diff --git a/wganHCAL.py b/wganHCAL.py index 54d52f5f778ce7e7b23e1a4643d4d996178b1af7..a3dd6c195e7c5eae3a8ecbd562350b41fd063869 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -68,15 +68,9 @@ def train(args, aD, aG, aGE, device, train_loader, optimizer_d, optimizer_g, epo label = energy.to(device).float() ### - ## Generate Fake ECAL - zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1)))) - fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1)).detach() - fake_ecal = fake_ecal.unsqueeze(1) - #### - ## Generate Fake HCAL - z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz)))) + z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=False) fake_dataHCAL = aG(z, label, fake_ecal).detach() ## 48 x 30 x 30 @@ -86,8 +80,8 @@ def train(args, aD, aG, aGE, device, train_loader, optimizer_d, optimizer_g, epo ## Calculate Gradient Penalty Term gradient_penalty = calc_gradient_penalty(aD, real_dataECAL, real_dataHCAL, fake_dataHCAL, label, args.batch_size, device, layer=48, xsize=30, ysize=30) - ## Critic fwd pass on Fake - disc_fake = aD(fake_ecal, fake_dataHCAL.unsqueeze(1), label) + ## Critic fwd pass on Fake HCAL + disc_fake = aD(real_dataECAL, fake_dataHCAL.unsqueeze(1), label) ## wasserstein-1 distace @@ -122,6 +116,13 @@ def train(args, aD, aG, aGE, device, train_loader, optimizer_d, optimizer_g, epo optimizer_g.zero_grad() + ## Generate Fake ECAL + zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False) + fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1)).detach() + fake_ecal = fake_ecal.unsqueeze(1) + #### + + z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=True) ## generate fake data out of noise fake_dataHCALG = aG(z, label, fake_ecal)