From f963a38d40e4403b01782f89f4959fa65702e958 Mon Sep 17 00:00:00 2001
From: Engin Eren <engin.eren@desy.de>
Date: Tue, 12 Jul 2022 10:37:46 +0200
Subject: [PATCH] now Fake-Fake-Fake combination

---
 wganHCAL.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/wganHCAL.py b/wganHCAL.py
index 9526a71..1435ef1 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -63,7 +63,7 @@ def train(args, aD, aG, aGE, device, train_loader, optimizer_d, optimizer_g, epo
         optimizer_d.zero_grad()
 
         ## Get Real data
-        real_dataECAL = dataE.to(device).unsqueeze(1).float()
+        #real_dataECAL = dataE.to(device).unsqueeze(1).float()
         real_dataHCAL = dataH.to(device).unsqueeze(1).float()
         label = energy.to(device).float()
         ###
@@ -78,10 +78,10 @@ def train(args, aD, aG, aGE, device, train_loader, optimizer_d, optimizer_g, epo
         fake_dataHCAL = aG(z, label, fake_ecal).detach() ## 48 x 30 x 30        
 
         ## Critic fwd pass on Real
-        disc_real = aD(real_dataECAL, real_dataHCAL, label) 
+        disc_real = aD(fake_ecal, real_dataHCAL, label) 
 
         ## Calculate Gradient Penalty Term
-        gradient_penalty = calc_gradient_penalty(aD, real_dataECAL, real_dataHCAL, fake_dataHCAL, label, args.batch_size, device, layer=48, xsize=30, ysize=30)
+        gradient_penalty = calc_gradient_penalty(aD, fake_ecal, real_dataHCAL, fake_dataHCAL, label, args.batch_size, device, layer=48, xsize=30, ysize=30)
 
         ## Critic fwd pass on Fake HCAL
         disc_fake = aD(fake_ecal, fake_dataHCAL.unsqueeze(1), label)
-- 
GitLab