diff --git a/wganSingleGen.py b/wganSingleGen.py
index 4a92e3a5d503c488cac692ee8762cc0cb1dce63d..96e11b2e0dc1c932a6c88aa6c8ac5d6646936e30 100644
--- a/wganSingleGen.py
+++ b/wganSingleGen.py
@@ -307,7 +307,7 @@ def run(args):
         eph = 0
         print ("init models")
         
-    experiment.set_model_graph(str(Crit_E_H), overwrite=False)
+    #experiment.set_model_graph(str(Crit_E_H), overwrite=False)
     experiment.set_model_graph(str(Gen_E_H), overwrite=False)
     
     print('starting training...')
@@ -315,25 +315,28 @@ def run(args):
     for epoch in range(1, args.epochs + 1):
         epoch += eph
         
-        train(args, Crit_E_H, Gen_E_H, device, train_loader, optimizerD_E_H, optimizerG_E_H, epoch, experiment)
+        if args.world_size > 1: 
+            train_loader.sampler.set_epoch(epoch)
+            train(args, Crit_E_H, Gen_E_H, device, train_loader, optimizerD_E_H, optimizerG_E_H, epoch, experiment)
         
-        # saving to checkpoints
-        g_E_H_path = args.chpt_base + args.exp + "_generator_E_H_"+ str(epoch) + ".pt"
-        c_E_H_path = args.chpt_base + args.exp + "_critic_E_H_"+ str(epoch) + ".pt"
-        
-        torch.save({
-            'epoch': epoch,
-            'model_state_dict': Gen_E_H.state_dict(),
-            'optimizer_state_dict': optimizerG_E_H.state_dict()
-            }, g_E_H_path)
-
-        torch.save({
-            'epoch': epoch,
-            'model_state_dict': Crit_E_H.state_dict(),
-            'optimizer_state_dict': optimizerD_E_H.state_dict()
-            }, c_E_H_path)
-
-        print('end training')
+        if args.rank == 0:
+            # saving to checkpoints
+            g_E_H_path = args.chpt_base + args.exp + "_generator_E_H_"+ str(epoch) + ".pt"
+            c_E_H_path = args.chpt_base + args.exp + "_critic_E_H_"+ str(epoch) + ".pt"
+
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': Gen_E_H.state_dict(),
+                'optimizer_state_dict': optimizerG_E_H.state_dict()
+                }, g_E_H_path)
+
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': Crit_E_H.state_dict(),
+                'optimizer_state_dict': optimizerD_E_H.state_dict()
+                }, c_E_H_path)
+
+    print('end training')
         
 def main():
     args = parse_args()