From 3661db87f9d294434bf7896b09411128cd6f0f69 Mon Sep 17 00:00:00 2001
From: Engin Eren <engin.eren@desy.de>
Date: Thu, 14 Jul 2022 10:53:06 +0200
Subject: [PATCH] Restore original architecture

---
 models/criticFull.py    |  8 ++++----
 models/generatorFull.py | 10 +++++-----
 wganHCAL.py             |  4 ++--
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/models/criticFull.py b/models/criticFull.py
index 438eee1..382112d 100644
--- a/models/criticFull.py
+++ b/models/criticFull.py
@@ -5,7 +5,7 @@ import torch.nn.functional as F
 
 
 class CriticEMB(nn.Module):
-    def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=32, size_embed=16):
+    def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=64, size_embed=64):
         super(CriticEMB, self).__init__()    
         self.ndf = ndf
         self.isize_1 = isize_1
@@ -47,9 +47,9 @@ class CriticEMB(nn.Module):
         
         self.econd_lin = torch.nn.Linear(1, size_embed) # label embedding
 
-        self.fc1 = torch.nn.Linear(size_embed*3, size_embed)  # 3 components after cat
-        self.fc2 = torch.nn.Linear(size_embed,  size_embed - 32)
-        self.fc3 = torch.nn.Linear(size_embed - 32, 1)
+        self.fc1 = torch.nn.Linear(size_embed*3, size_embed*2)  # 3 components after cat
+        self.fc2 = torch.nn.Linear(size_embed*2,  size_embed)
+        self.fc3 = torch.nn.Linear(size_embed, 1)
 
 
     def forward(self, img_ECAL, img_HCAL, E_true):
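The restored critic widens the embedding from 16 to 64 features and grows the fully-connected head accordingly: three 64-wide components are concatenated and reduced 192 -> 128 -> 64 -> 1. Below is a minimal shape check of that head only; the three 64-wide inputs (standing in for the ECAL, HCAL and energy-label embeddings) and the ReLU activations are illustrative assumptions, not the repository's forward() code.

    import torch
    import torch.nn as nn

    size_embed = 64
    fc1 = nn.Linear(size_embed * 3, size_embed * 2)  # 192 -> 128, "3 components after cat"
    fc2 = nn.Linear(size_embed * 2, size_embed)      # 128 -> 64
    fc3 = nn.Linear(size_embed, 1)                   # 64  -> 1 (critic score)

    # Three placeholder 64-wide embeddings, batch of 4 (assumption for illustration)
    x = torch.cat([torch.randn(4, size_embed) for _ in range(3)], dim=1)
    out = fc3(torch.relu(fc2(torch.relu(fc1(x)))))   # ReLU is a stand-in activation
    print(out.shape)  # torch.Size([4, 1])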
diff --git a/models/generatorFull.py b/models/generatorFull.py
index 6178d92..d8c2d7d 100644
--- a/models/generatorFull.py
+++ b/models/generatorFull.py
@@ -9,7 +9,7 @@ class Hcal_ecalEMB(nn.Module):
     """ 
         generator component of WGAN
     """
-    def __init__(self, ngf, ndf, nz, emb_size):
+    def __init__(self, ngf, ndf, nz, emb_size=16):
         super(Hcal_ecalEMB, self).__init__()
         
        
@@ -26,11 +26,11 @@ class Hcal_ecalEMB(nn.Module):
         
         self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64) 
         
-        self.econd_lin = torch.nn.Linear(1, 8) # label embedding
+        self.econd_lin = torch.nn.Linear(1, 64) # label embedding
 
-        self.fc1 = torch.nn.Linear(72, 64)  # 2 components after cat
-        self.fc2 = torch.nn.Linear(64,  32)
-        self.fc3 = torch.nn.Linear(32, emb_size)
+        self.fc1 = torch.nn.Linear(64*2, 128)  # 2 components after cat
+        self.fc2 = torch.nn.Linear(128,  64)
+        self.fc3 = torch.nn.Linear(64, emb_size)
         
         
         ## HCAL component of convolutions
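In the restored generator embedding head, both the flattened ECAL features and the energy label are mapped to 64 features, concatenated, and reduced 128 -> 64 -> emb_size (now defaulting to 16). A rough shape check under those dimensions is sketched below; ndf=32 matches the call site in wganHCAL.py, and the ReLU activations are assumptions standing in for whatever the real forward() uses.

    import torch
    import torch.nn as nn

    ndf, emb_size = 32, 16
    conv_lin_ECAL = nn.Linear(7 * 7 * 7 * ndf, 64)   # flattened ECAL conv features -> 64
    econd_lin = nn.Linear(1, 64)                     # energy-label embedding -> 64
    fc1 = nn.Linear(64 * 2, 128)                     # "2 components after cat"
    fc2 = nn.Linear(128, 64)
    fc3 = nn.Linear(64, emb_size)

    ecal_feat = conv_lin_ECAL(torch.randn(4, 7 * 7 * 7 * ndf))
    label = econd_lin(torch.randn(4, 1))
    emb = fc3(torch.relu(fc2(torch.relu(fc1(torch.cat([ecal_feat, label], dim=1))))))
    print(emb.shape)  # torch.Size([4, 16])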
diff --git a/wganHCAL.py b/wganHCAL.py
index 683a756..c32b557 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -274,8 +274,8 @@ def run(args):
 
 
     ## HCAL Generator and critic
-    mCrit = CriticEMB(args.ndf, size_embed=64).to(device)
-    mGen = Hcal_ecalEMB(args.ngf, 32, args.nz, emb_size=32).to(device)
+    mCrit = CriticEMB().to(device)
+    mGen = Hcal_ecalEMB(args.ngf, 32, args.nz).to(device)
     
     ## ECAL GENERATOR
     mGenE = DCGAN_G(args.ngf, args.nz).to(device)
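With the defaults restored above, wganHCAL.py constructs both HCAL networks without passing embedding sizes explicitly: CriticEMB() falls back to ndf=64 / size_embed=64 and Hcal_ecalEMB to emb_size=16. A hedged usage sketch follows; ngf and nz are placeholder values for the args.ngf / args.nz command-line options, not values taken from the repository.

    import torch
    from models.criticFull import CriticEMB
    from models.generatorFull import Hcal_ecalEMB

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ngf, nz = 32, 100  # illustrative placeholders for args.ngf / args.nz

    mCrit = CriticEMB().to(device)               # uses restored defaults ndf=64, size_embed=64
    mGen = Hcal_ecalEMB(ngf, 32, nz).to(device)  # emb_size falls back to its default of 16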
-- 
GitLab