From a116159f36c91fb1b05d43e904d061db557fae92 Mon Sep 17 00:00:00 2001 From: Engin Eren <engin.eren@desy.de> Date: Mon, 23 May 2022 11:33:38 +0200 Subject: [PATCH] train loader set epoch --- pytorch_job_wganHCAL_single.yaml | 2 +- wganHCAL.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_job_wganHCAL_single.yaml b/pytorch_job_wganHCAL_single.yaml index e12c916..a6d5cba 100644 --- a/pytorch_job_wganHCAL_single.yaml +++ b/pytorch_job_wganHCAL_single.yaml @@ -37,7 +37,7 @@ spec: command: [sh, -c] args: - cp /secret/krb-secret-vol/krb5cc_1000 /tmp/krb5cc_0 && chmod 600 /tmp/krb5cc_0 - && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --nworkers 2 --chpt --chpt_eph 63 --lrGen 0.00001 --ncrit 5; + && python -u wganHCAL.py --backend nccl --epochs 20 --exp wganHCALv1 --nworkers 2 --chpt --chpt_eph 63 --lrGen 0.00001 --ncrit 5; volumeMounts: - name: eos mountPath: /eos diff --git a/wganHCAL.py b/wganHCAL.py index bddc2fa..f6ceaaa 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -303,7 +303,12 @@ def run(args): print ("starting training...") for epoch in range(1, args.epochs + 1): epoch += eph - train_loader.sampler.set_epoch(epoch) + + if args.world_size > 1: + train_loader.sampler.set_epoch(epoch) + else: + train_loader.set_epoch(epoch) + train(args, mCrit, mGen, device, train_loader, optimizerD, optimizerG, epoch, experiment) if args.rank == 0: gPATH = args.chpt_base + args.exp + "_generator_"+ str(epoch) + ".pt" -- GitLab