From 4b9756a23c0e67d9c07bb39629a7983abd2e5cf2 Mon Sep 17 00:00:00 2001
From: Engin Eren <engin.eren@desy.de>
Date: Mon, 23 May 2022 10:57:10 +0200
Subject: [PATCH] forgot to add nn.parallel.DataParallel

---
 pytorch_job_wganHCAL_single.yaml | 2 +-
 wganHCAL.py                      | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/pytorch_job_wganHCAL_single.yaml b/pytorch_job_wganHCAL_single.yaml
index a940079..e12c916 100644
--- a/pytorch_job_wganHCAL_single.yaml
+++ b/pytorch_job_wganHCAL_single.yaml
@@ -37,7 +37,7 @@ spec:
           command: [sh, -c]
           args:
             - cp /secret/krb-secret-vol/krb5cc_1000 /tmp/krb5cc_0 && chmod 600 /tmp/krb5cc_0
-              && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --lrCrit 0.0001 --lrGen 0.00001;
+              && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --nworkers 2 --chpt --chpt_eph 63 --lrGen 0.00001 --ncrit 5;
           volumeMounts:
           - name: eos
             mountPath: /eos
diff --git a/wganHCAL.py b/wganHCAL.py
index 18e0bf1..bddc2fa 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -273,6 +273,9 @@ def run(args):
             else nn.parallel.DistributedDataParallelCPU
         mCrit = Distributor(mCrit, device_ids=[args.local_rank], output_device=args.local_rank )
         mGen = Distributor(mGen, device_ids=[args.local_rank], output_device=args.local_rank)
+    else:
+        mGen = nn.parallel.DataParallel(mGen)
+        mCrit = nn.parallel.DataParallel(mCrit)
 
     optimizerG = optim.Adam(mGen.parameters(), lr=args.lrGen, betas=(0.5, 0.9))
-- 
GitLab
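
Note on the change: the patch adds a fallback so the generator (mGen) and critic (mCrit) are wrapped in nn.parallel.DataParallel when the job is not launched in distributed mode. The sketch below is only an illustration of that pattern, not the exact wganHCAL.py code; the wrap_models helper and the args.distributed flag are assumptions, since wganHCAL.py derives the distributed/non-distributed decision from its own command-line arguments.

import torch.nn as nn

def wrap_models(mGen, mCrit, args):
    # Hypothetical helper mirroring the pattern in the patch.
    if args.distributed:
        # One process per GPU: DistributedDataParallel pins each replica
        # to the local rank's device and syncs gradients across processes.
        mGen = nn.parallel.DistributedDataParallel(
            mGen, device_ids=[args.local_rank], output_device=args.local_rank)
        mCrit = nn.parallel.DistributedDataParallel(
            mCrit, device_ids=[args.local_rank], output_device=args.local_rank)
    else:
        # Single process: DataParallel splits each batch across all visible
        # GPUs and gathers the outputs on the default device.
        mGen = nn.parallel.DataParallel(mGen)
        mCrit = nn.parallel.DataParallel(mCrit)
    return mGen, mCrit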