diff --git a/models/criticFull.py b/models/criticFull.py index de1c9b98202329010ce814d2849a7d6b1d7e28fa..438eee1baa4355e148367a75527ed698befa7802 100644 --- a/models/criticFull.py +++ b/models/criticFull.py @@ -43,13 +43,13 @@ class CriticEMB(nn.Module): self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, size_embed) - self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, size_embed*3) + self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, size_embed) self.econd_lin = torch.nn.Linear(1, size_embed) # label embedding - self.fc1 = torch.nn.Linear(size_embed*5, size_embed) # 3 components after cat - self.fc2 = torch.nn.Linear(size_embed, size_embed - 8) - self.fc3 = torch.nn.Linear(size_embed - 8, 1) + self.fc1 = torch.nn.Linear(size_embed*3, size_embed) # 3 components after cat + self.fc2 = torch.nn.Linear(size_embed, size_embed - 32) + self.fc3 = torch.nn.Linear(size_embed - 32, 1) def forward(self, img_ECAL, img_HCAL, E_true): diff --git a/models/generatorFull.py b/models/generatorFull.py index b1ab6da8ace07ca086ca35e434bce44fe6aecb1a..6178d92d0fb8840e910fb60e59845444be01a569 100644 --- a/models/generatorFull.py +++ b/models/generatorFull.py @@ -24,13 +24,13 @@ class Hcal_ecalEMB(nn.Module): self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14]) self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False) - self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 8) + self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64) - self.econd_lin = torch.nn.Linear(1, 16) # label embedding + self.econd_lin = torch.nn.Linear(1, 8) # label embedding - self.fc1 = torch.nn.Linear(24, 16) # 2 components after cat - self.fc2 = torch.nn.Linear(16, 16) - self.fc3 = torch.nn.Linear(16, emb_size) + self.fc1 = torch.nn.Linear(72, 64) # 2 components after cat + self.fc2 = torch.nn.Linear(64, 32) + self.fc3 = torch.nn.Linear(32, emb_size) ## HCAL component of convolutions diff --git a/wganHCAL.py b/wganHCAL.py index 1435ef161315efa9dd4a1990865be0dca93b1694..3696feb2627f6bf7357ada02de918cac6533c598 100644 --- a/wganHCAL.py +++ b/wganHCAL.py @@ -274,8 +274,8 @@ def run(args): ## HCAL Generator and critic - mCrit = CriticEMB(args.ndf).to(device) - mGen = Hcal_ecalEMB(args.ngf, 32, args.nz, emb_size=16).to(device) + mCrit = CriticEMB(args.ndf, emb_size=64).to(device) + mGen = Hcal_ecalEMB(args.ngf, 32, args.nz, emb_size=32).to(device) ## ECAL GENERATOR mGenE = DCGAN_G(args.ngf, args.nz).to(device)