diff --git a/wganHCAL.py b/wganHCAL.py index 5fdfb1229aa51e2c0cebdd49013a7819af8334ac..aecb849c81a356b140650a35fce981e738516576 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -33,11 +33,11 @@ def calc_gradient_penalty(netD, real_ecal, real_hcal, fake_ecal, fake_hcal, real alphaH = torch.rand(BATCH_SIZE, 1) alphaH = alphaH.expand(BATCH_SIZE, int(real_hcal.nelement()/BATCH_SIZE)).contiguous() - alphaH = alphaH.view(BATCH_SIZE, 1, layer, xsize, ysize) - alphaE = alphaH.to(device) + alphaH = alphaH.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize) + alphaH = alphaH.to(device) - fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer, xsize, ysize) - fake_ecal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize) + fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize) + fake_ecal = fake_ecal.view(BATCH_SIZE, 1, layer, xsize, ysize) interpolatesHCAL = alphaH * real_hcal.detach() + ((1 - alphaH) * fake_hcal.detach()) interpolatesECAL = alphaE * real_ecal.detach() + ((1 - alphaE) * fake_ecal.detach())