diff --git a/wgan_ECAL_HCAL_3crit.py b/wgan_ECAL_HCAL_3crit.py
index 39cb390a0fa1595f7016d7ac982529dc7a030f05..37a9a9d914fbe80b44fea375cb4ff9135e622c31 100644
--- a/wgan_ECAL_HCAL_3crit.py
+++ b/wgan_ECAL_HCAL_3crit.py
@@ -112,24 +112,22 @@ def calc_gradient_penalty_ECAL_HCAL(netD, real_ecal, real_hcal, fake_ecal, fake_
     return gradient_penalty
 
 
-def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E, optimizer_d_H, optimizer_d_H_E, optimizer_g_E, optimizer_g_H, epoch, experiment):
-
-    ### CRITIC TRAINING
-    aDE.train()
-    aDH.train()
-    aD_H_E.train()
-    aGH.eval()
-    aGE.eval()
+def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E, optimizer_d_H, optimizer_d_H_E, optimizer_g_E, optimizer_g_H, optimizer_g_H_E, epoch, experiment):
+    Tensor = torch.cuda.FloatTensor
 
     for batch_idx, (dataE, dataH, energy) in enumerate(train_loader):
-
-        # ECAL critic
+        ## ECAL CRITIC TRAINING
+        aDE.train()
+        aGE.eval()
+
+        # zero out critic gradients
         optimizer_d_E.zero_grad()
 
         ## Get Real data
         real_dataECAL = dataE.to(device).unsqueeze(1).float()
+        real_dataHCAL = dataH.to(device).unsqueeze(1).float()
         label = energy.to(device).float()
         ###
 
@@ -159,19 +157,103 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
         disc_cost_E.backward()
         optimizer_d_E.step()
 
+        #print("Generator training started")
+        ## ECAL GENERATOR TRAINING
+        ## training generator per ncrit
+        if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
+            ## GENERATOR TRAINING
+            aDE.eval()
+            aGE.train()
 
-        # ECAL + HCAL critic
-        optimizer_d_H_E.zero_grad()
+            # zero out generator gradients
+            optimizer_g_E.zero_grad()
 
-        ## Get Real data
-        real_dataECAL = dataE.to(device).unsqueeze(1).float()
-        real_dataHCAL = dataH.to(device).unsqueeze(1).float()
-        label = energy.to(device).float()
-        ###
+            ## Generate Fake ECAL
+            zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=True)
+            fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
+            fake_ecal = fake_ecal.unsqueeze(1)
+
+            ## Loss function for ECAL generator
+            gen_E_cost = aDE(fake_ecal, label)
+            g_E_cost = -torch.mean(gen_E_cost)
+            g_E_cost.backward()
+            optimizer_g_E.step()
+
+
+        ## HCAL CRITIC TRAINING
+        aDH.train()
+        aGH.eval()
+
+        # zero out critic gradients
+        optimizer_d_H.zero_grad()
+
+        # Generate Fake HCAL
+        zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False)
+        fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
+        fake_ecal = fake_ecal.unsqueeze(1).detach()
+
+        z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=False)
+        fake_dataHCAL = aGH(z, label.view(-1, 1, 1, 1, 1), fake_ecal)   ## 48 x 30 x 30
+        fake_dataHCAL = fake_dataHCAL.unsqueeze(1).detach()
+
+        ## Critic fwd pass on Real
+        disc_real_H = aDH(real_dataHCAL, label)
+
+        ## Calculate gradient penalty term
+        gradient_penalty_H = calc_gradient_penalty_HCAL(aDH, real_dataHCAL, fake_dataHCAL, label, args.batch_size, device, layer=48, xsize=30, ysize=30)
+
+        ## Critic fwd pass on fake data
+        disc_fake_H = aDH(fake_dataHCAL, label)
+
+        w_dist_H = torch.mean(disc_fake_H) - torch.mean(disc_real_H)
+
+        # final disc cost
+        disc_cost_H = w_dist_H + args.lambd * gradient_penalty_H
+
+        disc_cost_H.backward()
+        optimizer_d_H.step()
+
+        ## HCAL GENERATOR TRAINING
+        ## training generator per ncrit
+        if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
+            ## GENERATOR TRAINING
+            aDH.eval()
+            aGH.train()
+
+            # zero out generator gradients
+            optimizer_g_H.zero_grad()
+
+            # Generate Fake HCAL
+            zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False)
+            fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
+            fake_ecal = fake_ecal.unsqueeze(1).detach()
+
+            z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=True)
+            fake_dataHCAL = aGH(z, label.view(-1, 1, 1, 1, 1), fake_ecal)   ## 48 x 30 x 30
+            fake_dataHCAL = fake_dataHCAL.unsqueeze(1)
+
+            ## Loss function for HCAL generator
+            gen_H_cost = aDH(fake_dataHCAL, label)
+            g_H_cost = -torch.mean(gen_H_cost)
+            g_H_cost.backward()
+            optimizer_g_H.step()
+
+
+        ## ECAL + HCAL CRITIC TRAINING
+        aD_H_E.train()
+        aGH.eval()
+        aGE.eval()
+
+        # zero out critic gradients
+        optimizer_d_H_E.zero_grad()
 
         ## Get Fake ECAL
