diff --git a/models/data_loaderFull.py b/models/data_loaderFull.py
index e157bf1043c18b1196c523f847bda4f4f40c8044..b95449d6596687d41b98189980fb8a41acc2c978 100644
--- a/models/data_loaderFull.py
+++ b/models/data_loaderFull.py
@@ -26,20 +26,14 @@ class HDF5Dataset(data.Dataset):
     def __getitem__(self, index):
         # get ECAL part
         x = self.get_data(index)
-        if self.transform:
-            x = torch.from_numpy(self.transform(x)).float()
-        else:
-            x = torch.from_numpy(x).float()
+        #x = torch.from_numpy(x).float()
 
         ## get HCAL part
         y = self.get_data_hcal(index)
-        if self.transform:
-            y = torch.from_numpy(self.transform(y)).float()
-        else:
-            y = torch.from_numpy(y).float()
+        #y = torch.from_numpy(y).float()
 
-        e = torch.from_numpy(self.get_energy(index))
+        e = self.get_energy(index)
 
         if torch.sum(x) != torch.sum(x):   #checks for NANs
diff --git a/wganHCAL.py b/wganHCAL.py
index ced8009df9bd765bdd2f23dc1cb48aa009e06f32..f08fd6d1962fb3c4a24f0e627da0fab59ff7d2b9 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -57,9 +57,14 @@ def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch, e
     Tensor = torch.cuda.FloatTensor
 
     for batch_idx, (dataE, dataH, energy) in enumerate(train_loader):
-        real_dataECAL = dataE.to(device).unsqueeze(1)
-        real_dataHCAL = dataH.to(device).unsqueeze(1)
-        real_label = energy.to(device)
+        #real_dataECAL = dataE.to(device).unsqueeze(1)
+        real_dataECAL = torch.from_numpy(dataE).to(device).unsqueeze(1).float()
+
+        #real_dataHCAL = dataH.to(device).unsqueeze(1)
+        real_dataHCAL = torch.from_numpy(dataH).to(device).unsqueeze(1).float()
+
+        #real_label = energy.to(device)
+        real_label = torch.from_numpy(energy).to(device).float()
 
         optimizer_d.zero_grad()
 
@@ -69,10 +74,10 @@
         fake_dataHCAL = aG(z, real_label, real_dataECAL).detach()    ## 48 x 30 x 30
 
         ## Critic fwd pass on Real
-        disc_real = aD(real_dataECAL.float(), real_dataHCAL.float(), real_label.float())
+        disc_real = aD(real_dataECAL, real_dataHCAL, real_label)
 
         ## Calculate Gradient Penalty Term
-        gradient_penalty = calc_gradient_penalty(aD, real_dataECAL.float(), real_dataHCAL.float(), fake_dataHCAL, real_label, args.batch_size, device, layer=48, xsize=30, ysize=30)
+        gradient_penalty = calc_gradient_penalty(aD, real_dataECAL, real_dataHCAL, fake_dataHCAL, real_label, args.batch_size, device, layer=48, xsize=30, ysize=30)
 
         ## Critic fwd pass on Fake
         disc_fake = aD(real_dataECAL, fake_dataHCAL.unsqueeze(1), real_label)
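Usage note (a minimal sketch, not part of the patch): with the tensor conversion moved out of HDF5Dataset.__getitem__ and into train(), the DataLoader has to hand numpy batches to train() for the torch.from_numpy(dataE) / torch.from_numpy(dataH) / torch.from_numpy(energy) calls above to work, because PyTorch's default collate function already converts numpy samples into tensors. One way to keep batches as numpy is a custom collate_fn; the numpy_collate helper below is hypothetical and assumes __getitem__ returns (x, y, e) as numpy arrays.

import numpy as np
from torch.utils.data import DataLoader

def numpy_collate(batch):
    # Stack per-sample numpy arrays into numpy batches (instead of letting the
    # default collate convert them to tensors), so that train() can perform the
    # torch.from_numpy(...) conversion itself as in the patch above.
    xs, ys, es = zip(*batch)
    return np.stack(xs), np.stack(ys), np.stack(es)

# Hypothetical usage; dataset construction and arguments are placeholders.
# train_loader = DataLoader(dataset, batch_size=args.batch_size,
#                           shuffle=True, collate_fn=numpy_collate)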