Skip to content
Snippets Groups Projects
Commit 61830702 authored by Engin Eren's avatar Engin Eren
Browse files

new combinations of models

parent f963a38d
No related branches found
No related tags found
1 merge request!43crit peter
Pipeline #4224728 passed
...@@ -43,13 +43,13 @@ class CriticEMB(nn.Module): ...@@ -43,13 +43,13 @@ class CriticEMB(nn.Module):
self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, size_embed) self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, size_embed)
self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, size_embed*3) self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, size_embed)
self.econd_lin = torch.nn.Linear(1, size_embed) # label embedding self.econd_lin = torch.nn.Linear(1, size_embed) # label embedding
self.fc1 = torch.nn.Linear(size_embed*5, size_embed) # 3 components after cat self.fc1 = torch.nn.Linear(size_embed*3, size_embed) # 3 components after cat
self.fc2 = torch.nn.Linear(size_embed, size_embed - 8) self.fc2 = torch.nn.Linear(size_embed, size_embed - 32)
self.fc3 = torch.nn.Linear(size_embed - 8, 1) self.fc3 = torch.nn.Linear(size_embed - 32, 1)
def forward(self, img_ECAL, img_HCAL, E_true): def forward(self, img_ECAL, img_HCAL, E_true):
......
...@@ -24,13 +24,13 @@ class Hcal_ecalEMB(nn.Module): ...@@ -24,13 +24,13 @@ class Hcal_ecalEMB(nn.Module):
self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14]) self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14])
self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False) self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 8) self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64)
self.econd_lin = torch.nn.Linear(1, 16) # label embedding self.econd_lin = torch.nn.Linear(1, 8) # label embedding
self.fc1 = torch.nn.Linear(24, 16) # 2 components after cat self.fc1 = torch.nn.Linear(72, 64) # 2 components after cat
self.fc2 = torch.nn.Linear(16, 16) self.fc2 = torch.nn.Linear(64, 32)
self.fc3 = torch.nn.Linear(16, emb_size) self.fc3 = torch.nn.Linear(32, emb_size)
## HCAL component of convolutions ## HCAL component of convolutions
......
...@@ -274,8 +274,8 @@ def run(args): ...@@ -274,8 +274,8 @@ def run(args):
## HCAL Generator and critic ## HCAL Generator and critic
mCrit = CriticEMB(args.ndf).to(device) mCrit = CriticEMB(args.ndf, emb_size=64).to(device)
mGen = Hcal_ecalEMB(args.ngf, 32, args.nz, emb_size=16).to(device) mGen = Hcal_ecalEMB(args.ngf, 32, args.nz, emb_size=32).to(device)
## ECAL GENERATOR ## ECAL GENERATOR
mGenE = DCGAN_G(args.ngf, args.nz).to(device) mGenE = DCGAN_G(args.ngf, args.nz).to(device)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment