diff --git a/wganHCAL.py b/wganHCAL.py
index 5fdfb1229aa51e2c0cebdd49013a7819af8334ac..aecb849c81a356b140650a35fce981e738516576 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -33,11 +33,11 @@ def calc_gradient_penalty(netD, real_ecal, real_hcal, fake_ecal, fake_hcal, real
 
     alphaH = torch.rand(BATCH_SIZE, 1)
     alphaH = alphaH.expand(BATCH_SIZE, int(real_hcal.nelement()/BATCH_SIZE)).contiguous()
-    alphaH = alphaH.view(BATCH_SIZE, 1, layer, xsize, ysize)
-    alphaE = alphaH.to(device)
+    alphaH = alphaH.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize)
+    alphaH = alphaH.to(device)
 
-    fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer, xsize, ysize)
-    fake_ecal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize)
+    fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize)
+    fake_ecal = fake_ecal.view(BATCH_SIZE, 1, layer, xsize, ysize)
     
     interpolatesHCAL = alphaH * real_hcal.detach() + ((1 - alphaH) * fake_hcal.detach())
     interpolatesECAL = alphaE * real_ecal.detach() + ((1 - alphaE) * fake_ecal.detach())