diff --git a/.gitignore b/.gitignore
index 0fd481a5848990b0d291057d41bf247214afbd7a..d60834e3613e5f32dfb89da72b6ef61b9e8f629a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,8 @@ models/.ipynb_checkpoints
 models/.ipynb_checkpoints/*
 interactive/__pycache__/
 models/__pycache__/
+__pycache__/
+API_keys.py
 interactive/jsd/
 interactive/plots/
 interactive/*.png
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 0bf3e8f3bda6a7774980fc6093459f23ba37ff46..001cc8d20ae1965ba62c4f1730e4211688fa1497 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -3,7 +3,7 @@ build_kaniko_command:
     variables:
       # To push to a specific docker tag other than latest(the default), amend the --destination parameter, e.g. --destination $CI_REGISTRY_IMAGE:$CI_BUILD_REF_NAME
       # See https://docs.gitlab.com/ee/ci/variables/predefined_variables.html#variables-reference for available variables
-      IMAGE_DESTINATION: ${CI_REGISTRY_IMAGE}:Peter_image
+      IMAGE_DESTINATION: ${CI_REGISTRY_IMAGE}:SingleGen
     image: 
         # We recommend using the CERN version of the Kaniko image: gitlab-registry.cern.ch/ci-tools/docker-image-builder
         name: gitlab-registry.cern.ch/ci-tools/docker-image-builder
@@ -16,9 +16,3 @@ build_kaniko_command:
         # Print the full registry path of the pushed image
         - echo "Image pushed successfully to ${IMAGE_DESTINATION}"
 
-check_python_run:
-  stage: test
-  image: pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime
-  script: 
-    - pip install pyflakes
-    - pyflakes $CI_PROJECT_DIR/wgan_ECAL_HCAL_3crit.py
diff --git a/Dockerfile b/Dockerfile
index f31165b449627cd83e471d7ad02b8a25accc2f9a..5279f8e093f67438bff4f268a338b157b586a146 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -24,10 +24,12 @@ RUN  mkdir -p ${HOME}/models \
 
 WORKDIR ${HOME}
 
+ADD wganSingleGen.py ${HOME}/wganSingleGen.py
 ADD wgan.py ${HOME}/wgan.py
 ADD wganHCAL.py ${HOME}/wganHCAL.py
 ADD wgan_ECAL_HCAL_3crit.py ${HOME}/wgan_ECAL_HCAL_3crit.py
 
+
 COPY ./models/* ${HOME}/models/
 COPY docker/krb5.conf /etc/krb5.conf
 
diff --git a/gitlab-ci.yml b/gitlab-ci.yml
deleted file mode 100644
index 613873bbfc02280b1ec2284b2af070063998e48c..0000000000000000000000000000000000000000
--- a/gitlab-ci.yml
+++ /dev/null
@@ -1,18 +0,0 @@
-build_kaniko_command:
-    stage: build
-    variables:
-      # To push to a specific docker tag other than latest(the default), amend the --destination parameter, e.g. --destination $CI_REGISTRY_IMAGE:$CI_BUILD_REF_NAME
-      # See https://docs.gitlab.com/ee/ci/variables/predefined_variables.html#variables-reference for available variables
-      IMAGE_DESTINATION: ${CI_REGISTRY_IMAGE}:latest
-    image: 
-        # We recommend using the CERN version of the Kaniko image: gitlab-registry.cern.ch/ci-tools/docker-image-builder
-        name: gitlab-registry.cern.ch/ci-tools/docker-image-builder
-        entrypoint: [""]
-    script:
-        # Prepare Kaniko configuration file
-        - echo "{\"auths\":{\"$CI_REGISTRY\":{\"username\":\"$CI_REGISTRY_USER\",\"password\":\"$CI_REGISTRY_PASSWORD\"}}}" > /kaniko/.docker/config.json
-        # Build and push the image from the Dockerfile at the root of the project.
-        - /kaniko/executor --context $CI_PROJECT_DIR --dockerfile $CI_PROJECT_DIR/Dockerfile --destination $IMAGE_DESTINATION
-        # Print the full registry path of the pushed image
-        - echo "Image pushed successfully to ${IMAGE_DESTINATION}"
-
diff --git a/interactive/model_testing.ipynb b/interactive/model_testing.ipynb
index a6219d4d1c6fd826b193c27ffbf009d531da3d91..448893be473c088cbc5ba5fca0033f1462aa5dc4 100644
--- a/interactive/model_testing.ipynb
+++ b/interactive/model_testing.ipynb
@@ -11,7 +11,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -23,75 +23,112 @@
     "from torch.utils import data\n",
     "\n",
     "      \n",
-    "class DCGAN_G(nn.Module):\n",
-    "    \"\"\" \n",
-    "        generator component of WGAN\n",
+    "class Combined_Generator(nn.Module):\n",
+    "    \"\"\"\n",
+    "    combined generator for WGAN that generates both ECAL and HCAL components of shower, with a branching structure\n",
+    "    \n",
     "    \"\"\"\n",
+    "    \n",
     "    def __init__(self, ngf, nz):\n",
-    "        super(DCGAN_G, self).__init__()\n",
+    "        super(Combined_Generator, self).__init__()\n",
     "        \n",
     "        self.ngf = ngf\n",
     "        self.nz = nz\n",
-    "\n",
-    "        kernel = 4\n",
     "        \n",
-    "        # input energy shape [batch x 1 x 1 x 1 ] going into convolutional\n",
-    "        self.conv1_1 = nn.ConvTranspose3d(1, ngf*4, kernel, 1, 0, bias=False)\n",
-    "        # state size [ ngf*4 x 4 x 4 x 4]\n",
+    "        # input: cat(z, E)\n",
     "        \n",
-    "        # input noise shape [batch x nz x 1 x 1] going into convolutional\n",
-    "        self.conv1_100 = nn.ConvTranspose3d(nz, ngf*4, kernel, 1, 0, bias=False)\n",
-    "        # state size [ ngf*4 x 4 x 4 x 4]\n",
+    "        # fully connected layers for label embedding\n",
+    "        self.cond1 = torch.nn.Linear(self.nz+1, int(self.nz*1.5), bias=True)   # +1 for both the energy label\n",
+    "        self.cond2 = torch.nn.Linear(int(self.nz*1.5), ngf*200, bias=True)\n",
+    "        self.cond3 = torch.nn.Linear(ngf*200, 500*ngf, bias=True)\n",
+    "        self.cond4 = torch.nn.Linear(ngf*500, 250*ngf, bias=True)\n",
     "        \n",
+    "        ####### ECAL Branch #############\n",
+    "        # ECAL transpose convolutions to desired shape\n",
+    "        # input [5, 5, 5]\n",
+    "        self.ECALdeconv1 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,5,5), stride=(2,2,2), padding=0, bias=False)\n",
+    "        self.ECALlndc1 = torch.nn.LayerNorm([13,13,13])\n",
+    "        self.ECALdeconv2 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(4,4,4), stride=(2,2,2), padding=0, bias=False)\n",
+    "        self.ECALlndc2 = torch.nn.LayerNorm([28,28,28])\n",
+    "        self.ECALdeconv3 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(3,3,3), stride=(1,1,1), padding=0, bias=False)\n",
+    "        self.ECALlndc3 = torch.nn.LayerNorm([30,30,30])\n",
+    "        \n",
+    "        # ECAL conv layers\n",
+    "        self.ECALconv1 = torch.nn.Conv3d(ngf, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n",
+    "        self.ECALlnc1 = torch.nn.LayerNorm([30,30,30])\n",
+    "        self.ECALconv2 = torch.nn.Conv3d(ngf*2, ngf*4, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n",
+    "        self.ECALlnc2 = torch.nn.LayerNorm([30,30,30])\n",
+    "        self.ECALconv3 = torch.nn.Conv3d(ngf*4, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) \n",
+    "        self.ECALlnc3 = torch.nn.LayerNorm([30,30,30])\n",
+    "        self.ECALconv4 = torch.nn.Conv3d(ngf*2, 1, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n",
+    "        \n",
+    "        ####### HCAL Branch ############\n",
+    "        # HCAL transpose convolutions to deisred shape\n",
+    "        # input [5, 5, 5]\n",
+    "        self.HCALdeconv1 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(6,5,5), stride=(2,2,2), padding=0, bias=False)\n",
+    "        self.HCALlndc1 = torch.nn.LayerNorm([14,13,13])\n",
+    "        self.HCALdeconv2 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,4,4), stride=(3,2,2), padding=0, bias=False)\n",
+    "        self.HCALlndc2 = torch.nn.LayerNorm([44,28,28])\n",
+    "        self.HCALdeconv3 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,3,3), stride=(1,1,1), padding=0, bias=False)\n",
+    "        self.HCALlndc3 = torch.nn.LayerNorm([48,30,30])\n",
+    "        \n",
+    "        # HCAL conv layers\n",
+    "        self.HCALconv1 = torch.nn.Conv3d(ngf, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n",
+    "        self.HCALlnc1 = torch.nn.LayerNorm([48,30,30])\n",
+    "        self.HCALconv2 = torch.nn.Conv3d(ngf*2, ngf*4, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n",
+    "        self.HCALlnc2 = torch.nn.LayerNorm([48,30,30])\n",
+    "        self.HCALconv3 = torch.nn.Conv3d(ngf*4, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) \n",
+    "        self.HCALlnc3 = torch.nn.LayerNorm([48,30,30])\n",
+    "        self.HCALconv4 = torch.nn.Conv3d(ngf*2, 1, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n",
     "        \n",
-    "        # outs from first convolutions concatenate state size [ ngf*8 x 4 x 4]\n",
-    "        # and going into main convolutional part of Generator\n",
-    "        self.main_conv = nn.Sequential(\n",
-    "            \n",
-    "            nn.ConvTranspose3d(ngf*8, ngf*4, kernel, 2, 1, bias=False),\n",
-    "            nn.LayerNorm([8, 8, 8]),\n",
-    "            nn.ReLU(),\n",
-    "            # state shape [ (ndf*4) x 8 x 8 ]\n",
-    "\n",
-    "            nn.ConvTranspose3d(ngf*4, ngf*2, kernel, 2, 1, bias=False),\n",
-    "            nn.LayerNorm([16, 16, 16]),\n",
-    "            nn.ReLU(),\n",
-    "            # state shape [ (ndf*2) x 16 x 16 ]\n",
-    "\n",
-    "            nn.ConvTranspose3d(ngf*2, ngf, kernel, 2, 1, bias=False),\n",
-    "            nn.LayerNorm([32, 32, 32]),\n",
-    "            nn.ReLU(),\n",
-    "            # state shape [ (ndf) x 32 x 32 ]\n",
-    "\n",
-    "            nn.ConvTranspose3d(ngf, 1, 3, 1, 2, bias=False),\n",
-    "            nn.ReLU()\n",
-    "            # state shape [ 1 x 30 x 30 x 30 ]\n",
-    "        )\n",
-    "\n",
     "    def forward(self, noise, energy):\n",
-    "        energy_trans = self.conv1_1(energy)\n",
-    "        noise_trans = self.conv1_100(noise)\n",
-    "        input = torch.cat((energy_trans, noise_trans), 1)\n",
-    "        x = self.main_conv(input)\n",
-    "        x = x.view(-1, 30, 30, 30)\n",
-    "        return x\n",
-    "\n",
-    "\n",
-    "def weights_init(m):\n",
-    "    classname = m.__class__.__name__\n",
-    "    if classname.find('Conv') != -1:\n",
-    "        nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
-    "    elif classname.find('LayerNorm') != -1:\n",
-    "        nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
-    "        nn.init.constant_(m.bias.data, 0)\n",
-    "    elif classname.find('Linear') != -1:\n",
-    "        m.weight.data.normal_(0.0, 0.02)\n",
-    "        m.bias.data.fill_(0)"
+    "        z = torch.cat((noise, energy), 1)\n",
+    "        z = z.view(-1, self.nz+1)\n",
+    "        \n",
+    "        # label embedding\n",
+    "        x = F.leaky_relu(self.cond1(z), 0.2, inplace=True)\n",
+    "        x = F.leaky_relu(self.cond2(x), 0.2, inplace=True)\n",
+    "        x = F.leaky_relu(self.cond3(x), 0.2, inplace=True)\n",
+    "        x = F.leaky_relu(self.cond4(x), 0.2, inplace=True)\n",
+    "        \n",
+    "        # split x to spread it between ECAL and HCAL generator branches\n",
+    "        x_ECAL, x_HCAL = torch.tensor_split(x, 2)\n",
+    "        x_ECAL = x_ECAL.view(-1, self.ngf, 5, 5, 5)\n",
+    "        x_HCAL = x_HCAL.view(-1, self.ngf, 5, 5, 5)\n",
+    "        \n",
+    "        ########## ECAL Branch #########\n",
+    "        # ECAL deconvolutions up to desired shape\n",
+    "        x_ECAL = F.leaky_relu(self.ECALlndc1(self.ECALdeconv1(x_ECAL)), 0.2, inplace=True)\n",
+    "        x_ECAL = F.leaky_relu(self.ECALlndc2(self.ECALdeconv2(x_ECAL)), 0.2, inplace=True)\n",
+    "        x_ECAL = F.leaky_relu(self.ECALlndc3(self.ECALdeconv3(x_ECAL)), 0.2, inplace=True)\n",
+    "        \n",
+    "        # ECAL convolutions\n",
+    "        x_ECAL = F.leaky_relu(self.ECALlnc1(self.ECALconv1(x_ECAL)), 0.2, inplace=True)\n",
+    "        x_ECAL = F.leaky_relu(self.ECALlnc2(self.ECALconv2(x_ECAL)), 0.2, inplace=True)\n",
+    "        x_ECAL = F.leaky_relu(self.ECALlnc3(self.ECALconv3(x_ECAL)), 0.2, inplace=True)\n",
+    "        x_ECAL = F.relu(self.ECALconv4(x_ECAL), inplace=True)\n",
+    "        \n",
+    "        ######### HCAL Branch ##########\n",
+    "        # HCAL deconvolutions up to desired shape\n",
+    "        x_HCAL = F.leaky_relu(self.HCALlndc1(self.HCALdeconv1(x_HCAL)), 0.2, inplace=True)\n",
+    "        x_HCAL = F.leaky_relu(self.HCALlndc2(self.HCALdeconv2(x_HCAL)), 0.2, inplace=True)\n",
+    "        x_HCAL = F.leaky_relu(self.HCALlndc3(self.HCALdeconv3(x_HCAL)), 0.2, inplace=True)\n",
+    "        \n",
+    "        # HCAL convolutions\n",
+    "        x_HCAL = F.leaky_relu(self.HCALlnc1(self.HCALconv1(x_HCAL)), 0.2, inplace=True)\n",
+    "        x_HCAL = F.leaky_relu(self.HCALlnc2(self.HCALconv2(x_HCAL)), 0.2, inplace=True)\n",
+    "        x_HCAL = F.leaky_relu(self.HCALlnc3(self.HCALconv3(x_HCAL)), 0.2, inplace=True)\n",
+    "        x_HCAL = F.relu(self.HCALconv4(x_HCAL), inplace=True)\n",
+    "        \n",
+    "        #x_ECAL = x_ECAL.view(-1, 30, 30, 30)\n",
+    "        #x_HCAL = x_HCAL.view(-1, 48, 30, 30)\n",
+    "        \n",
+    "        return x_ECAL, x_HCAL"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -100,7 +137,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -117,7 +154,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -128,33 +165,54 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [],
    "source": [
-    "aG = DCGAN_G(16, LATENT).to(device)\n",
+    "aG = Combined_Generator(8, LATENT).to(device)\n",
     "\n",
     "with torch.no_grad():\n",
     "    noisev = noise  # totally freeze G, training D\n",
-    "\n",
-    "fake_data = aG(noisev, real_label).detach()\n",
     "            "
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 17,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([10, 1, 30, 30, 30])"
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
-    "fake_data.shape"
+    "fake_ecal, fake_hcal = aG(noisev, real_label)\n",
+    "fake_ecal.shape"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 18,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "15309618"
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "sum(p.numel() for p in aG.parameters() if p.requires_grad)\n",
     "#aG"
@@ -263,12 +321,95 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Global_Discriminator(nn.Module):\n",
+    "    def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=64):\n",
+    "        super(Global_Discriminator, self).__init__()    \n",
+    "        self.ndf = ndf\n",
+    "        self.isize_1 = isize_1\n",
+    "        self.isize_2 = isize_2\n",
+    "        self.nc = nc\n",
+    "        self.size_embed = 16\n",
+    "        self.conv1_bias = False\n",
+    "\n",
+    "      \n",
+    "        \n",
+    "        # ECAL component of convolutions\n",
+    "        # Designed for input 30*30*30\n",
+    "        self.conv_ECAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=(2,2,2), stride=(1,1,1), padding=0, bias=False)\n",
+    "        self.ln_ECAL_1 = torch.nn.LayerNorm([29,29,29])\n",
+    "        self.conv_ECAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)\n",
+    "        self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14])\n",
+    "        self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)\n",
+    "        # output [batch_size, ndf, 7, 7, 7]\n",
+    "        \n",
+    "        # HCAL component of convolutions\n",
+    "        # Designed for input 48*30*30\n",
+    "        self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=(5,0,0), bias=False)\n",
+    "        self.ln_HCAL_1 = torch.nn.LayerNorm([29,29,29])\n",
+    "        self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)\n",
+    "        self.ln_HCAL_2 = torch.nn.LayerNorm([14,14,14])\n",
+    "        self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)\n",
+    "        # output [batch_size, ndf, 7, 7, 7]\n",
+    "        \n",
+    "        # alternative structure for 48*25*25 HCAL\n",
+    "        #self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=0, bias=False)\n",
+    "        #self.ln_HCAL_1 = torch.nn.LayerNorm([24,24,24])\n",
+    "        #self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)\n",
+    "        #self.ln_HCAL_2 = torch.nn.LayerNorm([12,12,12])\n",
+    "        #self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)\n",
+    "        \n",
+    "\n",
+    "        self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64) \n",
+    "        self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, 64) \n",
+    "        \n",
+    "        self.econd_lin = torch.nn.Linear(1, 64) # label embedding\n",
+    "\n",
+    "        self.fc1 = torch.nn.Linear(64*3, 128)  # 3 components after cat\n",
+    "        self.fc2 = torch.nn.Linear(128,  64)\n",
+    "        self.fc3 = torch.nn.Linear(64, 1)\n",
+    "\n",
+    "\n",
+    "    def forward(self, img_ECAL, img_HCAL, E_true):\n",
+    "        batch_size = img_ECAL.size(0)\n",
+    "        # input: img_ECAL = [batch_size, 1, 30, 30, 30]\n",
+    "        #        img_HCAL = [batch_size, 1, 48, 30, 30] \n",
+    "        \n",
+    "        # ECAL \n",
+    "        x_ECAL = F.leaky_relu(self.ln_ECAL_1(self.conv_ECAL_1(img_ECAL)), 0.2)\n",
+    "        x_ECAL = F.leaky_relu(self.ln_ECAL_2(self.conv_ECAL_2(x_ECAL)), 0.2)\n",
+    "        x_ECAL = F.leaky_relu(self.conv_ECAL_3(x_ECAL), 0.2)\n",
+    "        x_ECAL = x_ECAL.view(-1, self.ndf*7*7*7)\n",
+    "        x_ECAL = F.leaky_relu(self.conv_lin_ECAL(x_ECAL), 0.2)\n",
+    "        \n",
+    "        # HCAL\n",
+    "        x_HCAL = F.leaky_relu(self.ln_HCAL_1(self.conv_HCAL_1(img_HCAL)), 0.2)\n",
+    "        x_HCAL = F.leaky_relu(self.ln_HCAL_2(self.conv_HCAL_2(x_HCAL)), 0.2)\n",
+    "        x_HCAL = F.leaky_relu(self.conv_HCAL_3(x_HCAL), 0.2)\n",
+    "        x_HCAL = x_HCAL.view(-1, self.ndf*7*7*7)\n",
+    "        x_HCAL = F.leaky_relu(self.conv_lin_HCAL(x_HCAL), 0.2)        \n",
+    "        \n",
+    "        x_E = F.leaky_relu(self.econd_lin(E_true), 0.2)\n",
+    "        \n",
+    "        xa = torch.cat((x_ECAL, x_HCAL, x_E), 1)\n",
+    "        \n",
+    "        xa = F.leaky_relu(self.fc1(xa), 0.2)\n",
+    "        xa = F.leaky_relu(self.fc2(xa), 0.2)\n",
+    "        xa = self.fc3(xa)\n",
+    "        \n",
+    "        return xa ### flattens"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
    "metadata": {},
    "outputs": [],
    "source": [
-    "from criticRes import *\n",
-    "#import criticRes"
+    "aD = Global_Discriminator(ndf=128)"
    ]
   },
   {
@@ -277,7 +418,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "aD = generate_model(34).to(device)\n",
+    "\n",
     "for batch_idx, (data, energy) in enumerate(train_loader):\n",
     "    #print (data.shape, energy.shape)\n",
     "    data = data.to(device).unsqueeze(1)  \n",
@@ -290,9 +431,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 25,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "8122869"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "sum(p.numel() for p in aD.parameters() if p.requires_grad)\n"
    ]
@@ -408,7 +560,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.6.9"
+   "version": "3.8.10"
   }
  },
  "nbformat": 4,
diff --git a/models/combined_generator_split.py b/models/combined_generator_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..2acac208656627fe461fbce61544e54c71aea2bf
--- /dev/null
+++ b/models/combined_generator_split.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.nn.functional as F
+from torch.autograd import Variable
+from torch.utils import data
+
+class Combined_Generator(nn.Module):
+    """
+    Combined generator for a WGAN that produces both the ECAL and HCAL components of a shower, using a branching architecture.
+    
+    """
+    
+    def __init__(self, ngf, nz):
+        super(Combined_Generator, self).__init__()
+        
+        self.ngf = ngf
+        self.nz = nz
+        
+        # input: cat(z, E)
+        
+        # fully connected layers for label embedding
+        self.cond1 = torch.nn.Linear(self.nz+1, int(self.nz*1.5), bias=True)   # +1 for the energy label
+        self.cond2 = torch.nn.Linear(int(self.nz*1.5), ngf*200, bias=True)
+        self.cond3 = torch.nn.Linear(ngf*200, 500*ngf, bias=True)
+        self.cond4 = torch.nn.Linear(ngf*500, 250*ngf, bias=True)
+        
+        ####### ECAL Branch #############
+        # ECAL transpose convolutions to desired shape
+        # input [5, 5, 5]
+        self.ECALdeconv1 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,5,5), stride=(2,2,2), padding=0, bias=False)
+        self.ECALlndc1 = torch.nn.LayerNorm([13,13,13])
+        self.ECALdeconv2 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(4,4,4), stride=(2,2,2), padding=0, bias=False)
+        self.ECALlndc2 = torch.nn.LayerNorm([28,28,28])
+        self.ECALdeconv3 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(3,3,3), stride=(1,1,1), padding=0, bias=False)
+        self.ECALlndc3 = torch.nn.LayerNorm([30,30,30])
+        
+        # ECAL conv layers
+        self.ECALconv1 = torch.nn.Conv3d(ngf, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)
+        self.ECALlnc1 = torch.nn.LayerNorm([30,30,30])
+        self.ECALconv2 = torch.nn.Conv3d(ngf*2, ngf*4, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)
+        self.ECALlnc2 = torch.nn.LayerNorm([30,30,30])
+        self.ECALconv3 = torch.nn.Conv3d(ngf*4, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) 
+        self.ECALlnc3 = torch.nn.LayerNorm([30,30,30])
+        self.ECALconv4 = torch.nn.Conv3d(ngf*2, 1, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)
+        
+        ####### HCAL Branch ############
+        # HCAL transpose convolutions to desired shape
+        # input [5, 5, 5]
+        self.HCALdeconv1 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(6,5,5), stride=(2,2,2), padding=0, bias=False)
+        self.HCALlndc1 = torch.nn.LayerNorm([14,13,13])
+        self.HCALdeconv2 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,4,4), stride=(3,2,2), padding=0, bias=False)
+        self.HCALlndc2 = torch.nn.LayerNorm([44,28,28])
+        self.HCALdeconv3 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,3,3), stride=(1,1,1), padding=0, bias=False)
+        self.HCALlndc3 = torch.nn.LayerNorm([48,30,30])
+        
+        # HCAL conv layers
+        self.HCALconv1 = torch.nn.Conv3d(ngf, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)
+        self.HCALlnc1 = torch.nn.LayerNorm([48,30,30])
+        self.HCALconv2 = torch.nn.Conv3d(ngf*2, ngf*4, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)
+        self.HCALlnc2 = torch.nn.LayerNorm([48,30,30])
+        self.HCALconv3 = torch.nn.Conv3d(ngf*4, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) 
+        self.HCALlnc3 = torch.nn.LayerNorm([48,30,30])
+        self.HCALconv4 = torch.nn.Conv3d(ngf*2, 1, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)
+        
+    def forward(self, noise, energy):
+        z = torch.cat((noise, energy), 1)
+        z = z.view(-1, self.nz+1)
+        
+        # label embedding
+        x = F.leaky_relu(self.cond1(z), 0.2, inplace=True)
+        x = F.leaky_relu(self.cond2(x), 0.2, inplace=True)
+        x = F.leaky_relu(self.cond3(x), 0.2, inplace=True)
+        x = F.leaky_relu(self.cond4(x), 0.2, inplace=True)
+        
+        # split x to spread it between ECAL and HCAL generator branches
+        x_ECAL, x_HCAL = torch.tensor_split(x, 2, dim=1)  # dim=1: split features, not the batch
+        x_ECAL = x_ECAL.view(-1, self.ngf, 5, 5, 5)
+        x_HCAL = x_HCAL.view(-1, self.ngf, 5, 5, 5)
+        
+        ########## ECAL Branch #########
+        # ECAL deconvolutions up to desired shape
+        x_ECAL = F.leaky_relu(self.ECALlndc1(self.ECALdeconv1(x_ECAL)), 0.2, inplace=True)
+        x_ECAL = F.leaky_relu(self.ECALlndc2(self.ECALdeconv2(x_ECAL)), 0.2, inplace=True)
+        x_ECAL = F.leaky_relu(self.ECALlndc3(self.ECALdeconv3(x_ECAL)), 0.2, inplace=True)
+        
+        # ECAL convolutions
+        x_ECAL = F.leaky_relu(self.ECALlnc1(self.ECALconv1(x_ECAL)), 0.2, inplace=True)
+        x_ECAL = F.leaky_relu(self.ECALlnc2(self.ECALconv2(x_ECAL)), 0.2, inplace=True)
+        x_ECAL = F.leaky_relu(self.ECALlnc3(self.ECALconv3(x_ECAL)), 0.2, inplace=True)
+        x_ECAL = F.relu(self.ECALconv4(x_ECAL), inplace=True)
+        
+        ######### HCAL Branch ##########
+        # HCAL deconvolutions up to desired shape
+        x_HCAL = F.leaky_relu(self.HCALlndc1(self.HCALdeconv1(x_HCAL)), 0.2, inplace=True)
+        x_HCAL = F.leaky_relu(self.HCALlndc2(self.HCALdeconv2(x_HCAL)), 0.2, inplace=True)
+        x_HCAL = F.leaky_relu(self.HCALlndc3(self.HCALdeconv3(x_HCAL)), 0.2, inplace=True)
+        
+        # HCAL convolutions
+        x_HCAL = F.leaky_relu(self.HCALlnc1(self.HCALconv1(x_HCAL)), 0.2, inplace=True)
+        x_HCAL = F.leaky_relu(self.HCALlnc2(self.HCALconv2(x_HCAL)), 0.2, inplace=True)
+        x_HCAL = F.leaky_relu(self.HCALlnc3(self.HCALconv3(x_HCAL)), 0.2, inplace=True)
+        x_HCAL = F.relu(self.HCALconv4(x_HCAL), inplace=True)
+        
+        #x_ECAL = x_ECAL.view(-1, 30, 30, 30)
+        #x_HCAL = x_HCAL.view(-1, 48, 30, 30)
+        
+        return x_ECAL, x_HCAL
+        
\ No newline at end of file
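A quick shape check for the new generator (a minimal sketch: it assumes the file above is importable as models.combined_generator_split and uses the ngf=8, nz=100 defaults from the training job):

```python
import torch
from models.combined_generator_split import Combined_Generator

ngf, nz, batch = 8, 100, 4
netG = Combined_Generator(ngf, nz)

# inputs shaped as in wganSingleGen.py: noise [B, nz, 1, 1, 1], energy label [B, 1, 1, 1, 1]
noise = torch.rand(batch, nz, 1, 1, 1) * 2 - 1
energy = torch.rand(batch, 1, 1, 1, 1)

with torch.no_grad():
    ecal, hcal = netG(noise, energy)

print(ecal.shape)  # expected: torch.Size([4, 1, 30, 30, 30])
print(hcal.shape)  # expected: torch.Size([4, 1, 48, 30, 30])
```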
diff --git a/models/global_disc.py b/models/global_disc.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f370f5d57dd594cb2e8f900fe018ef7942807d0
--- /dev/null
+++ b/models/global_disc.py
@@ -0,0 +1,86 @@
+import numpy as np
+import argparse
+import torch
+import torch.utils.data
+from torch import nn, optim
+from torch.nn import functional as F
+from torchvision import datasets, transforms
+from torchvision.utils import save_image
+from torch import autograd  
+
+class Global_Discriminator(nn.Module):
+    def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=64):
+        super(Global_Discriminator, self).__init__()    
+        self.ndf = ndf
+        self.isize_1 = isize_1
+        self.isize_2 = isize_2
+        self.nc = nc
+        self.size_embed = 16
+        self.conv1_bias = False
+
+      
+        
+        # ECAL component of convolutions
+        # Designed for input 30*30*30
+        self.conv_ECAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=(2,2,2), stride=(1,1,1), padding=0, bias=False)
+        self.ln_ECAL_1 = torch.nn.LayerNorm([29,29,29])
+        self.conv_ECAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)
+        self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14])
+        self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
+        # output [batch_size, ndf, 7, 7, 7]
+        
+        # HCAL component of convolutions
+        # Designed for input 48*30*30
+        self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=(5,0,0), bias=False)
+        self.ln_HCAL_1 = torch.nn.LayerNorm([29,29,29])
+        self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)
+        self.ln_HCAL_2 = torch.nn.LayerNorm([14,14,14])
+        self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
+        # output [batch_size, ndf, 7, 7, 7]
+        
+        # alternative structure for 48*25*25 HCAL
+        #self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=0, bias=False)
+        #self.ln_HCAL_1 = torch.nn.LayerNorm([24,24,24])
+        #self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)
+        #self.ln_HCAL_2 = torch.nn.LayerNorm([12,12,12])
+        #self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
+        
+
+        self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64) 
+        self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, 64) 
+        
+        self.econd_lin = torch.nn.Linear(1, 64) # label embedding
+
+        self.fc1 = torch.nn.Linear(64*3, 128)  # 3 components after cat
+        self.fc2 = torch.nn.Linear(128,  64)
+        self.fc3 = torch.nn.Linear(64, 1)
+
+
+    def forward(self, img_ECAL, img_HCAL, E_true):
+        batch_size = img_ECAL.size(0)
+        # input: img_ECAL = [batch_size, 1, 30, 30, 30]
+        #        img_HCAL = [batch_size, 1, 48, 30, 30] 
+        
+        # ECAL 
+        x_ECAL = F.leaky_relu(self.ln_ECAL_1(self.conv_ECAL_1(img_ECAL)), 0.2)
+        x_ECAL = F.leaky_relu(self.ln_ECAL_2(self.conv_ECAL_2(x_ECAL)), 0.2)
+        x_ECAL = F.leaky_relu(self.conv_ECAL_3(x_ECAL), 0.2)
+        x_ECAL = x_ECAL.view(-1, self.ndf*7*7*7)
+        x_ECAL = F.leaky_relu(self.conv_lin_ECAL(x_ECAL), 0.2)
+        
+        # HCAL
+        x_HCAL = F.leaky_relu(self.ln_HCAL_1(self.conv_HCAL_1(img_HCAL)), 0.2)
+        x_HCAL = F.leaky_relu(self.ln_HCAL_2(self.conv_HCAL_2(x_HCAL)), 0.2)
+        x_HCAL = F.leaky_relu(self.conv_HCAL_3(x_HCAL), 0.2)
+        x_HCAL = x_HCAL.view(-1, self.ndf*7*7*7)
+        x_HCAL = F.leaky_relu(self.conv_lin_HCAL(x_HCAL), 0.2)        
+        
+        x_E = F.leaky_relu(self.econd_lin(E_true), 0.2)
+        
+        xa = torch.cat((x_ECAL, x_HCAL, x_E), 1)
+        
+        xa = F.leaky_relu(self.fc1(xa), 0.2)
+        xa = F.leaky_relu(self.fc2(xa), 0.2)
+        xa = self.fc3(xa)
+        
+        return xa  # critic score, shape [batch_size, 1]
\ No newline at end of file
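The matching check for the global critic (again a sketch, assuming the module path models.global_disc; ndf is passed by keyword because the constructor's first positional parameter is isize_1):

```python
import torch
from models.global_disc import Global_Discriminator

batch = 4
netD = Global_Discriminator(ndf=64)

ecal = torch.randn(batch, 1, 30, 30, 30)   # ECAL shower image
hcal = torch.randn(batch, 1, 48, 30, 30)   # HCAL shower image
e_true = torch.rand(batch, 1) * 50         # conditioning energy label

with torch.no_grad():
    score = netD(ecal, hcal, e_true)

print(score.shape)  # expected: torch.Size([4, 1]) -- one critic score per shower
```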
diff --git a/pytorch_job_wganSingleGen_ncc.yaml b/pytorch_job_wganSingleGen_ncc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b980cec96d582d2027f729e672cf0caaded7cfa3
--- /dev/null
+++ b/pytorch_job_wganSingleGen_ncc.yaml
@@ -0,0 +1,76 @@
+apiVersion: "kubeflow.org/v1"
+kind: "PyTorchJob"
+metadata:
+  name: "pytorch-dist-wganhcal-nccl"
+spec:
+  pytorchReplicaSpecs:
+    Master:
+      replicas: 1
+      restartPolicy: OnFailure
+      template:
+        metadata:
+          labels:
+            mount-kerberos-secret: "true"
+            mount-eos: "true"
+            mount-nvidia-driver: "true"
+          annotations:
+            sidecar.istio.io/inject: "false"
+        spec: 
+          affinity:
+              nodeAffinity:
+                requiredDuringSchedulingIgnoredDuringExecution:
+                  nodeSelectorTerms:
+                  - matchExpressions:
+                    - key: magnum.openstack.org/nodegroup
+                      operator: In
+                      values:
+                      - v100
+                      - v100s
+          containers:
+            - name: pytorch
+              image: gitlab-registry.cern.ch/eneren/pytorchjob:SingleGen
+              imagePullPolicy: Always
+              env:
+              - name: PYTHONUNBUFFERED
+                value: "1"
+              command: [sh, -c] 
+              args:  
+                - python -u wganSingleGen.py --backend nccl --epochs 50 --exp wganSingleGenV1 --batch-size 64 --ncrit 4  
+              resources: 
+                limits:
+                  nvidia.com/gpu: 1
+    Worker:
+      replicas: 4
+      restartPolicy: OnFailure
+      template:
+        metadata:
+          labels:
+            mount-kerberos-secret: "true"
+            mount-eos: "true"
+            mount-nvidia-driver: "true"
+          annotations:
+            sidecar.istio.io/inject: "false"
+        spec:
+          affinity:
+              nodeAffinity:
+                requiredDuringSchedulingIgnoredDuringExecution:
+                  nodeSelectorTerms:
+                  - matchExpressions:
+                    - key: magnum.openstack.org/nodegroup
+                      operator: In
+                      values:
+                      - v100
+                      - v100s
+          containers: 
+            - name: pytorch
+              image: gitlab-registry.cern.ch/eneren/pytorchjob:SingleGen
+              imagePullPolicy: Always
+              env:
+              - name: PYTHONUNBUFFERED
+                value: "1"
+              command: [sh, -c] 
+              args:  
+                - python -u wganSingleGen.py --backend nccl --epochs 50 --exp wganSingleGenV1 --batch-size 64 --ncrit 4
+              resources: 
+                limits:
+                  nvidia.com/gpu: 1
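The job relies on the PyTorch operator injecting MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE into each replica; parse_args() in wganSingleGen.py below reads them before dist.init_process_group is called. A hedged sketch of reproducing that contract for a two-process run outside the cluster (the script name and flags come from this repo; the env values are local-test assumptions, and the training loop itself still expects a GPU since it uses torch.cuda.FloatTensor):

```python
import os
import subprocess

procs = []
for rank in range(2):
    # assumption: mirrors the env vars the PyTorch operator sets per replica
    env = dict(os.environ,
               MASTER_ADDR="127.0.0.1", MASTER_PORT="29500",
               WORLD_SIZE="2", RANK=str(rank), LOCAL_RANK="0")
    procs.append(subprocess.Popen(
        ["python", "-u", "wganSingleGen.py",
         "--backend", "gloo",   # gloo rendezvous works without NCCL/GPUs
         "--epochs", "1", "--exp", "localSmoke", "--batch-size", "8"],
        env=env))
for p in procs:
    p.wait()
```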
diff --git a/wganSingleGen.py b/wganSingleGen.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ece4d0cce35c0e75c722d2c632008b2b2aa8e47
--- /dev/null
+++ b/wganSingleGen.py
@@ -0,0 +1,344 @@
+from comet_ml import Experiment
+import argparse
+import os, sys
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.optim as optim
+from torch import autograd
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import DataLoader
+from torch.autograd import Variable
+
+from API_keys import api_key
+
+torch.autograd.set_detect_anomaly(True)
+
+os.environ['MKL_THREADING_LAYER'] = 'GNU'
+
+sys.path.append('/opt/regressor/src')
+
+
+
+#from models.combined_generator import Combined_Generator
+from models.combined_generator_split import Combined_Generator
+from models.global_disc import Global_Discriminator
+from models.data_loaderFull import HDF5Dataset
+
+
+def calc_gradient_penalty_ECAL_HCAL(netD, real_ecal, real_hcal, fake_ecal, fake_hcal, real_label, BATCH_SIZE, device, layer, layer_hcal, xsize, ysize):
+    
+    alphaE = torch.rand(BATCH_SIZE, 1)
+    alphaE = alphaE.expand(BATCH_SIZE, int(real_ecal.nelement()/BATCH_SIZE)).contiguous()
+    alphaE = alphaE.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    alphaE = alphaE.to(device)
+
+
+    alphaH = torch.rand(BATCH_SIZE, 1)
+    alphaH = alphaH.expand(BATCH_SIZE, int(real_hcal.nelement()/BATCH_SIZE)).contiguous()
+    alphaH = alphaH.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize)
+    alphaH = alphaH.to(device)
+
+    fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize)
+    fake_ecal = fake_ecal.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    
+    interpolatesHCAL = alphaH * real_hcal.detach() + ((1 - alphaH) * fake_hcal.detach())
+    interpolatesECAL = alphaE * real_ecal.detach() + ((1 - alphaE) * fake_ecal.detach())
+
+
+    interpolatesHCAL = interpolatesHCAL.to(device)
+    interpolatesHCAL.requires_grad_(True)   
+
+    interpolatesECAL = interpolatesECAL.to(device)
+    interpolatesECAL.requires_grad_(True)   
+
+    disc_interpolates = netD(interpolatesECAL.float(), interpolatesHCAL.float(), real_label.float())
+
+    # autograd.grad returns one gradient tensor per input: keep both so the
+    # penalty constrains the critic's gradient w.r.t. the ECAL and the HCAL input
+    grads = autograd.grad(outputs=disc_interpolates, inputs=[interpolatesECAL, interpolatesHCAL],
+                          grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+                          create_graph=True, retain_graph=True, only_inputs=True)
+
+    gradients = torch.cat([g.view(g.size(0), -1) for g in grads], dim=1)
+    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+    return gradient_penalty
+
+def train(args, netD_E_H, netG_E_H, device, train_loader, optimizer_D_E_H, optimizer_G_E_H, epoch, experiment):
+    Tensor = torch.cuda.FloatTensor
+    
+    for batch_idx, (dataE, dataH, energy) in enumerate(train_loader):
+        ## GLOBAL CRITIC TRAINING
+        netD_E_H.train()
+        netG_E_H.eval()
+        
+        # zero out critic gradients
+        optimizer_D_E_H.zero_grad()
+        
+        ## Get Real data
+        real_dataECAL = dataE.to(device).unsqueeze(1).float()
+        real_dataHCAL = dataH.to(device).unsqueeze(1).float()
+        label = energy.to(device).float()
+        
+        ## Generate fake ECAL and HCAL
+        z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False)
+        fake_ecal, fake_hcal = netG_E_H(z, label.view(-1, 1, 1, 1, 1))
+        fake_ecal = fake_ecal.detach()
+        fake_hcal = fake_hcal.detach()
+        
+        ## Global Critic forward pass on Real
+        disc_real_E_H = netD_E_H(real_dataECAL, real_dataHCAL, label)
+        
+        ## Calculate Gradient Penalty Term
+        gradient_penalty_E_H = calc_gradient_penalty_ECAL_HCAL(netD_E_H, real_dataECAL, real_dataHCAL, fake_ecal, fake_hcal, label, args.batch_size, device, layer=30, layer_hcal=48, xsize=30, ysize=30)
+        
+        ## Global Critic forward pass on fake data
+        disc_fake_E_H = netD_E_H(fake_ecal, fake_hcal, label)
+        
+        ## Wasserstein-1 distance estimate for the critic
+        w_dist_E_H = torch.mean(disc_fake_E_H) - torch.mean(disc_real_E_H)
+        
+        ## final disc cost
+        disc_cost_E_H = w_dist_E_H + args.lambd * gradient_penalty_E_H
+        
+        disc_cost_E_H.backward()
+        optimizer_D_E_H.step()
+        
+        ## GENERATOR TRAINING
+        ## training generator every ncrit
+        if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
+            netD_E_H.eval()
+            netG_E_H.train()
+            
+            # zero out generator gradients
+            optimizer_G_E_H.zero_grad()
+
+            ## Generate fake ECAL and HCAL
+            z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=True)
+            fake_ecal, fake_hcal = netG_E_H(z, label.view(-1, 1, 1, 1, 1))
+
+
+            ## Loss for the combined ECAL+HCAL generator
+            gen_E_H_cost = netD_E_H(fake_ecal, fake_hcal, label)
+            g_E_H_cost = -torch.mean(gen_E_H_cost)
+            g_E_H_cost.backward()
+            optimizer_G_E_H.step()
+
+            if batch_idx % args.log_interval == 0:
+                    print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f}'.format(
+                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), g_E_H_cost.item()))
+
+                    niter = epoch * len(train_loader) + batch_idx
+                    experiment.log_metric("L_Gen_E_H", g_E_H_cost, step=niter)
+
+                    print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), disc_cost_E_H.item()))
+                    experiment.log_metric("L_crit_H_E", disc_cost_E_H, step=niter)
+                    experiment.log_metric("gradient_pen_E_H", gradient_penalty_E_H, step=niter)
+                    experiment.log_metric("Wasserstein Dist E H", w_dist_E_H, step=niter)
+                    experiment.log_metric("Critic Score E H(Real)", torch.mean(disc_real_E_H), step=niter)
+                    experiment.log_metric("Critic Score E H (Fake)", torch.mean(disc_fake_E_H), step=niter)
+
+            
+def parse_args():
+    parser = argparse.ArgumentParser(description='WGAN training on hadron showers')
+    parser.add_argument('--batch-size', type=int, default=64*4, metavar='N',
+                        help='input batch size for training (default: 64*4)')
+    
+    parser.add_argument('--nz', type=int, default=100, metavar='N',
+                        help='latent space for generator (default: 100)')
+    
+    parser.add_argument('--lambd', type=int, default=15, metavar='N',
+                        help='weight of the gradient penalty (default: 15)')
+
+    parser.add_argument('--ndf', type=int, default=128, metavar='N',
+                        help='n-feature of critic (default: 128)')
+
+    parser.add_argument('--ngf', type=int, default=8, metavar='N',
+                        help='n-feature of generator (default: 8)')
+
+    parser.add_argument('--ncrit', type=int, default=10, metavar='N',
+                        help='critic updates per generator update (default: 10)')
+
+    parser.add_argument('--epochs', type=int, default=100, metavar='N',
+                        help='number of epochs to train (default: 100)')
+
+    parser.add_argument('--nworkers', type=int, default=1, metavar='N',
+                        help='number of data-loader workers (default: 1)')
+
+    parser.add_argument('--lrCrit_E_H', type=float, default=0.00001, metavar='LR',
+                        help='learning rate of the critic (default: 0.00001)')
+
+    parser.add_argument('--chpt_base', type=str, default='/eos/user/e/eneren/experiments/',
+                        help='base directory for saving and loading checkpoints')
+
+    parser.add_argument('--lrGen_E_H', type=float, default=0.0001, metavar='LR',
+                        help='learning rate of the generator (default: 0.0001)')
+  
+    parser.add_argument('--chpt', action='store_true', default=False,
+                        help='continue training from a saved model')
+ 
+    parser.add_argument('--exp', type=str, default='dist_wgan',
+                        help='name of the experiment')
+
+    parser.add_argument('--chpt_eph', type=int, default=1,
+                        help='continue checkpoint epoch')
+
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
+                        help='how many batches to wait before logging training status')
+
+    if dist.is_available():
+        parser.add_argument('--backend', type=str, help='Distributed backend',
+                            choices=[dist.Backend.GLOO, dist.Backend.NCCL, dist.Backend.MPI],
+                            default=dist.Backend.GLOO)
+    
+    parser.add_argument('--local_rank', type=int, default=0)
+
+    args = parser.parse_args()
+
+
+    args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank))
+    args.rank = int(os.environ.get('RANK', 0))              # default to a single-process run
+    args.world_size = int(os.environ.get('WORLD_SIZE', 1))  # when launched outside the PyTorchJob
+
+
+    # postprocess args
+    args.device = 'cuda:{}'.format(args.local_rank)  # PytorchJob/launch.py
+    args.batch_size = max(args.batch_size,
+                          args.world_size * 2)  # min valid batchsize
+    return args
+    
+
+def run(args):
+    # Training settings
+    
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+    if use_cuda:
+        print('Using CUDA')
+
+    # set up COMET ML experiment for logging
+    experiment = Experiment(api_key=api_key,
+                        project_name="ecal-hcal-shower", workspace="engineren", auto_output_logging="simple")
+    experiment.add_tag(args.exp)
+
+    experiment.log_parameters(
+        {
+        "batch_size" : args.batch_size,
+        "latent": args.nz,
+        "lambda": args.lambd,
+        "ncrit" : args.ncrit,
+        "ngf": args.ngf,
+        "ndf": args.ndf
+        }
+    )
+
+    torch.manual_seed(args.seed)
+
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    if args.world_size > 1:
+        print('Using distributed PyTorch with {} backend'.format(args.backend))
+        dist.init_process_group(backend=args.backend)
+    
+    print('[init] == local rank: {}, global rank: {}, world size: {} =='.format(args.local_rank, args.rank, args.world_size))
+
+
+    print ("loading data")
+
+    dataset = HDF5Dataset('/eos/user/e/eneren/scratch/50GeV75k.hdf5', transform=None, train_size=75000)
+
+
+
+    if args.world_size > 1:
+        sampler = DistributedSampler(dataset, shuffle=True)    
+        train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.nworkers, drop_last=True, pin_memory=False)
+    else:
+        train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.nworkers, shuffle=True, drop_last=True, pin_memory=False)
+
+
+    ## Global critic
+    Crit_E_H = Global_Discriminator(ndf=args.ndf)  # first positional parameter is isize_1, so pass ndf by keyword
+    
+    ## Combined generator
+    Gen_E_H = Combined_Generator(args.ngf, args.nz)
+    
+        
+    Crit_E_H = Crit_E_H.to(device)
+    Gen_E_H = Gen_E_H.to(device)
+    
+    print('Critic trainable params:', sum(p.numel() for p in Crit_E_H.parameters() if p.requires_grad))
+    print('Generator trainable params:', sum(p.numel() for p in Gen_E_H.parameters() if p.requires_grad))
+    
+    if args.world_size > 1:
+        Distributor = nn.parallel.DistributedDataParallel if use_cuda \
+            else nn.parallel.DistributedDataParallelCPU
+        Crit_E_H = Distributor(Crit_E_H, device_ids=[args.local_rank], output_device=args.local_rank )
+        Gen_E_H = Distributor(Gen_E_H, device_ids=[args.local_rank], output_device=args.local_rank )
+       
+    else:
+        Crit_E_H = nn.parallel.DataParallel(Crit_E_H)
+        Gen_E_H = nn.parallel.DataParallel(Gen_E_H)
+        
+    
+    # instantiate opimizers
+    optimizerG_E_H = optim.Adam(Gen_E_H.parameters(), lr=args.lrGen_E_H, betas=(0.5, 0.9))
+    optimizerD_E_H = optim.Adam(Crit_E_H.parameters(), lr=args.lrCrit_E_H, betas=(0.5, 0.9))
+    
+    
+    
+    # load from checkpoint if desired
+    if (args.chpt):
+        critic_E_H_checkpoint = torch.load(args.chpt_base + args.exp + "_critic_E_H_" + str(args.chpt_eph) + ".pt")
+        gen_E_H_checkpoint = torch.load(args.chpt_base + args.exp + "_generator_E_H_" + str(args.chpt_eph) + ".pt")
+
+        Crit_E_H.load_state_dict(critic_E_H_checkpoint['model_state_dict'])
+        optimizerD_E_H.load_state_dict(critic_E_H_checkpoint['optimizer_state_dict'])
+
+        Gen_E_H.load_state_dict(gen_E_H_checkpoint['model_state_dict'])
+        optimizerG_E_H.load_state_dict(gen_E_H_checkpoint['optimizer_state_dict'])
+
+        eph = gen_E_H_checkpoint['epoch']
+
+    else:
+        eph = 0
+        print ("init models")
+        
+    experiment.set_model_graph(str(Crit_E_H), overwrite=False)
+    experiment.set_model_graph(str(Gen_E_H), overwrite=False)
+    
+    print('starting training...')
+    
+    for epoch in range(1, args.epochs + 1):
+        epoch += eph
+        
+        train(args, Crit_E_H, Gen_E_H, device, train_loader, optimizerD_E_H, optimizerG_E_H, epoch, experiment)
+        
+        # saving to checkpoints
+        g_E_H_path = args.chpt_base + args.exp + "_generator_E_H_"+ str(epoch) + ".pt"
+        c_E_H_path = args.chpt_base + args.exp + "_critic_E_H_"+ str(epoch) + ".pt"
+        
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': Gen_E_H.state_dict(),
+            'optimizer_state_dict': optimizerG_E_H.state_dict()
+            }, g_E_H_path)
+
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': Crit_E_H.state_dict(),
+            'optimizer_state_dict': optimizerD_E_H.state_dict()
+            }, c_E_H_path)
+
+    print('end training')
+        
+def main():
+    args = parse_args()
+    run(args)
+    
+if __name__ == '__main__':
+    main()
\ No newline at end of file
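For reference, the two-input gradient-penalty pattern used in calc_gradient_penalty_ECAL_HCAL, reduced to a toy critic. This is a sketch of why autograd.grad is given both interpolated inputs and why the per-input gradients are concatenated before taking the 2-norm, so the penalty constrains the critic's gradient with respect to the full (ECAL, HCAL) input:

```python
import torch
from torch import autograd, nn

class ToyCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(6 + 8, 1)  # two flattened "detector" inputs

    def forward(self, a, b):
        return self.fc(torch.cat((a, b), dim=1))

netD = ToyCritic()
batch = 5
real_a, fake_a = torch.randn(batch, 6), torch.randn(batch, 6)
real_b, fake_b = torch.randn(batch, 8), torch.randn(batch, 8)

# interpolate real and fake with a shared per-sample alpha
alpha = torch.rand(batch, 1)
inter_a = (alpha * real_a + (1 - alpha) * fake_a).requires_grad_(True)
inter_b = (alpha * real_b + (1 - alpha) * fake_b).requires_grad_(True)

out = netD(inter_a, inter_b)
# one gradient tensor per input; concatenate them before taking the joint norm
grads = autograd.grad(outputs=out, inputs=[inter_a, inter_b],
                      grad_outputs=torch.ones_like(out),
                      create_graph=True, retain_graph=True, only_inputs=True)
gradients = torch.cat([g.view(g.size(0), -1) for g in grads], dim=1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
print(gradient_penalty.item())
```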