From aa600a0e56694abaa5d5c149b674f3fc168920c3 Mon Sep 17 00:00:00 2001 From: Engin Eren <engin.eren@desy.de> Date: Thu, 14 Jul 2022 16:09:48 +0200 Subject: [PATCH] GP, typo --- wganHCAL.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/wganHCAL.py b/wganHCAL.py index 5fdfb12..aecb849 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -33,11 +33,11 @@ def calc_gradient_penalty(netD, real_ecal, real_hcal, fake_ecal, fake_hcal, real alphaH = torch.rand(BATCH_SIZE, 1) alphaH = alphaH.expand(BATCH_SIZE, int(real_hcal.nelement()/BATCH_SIZE)).contiguous() - alphaH = alphaH.view(BATCH_SIZE, 1, layer, xsize, ysize) - alphaE = alphaH.to(device) + alphaH = alphaH.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize) + alphaH = alphaH.to(device) - fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer, xsize, ysize) - fake_ecal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize) + fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize) + fake_ecal = fake_ecal.view(BATCH_SIZE, 1, layer, xsize, ysize) interpolatesHCAL = alphaH * real_hcal.detach() + ((1 - alphaH) * fake_hcal.detach()) interpolatesECAL = alphaE * real_ecal.detach() + ((1 - alphaE) * fake_ecal.detach()) -- GitLab