From 47cf9a21a90afc999508a7dd9753dfa479b0e61a Mon Sep 17 00:00:00 2001
From: Engin Eren <engin.eren@desy.de>
Date: Tue, 17 May 2022 15:56:46 +0200
Subject: [PATCH] reverting back and BS = 50

---
 models/data_loaderFull.py |  6 +++---
 wganHCAL.py               | 16 ++++++++--------
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/models/data_loaderFull.py b/models/data_loaderFull.py
index bc8d038..0302b3a 100644
--- a/models/data_loaderFull.py
+++ b/models/data_loaderFull.py
@@ -26,13 +26,13 @@ class HDF5Dataset(data.Dataset):
     def __getitem__(self, index):
         # get ECAL part
         x = self.get_data(index)
-        #x = torch.from_numpy(x).float()
+        x = torch.from_numpy(x).float()
         
         ## get HCAL part
         y = self.get_data_hcal(index)
-        #y = torch.from_numpy(y).float()
+        y = torch.from_numpy(y).float()
         
-        e = self.get_energy(index)
+        e = torch.from_numpy(self.get_energy(index))
         
         return x, y, e
     
diff --git a/wganHCAL.py b/wganHCAL.py
index f08fd6d..08b9096 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -57,14 +57,14 @@ def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch, e
     Tensor = torch.cuda.FloatTensor 
    
     for batch_idx, (dataE, dataH, energy) in enumerate(train_loader):
-        #real_dataECAL = dataE.to(device).unsqueeze(1)
-        real_dataECAL = torch.from_numpy(dataE).to(device).unsqueeze(1).float()
+        real_dataECAL = dataE.to(device).unsqueeze(1).float()
         
-        #real_dataHCAL = dataH.to(device).unsqueeze(1)
-        real_dataHCAL = torch.from_numpy(dataH).to(device).unsqueeze(1).float()
         
-        #real_label = energy.to(device)
-        real_label = torch.from_numpy(energy).to(device).float()
+        real_dataHCAL = dataH.to(device).unsqueeze(1).float()
+        
+        
+        real_label = energy.to(device).float()
+        
         
         optimizer_d.zero_grad()
         
@@ -138,8 +138,8 @@ def is_distributed():
 
 def parse_args():
     parser = argparse.ArgumentParser(description='WGAN training on hadron showers')
-    parser.add_argument('--batch-size', type=int, default=100, metavar='N',
-                        help='input batch size for training (default: 100)')
+    parser.add_argument('--batch-size', type=int, default=50, metavar='N',
+                        help='input batch size for training (default: 50)')
     
     parser.add_argument('--nz', type=int, default=100, metavar='N',
                         help='latent space for generator (default: 100)')
-- 
GitLab