Skip to content
Snippets Groups Projects
Commit 83b4e3fd authored by Engin Eren
Browse files

save if rank=0

parent 81b28b4a
No related branches found
No related tags found
No related merge requests found
Pipeline #4560166 passed
...@@ -307,7 +307,7 @@ def run(args): ...@@ -307,7 +307,7 @@ def run(args):
eph = 0 eph = 0
print ("init models") print ("init models")
experiment.set_model_graph(str(Crit_E_H), overwrite=False) #experiment.set_model_graph(str(Crit_E_H), overwrite=False)
experiment.set_model_graph(str(Gen_E_H), overwrite=False) experiment.set_model_graph(str(Gen_E_H), overwrite=False)
print('starting training...') print('starting training...')
...@@ -315,25 +315,28 @@ def run(args): ...@@ -315,25 +315,28 @@ def run(args):
for epoch in range(1, args.epochs + 1): for epoch in range(1, args.epochs + 1):
epoch += eph epoch += eph
train(args, Crit_E_H, Gen_E_H, device, train_loader, optimizerD_E_H, optimizerG_E_H, epoch, experiment) if args.world_size > 1:
train_loader.sampler.set_epoch(epoch)
train(args, Crit_E_H, Gen_E_H, device, train_loader, optimizerD_E_H, optimizerG_E_H, epoch, experiment)
# saving to checkpoints if args.rank == 0:
g_E_H_path = args.chpt_base + args.exp + "_generator_E_H_"+ str(epoch) + ".pt" # saving to checkpoints
c_E_H_path = args.chpt_base + args.exp + "_critic_E_H_"+ str(epoch) + ".pt" g_E_H_path = args.chpt_base + args.exp + "_generator_E_H_"+ str(epoch) + ".pt"
c_E_H_path = args.chpt_base + args.exp + "_critic_E_H_"+ str(epoch) + ".pt"
torch.save({
'epoch': epoch, torch.save({
'model_state_dict': Gen_E_H.state_dict(), 'epoch': epoch,
'optimizer_state_dict': optimizerG_E_H.state_dict() 'model_state_dict': Gen_E_H.state_dict(),
}, g_E_H_path) 'optimizer_state_dict': optimizerG_E_H.state_dict()
}, g_E_H_path)
torch.save({
'epoch': epoch, torch.save({
'model_state_dict': Crit_E_H.state_dict(), 'epoch': epoch,
'optimizer_state_dict': optimizerD_E_H.state_dict() 'model_state_dict': Crit_E_H.state_dict(),
}, c_E_H_path) 'optimizer_state_dict': optimizerD_E_H.state_dict()
}, c_E_H_path)
print('end training')
print('end training')
def main(): def main():
args = parse_args() args = parse_args()
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.