From 3ccde3b443f0be26606c25f1994295dfc2c9b92b Mon Sep 17 00:00:00 2001
From: Engin Eren <engin.eren@desy.de>
Date: Tue, 17 May 2022 14:58:30 +0200
Subject: [PATCH] Pass the missing ECAL data into the gradient-penalty calculation

---
 wganHCAL.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/wganHCAL.py b/wganHCAL.py
index 0988c4c..ced8009 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -22,7 +22,7 @@ from models.generatorFull import Hcal_ecalEMB
 from models.data_loaderFull import HDF5Dataset
 from models.criticFull import CriticEMB
 
-def calc_gradient_penalty(netD, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize):
+def calc_gradient_penalty(netD, real_dataECAL, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize):
     
     alpha = torch.rand(BATCH_SIZE, 1)
     alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement()/BATCH_SIZE)).contiguous()
@@ -36,7 +36,7 @@ def calc_gradient_penalty(netD, real_data, fake_data, real_label, BATCH_SIZE, de
     interpolates = interpolates.to(device)
     interpolates.requires_grad_(True)   
 
-    disc_interpolates = netD(interpolates.float(), real_label.float())
+    disc_interpolates = netD(real_dataECAL, interpolates.float(), real_label.float())
 
 
     gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
@@ -72,7 +72,7 @@ def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch, e
         disc_real = aD(real_dataECAL.float(), real_dataHCAL.float(), real_label.float()) 
 
         ## Calculate Gradient Penalty Term
-        gradient_penalty = calc_gradient_penalty(aD, real_dataHCAL.float(), fake_dataHCAL, real_label, args.batch_size, device, layer=48, xsize=30, ysize=30)
+        gradient_penalty = calc_gradient_penalty(aD, real_dataECAL.float(), real_dataHCAL.float(), fake_dataHCAL, real_label, args.batch_size, device, layer=48, xsize=30, ysize=30)
 
         ## Critic fwd pass on Fake 
         disc_fake = aD(real_dataECAL, fake_dataHCAL.unsqueeze(1), real_label)
-- 
GitLab