diff --git a/wganHCAL.py b/wganHCAL.py index a61ab5a0079ba15bfcd6fda5c3a2b0dad46732ab..9835cc800dc25d2235414fba1dd116009ce7515e 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -85,7 +85,8 @@ def train(args, aD, aG, aGE, device, train_loader, optimizer_d, optimizer_g, epo fake_ecal = fake_ecal.unsqueeze(1) ## Generate Fake HCAL - z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=False) + z = zE.view(args.batch_size, args.nz) + #z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=False) fake_dataHCAL = aG(z, label, fake_ecal).detach() ## 48 x 30 x 30 ## Critic fwd pass on Real @@ -136,7 +137,8 @@ def train(args, aD, aG, aGE, device, train_loader, optimizer_d, optimizer_g, epo fake_ecal = fake_ecal.unsqueeze(1) #### - z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=True) + z = zE.view(args.batch_size, args.nz) + #z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=True) ## generate fake data out of noise fake_dataHCALG = aG(z, label, fake_ecal)