diff --git a/wganSingleGen.py b/wganSingleGen.py index 8509436c04c28a757bb31234b92b5d65f17bc6f3..af5ec5b4c25fb8b10e45985c6be502f065b369b7 100644 --- a/wganSingleGen.py +++ b/wganSingleGen.py @@ -10,6 +10,7 @@ from torch import autograd from torch.utils.data.distributed import DistributedSampler from torch.utils.data import DataLoader from torch.autograd import Variable +import time from API_keys import api_key @@ -317,8 +318,11 @@ def run(args): if args.world_size > 1: train_loader.sampler.set_epoch(epoch) + t0 = time.time() train(args, Crit_E_H, Gen_E_H, device, train_loader, optimizerD_E_H, optimizerG_E_H, epoch, experiment) - + dt1 = (time.time() - t0) + print('Epoch {:03d} had training time(s): {:.4f}'.format(epoch, dt1)) + if args.rank == 0: # saving to checkpoints g_E_H_path = args.chpt_base + args.exp + "_generator_E_H_"+ str(epoch) + ".pt"