-        fake_ecal = fake_ecal_gen.clone().detach()
+        #fake_ecal = fake_ecal_gen.clone().detach()
+        zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False)
+        fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
+        fake_ecal = fake_ecal.unsqueeze(1).detach()
+
 
         ## Generate Fake HCAL
         #z = zE.view(args.batch_size, args.nz)
@@ -197,49 +279,14 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
         optimizer_d_H_E.step()
 
-        if batch_idx % args.log_interval == 0:
-            print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
-                epoch, batch_idx * len(dataH), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), disc_cost_H_E.item()))
-            niter = epoch * len(train_loader) + batch_idx
-            experiment.log_metric("L_crit_E", disc_cost_E, step=niter)
-            experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter)
-            experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter)
-            experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter)
-            experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter)
-            experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter)
-            experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter)
-            experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter)
-            experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter)
-            experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter)
-
-
         ## training generator per ncrit
         if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
             ## GENERATOR TRAINING
-            aDE.eval()
-            aDH.eval()
+            aD_H_E.eval()
             aGH.train()
             aGE.train()
-            #print("Generator training started")
-
-            # Optimize ECAL generator
-            optimizer_g_E.zero_grad()
-
-            ## Generate Fake ECAL
-            zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=True)
-            fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
-            fake_ecal = fake_ecal.unsqueeze(1)
-
-            ## Loss function for ECAL generator
-            gen_E_cost = aDE(fake_ecal, label)
-            g_E_cost = -torch.mean(gen_E_cost)
-            g_E_cost.backward()
-            optimizer_g_E.step()
-
-
             # Optimize HCAL generator
             optimizer_g_H_E.zero_grad()
 
@@ -253,21 +300,43 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
             #z = zE.view(args.batch_size, args.nz)
             z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=True)
             ## generate fake data out of noise
-            fake_dataHCALG = aGH(z, label, fake_ecal)
+            fake_dataHCAL = aGH(z, label, fake_ecal)
 
             ## Total loss function for generator
-            gen_cost = aDH(fake_ecal, fake_dataHCALG.unsqueeze(1), label)
+            gen_cost = aD_H_E(fake_ecal, fake_dataHCAL.unsqueeze(1), label)
             g_cost = -torch.mean(gen_cost)
             g_cost.backward()
             optimizer_g_H_E.step()
 
-            if batch_idx % args.log_interval == 0 :
-                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f} lossGH={:.4f}'.format(
-                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
-                    100. * batch_idx / len(train_loader), g_E_cost.item(), g_cost.item()))
-                niter = epoch * len(train_loader) + batch_idx
-                experiment.log_metric("L_Gen_E", g_E_cost, step=niter)
-                experiment.log_metric("L_Gen_H", g_cost, step=niter)
+            if batch_idx % args.log_interval == 0:
+                print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), disc_cost_H_E.item()))
+                niter = epoch * len(train_loader) + batch_idx
+                experiment.log_metric("L_crit_E", disc_cost_E, step=niter)
+                experiment.log_metric("L_crit_H", disc_cost_H, step=niter)
+                experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter)
+                experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter)
+                experiment.log_metric("gradient_pen_H", gradient_penalty_H, step=niter)
+                experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter)
+                experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter)
+                experiment.log_metric("Wasserstein Dist H", w_dist_H, step=niter)
+                experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter)
+                experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter)
+                experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter)
+                experiment.log_metric("Critic Score H (Real)", torch.mean(disc_real_H), step=niter)
+                experiment.log_metric("Critic Score H (Fake)", torch.mean(disc_fake_H), step=niter)
+                experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter)
+                experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter)
+
+            if batch_idx % args.log_interval == 0 :
+                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f} lossGH={:.4f} lossGHE={:.4f}'.format(
+                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), g_E_cost.item(), g_H_cost.item(), g_cost.item()))
+                niter = epoch * len(train_loader) + batch_idx
+                experiment.log_metric("L_Gen_E", g_E_cost, step=niter)
+                experiment.log_metric("L_Gen_H", g_H_cost, step=niter)
+                experiment.log_metric("L_Gen_H_E", g_cost, step=niter)
 
 
 def is_distributed():
