From f963a38d40e4403b01782f89f4959fa65702e958 Mon Sep 17 00:00:00 2001 From: Engin Eren <engin.eren@desy.de> Date: Tue, 12 Jul 2022 10:37:46 +0200 Subject: [PATCH] now Fake-Fake-Fake combination --- wganHCAL.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wganHCAL.py b/wganHCAL.py index 9526a71..1435ef1 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -63,7 +63,7 @@ def train(args, aD, aG, aGE, device, train_loader, optimizer_d, optimizer_g, epo optimizer_d.zero_grad() ## Get Real data - real_dataECAL = dataE.to(device).unsqueeze(1).float() + #real_dataECAL = dataE.to(device).unsqueeze(1).float() real_dataHCAL = dataH.to(device).unsqueeze(1).float() label = energy.to(device).float() ### @@ -78,10 +78,10 @@ def train(args, aD, aG, aGE, device, train_loader, optimizer_d, optimizer_g, epo fake_dataHCAL = aG(z, label, fake_ecal).detach() ## 48 x 30 x 30 ## Critic fwd pass on Real - disc_real = aD(real_dataECAL, real_dataHCAL, label) + disc_real = aD(fake_ecal, real_dataHCAL, label) ## Calculate Gradient Penalty Term - gradient_penalty = calc_gradient_penalty(aD, real_dataECAL, real_dataHCAL, fake_dataHCAL, label, args.batch_size, device, layer=48, xsize=30, ysize=30) + gradient_penalty = calc_gradient_penalty(aD, fake_ecal, real_dataHCAL, fake_dataHCAL, label, args.batch_size, device, layer=48, xsize=30, ysize=30) ## Critic fwd pass on Fake HCAL disc_fake = aD(fake_ecal, fake_dataHCAL.unsqueeze(1), label) -- GitLab