From 3ccde3b443f0be26606c25f1994295dfc2c9b92b Mon Sep 17 00:00:00 2001 From: Engin Eren <engin.eren@desy.de> Date: Tue, 17 May 2022 14:58:30 +0200 Subject: [PATCH] forgot to add ECAL data into G.Penalty --- wganHCAL.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wganHCAL.py b/wganHCAL.py index 0988c4c..ced8009 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -22,7 +22,7 @@ from models.generatorFull import Hcal_ecalEMB from models.data_loaderFull import HDF5Dataset from models.criticFull import CriticEMB -def calc_gradient_penalty(netD, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize): +def calc_gradient_penalty(netD, real_dataECAL, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize): alpha = torch.rand(BATCH_SIZE, 1) alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement()/BATCH_SIZE)).contiguous() @@ -36,7 +36,7 @@ def calc_gradient_penalty(netD, real_data, fake_data, real_label, BATCH_SIZE, de interpolates = interpolates.to(device) interpolates.requires_grad_(True) - disc_interpolates = netD(interpolates.float(), real_label.float()) + disc_interpolates = netD(real_dataECAL, interpolates.float(), real_label.float()) gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, @@ -72,7 +72,7 @@ def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch, e disc_real = aD(real_dataECAL.float(), real_dataHCAL.float(), real_label.float()) ## Calculate Gradient Penalty Term - gradient_penalty = calc_gradient_penalty(aD, real_dataHCAL.float(), fake_dataHCAL, real_label, args.batch_size, device, layer=48, xsize=30, ysize=30) + gradient_penalty = calc_gradient_penalty(aD, real_dataECAL.float(), real_dataHCAL.float(), fake_dataHCAL, real_label, args.batch_size, device, layer=48, xsize=30, ysize=30) ## Critic fwd pass on Fake disc_fake = aD(real_dataECAL, fake_dataHCAL.unsqueeze(1), real_label) -- GitLab