diff --git a/wgan_ECAL_HCAL_3crit.py b/wgan_ECAL_HCAL_3crit.py
index 39cb390a0fa1595f7016d7ac982529dc7a030f05..37a9a9d914fbe80b44fea375cb4ff9135e622c31 100644
--- a/wgan_ECAL_HCAL_3crit.py
+++ b/wgan_ECAL_HCAL_3crit.py
@@ -112,24 +112,22 @@ def calc_gradient_penalty_ECAL_HCAL(netD, real_ecal, real_hcal, fake_ecal, fake_
     return gradient_penalty
 
 
-def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E, optimizer_d_H, optimizer_d_H_E, optimizer_g_E, optimizer_g_H, epoch, experiment):
-    
-    ### CRITIC TRAINING
-    aDE.train()
-    aDH.train()
-    aD_H_E.train()
-    aGH.eval()
-    aGE.eval()
+def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E, optimizer_d_H, optimizer_d_H_E, optimizer_g_E, optimizer_g_H, optimizer_g_H_E, epoch, experiment):
+
 
     Tensor = torch.cuda.FloatTensor 
    
     for batch_idx, (dataE, dataH, energy) in enumerate(train_loader):
-        
-	    # ECAL critic
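+        # Per batch: each of the three critics (ECAL, HCAL, joint ECAL+HCAL) takes one update,
+        # while the corresponding generators are updated only every args.ncrit batches.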
+        ## ECAL CRITIC TRAINING
+        aDE.train()
+        aGE.eval()
+
+        # zero out critic gradients
         optimizer_d_E.zero_grad()
 
         ## Get Real data
         real_dataECAL = dataE.to(device).unsqueeze(1).float()
+        real_dataHCAL = dataH.to(device).unsqueeze(1).float()
         label = energy.to(device).float()
         ###
 
@@ -159,19 +157,103 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
         disc_cost_E.backward()
         optimizer_d_E.step()
 
+        #print("Generator training started")
 
+        ## ECAL GENERATOR TRAINING
+        ## training generator per ncrit
+        if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
+            ## GENERATOR TRAINING
+            aDE.eval()
+            aGE.train()
 
-        # ECAL + HCAL critic
-        optimizer_d_H_E.zero_grad()
+            # zero out generator gradients
+            optimizer_g_E.zero_grad()
 
-        ## Get Real data
-        real_dataECAL = dataE.to(device).unsqueeze(1).float()
-        real_dataHCAL = dataH.to(device).unsqueeze(1).float()
-        label = energy.to(device).float()
-        ###
+            ## Generate Fake ECAL
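+            # uniform noise plus the energy label, reshaped to the 5D conditioning format used by the generator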
+            zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=True)
+            fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
+            fake_ecal = fake_ecal.unsqueeze(1)
+
+            ## Loss function for ECAL generator
+            gen_E_cost = aDE(fake_ecal, label)
+            g_E_cost = -torch.mean(gen_E_cost)
+            g_E_cost.backward()
+            optimizer_g_E.step()
+
+
+        ## HCAL CRITIC TRAINING
+        aDH.train()
+        aGH.eval()
+
+        # zero out critic gradients
+        optimizer_d_H.zero_grad()
+
+        # Generate Fake HCAL
+        zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False)
+        fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
+        fake_ecal = fake_ecal.unsqueeze(1).detach()
+
+        z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=False)
+        fake_dataHCAL = aGH(z, label.view(-1, 1, 1, 1, 1), fake_ecal) ## 48 x 30 x 30
+        fake_dataHCAL = fake_dataHCAL.unsqueeze(1).detach()
+
+        ## Critic fwd pass on Real
+        disc_real_H = aDH(real_dataHCAL, label)
+
+        ## Calculate gradient penalty term
+        gradient_penalty_H = calc_gradient_penalty_HCAL(aDH, real_dataHCAL, fake_dataHCAL, label, args.batch_size, device, layer=48, xsize=30, ysize=30)
+
+        ## Critic fwd pass on fake data
+        disc_fake_H = aDH(fake_dataHCAL, label)
+
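+        # empirical critic objective term: E[D(fake)] - E[D(real)] (the negated Wasserstein-distance estimate)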
+        w_dist_H = torch.mean(disc_fake_H) - torch.mean(disc_real_H)
+
+        # WGAN-GP critic loss: Wasserstein term plus lambda-weighted gradient penalty
+        disc_cost_H = w_dist_H + args.lambd * gradient_penalty_H
+
+        disc_cost_H.backward()
+        optimizer_d_H.step()
+
+        ## HCAL GENERATOR TRAINING
+        ## training generator per ncrit
+        if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
+            ## GENERATOR TRAINING
+            aDH.eval()
+            aGH.train()
+
+            # zero out generator gradients
+            optimizer_g_H.zero_grad()
+
+            # Generate Fake HCAL
+            zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False)
+            fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
+            fake_ecal = fake_ecal.unsqueeze(1).detach()
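+            # fake_ecal is detached so this step updates only the HCAL generator, not the ECAL generator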
+
+            z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=True)
+            fake_dataHCAL = aGH(z, label.view(-1, 1, 1, 1, 1), fake_ecal) ## 48 x 30 x 30
+            fake_dataHCAL = fake_dataHCAL.unsqueeze(1)
+
+            ## Loss function for HCAL generator
+            gen_H_cost = aDH(fake_dataHCAL, label)
+            g_H_cost = -torch.mean(gen_H_cost)
+            g_H_cost.backward()
+            optimizer_g_H.step()
+
+
+        ## ECAL + HCAL CRITIC TRAINING
+        aD_H_E.train()
+        aGH.eval()
+        aGE.eval()
+
+        # zero out critic gradients
+        optimizer_d_H_E.zero_grad()
 
         ## Get Fake ECAL
-        fake_ecal = fake_ecal_gen.clone().detach()
+        #fake_ecal = fake_ecal_gen.clone().detach()
+        zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False)
+        fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
+        fake_ecal = fake_ecal.unsqueeze(1).detach()
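+        # detached so the ECAL generator receives no gradient from the joint-critic update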
+
 
         ## Generate Fake HCAL
         #z = zE.view(args.batch_size, args.nz)
@@ -197,49 +279,14 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
         optimizer_d_H_E.step()
 
 
-        if batch_idx % args.log_interval == 0:
-            print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
-                epoch, batch_idx * len(dataH), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), disc_cost_H_E.item()))
-            niter = epoch * len(train_loader) + batch_idx
-            experiment.log_metric("L_crit_E", disc_cost_E, step=niter)
-            experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter)
-            experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter)
-            experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter)
-            experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter)
-            experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter)
-            experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter)
-            experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter)
-            experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter)
-            experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter)
-
-
         
         ## training generator per ncrit 
         if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
             ## GENERATOR TRAINING
-            aDE.eval()
-            aDH.eval()
+            aD_H_E.eval()
             aGH.train()
             aGE.train()
 
-            #print("Generator training started")
-	    
-	        # Optimize ECAL generator
-            optimizer_g_E.zero_grad()
-
-            ## Generate Fake ECAL
-            zE = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=True)
-            fake_ecal = aGE(zE, label.view(-1, 1, 1, 1, 1))
-            fake_ecal = fake_ecal.unsqueeze(1)
-
-            ## Loss function for ECAL generator
-            gen_E_cost = aDE(fake_ecal, label)
-            g_E_cost = -torch.mean(gen_E_cost)
-            g_E_cost.backward()
-            optimizer_g_E.step() 
-
-
             # Optimize HCAL generator
             optimizer_g_H_E.zero_grad()
 
@@ -253,21 +300,43 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
             #z = zE.view(args.batch_size, args.nz)
             z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))), requires_grad=True)
             ## generate fake data out of noise
-            fake_dataHCALG = aGH(z, label, fake_ecal)
+            fake_dataHCAL = aGH(z, label, fake_ecal)
             
             ## Total loss function for generator
-            gen_cost = aDH(fake_ecal, fake_dataHCALG.unsqueeze(1), label)
+            gen_cost = aD_H_E(fake_ecal, fake_dataHCAL.unsqueeze(1), label)
             g_cost = -torch.mean(gen_cost) 
             g_cost.backward()
             optimizer_g_H_E.step()
 
-            if batch_idx % args.log_interval == 0 :
-                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f} lossGH={:.4f}'.format(
-                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
-                    100. * batch_idx / len(train_loader), g_E_cost.item(),  g_cost.item()))
-                niter = epoch * len(train_loader) + batch_idx
-                experiment.log_metric("L_Gen_E", g_E_cost, step=niter)
-                experiment.log_metric("L_Gen_H", g_cost, step=niter)
+        if batch_idx % args.log_interval == 0:
+            print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), disc_cost_H_E.item()))
+            niter = epoch * len(train_loader) + batch_idx
+            experiment.log_metric("L_crit_E", disc_cost_E, step=niter)
+            experiment.log_metric("L_crit_H", disc_cost_H, step=niter)
+            experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter)
+            experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter)
+            experiment.log_metric("gradient_pen_H", gradient_penalty_H, step=niter)
+            experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter)
+            experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter)
+            experiment.log_metric("Wasserstein Dist H", w_dist_H, step=niter)
+            experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter)
+            experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter)
+            experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter)
+            experiment.log_metric("Critic Score H (Real)", torch.mean(disc_real_H), step=niter)
+            experiment.log_metric("Critic Score H (Fake)", torch.mean(disc_fake_H), step=niter)
+            experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter)
+            experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter)
+
+        # generator losses are only defined after the first generator update (batch_idx >= args.ncrit)
+        if (batch_idx % args.log_interval == 0) and (batch_idx >= args.ncrit):
+            print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f} lossGH={:.4f} lossGHE={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), g_E_cost.item(), g_H_cost.item(), g_cost.item()))
+            niter = epoch * len(train_loader) + batch_idx
+            experiment.log_metric("L_Gen_E", g_E_cost, step=niter)
+            experiment.log_metric("L_Gen_H", g_H_cost, step=niter)
+            experiment.log_metric("L_Gen_H_E", g_cost, step=niter)
 
 
 def is_distributed():
@@ -306,6 +375,9 @@ def parse_args():
     parser.add_argument('--nworkers', type=int, default=1, metavar='N',
                         help='number of epochs to train (default: 1)')
 
+    parser.add_argument('--lrCrit_H_E', type=float, default=0.00001, metavar='LR',
+                        help='learning rate Critic_H_E (default: 0.00001)')
+
     parser.add_argument('--lrCrit_H', type=float, default=0.00001, metavar='LR',
                         help='learning rate CriticH (default: 0.00001)')
 
@@ -315,6 +387,9 @@ def parse_args():
     parser.add_argument('--lrGen_H_E', type=float, default=0.0001, metavar='LR',
                         help='learning rate Generator_H_E (default: 0.0001)')
 
+    parser.add_argument('--lrGen_H', type=float, default=0.0001, metavar='LR',
+                        help='learning rate Generator_H (default: 0.0001)')
+
     parser.add_argument('--lrGen_E', type=float, default=0.0001, metavar='LR',
                         help='learning rate GeneratorE (default: 0.0001)')
 
@@ -379,7 +454,8 @@ def run(args):
         "latent": args.nz,
         "lambda": args.lambd,
         "ncrit" : args.ncrit,
-        "resN_H": args.ndf,
+        "resN_E_H": args.ndf,
+        "resN_H": args.dres,
         "resN_E": args.dres,
         "ngf": args.ngf
         }
@@ -411,55 +487,69 @@ def run(args):
         train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.nworkers, shuffle=True, drop_last=True, pin_memory=False)
 
 
+    ## ECAL + HCAL joint critic
+    mCrit_H_E = CriticEMB().to(device)
+
     ## HCAL Generator and critic
-    mCritH = CriticEMB().to(device)
+    mCritH = generate_model(args.dres).to(device)
     mGenH = Hcal_ecalEMB(args.ngf, 32, args.nz).to(device)
     
     ## ECAL GENERATOR and critic
     mGenE = DCGAN_G(args.ngf, args.nz).to(device)
     mCritE = generate_model(args.dres).to(device)
 
-   # ## Global Critic
-   # mCritGlob = 
 
     if args.world_size > 1: 
         Distributor = nn.parallel.DistributedDataParallel if use_cuda \
             else nn.parallel.DistributedDataParallelCPU
+        mCrit_H_E = Distributor(mCrit_H_E, device_ids=[args.local_rank], output_device=args.local_rank )
         mCritH = Distributor(mCritH, device_ids=[args.local_rank], output_device=args.local_rank )
         mGenH = Distributor(mGenH, device_ids=[args.local_rank], output_device=args.local_rank)
         mGenE = Distributor(mGenE, device_ids=[args.local_rank], output_device=args.local_rank)
         mCritE = Distributor(mCritE, device_ids=[args.local_rank], output_device=args.local_rank)
     else:
+        mCrit_H_E = nn.parallel.DataParallel(mCrit_H_E)
         mGenH = nn.parallel.DataParallel(mGenH)
         mCritH = nn.parallel.DataParallel(mCritH)
         mGenE = nn.parallel.DataParallel(mGenE)
         mCritE = nn.parallel.DataParallel(mCritE)
     
     optimizerG_H_E = optim.Adam(list(mGenH.parameters())+list(mGenE.parameters()), lr=args.lrGen_H_E, betas=(0.5, 0.9))
-    
-    optimizerG_E = optim.Adam(mGenE.parameters(), lr=args.lrGen_E, betas=(0.5, 0.9))
 
-    optimizerD_H_E = optim.Adam(mCritH.parameters(), lr=args.lrCrit_H, betas=(0.5, 0.9))
+    optimizerD_H_E = optim.Adam(mCrit_H_E.parameters(), lr=args.lrCrit_H_E, betas=(0.5, 0.9))
+
+    optimizerG_H = optim.Adam(mGenH.parameters(), lr=args.lrGen_H, betas=(0.5, 0.9))
+
+    optimizerD_H = optim.Adam(mCritH.parameters(), lr=args.lrCrit_H, betas=(0.5, 0.9))
+
+    optimizerG_E = optim.Adam(mGenE.parameters(), lr=args.lrGen_E, betas=(0.5, 0.9))
 
     optimizerD_E = optim.Adam(mCritE.parameters(), lr=args.lrCrit_E, betas=(0.5, 0.9))
 
 
     if (args.chpt):
         critic_E_checkpoint = torch.load(args.chpt_base + args.exp + "_criticE_"+ str(args.chpt_eph) + ".pt")
-        critic_E_H_checkpoint = torch.load(args.chpt_base + args.exp + "_criticH_"+ str(args.chpt_eph) + ".pt")
+        critic_H_checkpoint = torch.load(args.chpt_base + args.exp + "_criticH_"+ str(args.chpt_eph) + ".pt")
+        critic_E_H_checkpoint = torch.load(args.chpt_base + args.exp + "_criticH_E_"+ str(args.chpt_eph) + ".pt")
         gen_E_checkpoint = torch.load(args.chpt_base + args.exp + "_generatorE_"+ str(args.chpt_eph) + ".pt")
         gen_H_checkpoint = torch.load(args.chpt_base + args.exp + "_generatorH_"+ str(args.chpt_eph) + ".pt")
+        gen_H_E_checkpoint = torch.load(args.chpt_base + args.exp + "_generatorH_E_"+ str(args.chpt_eph) + ".pt")
+
         
         mGenE.load_state_dict(gen_E_checkpoint['model_state_dict'])
         optimizerG_E.load_state_dict(gen_E_checkpoint['optimizer_state_dict'])
 
         mGenH.load_state_dict(gen_H_checkpoint['model_state_dict'])
-        optimizerG_H_E.load_state_dict(gen_H_checkpoint['optimizer_state_dict'])
+        optimizerG_H.load_state_dict(gen_H_checkpoint['optimizer_state_dict'])
+        optimizerG_H_E.load_state_dict(gen_H_E_checkpoint['optimizer_state_dict'])
 
         mCritE.load_state_dict(critic_E_checkpoint['model_state_dict'])
         optimizerD_E.load_state_dict(critic_E_checkpoint['optimizer_state_dict'])
         
-        mCritH.load_state_dict(critic_E_H_checkpoint['model_state_dict'])
+        mCritH.load_state_dict(critic_H_checkpoint['model_state_dict'])
+        optimizerD_H.load_state_dict(critic_H_checkpoint['optimizer_state_dict'])
+
+        mCrit_H_E.load_state_dict(critic_E_H_checkpoint['model_state_dict'])
         optimizerD_H_E.load_state_dict(critic_E_H_checkpoint['optimizer_state_dict'])
 
         eph = gen_H_checkpoint['epoch']
@@ -486,22 +576,34 @@ def run(args):
         
         if args.world_size > 1: 
             train_loader.sampler.set_epoch(epoch)
-
-        train(args, mCritE, mCritH, mGenE, mGenH, device, train_loader, optimizerD_E, optimizerD_H_E, optimizerG_E, optimizerG_H_E, epoch, experiment)
+
+        train(args, mCritE, mCritH, mCrit_H_E, mGenE, mGenH, device, train_loader, optimizerD_E, optimizerD_H, optimizerD_H_E, optimizerG_E, optimizerG_H, optimizerG_H_E, epoch, experiment)
         if args.rank == 0:
-            gPATH = args.chpt_base + args.exp + "_generatorH_"+ str(epoch) + ".pt"
+            gHPATH = args.chpt_base + args.exp + "_generatorH_"+ str(epoch) + ".pt"
             ePATH = args.chpt_base + args.exp + "_generatorE_"+ str(epoch) + ".pt"
-            cPATH = args.chpt_base + args.exp + "_criticH_"+ str(epoch) + ".pt"
+            gPATH = args.chpt_base + args.exp + "_generatorH_E_"+ str(epoch) + ".pt"
+            cPATH = args.chpt_base + args.exp + "_criticH_E_"+ str(epoch) + ".pt"
+            cHPATH = args.chpt_base + args.exp + "_criticH_"+ str(epoch) + ".pt"
             cePATH = args.chpt_base + args.exp + "_criticE_"+ str(epoch) + ".pt"
             torch.save({
                 'epoch': epoch,
                 'model_state_dict': mGenH.state_dict(),
+                'optimizer_state_dict': optimizerG_H.state_dict()
+                }, gHPATH)
+
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': mCritH.state_dict(),
+                'optimizer_state_dict': optimizerD_H.state_dict()
+                }, cHPATH)
+
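+            # this checkpoint stores only the shared ECAL+HCAL generator optimizer state;
+            # the generator weights themselves are saved in their per-network checkpoints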
+            torch.save({
+                'epoch': epoch,
                 'optimizer_state_dict': optimizerG_H_E.state_dict()
                 }, gPATH)
-            
+
             torch.save({
                 'epoch': epoch,
-                'model_state_dict': mCritH.state_dict(),
+                'model_state_dict': mCrit_H_E.state_dict(),
                 'optimizer_state_dict': optimizerD_H_E.state_dict()
                 }, cPATH)