diff --git a/wganSingleGen.py b/wganSingleGen.py
index 4a92e3a5d503c488cac692ee8762cc0cb1dce63d..96e11b2e0dc1c932a6c88aa6c8ac5d6646936e30 100644
--- a/wganSingleGen.py
+++ b/wganSingleGen.py
@@ -307,7 +307,7 @@ def run(args):
     eph = 0
 
     print ("init models")
-    experiment.set_model_graph(str(Crit_E_H), overwrite=False)
+    #experiment.set_model_graph(str(Crit_E_H), overwrite=False)
     experiment.set_model_graph(str(Gen_E_H), overwrite=False)
 
     print('starting training...')
@@ -315,25 +315,28 @@ def run(args):
     for epoch in range(1, args.epochs + 1):
         epoch += eph
-        train(args, Crit_E_H, Gen_E_H, device, train_loader, optimizerD_E_H, optimizerG_E_H, epoch, experiment)
+        if args.world_size > 1:
+            train_loader.sampler.set_epoch(epoch)
+        train(args, Crit_E_H, Gen_E_H, device, train_loader, optimizerD_E_H, optimizerG_E_H, epoch, experiment)
 
-        # saving to checkpoints
-        g_E_H_path = args.chpt_base + args.exp + "_generator_E_H_"+ str(epoch) + ".pt"
-        c_E_H_path = args.chpt_base + args.exp + "_critic_E_H_"+ str(epoch) + ".pt"
-
-        torch.save({
-            'epoch': epoch,
-            'model_state_dict': Gen_E_H.state_dict(),
-            'optimizer_state_dict': optimizerG_E_H.state_dict()
-        }, g_E_H_path)
-
-        torch.save({
-            'epoch': epoch,
-            'model_state_dict': Crit_E_H.state_dict(),
-            'optimizer_state_dict': optimizerD_E_H.state_dict()
-        }, c_E_H_path)
-
-    print('end training')
+        if args.rank == 0:
+            # saving to checkpoints
+            g_E_H_path = args.chpt_base + args.exp + "_generator_E_H_"+ str(epoch) + ".pt"
+            c_E_H_path = args.chpt_base + args.exp + "_critic_E_H_"+ str(epoch) + ".pt"
+
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': Gen_E_H.state_dict(),
+                'optimizer_state_dict': optimizerG_E_H.state_dict()
+            }, g_E_H_path)
+
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': Crit_E_H.state_dict(),
+                'optimizer_state_dict': optimizerD_E_H.state_dict()
+            }, c_E_H_path)
+
+            print('end training')
 
 
 def main():
     args = parse_args()