@@ -306,6 +375,9 @@ def parse_args():
     parser.add_argument('--nworkers', type=int, default=1, metavar='N',
                         help='number of epochs to train (default: 1)')
 
+    parser.add_argument('--lrCrit_H_E', type=float, default=0.00001, metavar='LR',
+                        help='learning rate Critic_H_E (default: 0.00001)')
+
     parser.add_argument('--lrCrit_H', type=float, default=0.00001, metavar='LR',
                         help='learning rate CriticH (default: 0.00001)')
 
@@ -315,6 +387,9 @@ def parse_args():
     parser.add_argument('--lrGen_H_E', type=float, default=0.0001, metavar='LR',
                         help='learning rate Generator_H_E (default: 0.0001)')
 
+    parser.add_argument('--lrGen_H', type=float, default=0.0001, metavar='LR',
+                        help='learning rate Generator_H (default: 0.0001)')
+
     parser.add_argument('--lrGen_E', type=float, default=0.0001, metavar='LR',
                         help='learning rate GeneratorE (default: 0.0001)')
 
@@ -379,7 +454,8 @@ def run(args):
         "latent": args.nz,
         "lambda": args.lambd,
         "ncrit" : args.ncrit,
-        "resN_H": args.ndf,
+        "resN_E_H": args.ndf,
+        "resN_H": args.dres,
         "resN_E": args.dres,
         "ngf": args.ngf
     }
 
@@ -411,55 +487,69 @@ def run(args):
     train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.nworkers, shuffle=True, drop_last=True, pin_memory=False)
 
+    ## ECAL + HCAL Generator and critic
+    mCrit_H_E = CriticEMB().to(device)
+
     ## HCAL Generator and critic
-    mCritH = CriticEMB().to(device)
+    mCritH = generate_model(args.dres).to(device)
     mGenH = Hcal_ecalEMB(args.ngf, 32, args.nz).to(device)
 
     ## ECAL GENERATOR and critic
     mGenE = DCGAN_G(args.ngf, args.nz).to(device)
     mCritE = generate_model(args.dres).to(device)
 
-    # ## Global Critic
-    # mCritGlob =
 
     if args.world_size > 1:
         Distributor = nn.parallel.DistributedDataParallel if use_cuda \
             else nn.parallel.DistributedDataParallelCPU
+        mCrit_H_E = Distributor(mCrit_H_E, device_ids=[args.local_rank], output_device=args.local_rank )
         mCritH = Distributor(mCritH, device_ids=[args.local_rank], output_device=args.local_rank )
         mGenH = Distributor(mGenH, device_ids=[args.local_rank], output_device=args.local_rank)
         mGenE = Distributor(mGenE, device_ids=[args.local_rank], output_device=args.local_rank)
         mCritE = Distributor(mCritE, device_ids=[args.local_rank], output_device=args.local_rank)
     else:
+        mCrit_H_E = nn.parallel.DataParallel(mCrit_H_E)
         mGenH = nn.parallel.DataParallel(mGenH)
         mCritH = nn.parallel.DataParallel(mCritH)
        mGenE = nn.parallel.DataParallel(mGenE)
        mCritE = nn.parallel.DataParallel(mCritE)
 
     optimizerG_H_E = optim.Adam(list(mGenH.parameters())+list(mGenE.parameters()), lr=args.lrGen_H_E, betas=(0.5, 0.9))
-
-    optimizerG_E = optim.Adam(mGenE.parameters(), lr=args.lrGen_E, betas=(0.5, 0.9))
-    optimizerD_H_E = optim.Adam(mCritH.parameters(), lr=args.lrCrit_H, betas=(0.5, 0.9))
+    optimizerD_H_E = optim.Adam(mCrit_H_E.parameters(), lr=args.lrCrit_H_E, betas=(0.5, 0.9))
+
+    optimizerG_H = optim.Adam(mGenH.parameters(), lr=args.lrGen_H, betas=(0.5, 0.9))
+
+    optimizerD_H = optim.Adam(mCritH.parameters(), lr=args.lrCrit_H, betas=(0.5, 0.9))
+
+    optimizerG_E = optim.Adam(mGenE.parameters(), lr=args.lrGen_E, betas=(0.5, 0.9))
 
     optimizerD_E = optim.Adam(mCritE.parameters(), lr=args.lrCrit_E, betas=(0.5, 0.9))
 
     if (args.chpt):
         critic_E_checkpoint = torch.load(args.chpt_base + args.exp + "_criticE_"+ str(args.chpt_eph) + ".pt")
