Skip to content
Snippets Groups Projects
Commit 47cf9a21 authored by Engin Eren's avatar Engin Eren
Browse files

reverting back and BS = 50

parent 271991f1
No related branches found
No related tags found
1 merge request!3Test
Pipeline #3983825 passed
......@@ -26,13 +26,13 @@ class HDF5Dataset(data.Dataset):
def __getitem__(self, index):
# get ECAL part
x = self.get_data(index)
#x = torch.from_numpy(x).float()
x = torch.from_numpy(x).float()
## get HCAL part
y = self.get_data_hcal(index)
#y = torch.from_numpy(y).float()
y = torch.from_numpy(y).float()
e = self.get_energy(index)
e = torch.from_numpy(self.get_energy(index))
return x, y, e
......
......@@ -57,14 +57,14 @@ def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch, e
Tensor = torch.cuda.FloatTensor
for batch_idx, (dataE, dataH, energy) in enumerate(train_loader):
#real_dataECAL = dataE.to(device).unsqueeze(1)
real_dataECAL = torch.from_numpy(dataE).to(device).unsqueeze(1).float()
real_dataECAL = dataE.to(device).unsqueeze(1).float()
#real_dataHCAL = dataH.to(device).unsqueeze(1)
real_dataHCAL = torch.from_numpy(dataH).to(device).unsqueeze(1).float()
#real_label = energy.to(device)
real_label = torch.from_numpy(energy).to(device).float()
real_dataHCAL = dataH.to(device).unsqueeze(1).float()
real_label = energy.to(device).float()
optimizer_d.zero_grad()
......@@ -138,8 +138,8 @@ def is_distributed():
def parse_args():
parser = argparse.ArgumentParser(description='WGAN training on hadron showers')
parser.add_argument('--batch-size', type=int, default=100, metavar='N',
help='input batch size for training (default: 100)')
parser.add_argument('--batch-size', type=int, default=50, metavar='N',
help='input batch size for training (default: 50)')
parser.add_argument('--nz', type=int, default=100, metavar='N',
help='latent space for generator (default: 100)')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment