diff --git a/Dockerfile b/Dockerfile
index dfcf4f974a21addec49afeae7422f16336c21739..23e5c4abe6f43a16aece4479d92e140577db9d42 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -15,8 +15,10 @@ RUN mkdir -p /opt/regressor && \
     && pip install h5py pyflakes comet_ml
     && export MKL_SERVICE_FORCE_INTEL=1
 
 WORKDIR /opt/regressor/src
-ADD regressor.py /opt/regressor/src/regressor.py
+
 ADD wgan.py /opt/regressor/src/wgan.py
+ADD wganHCAL.py /opt/regressor/src/wganHCAL.py
+
 COPY ./models/* /opt/regressor/src/models/
 COPY docker/krb5.conf /etc/krb5.conf
diff --git a/models/criticFull.py b/models/criticFull.py
new file mode 100644
index 0000000000000000000000000000000000000000..16333c33a13a65beeda8b45ec3c5cfbdd1108150
--- /dev/null
+++ b/models/criticFull.py
@@ -0,0 +1,82 @@
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.nn.functional as F
+
+
+class CriticEMB(nn.Module):
+    def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=64):
+        super(CriticEMB, self).__init__()
+        self.ndf = ndf
+        self.isize_1 = isize_1
+        self.isize_2 = isize_2
+        self.nc = nc
+        self.size_embed = 16
+        self.conv1_bias = False
+
+        # ECAL component of convolutions
+        # Designed for input 30*30*30
+        self.conv_ECAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=(2,2,2), stride=(1,1,1), padding=0, bias=False)
+        self.ln_ECAL_1 = torch.nn.LayerNorm([29,29,29])
+        self.conv_ECAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)
+        self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14])
+        self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
+        # output [batch_size, ndf, 7, 7, 7]
+
+        # HCAL component of convolutions
+        # Designed for input 48*30*30
+        self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=(5,0,0), bias=False)
+        self.ln_HCAL_1 = torch.nn.LayerNorm([29,29,29])
+        self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)
+        self.ln_HCAL_2 = torch.nn.LayerNorm([14,14,14])
+        self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
+        # output [batch_size, ndf, 7, 7, 7]
+
+        # alternative structure for 48*25*25 HCAL
+        #self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=0, bias=False)
+        #self.ln_HCAL_1 = torch.nn.LayerNorm([24,24,24])
+        #self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)
+        #self.ln_HCAL_2 = torch.nn.LayerNorm([12,12,12])
+        #self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
+
+        self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64)
+        self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, 64)
+
+        self.econd_lin = torch.nn.Linear(1, 64)  # label embedding
+
+        self.fc1 = torch.nn.Linear(64*3, 128)  # 3 components after cat
+        self.fc2 = torch.nn.Linear(128, 64)
+        self.fc3 = torch.nn.Linear(64, 1)
+
+    def forward(self, img_ECAL, img_HCAL, E_true):
+        batch_size = img_ECAL.size(0)
+        # input: img_ECAL = [batch_size, 1, 30, 30, 30]
+        #        img_HCAL = [batch_size, 1, 48, 30, 30]
+
+        # ECAL branch
+        x_ECAL = F.leaky_relu(self.ln_ECAL_1(self.conv_ECAL_1(img_ECAL)), 0.2)
+        x_ECAL = F.leaky_relu(self.ln_ECAL_2(self.conv_ECAL_2(x_ECAL)), 0.2)
+        x_ECAL = F.leaky_relu(self.conv_ECAL_3(x_ECAL), 0.2)
+        x_ECAL = x_ECAL.view(-1, self.ndf*7*7*7)
+        x_ECAL = F.leaky_relu(self.conv_lin_ECAL(x_ECAL), 0.2)
+
+        # HCAL branch
+        x_HCAL = F.leaky_relu(self.ln_HCAL_1(self.conv_HCAL_1(img_HCAL)), 0.2)
+        x_HCAL = F.leaky_relu(self.ln_HCAL_2(self.conv_HCAL_2(x_HCAL)), 0.2)
+        x_HCAL = F.leaky_relu(self.conv_HCAL_3(x_HCAL), 0.2)
+        x_HCAL = x_HCAL.view(-1, self.ndf*7*7*7)
+        x_HCAL = F.leaky_relu(self.conv_lin_HCAL(x_HCAL), 0.2)
+
+        # true-energy embedding
+        x_E = F.leaky_relu(self.econd_lin(E_true), 0.2)
+
+        xa = torch.cat((x_ECAL, x_HCAL, x_E), 1)
+
+        xa = F.leaky_relu(self.fc1(xa), 0.2)
+        xa = F.leaky_relu(self.fc2(xa), 0.2)
+        xa = self.fc3(xa)
+
+        return xa
\ No newline at end of file
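For orientation, here is a minimal, hypothetical shape check for the critic added above. It assumes the layout baked into the Dockerfile (so that `models.criticFull` is importable from the project root); batch size and energies are arbitrary illustration values, not part of the diff:

```python
import torch
from models.criticFull import CriticEMB

critic = CriticEMB(ndf=64)                # defaults match the module above

ecal = torch.randn(4, 1, 30, 30, 30)      # [batch, channel, layers, x, y]
hcal = torch.randn(4, 1, 48, 30, 30)
e_true = torch.rand(4, 1) * 100           # conditioning energy, arbitrary scale

score = critic(ecal, hcal, e_true)
print(score.shape)                        # expected: torch.Size([4, 1])
```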
diff --git a/models/data_loaderFull.py b/models/data_loaderFull.py
index 17f9d2f01e00df7ebec2a923843ff86103a528c8..e157bf1043c18b1196c523f847bda4f4f40c8044 100644
--- a/models/data_loaderFull.py
+++ b/models/data_loaderFull.py
@@ -1,7 +1,6 @@
 import numpy as np
 import torch
-from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
 import os
 import h5py
 from torch.utils import data
 
diff --git a/models/generatorFull.py b/models/generatorFull.py
new file mode 100644
index 0000000000000000000000000000000000000000..029b553f80091ce3edc29d105e183b44da07f409
--- /dev/null
+++ b/models/generatorFull.py
@@ -0,0 +1,116 @@
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.nn.functional as F
+
+
+class Hcal_ecalEMB(nn.Module):
+    """
+    Generator component of the WGAN: produces an HCAL shower
+    conditioned on the true energy and the corresponding ECAL shower.
+    """
+    def __init__(self, ngf, ndf, nz, emb_size):
+        super(Hcal_ecalEMB, self).__init__()
+
+        self.ndf = ndf
+        self.emb_size = emb_size
+
+        # ECAL component of convolutions
+        # Designed for input 30*30*30
+        self.conv_ECAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=(2,2,2), stride=(1,1,1), padding=0, bias=False)
+        self.ln_ECAL_1 = torch.nn.LayerNorm([29,29,29])
+        self.conv_ECAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)
+        self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14])
+        self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
+
+        self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64)
+
+        self.econd_lin = torch.nn.Linear(1, 64)  # label embedding
+
+        self.fc1 = torch.nn.Linear(64*2, 128)  # 2 components after cat
+        self.fc2 = torch.nn.Linear(128, 64)
+        self.fc3 = torch.nn.Linear(64, emb_size)
+
+        ## HCAL component of convolutions
+        self.ngf = ngf
+        self.nz = nz
+
+        kernel = 4
+
+        self.conv1 = nn.ConvTranspose3d(emb_size + nz + 1, ngf, kernel, 1, 0, bias=False)
+        # output torch.Size([batch_size, ngf, 4, 4, 4])
+
+        # out of the first transposed convolution, state size [ ngf x 4 x 4 x 4 ],
+        # going into the main convolutional part of the generator
+        self.main_conv = nn.Sequential(
+
+            nn.ConvTranspose3d(ngf, ngf*4, kernel_size=(4,2,2), stride=2, padding=1, bias=False),
+            nn.LayerNorm([8, 6, 6]),
+            nn.ReLU(),
+            # state shape [ (ngf*4) x 8 x 6 x 6 ]
+
+            nn.ConvTranspose3d(ngf*4, ngf*2, kernel_size=(4,2,2), stride=2, padding=1, bias=False),
+            nn.LayerNorm([16, 10, 10]),
+            nn.ReLU(),
+            # state shape [ (ngf*2) x 16 x 10 x 10 ]
+
+            nn.ConvTranspose3d(ngf*2, ngf, kernel_size=(4,4,4), stride=(2,1,1), padding=1, bias=False),
+            nn.LayerNorm([32, 11, 11]),
+            nn.ReLU(),
+            # state shape [ ngf x 32 x 11 x 11 ]
+
+            nn.ConvTranspose3d(ngf, ngf, kernel_size=(10,4,4), stride=1, padding=1, bias=False),
+            nn.LayerNorm([39, 12, 12]),
+            nn.ReLU(),
+            # state shape [ ngf x 39 x 12 x 12 ]
+
+            nn.ConvTranspose3d(ngf, 5, kernel_size=(8,3,3), stride=(1,2,2), padding=1, bias=False),
+            nn.LayerNorm([44, 23, 23]),
+            nn.ReLU(),
+            # state shape [ 5 x 44 x 23 x 23 ]
+
+            nn.ConvTranspose3d(5, 1, kernel_size=(7,10,10), stride=1, padding=1, bias=False),
+            nn.ReLU()
+
+            ## final output ---> [48 x 30 x 30]
+        )
+
+    def forward(self, noise, energy, img_ECAL):
+        batch_size = img_ECAL.size(0)
+        # input: img_ECAL = [batch_size, 1, 30, 30, 30]
+
+        # ECAL embedding
+        x_ECAL = F.leaky_relu(self.ln_ECAL_1(self.conv_ECAL_1(img_ECAL)), 0.2)
+        x_ECAL = F.leaky_relu(self.ln_ECAL_2(self.conv_ECAL_2(x_ECAL)), 0.2)
+        x_ECAL = F.leaky_relu(self.conv_ECAL_3(x_ECAL), 0.2)
+        x_ECAL = x_ECAL.view(-1, self.ndf*7*7*7)
+        x_ECAL = F.leaky_relu(self.conv_lin_ECAL(x_ECAL), 0.2)
+
+        # true-energy embedding
+        x_E = F.leaky_relu(self.econd_lin(energy), 0.2)
+
+        xa = torch.cat((x_ECAL, x_E), 1)
+
+        xa = F.leaky_relu(self.fc1(xa), 0.2)
+        xa = F.leaky_relu(self.fc2(xa), 0.2)
+        xa = self.fc3(xa)
+
+        # concatenate ECAL/energy embedding, energy label and noise,
+        # then reshape for the 3D transposed convolutions
+        xzm = torch.cat((xa, energy, noise), 1)
+        xzm = xzm.view(xzm.size(0), xzm.size(1), 1, 1, 1)
+
+        inpt = self.conv1(xzm)
+
+        x = self.main_conv(inpt)
+
+        x = x.view(-1, 48, 30, 30)
+
+        return x
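As with the critic, a quick hypothetical shape check for the generator above; the constructor arguments mirror how `run()` in wganHCAL.py instantiates it, and the input values are purely illustrative:

```python
import torch
from models.generatorFull import Hcal_ecalEMB

nz = 100                                    # latent size, matching the --nz default below
gen = Hcal_ecalEMB(ngf=32, ndf=32, nz=nz, emb_size=32)

ecal = torch.randn(4, 1, 30, 30, 30)        # conditioning ECAL shower
energy = torch.rand(4, 1) * 100             # conditioning energy, arbitrary scale
noise = torch.empty(4, nz).uniform_(-1, 1)  # latent vector, sampled as in wganHCAL.py

fake_hcal = gen(noise, energy, ecal)
print(fake_hcal.shape)                      # expected: torch.Size([4, 48, 30, 30])
```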
diff --git a/pytorch_job_wganHCAL_nccl.yaml b/pytorch_job_wganHCAL_nccl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fcdf0407205c813895c65c8207c729516fb1932b
--- /dev/null
+++ b/pytorch_job_wganHCAL_nccl.yaml
@@ -0,0 +1,94 @@
+apiVersion: "kubeflow.org/v1"
+kind: "PyTorchJob"
+metadata:
+  name: "pytorch-dist-wganHCAL-nccl"
+spec:
+  pytorchReplicaSpecs:
+    Master:
+      replicas: 1
+      restartPolicy: OnFailure
+      template:
+        metadata:
+          annotations:
+            sidecar.istio.io/inject: "false"
+        spec:
+          volumes:
+          - name: eos
+            hostPath:
+              path: /var/eos
+          - name: krb-secret-vol
+            secret:
+              secretName: krb-secret
+          - name: nvidia-driver
+            hostPath:
+              path: /opt/nvidia-driver
+              type: ""
+          containers:
+          - name: pytorch
+            image: gitlab-registry.cern.ch/eneren/pytorchjob:ddp
+            imagePullPolicy: Always
+            env:
+            - name: PATH
+              value: /opt/conda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/nvidia-driver/bin
+            - name: LD_LIBRARY_PATH
+              value: /opt/nvidia-driver/lib64
+            - name: PYTHONUNBUFFERED
+              value: "1"
+            command: [sh, -c]
+            args:
+            - cp /secret/krb-secret-vol/krb5cc_1000 /tmp/krb5cc_0 && chmod 600 /tmp/krb5cc_0
+              && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --lrCrit 0.0001 --lrGen 0.00001;
+            volumeMounts:
+            - name: eos
+              mountPath: /eos
+            - name: krb-secret-vol
+              mountPath: "/secret/krb-secret-vol"
+            - name: nvidia-driver
+              mountPath: /opt/nvidia-driver
+            resources:
+              limits:
+                nvidia.com/gpu: 1
+    Worker:
+      replicas: 1
+      restartPolicy: OnFailure
+      template:
+        metadata:
+          annotations:
+            sidecar.istio.io/inject: "false"
+        spec:
+          volumes:
+          - name: eos
+            hostPath:
+              path: /var/eos
+          - name: krb-secret-vol
+            secret:
+              secretName: krb-secret
+          - name: nvidia-driver
+            hostPath:
+              path: /opt/nvidia-driver
+              type: ""
+          containers:
+          - name: pytorch
+            image: gitlab-registry.cern.ch/eneren/pytorchjob:ddp
+            imagePullPolicy: Always
+            env:
+            - name: PATH
+              value: /opt/conda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/nvidia-driver/bin
+            - name: LD_LIBRARY_PATH
+              value: /opt/nvidia-driver/lib64
+            - name: PYTHONUNBUFFERED
+              value: "1"
+            command: [sh, -c]
+            args:
+            - cp /secret/krb-secret-vol/krb5cc_1000 /tmp/krb5cc_0 && chmod 600 /tmp/krb5cc_0
+              && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --lrCrit 0.0001 --lrGen 0.00001;
+            volumeMounts:
+            - name: eos
+              mountPath: /eos
+            - name: krb-secret-vol
+              mountPath: "/secret/krb-secret-vol"
+            - name: nvidia-driver
+              mountPath: /opt/nvidia-driver
+            resources:
+              limits:
+                nvidia.com/gpu: 1
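For reference, the training loop added to wganHCAL.py below implements the usual WGAN-GP objective. Writing $D$ for the critic (here additionally conditioned on the ECAL shower and the true energy) and $G$ for the generator, the losses computed in `train()` are

$$
L_{\text{critic}} = \mathbb{E}\big[D(\tilde{x})\big] - \mathbb{E}\big[D(x)\big]
  + \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad
L_{\text{gen}} = -\,\mathbb{E}\big[D(\tilde{x})\big],
$$

where $x$ is a real HCAL shower, $\tilde{x} = G(z, E, \text{ECAL})$ a generated one, and $\hat{x} = \alpha x + (1-\alpha)\tilde{x}$ with $\alpha \sim U(0,1)$ is the interpolated sample built in `calc_gradient_penalty`. The penalty weight $\lambda$ is the `--lambd` argument (default 15).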
diff --git a/wganHCAL.py b/wganHCAL.py
index f3046bb49dbe36434173560c7c519b5e2701f611..0988c4c1271e793203b198a8adb3e94ba9604c02 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -12,3 +12,307 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data import DataLoader
 from torch.autograd import Variable
 
+os.environ['MKL_THREADING_LAYER'] = 'GNU'
+
+torch.autograd.set_detect_anomaly(True)
+
+sys.path.append('/opt/regressor/src')
+
+from models.generatorFull import Hcal_ecalEMB
+from models.data_loaderFull import HDF5Dataset
+from models.criticFull import CriticEMB
+
+
+def calc_gradient_penalty(netD, real_ecal, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize):
+    # gradient penalty on HCAL showers interpolated between real and generated samples;
+    # the critic also needs the (fixed) real ECAL image and the true energy
+    alpha = torch.rand(BATCH_SIZE, 1)
+    alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement()/BATCH_SIZE)).contiguous()
+    alpha = alpha.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    alpha = alpha.to(device)
+
+    fake_data = fake_data.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())
+
+    interpolates = interpolates.to(device)
+    interpolates.requires_grad_(True)
+
+    disc_interpolates = netD(real_ecal.float(), interpolates.float(), real_label.float())
+
+    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
+                              grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+                              create_graph=True, retain_graph=True, only_inputs=True)[0]
+
+    gradients = gradients.view(gradients.size(0), -1)
+    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+    return gradient_penalty
+
+
+def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch, experiment):
+
+    ### CRITIC TRAINING
+    aD.train()
+    aG.eval()
+
+    Tensor = torch.cuda.FloatTensor
+
+    for batch_idx, (dataE, dataH, energy) in enumerate(train_loader):
+        real_dataECAL = dataE.to(device).unsqueeze(1)
+        real_dataHCAL = dataH.to(device).unsqueeze(1)
+        real_label = energy.to(device)
+
+        optimizer_d.zero_grad()
+
+        z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))))
+
+        ## Generate fake data
+        fake_dataHCAL = aG(z, real_label, real_dataECAL).detach()  ## 48 x 30 x 30
+
+        ## Critic fwd pass on real
+        disc_real = aD(real_dataECAL.float(), real_dataHCAL.float(), real_label.float())
+
+        ## Calculate gradient penalty term
+        gradient_penalty = calc_gradient_penalty(aD, real_dataECAL, real_dataHCAL.float(), fake_dataHCAL, real_label, args.batch_size, device, layer=48, xsize=30, ysize=30)
+
+        ## Critic fwd pass on fake
+        disc_fake = aD(real_dataECAL, fake_dataHCAL.unsqueeze(1), real_label)
+
+        ## Wasserstein-1 distance
+        w_dist = torch.mean(disc_fake) - torch.mean(disc_real)
+        # final critic cost
+        disc_cost = w_dist + args.lambd * gradient_penalty
+
+        disc_cost.backward()
+        optimizer_d.step()
+
+        if batch_idx % args.log_interval == 0:
+            print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), disc_cost.item()))
+            niter = epoch * len(train_loader) + batch_idx
+            experiment.log_metric("L_crit", disc_cost, step=niter)
+            experiment.log_metric("gradient_pen", gradient_penalty, step=niter)
+            experiment.log_metric("Wasserstein Dist", w_dist, step=niter)
+            experiment.log_metric("Critic Score (Real)", torch.mean(disc_real), step=niter)
+            experiment.log_metric("Critic Score (Fake)", torch.mean(disc_fake), step=niter)
+
+        ## train the generator once every ncrit critic updates
+        if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
+            ## GENERATOR TRAINING
+            aD.eval()
+            aG.train()
+
+            #print("Generator training started")
+
+            optimizer_g.zero_grad()
+
+            ## generate fake data out of noise
+            fake_dataHCALG = aG(z, real_label, real_dataECAL)
+
+            ## generator loss: maximize the critic score on generated showers
+            gen_cost = aD(real_dataECAL, fake_dataHCALG.unsqueeze(1), real_label)
+            g_cost = -torch.mean(gen_cost)
+            g_cost.backward()
+            optimizer_g.step()
+
+            if batch_idx % args.log_interval == 0:
+                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), g_cost.item()))
+                niter = epoch * len(train_loader) + batch_idx
+                experiment.log_metric("L_Gen", g_cost, step=niter)
+
+
+def is_distributed():
+    return dist.is_available() and dist.is_initialized()
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='WGAN training on hadron showers')
+    parser.add_argument('--batch-size', type=int, default=100, metavar='N',
+                        help='input batch size for training (default: 100)')
+
+    parser.add_argument('--nz', type=int, default=100, metavar='N',
+                        help='latent space size for the generator (default: 100)')
+
+    parser.add_argument('--lambd', type=int, default=15, metavar='N',
+                        help='weight of the gradient penalty (default: 15)')
+
+    parser.add_argument('--kappa', type=float, default=0.001, metavar='N',
+                        help='weight of label conditioning (default: 0.001)')
+
+    parser.add_argument('--ndf', type=int, default=64, metavar='N',
+                        help='number of features of the critic (default: 64)')
+
+    parser.add_argument('--ngf', type=int, default=32, metavar='N',
+                        help='number of features of the generator (default: 32)')
+
+    parser.add_argument('--ncrit', type=int, default=10, metavar='N',
+                        help='critic updates before one generator update (default: 10)')
+
+    parser.add_argument('--epochs', type=int, default=1, metavar='N',
+                        help='number of epochs to train (default: 1)')
+    parser.add_argument('--lrCrit', type=float, default=0.00001, metavar='LR',
+                        help='learning rate for the critic (default: 0.00001)')
+    parser.add_argument('--lrGen', type=float, default=0.0001, metavar='LR',
+                        help='learning rate for the generator (default: 0.0001)')
+
+    parser.add_argument('--chpt', action='store_true', default=False,
+                        help='continue training from a saved model')
+
+    parser.add_argument('--chpt_base', type=str, default='/eos/user/e/eneren/experiments/',
+                        help='base path for saved checkpoints')
+
+    parser.add_argument('--exp', type=str, default='dist_wgan',
+                        help='name of the experiment')
+
+    parser.add_argument('--chpt_eph', type=int, default=1,
+                        help='epoch of the checkpoint to continue from')
+
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
+                        help='how many batches to wait before logging training status')
+
+    if dist.is_available():
+        parser.add_argument('--backend', type=str, help='Distributed backend',
+                            choices=[dist.Backend.GLOO, dist.Backend.NCCL, dist.Backend.MPI],
+                            default=dist.Backend.GLOO)
+
+    parser.add_argument('--local_rank', type=int, default=0)
+
+    args = parser.parse_args()
+
+    args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank))
+    args.rank = int(os.environ.get('RANK'))
+    args.world_size = int(os.environ.get('WORLD_SIZE'))
+
+    # postprocess args
+    args.device = f'cuda:{args.local_rank}'  # PytorchJob/launch.py
+    args.batch_size = max(args.batch_size,
+                          args.world_size * 2)  # min valid batchsize
+    return args
+
+
+def run(args):
+    # Training settings
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+    if use_cuda:
+        print('Using CUDA')
+
+    experiment = Experiment(api_key="keGmeIz4GfKlQZlOP6cit4QOi",
+                            project_name="ecal-hcal-shower", workspace="engineren", auto_output_logging="simple")
+    experiment.add_tag(args.exp)
+
+    experiment.log_parameters(
+        {
+            "batch_size": args.batch_size,
+            "latent": args.nz,
+            "lambda": args.lambd,
+            "ncrit": args.ncrit,
+            "resN": args.ndf,
+            "ngf": args.ngf
+        }
+    )
+
+    torch.manual_seed(args.seed)
+
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    print('Using distributed PyTorch with {} backend'.format(args.backend))
+    dist.init_process_group(backend=args.backend)
+
+    print('[init] == local rank: {}, global rank: {}, world size: {} =='.format(args.local_rank, args.rank, args.world_size))
+
+    print("loading data")
+    #dataset = HDF5Dataset('/eos/user/e/eneren/scratch/40GeV40k.hdf5', transform=None, train_size=40000)
+    #dataset = HDF5Dataset('/eos/user/e/eneren/scratch/60GeV20k.hdf5', transform=None, train_size=20000)
+    dataset = HDF5Dataset('/eos/user/e/eneren/scratch/4060GeV.hdf5', transform=None, train_size=60000)
+
+    sampler = DistributedSampler(dataset, shuffle=True)
+    train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=1, drop_last=True, pin_memory=False)
+
+    mCrit = CriticEMB().to(device)
+    mGen = Hcal_ecalEMB(args.ngf, 32, args.nz, emb_size=32).to(device)
+
+    if args.world_size > 1:
+        Distributor = nn.parallel.DistributedDataParallel if use_cuda \
+            else nn.parallel.DistributedDataParallelCPU
+        mCrit = Distributor(mCrit, device_ids=[args.local_rank], output_device=args.local_rank)
+        mGen = Distributor(mGen, device_ids=[args.local_rank], output_device=args.local_rank)
+
+    optimizerG = optim.Adam(mGen.parameters(), lr=args.lrGen, betas=(0.5, 0.9))
+    optimizerD = optim.Adam(mCrit.parameters(), lr=args.lrCrit, betas=(0.5, 0.9))
+
+    if (args.chpt):
+        critic_checkpoint = torch.load(args.chpt_base + args.exp + "_critic_" + str(args.chpt_eph) + ".pt")
+        gen_checkpoint = torch.load(args.chpt_base + args.exp + "_generator_" + str(args.chpt_eph) + ".pt")
+
+        mGen.load_state_dict(gen_checkpoint['model_state_dict'])
+        optimizerG.load_state_dict(gen_checkpoint['optimizer_state_dict'])
+
+        mCrit.load_state_dict(critic_checkpoint['model_state_dict'])
+        optimizerD.load_state_dict(critic_checkpoint['optimizer_state_dict'])
+
+        eph = gen_checkpoint['epoch']
+
+    else:
+        eph = 0
+        print("init models")
+
+    experiment.set_model_graph(str(mGen), overwrite=False)
+
+    print("starting training...")
+    for epoch in range(1, args.epochs + 1):
+        epoch += eph
+        train_loader.sampler.set_epoch(epoch)
+        train(args, mCrit, mGen, device, train_loader, optimizerD, optimizerG, epoch, experiment)
+        if args.rank == 0:
+            gPATH = args.chpt_base + args.exp + "_generator_" + str(epoch) + ".pt"
+            cPATH = args.chpt_base + args.exp + "_critic_" + str(epoch) + ".pt"
+
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': mGen.state_dict(),
+                'optimizer_state_dict': optimizerG.state_dict()
+            }, gPATH)
+
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': mCrit.state_dict(),
+                'optimizer_state_dict': optimizerD.state_dict()
+            }, cPATH)
+
+    print("end training")
+
+
+def main():
+    args = parse_args()
+    run(args)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
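One possible way to reuse the per-epoch checkpoints written by `run()` above, as a hedged sketch: the path is composed from the `--chpt_base`, `--exp` and epoch values used elsewhere in this diff and is illustrative only, and the `module.` prefix handling assumes the checkpoint was saved from a DistributedDataParallel-wrapped model:

```python
import torch
from models.generatorFull import Hcal_ecalEMB

# illustrative path following the chpt_base + exp + "_generator_" + epoch convention of run()
ckpt = torch.load('/eos/user/e/eneren/experiments/wganHCALv1_generator_50.pt', map_location='cpu')

state = ckpt['model_state_dict']
# strip the 'module.' prefix added by DistributedDataParallel, if present
state = {k.replace('module.', '', 1): v for k, v in state.items()}

gen = Hcal_ecalEMB(ngf=32, ndf=32, nz=100, emb_size=32)
gen.load_state_dict(state)
gen.eval()   # ready for sampling fake HCAL showers from ECAL + energy inputs
```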