From 83b4e3fd826dae7aeff2e432aa9b12e9e480ed28 Mon Sep 17 00:00:00 2001 From: eneren <engin.eren@cern.ch> Date: Fri, 30 Sep 2022 14:33:48 +0000 Subject: [PATCH] save if rank=0 --- wganSingleGen.py | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/wganSingleGen.py b/wganSingleGen.py index 4a92e3a..96e11b2 100644 --- a/wganSingleGen.py +++ b/wganSingleGen.py @@ -307,7 +307,7 @@ def run(args): eph = 0 print ("init models") - experiment.set_model_graph(str(Crit_E_H), overwrite=False) + #experiment.set_model_graph(str(Crit_E_H), overwrite=False) experiment.set_model_graph(str(Gen_E_H), overwrite=False) print('starting training...') @@ -315,25 +315,28 @@ def run(args): for epoch in range(1, args.epochs + 1): epoch += eph - train(args, Crit_E_H, Gen_E_H, device, train_loader, optimizerD_E_H, optimizerG_E_H, epoch, experiment) + if args.world_size > 1: + train_loader.sampler.set_epoch(epoch) + train(args, Crit_E_H, Gen_E_H, device, train_loader, optimizerD_E_H, optimizerG_E_H, epoch, experiment) - # saving to checkpoints - g_E_H_path = args.chpt_base + args.exp + "_generator_E_H_"+ str(epoch) + ".pt" - c_E_H_path = args.chpt_base + args.exp + "_critic_E_H_"+ str(epoch) + ".pt" - - torch.save({ - 'epoch': epoch, - 'model_state_dict': Gen_E_H.state_dict(), - 'optimizer_state_dict': optimizerG_E_H.state_dict() - }, g_E_H_path) - - torch.save({ - 'epoch': epoch, - 'model_state_dict': Crit_E_H.state_dict(), - 'optimizer_state_dict': optimizerD_E_H.state_dict() - }, c_E_H_path) - - print('end training') + if args.rank == 0: + # saving to checkpoints + g_E_H_path = args.chpt_base + args.exp + "_generator_E_H_"+ str(epoch) + ".pt" + c_E_H_path = args.chpt_base + args.exp + "_critic_E_H_"+ str(epoch) + ".pt" + + torch.save({ + 'epoch': epoch, + 'model_state_dict': Gen_E_H.state_dict(), + 'optimizer_state_dict': optimizerG_E_H.state_dict() + }, g_E_H_path) + + torch.save({ + 'epoch': epoch, + 'model_state_dict': Crit_E_H.state_dict(), + 'optimizer_state_dict': optimizerD_E_H.state_dict() + }, c_E_H_path) + + print('end training') def main(): args = parse_args() -- GitLab