-        critic_E_H_checkpoint = torch.load(args.chpt_base + args.exp + "_criticH_"+ str(args.chpt_eph) + ".pt")
+        critic_H_checkpoint = torch.load(args.chpt_base + args.exp + "_criticH_"+ str(args.chpt_eph) + ".pt")
+        critic_E_H_checkpoint = torch.load(args.chpt_base + args.exp + "_criticH_E_"+ str(args.chpt_eph) + ".pt")
         gen_E_checkpoint = torch.load(args.chpt_base + args.exp + "_generatorE_"+ str(args.chpt_eph) + ".pt")
         gen_H_checkpoint = torch.load(args.chpt_base + args.exp + "_generatorH_"+ str(args.chpt_eph) + ".pt")
+        gen_H_E_checkpoint = torch.load(args.chpt_base + args.exp + "_generatorH_E_"+ str(args.chpt_eph) + ".pt")
+
         mGenE.load_state_dict(gen_E_checkpoint['model_state_dict'])
         optimizerG_E.load_state_dict(gen_E_checkpoint['optimizer_state_dict'])
 
         mGenH.load_state_dict(gen_H_checkpoint['model_state_dict'])
-        optimizerG_H_E.load_state_dict(gen_H_checkpoint['optimizer_state_dict'])
+        optimizerG_H.load_state_dict(gen_H_checkpoint['optimizer_state_dict'])
+        optimizerG_H_E.load_state_dict(gen_H_E_checkpoint['optimizer_state_dict'])
 
         mCritE.load_state_dict(critic_E_checkpoint['model_state_dict'])
         optimizerD_E.load_state_dict(critic_E_checkpoint['optimizer_state_dict'])
 
-        mCritH.load_state_dict(critic_E_H_checkpoint['model_state_dict'])
+        mCritH.load_state_dict(critic_H_checkpoint['model_state_dict'])
+        optimizerD_H.load_state_dict(critic_H_checkpoint['optimizer_state_dict'])
+
+        mCrit_H_E.load_state_dict(critic_E_H_checkpoint['model_state_dict'])
         optimizerD_H_E.load_state_dict(critic_E_H_checkpoint['optimizer_state_dict'])
 
         eph = gen_H_checkpoint['epoch']
@@ -486,22 +576,34 @@ def run(args):
 
         if args.world_size > 1:
             train_loader.sampler.set_epoch(epoch)
-
-        train(args, mCritE, mCritH, mGenE, mGenH, device, train_loader, optimizerD_E, optimizerD_H_E, optimizerG_E, optimizerG_H_E, epoch, experiment)
+        train(args, mCritE, mCritH, mCrit_H_E, mGenE, mGenH, device, train_loader, optimizerD_E, optimizerD_H, optimizerD_H_E, optimizerG_E, optimizerG_H, optimizerG_H_E, epoch, experiment)
 
         if args.rank == 0:
-            gPATH = args.chpt_base + args.exp + "_generatorH_"+ str(epoch) + ".pt"
+            gHPATH = args.chpt_base + args.exp + "_generatorH_"+ str(epoch) + ".pt"
             ePATH = args.chpt_base + args.exp + "_generatorE_"+ str(epoch) + ".pt"
-            cPATH = args.chpt_base + args.exp + "_criticH_"+ str(epoch) + ".pt"
+            gPATH = args.chpt_base + args.exp + "_generatorH_E_"+ str(epoch) + ".pt"
+            cPATH = args.chpt_base + args.exp + "_criticH_E_"+ str(epoch) + ".pt"
+            cHPATH = args.chpt_base + args.exp + "_criticH_"+ str(epoch) + ".pt"
            cePATH = args.chpt_base + args.exp + "_criticE_"+ str(epoch) + ".pt"
 
             torch.save({
                 'epoch': epoch,
                 'model_state_dict': mGenH.state_dict(),
+                'optimizer_state_dict': optimizerG_H.state_dict()
+                }, gHPATH)
+
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': mCritH.state_dict(),
+                'optimizer_state_dict': optimizerD_H.state_dict()
+                }, cHPATH)
+
+            torch.save({
+                'epoch': epoch,
                 'optimizer_state_dict': optimizerG_H_E.state_dict()
                 }, gPATH)
-
+
             torch.save({
                 'epoch': epoch,
-                'model_state_dict': mCritH.state_dict(),
+                'model_state_dict': mCrit_H_E.state_dict(),
                 'optimizer_state_dict': optimizerD_H_E.state_dict()
                 }, cPATH)