diff --git a/wgan_ECAL_HCAL_3crit.py b/wgan_ECAL_HCAL_3crit.py new file mode 100644 index 0000000000000000000000000000000000000000..39cb390a0fa1595f7016d7ac982529dc7a030f05 --- /dev/null +++ b/wgan_ECAL_HCAL_3crit.py @@ -0,0 +1,530 @@ +from __future__ import print_function +from comet_ml import Experiment +import argparse +import os, sys +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +from torch import autograd +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader +from torch.autograd import Variable + +os.environ['MKL_THREADING_LAYER'] = 'GNU' + +torch.autograd.set_detect_anomaly(True) + +sys.path.append('/opt/regressor/src') + +from models.generatorFull import Hcal_ecalEMB +from models.data_loaderFull import HDF5Dataset +from models.criticFull import CriticEMB +from models.generator import DCGAN_G +from models.criticRes import generate_model + +def calc_gradient_penalty_ECAL(netD, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize): + + alpha = torch.rand(BATCH_SIZE, 1) + alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement()/BATCH_SIZE)).contiguous() + alpha = alpha.view(BATCH_SIZE, 1, layer, xsize, ysize) + alpha = alpha.to(device) + + + fake_data = fake_data.view(BATCH_SIZE, 1, layer, xsize, ysize) + interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach()) + + interpolates = interpolates.to(device) + interpolates.requires_grad_(True) + + disc_interpolates = netD(interpolates.float(), real_label.float()) + + + gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + gradients = gradients.view(gradients.size(0), -1) + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + return gradient_penalty + +def calc_gradient_penalty_HCAL(netD, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize): + + alpha = torch.rand(BATCH_SIZE, 1) + alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement()/BATCH_SIZE)).contiguous() + alpha = alpha.view(BATCH_SIZE, 1, layer, xsize, ysize) + alpha = alpha.to(device) + + + fake_data = fake_data.view(BATCH_SIZE, 1, layer, xsize, ysize) + interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach()) + + interpolates = interpolates.to(device) + interpolates.requires_grad_(True) + + disc_interpolates = netD(interpolates.float(), real_label.float()) + + + gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + gradients = gradients.view(gradients.size(0), -1) + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + return gradient_penalty + + +def calc_gradient_penalty_ECAL_HCAL(netD, real_ecal, real_hcal, fake_ecal, fake_hcal, real_label, BATCH_SIZE, device, layer, layer_hcal, xsize, ysize): + + alphaE = torch.rand(BATCH_SIZE, 1) + alphaE = alphaE.expand(BATCH_SIZE, int(real_ecal.nelement()/BATCH_SIZE)).contiguous() + alphaE = alphaE.view(BATCH_SIZE, 1, layer, xsize, ysize) + alphaE = alphaE.to(device) + + + alphaH = torch.rand(BATCH_SIZE, 1) + alphaH = alphaH.expand(BATCH_SIZE, int(real_hcal.nelement()/BATCH_SIZE)).contiguous() + alphaH = alphaH.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize) + alphaH = alphaH.to(device) + + fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize) + fake_ecal = fake_ecal.view(BATCH_SIZE, 1, layer, xsize, ysize) + + interpolatesHCAL = alphaH * real_hcal.detach() + ((1 - alphaH) * fake_hcal.detach()) + interpolatesECAL = alphaE * real_ecal.detach() + ((1 - alphaE) * fake_ecal.detach()) + + + interpolatesHCAL = interpolatesHCAL.to(device) + interpolatesHCAL.requires_grad_(True) + + interpolatesECAL = interpolatesECAL.to(device) + interpolatesECAL.requires_grad_(True) + + disc_interpolates = netD(interpolatesECAL.float(), interpolatesHCAL.float(), real_label.float()) + + gradients = autograd.grad(outputs=disc_interpolates, inputs=[interpolatesECAL, interpolatesHCAL], + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + gradients = gradients.view(gradients.size(0), -1) + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + return gradient_penalty + + +def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E, optimizer_d_H, optimizer_d_H_E, optimizer_g_E, optimizer_g_H, epoch, experiment): + + ### CRITIC TRAINING + aDE.train() + aDH.train() + aD_H_E.train() + aGH.eval() + aGE.eval() + + Tensor = torch.cuda.FloatTensor + + for batch_idx, (dataE, dataH, energy) in enumerate(train_loader): + + # ECAL critic + optimizer_d_E.zero_grad() + + ## Get Real data + real_dataECAL = dataE.to(device).unsqueeze(1).float() + label = energy.to(device).float() + ### + + ## Generate Fake ECAL + zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False) + fake_ecal_gen = aGE(zE, label.view(-1, 1, 1, 1, 1)).detach() + fake_ecal_gen = fake_ecal_gen.unsqueeze(1) + + fake_ecal = fake_ecal_gen.clone().detach() + + ## Critic fwd pass on Real + disc_real_E = aDE(real_dataECAL, label) + + ## Calculate Gradient Penalty Term + gradient_penalty_E = calc_gradient_penalty_ECAL(aDE, real_dataECAL, fake_ecal, label, args.batch_size, device, layer=30, xsize=30, ysize=30) + + ## Critic fwd pass on fake data + disc_fake_E = aDE(fake_ecal, label) + + + ## wasserstein-1 distace for critic + w_dist_E = torch.mean(disc_fake_E) - torch.mean(disc_real_E) + + # final disc cost + disc_cost_E = w_dist_E + args.lambd * gradient_penalty_E + + disc_cost_E.backward() + optimizer_d_E.step() + + + + # ECAL + HCAL critic + optimizer_d_H_E.zero_grad() + + ## Get Real data + real_dataECAL = dataE.to(device).unsqueeze(1).float() + real_dataHCAL = dataH.to(device).unsqueeze(1).float() + label = energy.to(device).float() + ### + + ## Get Fake ECAL + fake_ecal = fake_ecal_gen.clone().detach() + + ## Generate Fake HCAL + #z = zE.view(args.batch_size, args.nz) + z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=False) + fake_dataHCAL = aGH(z, label, fake_ecal).detach() ## 48 x 30 x 30 + + ## Critic fwd pass on Real + disc_real_H_E = aD_H_E(real_dataECAL, real_dataHCAL, label) + + ## Calculate Gradient Penalty Term + gradient_penalty_H_E = calc_gradient_penalty_ECAL_HCAL(aD_H_E, real_dataECAL, real_dataHCAL, fake_ecal, fake_dataHCAL, label, args.batch_size, device, layer=30, layer_hcal=48, xsize=30, ysize=30) + + ## Critic fwd pass on fake data + disc_fake_H_E = aD_H_E(fake_ecal, fake_dataHCAL.unsqueeze(1), label) + + ## wasserstein-1 distace for critic + w_dist_H_E = torch.mean(disc_fake_H_E) - torch.mean(disc_real_H_E) + + # final disc cost + disc_cost_H_E = w_dist_H_E + args.lambd * gradient_penalty_H_E + + disc_cost_H_E.backward() + optimizer_d_H_E.step() + + + if batch_idx % args.log_interval == 0: + print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format( + epoch, batch_idx * len(dataH), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), disc_cost_H_E.item())) + niter = epoch * len(train_loader) + batch_idx + experiment.log_metric("L_crit_E", disc_cost_E, step=niter) + experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter) + experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter) + experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter) + experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter) + experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter) + experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter) + experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter) + experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter) + experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter) + + + + ## training generator per ncrit + if (batch_idx % args.ncrit == 0) and (batch_idx != 0): + ## GENERATOR TRAINING + aDE.eval() + aDH.eval() + aGH.train() + aGE.train() + + #print("Generator training started") + + # Optimize ECAL generator + optimizer_g_E.zero_grad() + + ## Generate Fake ECAL + zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=True) + fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1)) + fake_ecal = fake_ecal.unsqueeze(1) + + ## Loss function for ECAL generator + gen_E_cost = aDE(fake_ecal, label) + g_E_cost = -torch.mean(gen_E_cost) + g_E_cost.backward() + optimizer_g_E.step() + + + # Optimize HCAL generator + optimizer_g_H_E.zero_grad() + + + ## Generate Fake ECAL + zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=True) + fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1)) + fake_ecal = fake_ecal.unsqueeze(1) + #### + + #z = zE.view(args.batch_size, args.nz) + z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=True) + ## generate fake data out of noise + fake_dataHCALG = aGH(z, label, fake_ecal) + + ## Total loss function for generator + gen_cost = aDH(fake_ecal, fake_dataHCALG.unsqueeze(1), label) + g_cost = -torch.mean(gen_cost) + g_cost.backward() + optimizer_g_H_E.step() + + if batch_idx % args.log_interval == 0 : + print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f} lossGH={:.4f}'.format( + epoch, batch_idx * len(dataH), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), g_E_cost.item(), g_cost.item())) + niter = epoch * len(train_loader) + batch_idx + experiment.log_metric("L_Gen_E", g_E_cost, step=niter) + experiment.log_metric("L_Gen_H", g_cost, step=niter) + + +def is_distributed(): + return dist.is_available() and dist.is_initialized() + + +def parse_args(): + parser = argparse.ArgumentParser(description='WGAN training on hadron showers') + parser.add_argument('--batch-size', type=int, default=50, metavar='N', + help='input batch size for training (default: 50)') + + parser.add_argument('--nz', type=int, default=100, metavar='N', + help='latent space for generator (default: 100)') + + parser.add_argument('--lambd', type=int, default=15, metavar='N', + help='weight of gradient penalty (default: 15)') + + parser.add_argument('--kappa', type=float, default=0.001, metavar='N', + help='weight of label conditioning (default: 0.001)') + + parser.add_argument('--dres', type=int, default=34, metavar='N', + help='depth of Residual critic (default: 34)') + + parser.add_argument('--ndf', type=int, default=32, metavar='N', + help='n-feature of critic (default: 32)') + + parser.add_argument('--ngf', type=int, default=32, metavar='N', + help='n-feature of generator (default: 32)') + + parser.add_argument('--ncrit', type=int, default=10, metavar='N', + help='critic updates before generator one (default: 10)') + + parser.add_argument('--epochs', type=int, default=1, metavar='N', + help='number of epochs to train (default: 1)') + + parser.add_argument('--nworkers', type=int, default=1, metavar='N', + help='number of epochs to train (default: 1)') + + parser.add_argument('--lrCrit_H', type=float, default=0.00001, metavar='LR', + help='learning rate CriticH (default: 0.00001)') + + parser.add_argument('--lrCrit_E', type=float, default=0.00001, metavar='LR', + help='learning rate CriticE (default: 0.00001)') + + parser.add_argument('--lrGen_H_E', type=float, default=0.0001, metavar='LR', + help='learning rate Generator_H_E (default: 0.0001)') + + parser.add_argument('--lrGen_E', type=float, default=0.0001, metavar='LR', + help='learning rate GeneratorE (default: 0.0001)') + + parser.add_argument('--chpt', action='store_true', default=False, + help='continue training from a saved model') + + parser.add_argument('--chpt_base', type=str, default='/eos/user/e/eneren/experiments/', + help='continue training from a saved model') + + parser.add_argument('--exp', type=str, default='dist_wgan', + help='name of the experiment') + + parser.add_argument('--chpt_eph', type=int, default=1, + help='continue checkpoint epoch') + + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=100, metavar='N', + help='how many batches to wait before logging training status') + + + if dist.is_available(): + parser.add_argument('--backend', type=str, help='Distributed backend', + choices=[dist.Backend.GLOO, dist.Backend.NCCL, dist.Backend.MPI], + default=dist.Backend.GLOO) + + parser.add_argument('--local_rank', type=int, default=0) + + args = parser.parse_args() + + + args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank)) + args.rank = int(os.environ.get('RANK')) + args.world_size = int(os.environ.get('WORLD_SIZE')) + + + # postprocess args + args.device = 'cuda:{}'.format(args.local_rank) # PytorchJob/launch.py + args.batch_size = max(args.batch_size, + args.world_size * 2) # min valid batchsize + return args + + + +def run(args): + # Training settings + + use_cuda = not args.no_cuda and torch.cuda.is_available() + if use_cuda: + print('Using CUDA') + + + experiment = Experiment(api_key="keGmeIz4GfKlQZlOP6cit4QOi", + project_name="ecal-hcal-shower", workspace="engineren", auto_output_logging="simple") + experiment.add_tag(args.exp) + + experiment.log_parameters( + { + "batch_size" : args.batch_size, + "latent": args.nz, + "lambda": args.lambd, + "ncrit" : args.ncrit, + "resN_H": args.ndf, + "resN_E": args.dres, + "ngf": args.ngf + } + ) + + torch.manual_seed(args.seed) + + device = torch.device("cuda" if use_cuda else "cpu") + + + if args.world_size > 1: + print('Using distributed PyTorch with {} backend'.format(args.backend)) + dist.init_process_group(backend=args.backend) + + print('[init] == local rank: {}, global rank: {}, world size: {} =='.format(args.local_rank, args.rank, args.world_size)) + + + + print ("loading data") + #dataset = HDF5Dataset('/eos/user/e/eneren/scratch/40GeV40k.hdf5', transform=None, train_size=40000) + dataset = HDF5Dataset('/eos/user/e/eneren/scratch/50GeV75k.hdf5', transform=None, train_size=75000) + #dataset = HDF5Dataset('/eos/user/e/eneren/scratch/4060GeV.hdf5', transform=None, train_size=60000) + + + if args.world_size > 1: + sampler = DistributedSampler(dataset, shuffle=True) + train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.nworkers, drop_last=True, pin_memory=False) + else: + train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.nworkers, shuffle=True, drop_last=True, pin_memory=False) + + + ## HCAL Generator and critic + mCritH = CriticEMB().to(device) + mGenH = Hcal_ecalEMB(args.ngf, 32, args.nz).to(device) + + ## ECAL GENERATOR and critic + mGenE = DCGAN_G(args.ngf, args.nz).to(device) + mCritE = generate_model(args.dres).to(device) + + # ## Global Critic + # mCritGlob = + + if args.world_size > 1: + Distributor = nn.parallel.DistributedDataParallel if use_cuda \ + else nn.parallel.DistributedDataParallelCPU + mCritH = Distributor(mCritH, device_ids=[args.local_rank], output_device=args.local_rank ) + mGenH = Distributor(mGenH, device_ids=[args.local_rank], output_device=args.local_rank) + mGenE = Distributor(mGenE, device_ids=[args.local_rank], output_device=args.local_rank) + mCritE = Distributor(mCritE, device_ids=[args.local_rank], output_device=args.local_rank) + else: + mGenH = nn.parallel.DataParallel(mGenH) + mCritH = nn.parallel.DataParallel(mCritH) + mGenE = nn.parallel.DataParallel(mGenE) + mCritE = nn.parallel.DataParallel(mCritE) + + optimizerG_H_E = optim.Adam(list(mGenH.parameters())+list(mGenE.parameters()), lr=args.lrGen_H_E, betas=(0.5, 0.9)) + + optimizerG_E = optim.Adam(mGenE.parameters(), lr=args.lrGen_E, betas=(0.5, 0.9)) + + optimizerD_H_E = optim.Adam(mCritH.parameters(), lr=args.lrCrit_H, betas=(0.5, 0.9)) + + optimizerD_E = optim.Adam(mCritE.parameters(), lr=args.lrCrit_E, betas=(0.5, 0.9)) + + + if (args.chpt): + critic_E_checkpoint = torch.load(args.chpt_base + args.exp + "_criticE_"+ str(args.chpt_eph) + ".pt") + critic_E_H_checkpoint = torch.load(args.chpt_base + args.exp + "_criticH_"+ str(args.chpt_eph) + ".pt") + gen_E_checkpoint = torch.load(args.chpt_base + args.exp + "_generatorE_"+ str(args.chpt_eph) + ".pt") + gen_H_checkpoint = torch.load(args.chpt_base + args.exp + "_generatorH_"+ str(args.chpt_eph) + ".pt") + + mGenE.load_state_dict(gen_E_checkpoint['model_state_dict']) + optimizerG_E.load_state_dict(gen_E_checkpoint['optimizer_state_dict']) + + mGenH.load_state_dict(gen_H_checkpoint['model_state_dict']) + optimizerG_H_E.load_state_dict(gen_H_checkpoint['optimizer_state_dict']) + + mCritE.load_state_dict(critic_E_checkpoint['model_state_dict']) + optimizerD_E.load_state_dict(critic_E_checkpoint['optimizer_state_dict']) + + mCritH.load_state_dict(critic_E_H_checkpoint['model_state_dict']) + optimizerD_H_E.load_state_dict(critic_E_H_checkpoint['optimizer_state_dict']) + + eph = gen_H_checkpoint['epoch'] + + else: + eph = 0 + gen_E_checkpoint = torch.load("/eos/user/e/eneren/experiments/wganv1_generator_694.pt", map_location=torch.device('cuda')) + critic_E_checkpoint = torch.load("/eos/user/e/eneren/experiments/wganv1_critic_694.pt", map_location=torch.device('cuda')) + + mGenE.load_state_dict(gen_E_checkpoint['model_state_dict']) + optimizerG_E.load_state_dict(gen_E_checkpoint['optimizer_state_dict']) + + mCritE.load_state_dict(critic_E_checkpoint['model_state_dict']) + optimizerD_E.load_state_dict(critic_E_checkpoint['optimizer_state_dict']) + + print ("init models") + + + experiment.set_model_graph(str(mCritH), overwrite=False) + + print ("starting training...") + for epoch in range(1, args.epochs + 1): + epoch += eph + + if args.world_size > 1: + train_loader.sampler.set_epoch(epoch) + + train(args, mCritE, mCritH, mGenE, mGenH, device, train_loader, optimizerD_E, optimizerD_H_E, optimizerG_E, optimizerG_H_E, epoch, experiment) + if args.rank == 0: + gPATH = args.chpt_base + args.exp + "_generatorH_"+ str(epoch) + ".pt" + ePATH = args.chpt_base + args.exp + "_generatorE_"+ str(epoch) + ".pt" + cPATH = args.chpt_base + args.exp + "_criticH_"+ str(epoch) + ".pt" + cePATH = args.chpt_base + args.exp + "_criticE_"+ str(epoch) + ".pt" + torch.save({ + 'epoch': epoch, + 'model_state_dict': mGenH.state_dict(), + 'optimizer_state_dict': optimizerG_H_E.state_dict() + }, gPATH) + + torch.save({ + 'epoch': epoch, + 'model_state_dict': mCritH.state_dict(), + 'optimizer_state_dict': optimizerD_H_E.state_dict() + }, cPATH) + + torch.save({ + 'epoch': epoch, + 'model_state_dict': mGenE.state_dict(), + 'optimizer_state_dict': optimizerG_E.state_dict() + }, ePATH) + + torch.save({ + 'epoch': epoch, + 'model_state_dict': mCritE.state_dict(), + 'optimizer_state_dict': optimizerD_E.state_dict() + }, cePATH) + + + print ("end training") + + + +def main(): + args = parse_args() + run(args) + +if __name__ == '__main__': + main()