diff --git a/wgan_ECAL_HCAL_3crit.py b/wgan_ECAL_HCAL_3crit.py
index 0e6bc93fd6dae480faf7cd3e230a815d25549696..d3da95e137d3a0d29200b5337f510fa5539a33ae 100644
--- a/wgan_ECAL_HCAL_3crit.py
+++ b/wgan_ECAL_HCAL_3crit.py
@@ -180,10 +180,20 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
             g_E_cost.backward()
             optimizer_g_E.step()
 
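+            # periodically report the ECAL generator loss to stdout and the experiment tracker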
+            if batch_idx % args.log_interval == 0:
+                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f}'.format(
+                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), g_E_cost.item()))
+
+                niter = epoch * len(train_loader) + batch_idx
+                experiment.log_metric("L_Gen_E", g_E_cost, step=niter)
 
         ## HCAL CRITIC TRAINING
         aDH.train()
         aGH.eval()
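+        # put the ECAL generator in eval mode as well for the HCAL critic update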
+        aGE.eval()
 
         # zero out critic gradients
         optimizer_d_H.zero_grad()
@@ -239,6 +249,14 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
             g_H_cost.backward()
             optimizer_g_H.step()
 
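+            # periodically report the HCAL generator loss to stdout and the experiment tracker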
+            if batch_idx % args.log_interval == 0:
+                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGH={:.4f}'.format(
+                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), g_H_cost.item()))
+
+                niter = epoch * len(train_loader) + batch_idx
+                experiment.log_metric("L_Gen_H", g_H_cost, step=niter)
 
         ## ECAL + HCAL CRITIC TRAINING
         aD_H_E.train()
@@ -279,7 +297,28 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
         optimizer_d_H_E.step()
 
 
-        
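+        # periodically report critic losses, gradient penalties, Wasserstein distances, and mean critic scores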
+        if batch_idx % args.log_interval == 0:
+            print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), disc_cost_H_E.item()))
+            niter = epoch * len(train_loader) + batch_idx
+            experiment.log_metric("L_crit_E", disc_cost_E, step=niter)
+            experiment.log_metric("L_crit_H", disc_cost_H, step=niter)
+            experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter)
+            experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter)
+            experiment.log_metric("gradient_pen_H", gradient_penalty_H, step=niter)
+            experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter)
+            experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter)
+            experiment.log_metric("Wasserstein Dist H", w_dist_H, step=niter)
+            experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter)
+            experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter)
+            experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter)
+            experiment.log_metric("Critic Score H (Real)", torch.mean(disc_real_H), step=niter)
+            experiment.log_metric("Critic Score H (Fake)", torch.mean(disc_fake_H), step=niter)
+            experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter)
+            experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter)
+
         ## training generator per ncrit 
         if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
             ## GENERATOR TRAINING
@@ -308,37 +347,16 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
             g_cost.backward()
             optimizer_g_H_E.step()
 
-        if batch_idx % args.log_interval == 0:
-            print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
-                epoch, batch_idx * len(dataH), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), disc_cost_H_E.item()))
-            niter = epoch * len(train_loader) + batch_idx
-            experiment.log_metric("L_crit_E", disc_cost_E, step=niter)
-            experiment.log_metric("L_crit_H", disc_cost_H, step=niter)
-            experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter)
-            experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter)
-            experiment.log_metric("gradient_pen_H", gradient_penalty_H, step=niter)
-            experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter)
-            experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter)
-            experiment.log_metric("Wasserstein Dist H", w_dist_H, step=niter)
-            experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter)
-            experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter)
-            experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter)
-            experiment.log_metric("Critic Score H (Real)", torch.mean(disc_real_H), step=niter)
-            experiment.log_metric("Critic Score H (Fake)", torch.mean(disc_fake_H), step=niter)
-            experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter)
-            experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter)
-
-        if batch_idx % args.log_interval == 0 :
-            print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f} lossGH={:.4f} lossGHE={:.4f}'.format(
-                epoch, batch_idx * len(dataH), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), g_E_cost.item(), g_H_cost.item(), g_cost.item()))
-            niter = epoch * len(train_loader) + batch_idx
-            experiment.log_metric("L_Gen_E", g_E_cost, step=niter)
-            experiment.log_metric("L_Gen_H", g_H_cost, step=niter)
-            experiment.log_metric("L_Gen_H_E", g_cost, step=niter)
+            # periodically report the combined ECAL+HCAL generator loss
+            if batch_idx % args.log_interval == 0:
+                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGHE={:.4f}'.format(
+                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), g_cost.item()))
+
+                niter = epoch * len(train_loader) + batch_idx
+                experiment.log_metric("L_Gen_H_E", g_cost, step=niter)
 
 
 def is_distributed():
     return dist.is_available() and dist.is_initialized()