diff --git a/wganHCAL.py b/wganHCAL.py index 08b9096b94b7dab5544ff24562b2e6d2c9b5503c..822e2866667219aa133b806f78c06f7ff9d8506b 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -161,6 +161,10 @@ def parse_args(): parser.add_argument('--epochs', type=int, default=1, metavar='N', help='number of epochs to train (default: 1)') + + parser.add_argument('--nworkers', type=int, default=1, metavar='N', + help='number of epochs to train (default: 1)') + parser.add_argument('--lrCrit', type=float, default=0.00001, metavar='LR', help='learning rate Critic (default: 0.00001)') parser.add_argument('--lrGen', type=float, default=0.0001, metavar='LR', @@ -252,7 +256,7 @@ def run(args): sampler = DistributedSampler(dataset, shuffle=True) - train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=1, drop_last=True, pin_memory=False) + train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.nworkers, drop_last=True, pin_memory=False)