diff --git a/wgan_ECAL_HCAL_3crit.py b/wgan_ECAL_HCAL_3crit.py
new file mode 100644
index 0000000000000000000000000000000000000000..39cb390a0fa1595f7016d7ac982529dc7a030f05
--- /dev/null
+++ b/wgan_ECAL_HCAL_3crit.py
@@ -0,0 +1,530 @@
+from __future__ import print_function
+from comet_ml import Experiment
+import argparse
+import os, sys
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.optim as optim
+from torch import autograd
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import DataLoader
+from torch.autograd import Variable
+
+os.environ['MKL_THREADING_LAYER'] = 'GNU'
+
+torch.autograd.set_detect_anomaly(True)
+
+sys.path.append('/opt/regressor/src')
+
+from models.generatorFull import Hcal_ecalEMB
+from models.data_loaderFull import HDF5Dataset
+from models.criticFull import CriticEMB
+from models.generator import DCGAN_G
+from models.criticRes import generate_model
+
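+# Gradient-penalty helpers (WGAN-GP, Gulrajani et al. 2017): each draws a random
+# interpolation between real and generated showers, runs the critic on it, and
+# penalises the deviation of the critic's gradient norm from 1.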
+def calc_gradient_penalty_ECAL(netD, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize):
+    
+    alpha = torch.rand(BATCH_SIZE, 1)
+    alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement()/BATCH_SIZE)).contiguous()
+    alpha = alpha.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    alpha = alpha.to(device)
+
+
+    fake_data = fake_data.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())
+
+    interpolates = interpolates.to(device)
+    interpolates.requires_grad_(True)   
+
+    disc_interpolates = netD(interpolates.float(), real_label.float())
+
+
+    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
+                              grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+                              create_graph=True, retain_graph=True, only_inputs=True)[0]
+
+    gradients = gradients.view(gradients.size(0), -1)                              
+    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+    return gradient_penalty
+
+def calc_gradient_penalty_HCAL(netD, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize):
+
+    alpha = torch.rand(BATCH_SIZE, 1)
+    alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement()/BATCH_SIZE)).contiguous()
+    alpha = alpha.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    alpha = alpha.to(device)
+
+
+    fake_data = fake_data.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())
+
+    interpolates = interpolates.to(device)
+    interpolates.requires_grad_(True)
+
+    disc_interpolates = netD(interpolates.float(), real_label.float())
+
+
+    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
+                              grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+                              create_graph=True, retain_graph=True, only_inputs=True)[0]
+
+    gradients = gradients.view(gradients.size(0), -1)
+    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+    return gradient_penalty
+
+
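+# Joint gradient penalty for the combined ECAL+HCAL critic: the two calorimeter
+# volumes are interpolated independently and the penalty is taken over the critic
+# gradients with respect to both inputs.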
+def calc_gradient_penalty_ECAL_HCAL(netD, real_ecal, real_hcal, fake_ecal, fake_hcal, real_label, BATCH_SIZE, device, layer, layer_hcal, xsize, ysize):
+    
+    alphaE = torch.rand(BATCH_SIZE, 1)
+    alphaE = alphaE.expand(BATCH_SIZE, int(real_ecal.nelement()/BATCH_SIZE)).contiguous()
+    alphaE = alphaE.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    alphaE = alphaE.to(device)
+
+
+    alphaH = torch.rand(BATCH_SIZE, 1)
+    alphaH = alphaH.expand(BATCH_SIZE, int(real_hcal.nelement()/BATCH_SIZE)).contiguous()
+    alphaH = alphaH.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize)
+    alphaH = alphaH.to(device)
+
+    fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize)
+    fake_ecal = fake_ecal.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    
+    interpolatesHCAL = alphaH * real_hcal.detach() + ((1 - alphaH) * fake_hcal.detach())
+    interpolatesECAL = alphaE * real_ecal.detach() + ((1 - alphaE) * fake_ecal.detach())
+
+
+    interpolatesHCAL = interpolatesHCAL.to(device)
+    interpolatesHCAL.requires_grad_(True)   
+
+    interpolatesECAL = interpolatesECAL.to(device)
+    interpolatesECAL.requires_grad_(True)   
+
+    disc_interpolates = netD(interpolatesECAL.float(), interpolatesHCAL.float(), real_label.float())
+
+    gradientsE, gradientsH = autograd.grad(outputs=disc_interpolates, inputs=[interpolatesECAL, interpolatesHCAL],
+                                           grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+                                           create_graph=True, retain_graph=True, only_inputs=True)
+
+    ## penalise the joint gradient norm over both calorimeter inputs
+    gradients = torch.cat([gradientsE.view(gradientsE.size(0), -1), gradientsH.view(gradientsH.size(0), -1)], dim=1)
+    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+    return gradient_penalty
+
+
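+# One training epoch: the ECAL critic and the joint ECAL+HCAL critic are updated on
+# every batch; the two generators are updated once every args.ncrit batches.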
+def train(args, aDE, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E, optimizer_d_H_E, optimizer_g_E, optimizer_g_H_E, epoch, experiment):
+    
+    ### CRITIC TRAINING
+    aDE.train()
+    aD_H_E.train()
+    aGH.eval()
+    aGE.eval()
+
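+    # latent noise is created directly on the GPU; this training loop assumes CUDA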
+    Tensor = torch.cuda.FloatTensor 
+   
+    for batch_idx, (dataE, dataH, energy) in enumerate(train_loader):
+        
+        # ECAL critic
+        optimizer_d_E.zero_grad()
+
+        ## Get Real data
+        real_dataECAL = dataE.to(device).unsqueeze(1).float()
+        label = energy.to(device).float()
+        ###
+
+        ## Generate Fake ECAL
+        zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False)
+        fake_ecal_gen = aGE(zE, label.view(-1, 1, 1, 1, 1)).detach()
+        fake_ecal_gen = fake_ecal_gen.unsqueeze(1)         
+
+        fake_ecal = fake_ecal_gen.clone().detach()
+
+        ## Critic fwd pass on Real
+        disc_real_E = aDE(real_dataECAL, label)
+
+        ## Calculate Gradient Penalty Term
+        gradient_penalty_E = calc_gradient_penalty_ECAL(aDE, real_dataECAL, fake_ecal, label, args.batch_size, device, layer=30, xsize=30, ysize=30)
+
+        ## Critic fwd pass on fake data
+        disc_fake_E = aDE(fake_ecal, label)
+
+
+        ## Wasserstein-1 distance for critic
+        w_dist_E = torch.mean(disc_fake_E) - torch.mean(disc_real_E)
+
+        # final disc cost
+        disc_cost_E = w_dist_E + args.lambd * gradient_penalty_E
+
+        disc_cost_E.backward()
+        optimizer_d_E.step()
+
+
+
+        # ECAL + HCAL critic
+        optimizer_d_H_E.zero_grad()
+
+        ## Get Real data
+        real_dataECAL = dataE.to(device).unsqueeze(1).float()
+        real_dataHCAL = dataH.to(device).unsqueeze(1).float()
+        label = energy.to(device).float()
+        ###
+
+        ## Get Fake ECAL
+        fake_ecal = fake_ecal_gen.clone().detach()
+
+        ## Generate Fake HCAL
+        #z = zE.view(args.batch_size, args.nz)
+        z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=False)
+        fake_dataHCAL = aGH(z, label, fake_ecal).detach() ## 48 x 30 x 30	
+
+        ## Critic fwd pass on Real
+        disc_real_H_E = aD_H_E(real_dataECAL, real_dataHCAL, label)
+
+        ## Calculate Gradient Penalty Term
+        gradient_penalty_H_E = calc_gradient_penalty_ECAL_HCAL(aD_H_E, real_dataECAL, real_dataHCAL, fake_ecal, fake_dataHCAL, label, args.batch_size, device, layer=30, layer_hcal=48, xsize=30, ysize=30)
+
+        ## Critic fwd pass on fake data
+        disc_fake_H_E = aD_H_E(fake_ecal, fake_dataHCAL.unsqueeze(1), label)
+
+        ## Wasserstein-1 distance for critic
+        w_dist_H_E = torch.mean(disc_fake_H_E) - torch.mean(disc_real_H_E)
+
+        # final disc cost
+        disc_cost_H_E = w_dist_H_E + args.lambd * gradient_penalty_H_E
+
+        disc_cost_H_E.backward()
+        optimizer_d_H_E.step()
+
+
+        if batch_idx % args.log_interval == 0:
+            print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), disc_cost_H_E.item()))
+            niter = epoch * len(train_loader) + batch_idx
+            experiment.log_metric("L_crit_E", disc_cost_E, step=niter)
+            experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter)
+            experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter)
+            experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter)
+            experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter)
+            experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter)
+            experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter)
+            experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter)
+            experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter)
+            experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter)
+
+
+        
+        ## training generator per ncrit 
+        if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
+            ## GENERATOR TRAINING
+            aDE.eval()
+            aD_H_E.eval()
+            aGH.train()
+            aGE.train()
+
+            #print("Generator training started")
+
+            # Optimize ECAL generator
+            optimizer_g_E.zero_grad()
+
+            ## Generate Fake ECAL
+            zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=True)
+            fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
+            fake_ecal = fake_ecal.unsqueeze(1)
+
+            ## Loss function for ECAL generator
+            gen_E_cost = aDE(fake_ecal, label)
+            g_E_cost = -torch.mean(gen_E_cost)
+            g_E_cost.backward()
+            optimizer_g_E.step() 
+
+
+            # Optimize HCAL generator
+            optimizer_g_H_E.zero_grad()
+
+
+            ## Generate Fake ECAL
+            zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=True)
+            fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
+            fake_ecal = fake_ecal.unsqueeze(1) 
+            ####
+
+            #z = zE.view(args.batch_size, args.nz)
+            z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=True)
+            ## generate fake data out of noise
+            fake_dataHCALG = aGH(z, label, fake_ecal)
+            
+            ## Total loss function for generator
+            gen_cost = aD_H_E(fake_ecal, fake_dataHCALG.unsqueeze(1), label)
+            g_cost = -torch.mean(gen_cost) 
+            g_cost.backward()
+            optimizer_g_H_E.step()
+
+            ## put the critics back into training mode for the next critic updates
+            aDE.train()
+            aD_H_E.train()
+            aGH.eval()
+            aGE.eval()
+
+            if batch_idx % args.log_interval == 0 :
+                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f} lossGH={:.4f}'.format(
+                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), g_E_cost.item(),  g_cost.item()))
+                niter = epoch * len(train_loader) + batch_idx
+                experiment.log_metric("L_Gen_E", g_E_cost, step=niter)
+                experiment.log_metric("L_Gen_H", g_cost, step=niter)
+
+
+def is_distributed():
+    return dist.is_available() and dist.is_initialized()
+
+
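+# Command-line arguments; rank and world size are taken from the environment
+# variables set by the distributed launcher when available.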
+def parse_args():
+    parser = argparse.ArgumentParser(description='WGAN training on hadron showers')
+    parser.add_argument('--batch-size', type=int, default=50, metavar='N',
+                        help='input batch size for training (default: 50)')
+    
+    parser.add_argument('--nz', type=int, default=100, metavar='N',
+                        help='latent space for generator (default: 100)')
+    
+    parser.add_argument('--lambd', type=float, default=15, metavar='N',
+                        help='weight of gradient penalty (default: 15)')
+
+    parser.add_argument('--kappa', type=float, default=0.001, metavar='N',
+                        help='weight of label conditioning  (default: 0.001)')
+ 
+    parser.add_argument('--dres', type=int, default=34, metavar='N',
+                        help='depth of Residual critic (default: 34)')    
+
+    parser.add_argument('--ndf', type=int, default=32, metavar='N',
+                        help='n-feature of critic (default: 32)')
+
+    parser.add_argument('--ngf', type=int, default=32, metavar='N',
+                        help='n-feature of generator  (default: 32)')
+
+    parser.add_argument('--ncrit', type=int, default=10, metavar='N',
+                        help='critic updates per generator update (default: 10)')
+
+    parser.add_argument('--epochs', type=int, default=1, metavar='N',
+                        help='number of epochs to train (default: 1)')
+    
+    parser.add_argument('--nworkers', type=int, default=1, metavar='N',
+                        help='number of data loader workers (default: 1)')
+
+    parser.add_argument('--lrCrit_H', type=float, default=0.00001, metavar='LR',
+                        help='learning rate CriticH (default: 0.00001)')
+
+    parser.add_argument('--lrCrit_E', type=float, default=0.00001, metavar='LR',
+                        help='learning rate CriticE (default: 0.00001)')
+
+    parser.add_argument('--lrGen_H_E', type=float, default=0.0001, metavar='LR',
+                        help='learning rate Generator_H_E (default: 0.0001)')
+
+    parser.add_argument('--lrGen_E', type=float, default=0.0001, metavar='LR',
+                        help='learning rate GeneratorE (default: 0.0001)')
+
+    parser.add_argument('--chpt', action='store_true', default=False,
+                        help='continue training from a saved model')
+
+    parser.add_argument('--chpt_base', type=str, default='/eos/user/e/eneren/experiments/',
+                        help='base directory for saving and loading checkpoints')
+
+    parser.add_argument('--exp', type=str, default='dist_wgan',
+                        help='name of the experiment')
+
+    parser.add_argument('--chpt_eph', type=int, default=1,
+                        help='continue checkpoint epoch')
+
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
+                        help='how many batches to wait before logging training status')
+   
+
+    if dist.is_available():
+        parser.add_argument('--backend', type=str, help='Distributed backend',
+                            choices=[dist.Backend.GLOO, dist.Backend.NCCL, dist.Backend.MPI],
+                            default=dist.Backend.GLOO)
+    
+    parser.add_argument('--local_rank', type=int, default=0)
+
+    args = parser.parse_args()
+
+
+    args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank))
+    args.rank = int(os.environ.get('RANK', 0))
+    args.world_size = int(os.environ.get('WORLD_SIZE', 1))
+
+
+    # postprocess args
+    args.device = 'cuda:{}'.format(args.local_rank)  # PytorchJob/launch.py
+    args.batch_size = max(args.batch_size,
+                          args.world_size * 2)  # min valid batchsize
+    return args
+
+
+
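+# Sets up Comet logging, the HDF5 data loader, the four networks (ECAL/HCAL
+# generators, ECAL critic, joint ECAL+HCAL critic) and their optimizers, optionally
+# resumes from a checkpoint, then runs the epoch loop and saves checkpoints on rank 0.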
+def run(args):
+    # Training settings
+
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+    if use_cuda:
+        print('Using CUDA')
+
+    
+    experiment = Experiment(api_key="keGmeIz4GfKlQZlOP6cit4QOi",
+                        project_name="ecal-hcal-shower", workspace="engineren", auto_output_logging="simple")
+    experiment.add_tag(args.exp)
+
+    experiment.log_parameters(
+        {
+        "batch_size" : args.batch_size,
+        "latent": args.nz,
+        "lambda": args.lambd,
+        "ncrit" : args.ncrit,
+        "resN_H": args.ndf,
+        "resN_E": args.dres,
+        "ngf": args.ngf
+        }
+    )
+
+    torch.manual_seed(args.seed)
+
+    device = torch.device("cuda" if use_cuda else "cpu")
+    
+
+    if args.world_size > 1:
+        print('Using distributed PyTorch with {} backend'.format(args.backend))
+        dist.init_process_group(backend=args.backend)
+    
+    print('[init] == local rank: {}, global rank: {}, world size: {} =='.format(args.local_rank, args.rank, args.world_size))
+
+
+
+    print ("loading data")
+    #dataset = HDF5Dataset('/eos/user/e/eneren/scratch/40GeV40k.hdf5', transform=None, train_size=40000)
+    dataset = HDF5Dataset('/eos/user/e/eneren/scratch/50GeV75k.hdf5', transform=None, train_size=75000)
+    #dataset = HDF5Dataset('/eos/user/e/eneren/scratch/4060GeV.hdf5', transform=None, train_size=60000)
+
+
+    if args.world_size > 1:
+        sampler = DistributedSampler(dataset, shuffle=True)    
+        train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.nworkers, drop_last=True, pin_memory=False)
+    else:
+        train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.nworkers, shuffle=True, drop_last=True, pin_memory=False)
+
+
+    ## HCAL Generator and critic
+    mCritH = CriticEMB().to(device)
+    mGenH = Hcal_ecalEMB(args.ngf, 32, args.nz).to(device)
+    
+    ## ECAL GENERATOR and critic
+    mGenE = DCGAN_G(args.ngf, args.nz).to(device)
+    mCritE = generate_model(args.dres).to(device)
+
+   # ## Global Critic
+   # mCritGlob = 
+
+    if args.world_size > 1: 
+        Distributor = nn.parallel.DistributedDataParallel if use_cuda \
+            else nn.parallel.DistributedDataParallelCPU
+        mCritH = Distributor(mCritH, device_ids=[args.local_rank], output_device=args.local_rank )
+        mGenH = Distributor(mGenH, device_ids=[args.local_rank], output_device=args.local_rank)
+        mGenE = Distributor(mGenE, device_ids=[args.local_rank], output_device=args.local_rank)
+        mCritE = Distributor(mCritE, device_ids=[args.local_rank], output_device=args.local_rank)
+    else:
+        mGenH = nn.parallel.DataParallel(mGenH)
+        mCritH = nn.parallel.DataParallel(mCritH)
+        mGenE = nn.parallel.DataParallel(mGenE)
+        mCritE = nn.parallel.DataParallel(mCritE)
+    
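+    # optimizerG_H_E updates both generators together, since the HCAL generator is
+    # conditioned on the ECAL shower produced by the ECAL generator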
+    optimizerG_H_E = optim.Adam(list(mGenH.parameters())+list(mGenE.parameters()), lr=args.lrGen_H_E, betas=(0.5, 0.9))
+    
+    optimizerG_E = optim.Adam(mGenE.parameters(), lr=args.lrGen_E, betas=(0.5, 0.9))
+
+    optimizerD_H_E = optim.Adam(mCritH.parameters(), lr=args.lrCrit_H, betas=(0.5, 0.9))
+
+    optimizerD_E = optim.Adam(mCritE.parameters(), lr=args.lrCrit_E, betas=(0.5, 0.9))
+
+
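+    # either resume this experiment from its saved checkpoints, or warm-start the
+    # ECAL generator/critic from a previously trained ECAL-only WGAN checkpoint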
+    if (args.chpt):
+        critic_E_checkpoint = torch.load(args.chpt_base + args.exp + "_criticE_"+ str(args.chpt_eph) + ".pt")
+        critic_E_H_checkpoint = torch.load(args.chpt_base + args.exp + "_criticH_"+ str(args.chpt_eph) + ".pt")
+        gen_E_checkpoint = torch.load(args.chpt_base + args.exp + "_generatorE_"+ str(args.chpt_eph) + ".pt")
+        gen_H_checkpoint = torch.load(args.chpt_base + args.exp + "_generatorH_"+ str(args.chpt_eph) + ".pt")
+        
+        mGenE.load_state_dict(gen_E_checkpoint['model_state_dict'])
+        optimizerG_E.load_state_dict(gen_E_checkpoint['optimizer_state_dict'])
+
+        mGenH.load_state_dict(gen_H_checkpoint['model_state_dict'])
+        optimizerG_H_E.load_state_dict(gen_H_checkpoint['optimizer_state_dict'])
+
+        mCritE.load_state_dict(critic_E_checkpoint['model_state_dict'])
+        optimizerD_E.load_state_dict(critic_E_checkpoint['optimizer_state_dict'])
+        
+        mCritH.load_state_dict(critic_E_H_checkpoint['model_state_dict'])
+        optimizerD_H_E.load_state_dict(critic_E_H_checkpoint['optimizer_state_dict'])
+
+        eph = gen_H_checkpoint['epoch']
+    
+    else: 
+        eph = 0
+        gen_E_checkpoint = torch.load("/eos/user/e/eneren/experiments/wganv1_generator_694.pt", map_location=torch.device('cuda'))
+        critic_E_checkpoint = torch.load("/eos/user/e/eneren/experiments/wganv1_critic_694.pt", map_location=torch.device('cuda'))
+
+        mGenE.load_state_dict(gen_E_checkpoint['model_state_dict'])
+        optimizerG_E.load_state_dict(gen_E_checkpoint['optimizer_state_dict'])
+        
+        mCritE.load_state_dict(critic_E_checkpoint['model_state_dict'])
+        optimizerD_E.load_state_dict(critic_E_checkpoint['optimizer_state_dict'])
+        
+        print ("init models")
+
+        
+    experiment.set_model_graph(str(mCritH), overwrite=False)
+
+    print ("starting training...")
+    for epoch in range(1, args.epochs + 1):
+        epoch += eph
+        
+        if args.world_size > 1: 
+            train_loader.sampler.set_epoch(epoch)
+
+        train(args, mCritE, mCritH, mGenE, mGenH, device, train_loader, optimizerD_E, optimizerD_H_E, optimizerG_E, optimizerG_H_E, epoch, experiment)
+        if args.rank == 0:
+            gPATH = args.chpt_base + args.exp + "_generatorH_"+ str(epoch) + ".pt"
+            ePATH = args.chpt_base + args.exp + "_generatorE_"+ str(epoch) + ".pt"
+            cPATH = args.chpt_base + args.exp + "_criticH_"+ str(epoch) + ".pt"
+            cePATH = args.chpt_base + args.exp + "_criticE_"+ str(epoch) + ".pt"
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': mGenH.state_dict(),
+                'optimizer_state_dict': optimizerG_H_E.state_dict()
+                }, gPATH)
+            
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': mCritH.state_dict(),
+                'optimizer_state_dict': optimizerD_H_E.state_dict()
+                }, cPATH)
+
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': mGenE.state_dict(),
+                'optimizer_state_dict': optimizerG_E.state_dict()
+                }, ePATH)
+
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': mCritE.state_dict(),
+                'optimizer_state_dict': optimizerD_E.state_dict()
+            }, cePATH)
+
+
+    print ("end training")
+        
+
+
+def main():
+    args = parse_args()
+    run(args)
+            
+if __name__ == '__main__':
+    main()