From 3661db87f9d294434bf7896b09411128cd6f0f69 Mon Sep 17 00:00:00 2001 From: Engin Eren <engin.eren@desy.de> Date: Thu, 14 Jul 2022 10:53:06 +0200 Subject: [PATCH] restoring original arch. --- models/criticFull.py | 8 ++++---- models/generatorFull.py | 10 +++++----- wganHCAL.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/models/criticFull.py b/models/criticFull.py index 438eee1..382112d 100644 --- a/models/criticFull.py +++ b/models/criticFull.py @@ -5,7 +5,7 @@ import torch.nn.functional as F class CriticEMB(nn.Module): - def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=32, size_embed=16): + def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=64, size_embed=64): super(CriticEMB, self).__init__() self.ndf = ndf self.isize_1 = isize_1 @@ -47,9 +47,9 @@ class CriticEMB(nn.Module): self.econd_lin = torch.nn.Linear(1, size_embed) # label embedding - self.fc1 = torch.nn.Linear(size_embed*3, size_embed) # 3 components after cat - self.fc2 = torch.nn.Linear(size_embed, size_embed - 32) - self.fc3 = torch.nn.Linear(size_embed - 32, 1) + self.fc1 = torch.nn.Linear(size_embed*3, size_embed*2) # 3 components after cat + self.fc2 = torch.nn.Linear(size_embed*2, size_embed) + self.fc3 = torch.nn.Linear(size_embed, 1) def forward(self, img_ECAL, img_HCAL, E_true): diff --git a/models/generatorFull.py b/models/generatorFull.py index 6178d92..d8c2d7d 100644 --- a/models/generatorFull.py +++ b/models/generatorFull.py @@ -9,7 +9,7 @@ class Hcal_ecalEMB(nn.Module): """ generator component of WGAN """ - def __init__(self, ngf, ndf, nz, emb_size): + def __init__(self, ngf, ndf, nz, emb_size=16): super(Hcal_ecalEMB, self).__init__() @@ -26,11 +26,11 @@ class Hcal_ecalEMB(nn.Module): self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64) - self.econd_lin = torch.nn.Linear(1, 8) # label embedding + self.econd_lin = torch.nn.Linear(1, 64) # label embedding - self.fc1 = torch.nn.Linear(72, 64) # 2 components after cat - self.fc2 = torch.nn.Linear(64, 32) - self.fc3 = torch.nn.Linear(32, emb_size) + self.fc1 = torch.nn.Linear(64*2, 128) # 2 components after cat + self.fc2 = torch.nn.Linear(128, 64) + self.fc3 = torch.nn.Linear(64, emb_size) ## HCAL component of convolutions diff --git a/wganHCAL.py b/wganHCAL.py index 683a756..c32b557 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -274,8 +274,8 @@ def run(args): ## HCAL Generator and critic - mCrit = CriticEMB(args.ndf, size_embed=64).to(device) - mGen = Hcal_ecalEMB(args.ngf, 32, args.nz, emb_size=32).to(device) + mCrit = CriticEMB().to(device) + mGen = Hcal_ecalEMB(args.ngf, 32, args.nz).to(device) ## ECAL GENERATOR mGenE = DCGAN_G(args.ngf, args.nz).to(device) -- GitLab