From aa600a0e56694abaa5d5c149b674f3fc168920c3 Mon Sep 17 00:00:00 2001
From: Engin Eren <engin.eren@desy.de>
Date: Thu, 14 Jul 2022 16:09:48 +0200
Subject: [PATCH] GP, typo

---
 wganHCAL.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/wganHCAL.py b/wganHCAL.py
index 5fdfb12..aecb849 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -33,11 +33,11 @@ def calc_gradient_penalty(netD, real_ecal, real_hcal, fake_ecal, fake_hcal, real
 
     alphaH = torch.rand(BATCH_SIZE, 1)
     alphaH = alphaH.expand(BATCH_SIZE, int(real_hcal.nelement()/BATCH_SIZE)).contiguous()
-    alphaH = alphaH.view(BATCH_SIZE, 1, layer, xsize, ysize)
-    alphaE = alphaH.to(device)
+    alphaH = alphaH.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize)
+    alphaH = alphaH.to(device)
 
-    fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer, xsize, ysize)
-    fake_ecal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize)
+    fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize)
+    fake_ecal = fake_ecal.view(BATCH_SIZE, 1, layer, xsize, ysize)
     
     interpolatesHCAL = alphaH * real_hcal.detach() + ((1 - alphaH) * fake_hcal.detach())
     interpolatesECAL = alphaE * real_ecal.detach() + ((1 - alphaE) * fake_ecal.detach())
-- 
GitLab