diff --git a/wgan.py b/wgan.py index b3676f24ad190cb06407163c4cada819f10e95d4..7322b66b1e8f9600959114a9eff98aaa6004b3e4 100644 --- a/wgan.py +++ b/wgan.py @@ -262,7 +262,7 @@ def run(args): print ("loading data") #dataset = HDF5Dataset('/eos/user/e/eneren/scratch/40GeV40k.hdf5', transform=None, train_size=40000) - dataset = HDF5Dataset('/eos/user/e/eneren/scratch/50GeV75k.hdf5', transform=None, train_size=75000) + dataset = HDF5Dataset('/eos/user/e/eneren/scratch/60GeV20k.hdf5', transform=None, train_size=20000) sampler = DistributedSampler(dataset, shuffle=True) train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=1, pin_memory=False)