diff --git a/wganHCAL.py b/wganHCAL.py index a3dd6c195e7c5eae3a8ecbd562350b41fd063869..262c4580ff9f04d1fbdc88498002253fc9974687 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -68,10 +68,13 @@ def train(args, aD, aG, aGE, device, train_loader, optimizer_d, optimizer_g, epo label = energy.to(device).float() ### - + ## Generate Fake ECAL + zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False) + fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1)).detach() + fake_ecal = fake_ecal.unsqueeze(1) + ## Generate Fake HCAL z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=False) - fake_dataHCAL = aG(z, label, fake_ecal).detach() ## 48 x 30 x 30 ## Critic fwd pass on Real