Skip to content
Snippets Groups Projects
Commit 3ccde3b4 authored by Engin Eren's avatar Engin Eren
Browse files

forgot to add ECAL data into G.Penalty

parent 4687816d
No related branches found
No related tags found
1 merge request!3Test
Pipeline #3983463 passed
......@@ -22,7 +22,7 @@ from models.generatorFull import Hcal_ecalEMB
from models.data_loaderFull import HDF5Dataset
from models.criticFull import CriticEMB
def calc_gradient_penalty(netD, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize):
def calc_gradient_penalty(netD, real_dataECAL, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize):
alpha = torch.rand(BATCH_SIZE, 1)
alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement()/BATCH_SIZE)).contiguous()
......@@ -36,7 +36,7 @@ def calc_gradient_penalty(netD, real_data, fake_data, real_label, BATCH_SIZE, de
interpolates = interpolates.to(device)
interpolates.requires_grad_(True)
disc_interpolates = netD(interpolates.float(), real_label.float())
disc_interpolates = netD(real_dataECAL, interpolates.float(), real_label.float())
gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
......@@ -72,7 +72,7 @@ def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch, e
disc_real = aD(real_dataECAL.float(), real_dataHCAL.float(), real_label.float())
## Calculate Gradient Penalty Term
gradient_penalty = calc_gradient_penalty(aD, real_dataHCAL.float(), fake_dataHCAL, real_label, args.batch_size, device, layer=48, xsize=30, ysize=30)
gradient_penalty = calc_gradient_penalty(aD, real_dataECAL.float(), real_dataHCAL.float(), fake_dataHCAL, real_label, args.batch_size, device, layer=48, xsize=30, ysize=30)
## Critic fwd pass on Fake
disc_fake = aD(real_dataECAL, fake_dataHCAL.unsqueeze(1), real_label)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment