From 4b9756a23c0e67d9c07bb39629a7983abd2e5cf2 Mon Sep 17 00:00:00 2001
From: Engin Eren <engin.eren@desy.de>
Date: Mon, 23 May 2022 10:57:10 +0200
Subject: [PATCH] Add missing nn.parallel.DataParallel wrapper for non-distributed runs
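
When the job is not launched in distributed mode, the generator and
critic were not wrapped at all, so no multi-GPU data parallelism was
applied. This patch wraps both models in nn.parallel.DataParallel in
that case, so a single-process run can still spread each batch across
the visible GPUs. The single-node job spec is also updated: the run now
passes --nworkers 2, --ncrit 5 and the checkpoint flags --chpt
--chpt_eph 63 (apparently resuming from epoch 63), and no longer sets
--lrCrit explicitly.

Below is a minimal sketch of the resulting wrapping logic. The
wrap_models helper and the dist.is_initialized() check are illustrative
only and are not part of wganHCAL.py, which uses its own condition for
the distributed branch:

    import torch.nn as nn
    import torch.distributed as dist

    def wrap_models(mGen, mCrit, args):
        """Wrap generator and critic for multi-GPU training (sketch)."""
        if dist.is_available() and dist.is_initialized():
            # Distributed run: one process per GPU, mirroring the
            # existing DistributedDataParallel branch in run().
            mGen = nn.parallel.DistributedDataParallel(
                mGen, device_ids=[args.local_rank],
                output_device=args.local_rank)
            mCrit = nn.parallel.DistributedDataParallel(
                mCrit, device_ids=[args.local_rank],
                output_device=args.local_rank)
        else:
            # Single-process run: DataParallel replicates the model and
            # splits each input batch across all visible GPUs.
            mGen = nn.parallel.DataParallel(mGen)
            mCrit = nn.parallel.DataParallel(mCrit)
        return mGen, mCrit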

---
 pytorch_job_wganHCAL_single.yaml | 2 +-
 wganHCAL.py                      | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/pytorch_job_wganHCAL_single.yaml b/pytorch_job_wganHCAL_single.yaml
index a940079..e12c916 100644
--- a/pytorch_job_wganHCAL_single.yaml
+++ b/pytorch_job_wganHCAL_single.yaml
@@ -37,7 +37,7 @@ spec:
               command: [sh, -c] 
               args:  
                 - cp /secret/krb-secret-vol/krb5cc_1000 /tmp/krb5cc_0 && chmod 600 /tmp/krb5cc_0 
-                  && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --lrCrit 0.0001 --lrGen 0.00001;
+                  && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --nworkers 2 --chpt --chpt_eph 63 --lrGen 0.00001 --ncrit 5;
               volumeMounts:
               - name: eos
                 mountPath: /eos
diff --git a/wganHCAL.py b/wganHCAL.py
index 18e0bf1..bddc2fa 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -273,6 +273,9 @@ def run(args):
             else nn.parallel.DistributedDataParallelCPU
         mCrit = Distributor(mCrit, device_ids=[args.local_rank], output_device=args.local_rank )
         mGen = Distributor(mGen, device_ids=[args.local_rank], output_device=args.local_rank)
+    else:
+        mGen = nn.parallel.DataParallel(mGen)
+        mCrit = nn.parallel.DataParallel(mCrit)
 
     
     optimizerG = optim.Adam(mGen.parameters(), lr=args.lrGen, betas=(0.5, 0.9))
-- 
GitLab