From c3a08a4a2e3dd9104aa6df31b87472774fdad659 Mon Sep 17 00:00:00 2001
From: Engin Eren <engin.eren@desy.de>
Date: Mon, 11 Jul 2022 14:06:09 +0200
Subject: [PATCH] Still RFF, but with a different embedding-size config

---
 models/criticFull.py    | 16 ++++++++--------
 models/generatorFull.py | 10 +++++-----
 wganHCAL.py             |  6 +++---
 3 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/models/criticFull.py b/models/criticFull.py
index 16333c3..4051212 100644
--- a/models/criticFull.py
+++ b/models/criticFull.py
@@ -5,13 +5,13 @@ import torch.nn.functional as F
 
 
 class CriticEMB(nn.Module):
-    def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=64):
+    def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=32, size_embed=16):
         super(CriticEMB, self).__init__()    
         self.ndf = ndf
         self.isize_1 = isize_1
         self.isize_2 = isize_2
         self.nc = nc
-        self.size_embed = 16
+        self.size_embed = size_embed
         self.conv1_bias = False
 
       
@@ -42,14 +42,14 @@ class CriticEMB(nn.Module):
         #self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
         
 
-        self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64) 
-        self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, 64) 
+        self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, size_embed)    # ECAL embedding
+        self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, size_embed*3)  # HCAL embedding, 3x wider than ECAL
         
-        self.econd_lin = torch.nn.Linear(1, 64) # label embedding
+        self.econd_lin = torch.nn.Linear(1, size_embed) # label embedding
 
-        self.fc1 = torch.nn.Linear(64*3, 128)  # 3 components after cat
-        self.fc2 = torch.nn.Linear(128,  64)
-        self.fc3 = torch.nn.Linear(64, 1)
+        self.fc1 = torch.nn.Linear(size_embed*5, size_embed)  # ECAL (1x) + HCAL (3x) + E_true (1x) after cat
+        self.fc2 = torch.nn.Linear(size_embed, size_embed // 2)
+        self.fc3 = torch.nn.Linear(size_embed // 2, 1)
 
 
     def forward(self, img_ECAL, img_HCAL, E_true):
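
(Not part of the patch: a minimal shape check of the new critic head, assuming the default
size_embed=16. The concatenated input is the ECAL (1x), HCAL (3x) and E_true (1x) embeddings,
i.e. size_embed*5 = 80 features; note the integer division for the hidden widths, since
nn.Linear requires integer feature counts.)

    import torch
    import torch.nn as nn

    size_embed = 16
    fc1 = nn.Linear(size_embed * 5, size_embed)      # 80 -> 16
    fc2 = nn.Linear(size_embed, size_embed // 2)     # 16 -> 8
    fc3 = nn.Linear(size_embed // 2, 1)              # 8  -> 1 (critic score)

    batch = 4
    ecal_emb  = torch.randn(batch, size_embed)       # stand-in for conv_lin_ECAL output
    hcal_emb  = torch.randn(batch, size_embed * 3)   # stand-in for conv_lin_HCAL output
    econd_emb = torch.randn(batch, size_embed)       # stand-in for econd_lin output
    x = torch.cat((ecal_emb, hcal_emb, econd_emb), dim=1)
    assert fc3(fc2(fc1(x))).shape == (batch, 1)
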
diff --git a/models/generatorFull.py b/models/generatorFull.py
index 029b553..b1ab6da 100644
--- a/models/generatorFull.py
+++ b/models/generatorFull.py
@@ -24,13 +24,13 @@ class Hcal_ecalEMB(nn.Module):
         self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14])
         self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
         
-        self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64) 
+        self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 8) 
         
-        self.econd_lin = torch.nn.Linear(1, 64) # label embedding
+        self.econd_lin = torch.nn.Linear(1, 16) # label embedding
 
-        self.fc1 = torch.nn.Linear(64*2, 128)  # 2 components after cat
-        self.fc2 = torch.nn.Linear(128,  64)
-        self.fc3 = torch.nn.Linear(64, emb_size)
+        self.fc1 = torch.nn.Linear(24, 16)  # ECAL (8) + E_true (16) after cat
+        self.fc2 = torch.nn.Linear(16, 16)
+        self.fc3 = torch.nn.Linear(16, emb_size)
         
         
         ## HCAL component of convolutions
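
(Not part of the patch: a quick check of the generator-side widths introduced above. The ECAL
conv features are compressed to 8 and the energy label to 16, so the concatenated vector carries
24 features into fc1; ndf=32 is assumed here, matching the --ngf default in wganHCAL.py.)

    import torch
    import torch.nn as nn

    ndf = 32
    conv_lin_ECAL = nn.Linear(7 * 7 * 7 * ndf, 8)   # flattened ECAL conv features -> 8
    econd_lin = nn.Linear(1, 16)                    # energy label -> 16

    batch = 4
    ecal_flat = torch.randn(batch, 7 * 7 * 7 * ndf)
    e_true = torch.randn(batch, 1)
    x = torch.cat((conv_lin_ECAL(ecal_flat), econd_lin(e_true)), dim=1)
    assert x.shape == (batch, 24)                   # matches torch.nn.Linear(24, 16) for fc1
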
diff --git a/wganHCAL.py b/wganHCAL.py
index 8f35c47..9526a71 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -161,8 +161,8 @@ def parse_args():
     parser.add_argument('--kappa', type=float, default=0.001, metavar='N',
                         help='weight of label conditioning  (default: 0.001)')
 
-    parser.add_argument('--ndf', type=int, default=64, metavar='N',
-                        help='n-feature of critic (default: 64)')
+    parser.add_argument('--ndf', type=int, default=32, metavar='N',
+                        help='n-feature of critic (default: 32)')
 
     parser.add_argument('--ngf', type=int, default=32, metavar='N',
                         help='n-feature of generator  (default: 32)')
@@ -274,7 +274,7 @@ def run(args):
 
 
     ## HCAL Generator and critic
-    mCrit = CriticEMB().to(device)
+    mCrit = CriticEMB(ndf=args.ndf).to(device)
     mGen = Hcal_ecalEMB(args.ngf, 32, args.nz, emb_size=16).to(device)
     
     ## ECAL GENERATOR
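
(Not part of the patch: since the constructor signature is CriticEMB(isize_1=30, isize_2=48,
nc=2, ndf=32, size_embed=16), the critic feature count has to be passed by keyword; a positional
CriticEMB(args.ndf) would bind the value to isize_1 instead. A sketch of the wiring, with a
hypothetical Args stand-in for the parsed command-line arguments:)

    from models.criticFull import CriticEMB

    class Args:          # hypothetical stand-in for argparse output
        ndf = 32

    args = Args()
    mCrit = CriticEMB(ndf=args.ndf, size_embed=16)  # isize_1/isize_2/nc keep their defaults
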
-- 
GitLab