From 47cf9a21a90afc999508a7dd9753dfa479b0e61a Mon Sep 17 00:00:00 2001 From: Engin Eren <engin.eren@desy.de> Date: Tue, 17 May 2022 15:56:46 +0200 Subject: [PATCH] reverting back and BS = 50 --- models/data_loaderFull.py | 6 +++--- wganHCAL.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/models/data_loaderFull.py b/models/data_loaderFull.py index bc8d038..0302b3a 100644 --- a/models/data_loaderFull.py +++ b/models/data_loaderFull.py @@ -26,13 +26,13 @@ class HDF5Dataset(data.Dataset): def __getitem__(self, index): # get ECAL part x = self.get_data(index) - #x = torch.from_numpy(x).float() + x = torch.from_numpy(x).float() ## get HCAL part y = self.get_data_hcal(index) - #y = torch.from_numpy(y).float() + y = torch.from_numpy(y).float() - e = self.get_energy(index) + e = torch.from_numpy(self.get_energy(index)) return x, y, e diff --git a/wganHCAL.py b/wganHCAL.py index f08fd6d..08b9096 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -57,14 +57,14 @@ def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch, e Tensor = torch.cuda.FloatTensor for batch_idx, (dataE, dataH, energy) in enumerate(train_loader): - #real_dataECAL = dataE.to(device).unsqueeze(1) - real_dataECAL = torch.from_numpy(dataE).to(device).unsqueeze(1).float() + real_dataECAL = dataE.to(device).unsqueeze(1).float() - #real_dataHCAL = dataH.to(device).unsqueeze(1) - real_dataHCAL = torch.from_numpy(dataH).to(device).unsqueeze(1).float() - #real_label = energy.to(device) - real_label = torch.from_numpy(energy).to(device).float() + real_dataHCAL = dataH.to(device).unsqueeze(1).float() + + + real_label = energy.to(device).float() + optimizer_d.zero_grad() @@ -138,8 +138,8 @@ def is_distributed(): def parse_args(): parser = argparse.ArgumentParser(description='WGAN training on hadron showers') - parser.add_argument('--batch-size', type=int, default=100, metavar='N', - help='input batch size for training (default: 100)') + parser.add_argument('--batch-size', type=int, default=50, metavar='N', + help='input batch size for training (default: 50)') parser.add_argument('--nz', type=int, default=100, metavar='N', help='latent space for generator (default: 100)') -- GitLab