diff --git a/pytorch_job_wganHCAL_single.yaml b/pytorch_job_wganHCAL_single.yaml
index a940079a918a86f859543a7260e2cf741988a49c..e12c916204ab025203afc22c87281acdb04618e0 100644
--- a/pytorch_job_wganHCAL_single.yaml
+++ b/pytorch_job_wganHCAL_single.yaml
@@ -37,7 +37,7 @@ spec:
         command: [sh, -c]
         args:
           - cp /secret/krb-secret-vol/krb5cc_1000 /tmp/krb5cc_0 && chmod 600 /tmp/krb5cc_0
-            && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --lrCrit 0.0001 --lrGen 0.00001;
+            && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --nworkers 2 --chpt --chpt_eph 63 --lrGen 0.00001 --ncrit 5;
         volumeMounts:
         - name: eos
           mountPath: /eos
diff --git a/wganHCAL.py b/wganHCAL.py
index 18e0bf135f95d334f0a3e457b867790a2f139b75..bddc2fa4ac50561df113ba08eeb7ab1eeb880f35 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -273,6 +273,9 @@ def run(args):
             else nn.parallel.DistributedDataParallelCPU
         mCrit = Distributor(mCrit, device_ids=[args.local_rank], output_device=args.local_rank )
         mGen = Distributor(mGen, device_ids=[args.local_rank], output_device=args.local_rank)
+    else:
+        mGen = nn.parallel.DataParallel(mGen)
+        mCrit = nn.parallel.DataParallel(mCrit)

     optimizerG = optim.Adam(mGen.parameters(), lr=args.lrGen, betas=(0.5, 0.9))