From a116159f36c91fb1b05d43e904d061db557fae92 Mon Sep 17 00:00:00 2001
From: Engin Eren <engin.eren@desy.de>
Date: Mon, 23 May 2022 11:33:38 +0200
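Subject: [PATCH] Set train loader epoch based on world_size

In run(), call train_loader.sampler.set_epoch(epoch) only when
args.world_size > 1; otherwise call train_loader.set_epoch(epoch).
Also reduce --epochs from 50 to 20 in pytorch_job_wganHCAL_single.yaml.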

---
 pytorch_job_wganHCAL_single.yaml | 2 +-
 wganHCAL.py                      | 7 ++++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/pytorch_job_wganHCAL_single.yaml b/pytorch_job_wganHCAL_single.yaml
index e12c916..a6d5cba 100644
--- a/pytorch_job_wganHCAL_single.yaml
+++ b/pytorch_job_wganHCAL_single.yaml
@@ -37,7 +37,7 @@ spec:
               command: [sh, -c] 
               args:  
                 - cp /secret/krb-secret-vol/krb5cc_1000 /tmp/krb5cc_0 && chmod 600 /tmp/krb5cc_0 
-                  && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --nworkers 2 --chpt --chpt_eph 63 --lrGen 0.00001 --ncrit 5;
+                  && python -u wganHCAL.py --backend nccl --epochs 20 --exp wganHCALv1 --nworkers 2 --chpt --chpt_eph 63 --lrGen 0.00001 --ncrit 5;
               volumeMounts:
               - name: eos
                 mountPath: /eos
diff --git a/wganHCAL.py b/wganHCAL.py
index bddc2fa..f6ceaaa 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -303,7 +303,12 @@ def run(args):
     print ("starting training...")
     for epoch in range(1, args.epochs + 1):
         epoch += eph
-        train_loader.sampler.set_epoch(epoch)
+        
+        if args.world_size > 1: 
+            train_loader.sampler.set_epoch(epoch)
+        else:
+            train_loader.set_epoch(epoch)
+
         train(args, mCrit, mGen, device, train_loader, optimizerD, optimizerG, epoch, experiment)
         if args.rank == 0:
             gPATH = args.chpt_base + args.exp + "_generator_"+ str(epoch) + ".pt"
-- 
GitLab