diff --git a/Dockerfile b/Dockerfile
index dfcf4f974a21addec49afeae7422f16336c21739..23e5c4abe6f43a16aece4479d92e140577db9d42 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -15,8 +15,10 @@ RUN mkdir -p /opt/regressor && \
     && pip install h5py pyflakes comet_ml && export MKL_SERVICE_FORCE_INTEL=1
 
 WORKDIR /opt/regressor/src
-ADD regressor.py /opt/regressor/src/regressor.py
+
 ADD wgan.py /opt/regressor/src/wgan.py
+ADD wganHCAL.py /opt/regressor/src/wganHCAL.py
+
 COPY ./models/* /opt/regressor/src/models/
 COPY docker/krb5.conf /etc/krb5.conf
 
diff --git a/models/criticFull.py b/models/criticFull.py
new file mode 100644
index 0000000000000000000000000000000000000000..16333c33a13a65beeda8b45ec3c5cfbdd1108150
--- /dev/null
+++ b/models/criticFull.py
@@ -0,0 +1,82 @@
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.nn.functional as F
+
+
+class CriticEMB(nn.Module):
+    def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=64):
+        super(CriticEMB, self).__init__()    
+        self.ndf = ndf
+        self.isize_1 = isize_1
+        self.isize_2 = isize_2
+        self.nc = nc
+        self.size_embed = 16
+        self.conv1_bias = False
+
+      
+        
+        # ECAL component of convolutions
+        # Designed for input 30*30*30
+        self.conv_ECAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=(2,2,2), stride=(1,1,1), padding=0, bias=False)
+        self.ln_ECAL_1 = torch.nn.LayerNorm([29,29,29])
+        self.conv_ECAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)
+        self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14])
+        self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
+        # output [batch_size, ndf, 7, 7, 7]
+        
+        # HCAL component of convolutions
+        # Designed for input 48*30*30
+        self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=(5,0,0), bias=False)
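+        # depth padding of 5: (48 + 2*5 - 2)//2 + 1 = 29, so the HCAL branch matches the ECAL branch shape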
+        self.ln_HCAL_1 = torch.nn.LayerNorm([29,29,29])
+        self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)
+        self.ln_HCAL_2 = torch.nn.LayerNorm([14,14,14])
+        self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
+        # output [batch_size, ndf, 7, 7, 7]
+        
+        # alternative structure for 48*25*25 HCAL
+        #self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=0, bias=False)
+        #self.ln_HCAL_1 = torch.nn.LayerNorm([24,24,24])
+        #self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)
+        #self.ln_HCAL_2 = torch.nn.LayerNorm([12,12,12])
+        #self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
+        
+
+        self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64) 
+        self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, 64) 
+        
+        self.econd_lin = torch.nn.Linear(1, 64) # label embedding
+
+        self.fc1 = torch.nn.Linear(64*3, 128)  # 3 components after cat
+        self.fc2 = torch.nn.Linear(128,  64)
+        self.fc3 = torch.nn.Linear(64, 1)
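+        # single unbounded score per shower (no sigmoid), as required for a WGAN critic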
+
+
+    def forward(self, img_ECAL, img_HCAL, E_true):
+        batch_size = img_ECAL.size(0)
+        # input: img_ECAL = [batch_size, 1, 30, 30, 30]
+        #        img_HCAL = [batch_size, 1, 48, 30, 30] 
+        
+        # ECAL 
+        x_ECAL = F.leaky_relu(self.ln_ECAL_1(self.conv_ECAL_1(img_ECAL)), 0.2)
+        x_ECAL = F.leaky_relu(self.ln_ECAL_2(self.conv_ECAL_2(x_ECAL)), 0.2)
+        x_ECAL = F.leaky_relu(self.conv_ECAL_3(x_ECAL), 0.2)
+        x_ECAL = x_ECAL.view(-1, self.ndf*7*7*7)
+        x_ECAL = F.leaky_relu(self.conv_lin_ECAL(x_ECAL), 0.2)
+        
+        # HCAL
+        x_HCAL = F.leaky_relu(self.ln_HCAL_1(self.conv_HCAL_1(img_HCAL)), 0.2)
+        x_HCAL = F.leaky_relu(self.ln_HCAL_2(self.conv_HCAL_2(x_HCAL)), 0.2)
+        x_HCAL = F.leaky_relu(self.conv_HCAL_3(x_HCAL), 0.2)
+        x_HCAL = x_HCAL.view(-1, self.ndf*7*7*7)
+        x_HCAL = F.leaky_relu(self.conv_lin_HCAL(x_HCAL), 0.2)        
+        
+        x_E = F.leaky_relu(self.econd_lin(E_true), 0.2)
+        
+        xa = torch.cat((x_ECAL, x_HCAL, x_E), 1)
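+        # joint embedding of ECAL shower, HCAL shower and true energy: [batch_size, 64*3]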
+        
+        xa = F.leaky_relu(self.fc1(xa), 0.2)
+        xa = F.leaky_relu(self.fc2(xa), 0.2)
+        xa = self.fc3(xa)
+        
+        return xa  # critic score, shape [batch_size, 1]
\ No newline at end of file
diff --git a/models/data_loaderFull.py b/models/data_loaderFull.py
index 17f9d2f01e00df7ebec2a923843ff86103a528c8..e157bf1043c18b1196c523f847bda4f4f40c8044 100644
--- a/models/data_loaderFull.py
+++ b/models/data_loaderFull.py
@@ -1,7 +1,6 @@
 
 import numpy as np
 import torch
-from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
 import os
 import h5py
 from torch.utils import data
diff --git a/models/generatorFull.py b/models/generatorFull.py
new file mode 100644
index 0000000000000000000000000000000000000000..029b553f80091ce3edc29d105e183b44da07f409
--- /dev/null
+++ b/models/generatorFull.py
@@ -0,0 +1,116 @@
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.nn.functional as F
+
+
+
+class Hcal_ecalEMB(nn.Module):
+    """ 
+        generator component of WGAN
+    """
+    def __init__(self, ngf, ndf, nz, emb_size):
+        super(Hcal_ecalEMB, self).__init__()
+        
+       
+        self.ndf = ndf
+        self.emb_size = emb_size
+        # ECAL component of convolutions
+        # Designed for input 30*30*30
+        
+        self.conv_ECAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=(2,2,2), stride=(1,1,1), padding=0, bias=False)
+        self.ln_ECAL_1 = torch.nn.LayerNorm([29,29,29])
+        self.conv_ECAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)
+        self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14])
+        self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)
+        
+        self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64) 
+        
+        self.econd_lin = torch.nn.Linear(1, 64) # label embedding
+
+        self.fc1 = torch.nn.Linear(64*2, 128)  # 2 components after cat
+        self.fc2 = torch.nn.Linear(128,  64)
+        self.fc3 = torch.nn.Linear(64, emb_size)
+        
+        
+        ## HCAL component of convolutions
+        self.ngf = ngf
+        self.nz = nz
+        
+        kernel = 4
+        
+
+        self.conv1 = nn.ConvTranspose3d(emb_size + nz + 1, ngf, kernel, 1, 0, bias=False)
+        ## output: [batch_size, ngf, 4, 4, 4]
+    
+        
+        # the [ngf x 4 x 4 x 4] output of conv1 feeds the main convolutional part of the generator
+        self.main_conv = nn.Sequential(
+            
+            nn.ConvTranspose3d(ngf, ngf*4, kernel_size=(4,2,2), stride=2, padding=1, bias=False),
+            nn.LayerNorm([8, 6, 6]),
+            nn.ReLU(),
+            # state shape [ (ngf*4) x 8 x 6 x 6 ]
+
+            nn.ConvTranspose3d(ngf*4, ngf*2, kernel_size=(4,2,2), stride=2, padding=1, bias=False),
+            nn.LayerNorm([16, 10, 10]),
+            nn.ReLU(),
+            # state shape [ (ngf*2) x 16 x 10 x 10 ]
+
+            nn.ConvTranspose3d(ngf*2, ngf, kernel_size=(4,4,4), stride=(2,1,1), padding=1, bias=False),
+            nn.LayerNorm([32, 11, 11]),
+            nn.ReLU(),
+            # state shape [ ngf x 32 x 11 x 11 ]
+
+            nn.ConvTranspose3d(ngf, ngf, kernel_size=(10,4,4), stride=1, padding=1, bias=False),
+            nn.LayerNorm([39, 12, 12]),
+            nn.ReLU(),
+            # state shape [ ngf x 39 x 12 x 12 ]
+           
+            nn.ConvTranspose3d(ngf, 5, kernel_size=(8,3,3), stride=(1,2,2), padding=1, bias=False),
+            nn.LayerNorm([44, 23, 23]),
+            nn.ReLU(),
+            
+            # state shape [ 5 x 44 x 23 x 23 ]
+            
+            nn.ConvTranspose3d(5, 1, kernel_size=(7,10,10), stride=1, padding=1, bias=False),
+            nn.ReLU()
+            
+            ## final output ---> [1 x 48 x 30 x 30]
+        )
+
+    def forward(self, noise, energy, img_ECAL):
+        
+        
+        batch_size = img_ECAL.size(0)
+        # input: img_ECAL = [batch_size, 1, 30, 30, 30]
+        #        
+        
+        # ECAL 
+        x_ECAL = F.leaky_relu(self.ln_ECAL_1(self.conv_ECAL_1(img_ECAL)), 0.2)
+        x_ECAL = F.leaky_relu(self.ln_ECAL_2(self.conv_ECAL_2(x_ECAL)), 0.2)
+        x_ECAL = F.leaky_relu(self.conv_ECAL_3(x_ECAL), 0.2)
+        x_ECAL = x_ECAL.view(-1, self.ndf*7*7*7)
+        x_ECAL = F.leaky_relu(self.conv_lin_ECAL(x_ECAL), 0.2)
+        
+        x_E = F.leaky_relu(self.econd_lin(energy), 0.2)
+        
+        xa = torch.cat((x_ECAL, x_E), 1)
+        
+        xa = F.leaky_relu(self.fc1(xa), 0.2)
+        xa = F.leaky_relu(self.fc2(xa), 0.2)
+        xa = self.fc3(xa)
+        
+        
+        xzm = torch.cat((xa, energy, noise), 1) 
+        xzm = xzm.view(xzm.size(0), xzm.size(1), 1, 1, 1)
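+        # conditioning vector [batch_size, emb_size + 1 + nz] reshaped to a 1x1x1 volume, as expected by conv1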
+            
+    
+        inpt = self.conv1(xzm)
+             
+        x = self.main_conv(inpt)
+        
+        x = x.view(-1, 48, 30, 30)
+        
+        return x
diff --git a/pytorch_job_wganHCAL_nccl.yaml b/pytorch_job_wganHCAL_nccl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fcdf0407205c813895c65c8207c729516fb1932b
--- /dev/null
+++ b/pytorch_job_wganHCAL_nccl.yaml
@@ -0,0 +1,94 @@
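+# Distributed WGAN training for the HCAL generator: one master and one worker replica,
+# each with a single GPU, running wganHCAL.py with the NCCL backend.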
+apiVersion: "kubeflow.org/v1"
+kind: "PyTorchJob"
+metadata:
+  name: "pytorch-dist-wganHCAL-nccl"
+spec:
+  pytorchReplicaSpecs:
+    Master:
+      replicas: 1
+      restartPolicy: OnFailure
+      template:
+        metadata:
+          annotations:
+            sidecar.istio.io/inject: "false"
+        spec:
+          volumes:
+          - name: eos
+            hostPath:
+              path: /var/eos
+          - name: krb-secret-vol
+            secret:
+              secretName: krb-secret
+          - name: nvidia-driver
+            hostPath:
+              path: /opt/nvidia-driver
+              type: ""
+          containers:
+            - name: pytorch
+              image: gitlab-registry.cern.ch/eneren/pytorchjob:ddp
+              imagePullPolicy: Always
+              env:
+              - name: PATH
+                value: /opt/conda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/nvidia-driver/bin
+              - name: LD_LIBRARY_PATH
+                value: /opt/nvidia-driver/lib64
+              - name: PYTHONUNBUFFERED
+                value: "1"
+              command: [sh, -c] 
+              args:  
+                - cp /secret/krb-secret-vol/krb5cc_1000 /tmp/krb5cc_0 && chmod 600 /tmp/krb5cc_0 
+                  && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --lrCrit 0.0001 --lrGen 0.00001;
+              volumeMounts:
+              - name: eos
+                mountPath: /eos
+              - name: krb-secret-vol
+                mountPath: "/secret/krb-secret-vol"
+              - name: nvidia-driver
+                mountPath: /opt/nvidia-driver
+              resources: 
+                limits:
+                  nvidia.com/gpu: 1
+    Worker:
+      replicas: 1
+      restartPolicy: OnFailure
+      template:
+        metadata:
+          annotations:
+            sidecar.istio.io/inject: "false"
+        spec:
+          volumes:
+          - name: eos
+            hostPath:
+              path: /var/eos
+          - name: krb-secret-vol
+            secret:
+              secretName: krb-secret
+          - name: nvidia-driver
+            hostPath:
+              path: /opt/nvidia-driver
+              type: ""  
+          containers: 
+            - name: pytorch
+              image: gitlab-registry.cern.ch/eneren/pytorchjob:ddp
+              imagePullPolicy: Always
+              env:
+              - name: PATH
+                value: /opt/conda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/nvidia-driver/bin
+              - name: LD_LIBRARY_PATH
+                value: /opt/nvidia-driver/lib64
+              - name: PYTHONUNBUFFERED
+                value: "1"
+              command: [sh, -c] 
+              args:  
+                - cp /secret/krb-secret-vol/krb5cc_1000 /tmp/krb5cc_0 && chmod 600 /tmp/krb5cc_0 
+                  && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --lrCrit 0.0001 --lrGen 0.00001;
+              volumeMounts:
+              - name: eos
+                mountPath: /eos
+              - name: krb-secret-vol
+                mountPath: "/secret/krb-secret-vol"
+              - name: nvidia-driver
+                mountPath: /opt/nvidia-driver
+              resources: 
+                limits:
+                  nvidia.com/gpu: 1
diff --git a/wganHCAL.py b/wganHCAL.py
index f3046bb49dbe36434173560c7c519b5e2701f611..0988c4c1271e793203b198a8adb3e94ba9604c02 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -12,3 +12,307 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data import DataLoader
 from torch.autograd import Variable
 
+os.environ['MKL_THREADING_LAYER'] = 'GNU'
+
+torch.autograd.set_detect_anomaly(True)
+
+sys.path.append('/opt/regressor/src')
+
+from models.generatorFull import Hcal_ecalEMB
+from models.data_loaderFull import HDF5Dataset
+from models.criticFull import CriticEMB
+
+def calc_gradient_penalty(netD, real_dataECAL, real_data, fake_data, real_label, BATCH_SIZE, device, layer, xsize, ysize):
+    
+    alpha = torch.rand(BATCH_SIZE, 1)
+    alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement()/BATCH_SIZE)).contiguous()
+    alpha = alpha.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    alpha = alpha.to(device)
+
+
+    fake_data = fake_data.view(BATCH_SIZE, 1, layer, xsize, ysize)
+    interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())
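+    # random point on the straight line between the real and generated HCAL shower (WGAN-GP)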
+
+    interpolates = interpolates.to(device)
+    interpolates.requires_grad_(True)   
+
+    # the critic is conditioned on the (real) ECAL shower and the true energy; only the HCAL input is interpolated
+    disc_interpolates = netD(real_dataECAL.float(), interpolates.float(), real_label.float())
+
+
+    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
+                              grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+                              create_graph=True, retain_graph=True, only_inputs=True)[0]
+
+    gradients = gradients.view(gradients.size(0), -1)                              
+    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
+    return gradient_penalty
+
+
+def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch, experiment):
+    
+    ### CRITIC TRAINING
+    aD.train()
+    aG.eval()
+
+    Tensor = torch.cuda.FloatTensor 
+   
+    for batch_idx, (dataE, dataH, energy) in enumerate(train_loader):
+        real_dataECAL = dataE.to(device).unsqueeze(1)
+        real_dataHCAL = dataH.to(device).unsqueeze(1)
+        real_label = energy.to(device)
+        
+        optimizer_d.zero_grad()
+        
+        z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz))))
+
+        ## Generate Fake data
+        fake_dataHCAL = aG(z, real_label, real_dataECAL).detach() ## 48 x 30 x 30        
+
+        ## Critic fwd pass on Real
+        disc_real = aD(real_dataECAL.float(), real_dataHCAL.float(), real_label.float()) 
+
+        ## Calculate Gradient Penalty Term
+        gradient_penalty = calc_gradient_penalty(aD, real_dataECAL, real_dataHCAL.float(), fake_dataHCAL, real_label, args.batch_size, device, layer=48, xsize=30, ysize=30)
+
+        ## Critic fwd pass on Fake 
+        disc_fake = aD(real_dataECAL, fake_dataHCAL.unsqueeze(1), real_label)
+        
+
+        ## Wasserstein-1 distance estimate
+        w_dist = torch.mean(disc_fake) - torch.mean(disc_real)
+        # final critic cost: E[D(fake)] - E[D(real)] + lambda * gradient penalty
+        disc_cost = w_dist + args.lambd * gradient_penalty
+
+        disc_cost.backward()
+        optimizer_d.step()
+
+        if batch_idx % args.log_interval == 0:
+            print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), disc_cost.item()))
+            niter = epoch * len(train_loader) + batch_idx
+            experiment.log_metric("L_crit", disc_cost, step=niter)
+            experiment.log_metric("gradient_pen", gradient_penalty, step=niter)
+            experiment.log_metric("Wasserstein Dist", w_dist, step=niter)
+            experiment.log_metric("Critic Score (Real)", torch.mean(disc_real), step=niter)
+            experiment.log_metric("Critic Score (Fake)", torch.mean(disc_fake), step=niter)
+
+        
+        ## update the generator once every args.ncrit critic updates
+        if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
+            ## GENERATOR TRAINING
+            aD.eval()
+            aG.train()
+            
+            #print("Generator training started")
+
+            optimizer_g.zero_grad()
+    
+
+            ## generate fake data out of noise
+            fake_dataHCALG = aG(z, real_label, real_dataECAL)
+             
+            
+            ## generator loss: maximise the critic score on generated showers, i.e. minimise -E[D(G(z))]
+            gen_cost = aD(real_dataECAL, fake_dataHCALG.unsqueeze(1), real_label)
+            g_cost = -torch.mean(gen_cost) 
+            g_cost.backward()
+            optimizer_g.step()
+
+            if batch_idx % args.log_interval == 0 :
+                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                    100. * batch_idx / len(train_loader), g_cost.item()))
+                niter = epoch * len(train_loader) + batch_idx
+                experiment.log_metric("L_Gen", g_cost, step=niter)
+
+
+def is_distributed():
+    return dist.is_available() and dist.is_initialized()
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='WGAN training on hadron showers')
+    parser.add_argument('--batch-size', type=int, default=100, metavar='N',
+                        help='input batch size for training (default: 100)')
+    
+    parser.add_argument('--nz', type=int, default=100, metavar='N',
+                        help='latent space for generator (default: 100)')
+    
+    parser.add_argument('--lambd', type=int, default=15, metavar='N',
+                        help='weight of gradient penalty  (default: 15)')
+
+    parser.add_argument('--kappa', type=float, default=0.001, metavar='N',
+                        help='weight of label conditioning  (default: 0.001)')
+
+    parser.add_argument('--ndf', type=int, default=64, metavar='N',
+                        help='n-feature of critic (default: 64)')
+
+    parser.add_argument('--ngf', type=int, default=32, metavar='N',
+                        help='n-feature of generator  (default: 32)')
+
+    parser.add_argument('--ncrit', type=int, default=10, metavar='N',
+                        help='critic updates before generator one  (default: 10)')
+
+    parser.add_argument('--epochs', type=int, default=1, metavar='N',
+                        help='number of epochs to train (default: 1)')
+    parser.add_argument('--lrCrit', type=float, default=0.00001, metavar='LR',
+                        help='learning rate Critic (default: 0.00001)')
+    parser.add_argument('--lrGen', type=float, default=0.0001, metavar='LR',
+                        help='learning rate Generator (default: 0.0001)')
+
+    parser.add_argument('--chpt', action='store_true', default=False,
+                        help='continue training from a saved model')
+
+    parser.add_argument('--chpt_base', type=str, default='/eos/user/e/eneren/experiments/',
+                        help='continue training from a saved model')
+
+    parser.add_argument('--exp', type=str, default='dist_wgan',
+                        help='name of the experiment')
+
+    parser.add_argument('--chpt_eph', type=int, default=1,
+                        help='continue checkpoint epoch')
+
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
+                        help='how many batches to wait before logging training status')
+   
+
+    if dist.is_available():
+        parser.add_argument('--backend', type=str, help='Distributed backend',
+                            choices=[dist.Backend.GLOO, dist.Backend.NCCL, dist.Backend.MPI],
+                            default=dist.Backend.GLOO)
+    
+    parser.add_argument('--local_rank', type=int, default=0)
+
+    args = parser.parse_args()
+
+
+    args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank))
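+    # RANK and WORLD_SIZE are injected into the environment by the PyTorchJob operator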
+    args.rank = int(os.environ.get('RANK'))
+    args.world_size = int(os.environ.get('WORLD_SIZE'))
+
+
+    # postprocess args
+    args.device = f'cuda:{args.local_rank}'  # PytorchJob/launch.py
+    args.batch_size = max(args.batch_size,
+                          args.world_size * 2)  # min valid batchsize
+    return args
+
+
+
+def run(args):
+    # Training settings
+
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+    if use_cuda:
+        print('Using CUDA')
+
+    
+    experiment = Experiment(api_key="keGmeIz4GfKlQZlOP6cit4QOi",
+                        project_name="ecal-hcal-shower", workspace="engineren", auto_output_logging="simple")
+    experiment.add_tag(args.exp)
+
+    experiment.log_parameters(
+        {
+        "batch_size" : args.batch_size,
+        "latent": args.nz,
+        "lambda": args.lambd,
+        "ncrit" : args.ncrit,
+        "resN": args.ndf,
+        "ngf": args.ngf
+        }
+    )
+
+    torch.manual_seed(args.seed)
+
+    device = torch.device("cuda" if use_cuda else "cpu")
+    
+
+
+    print('Using distributed PyTorch with {} backend'.format(args.backend))
+    dist.init_process_group(backend=args.backend)
+
+    print('[init] == local rank: {}, global rank: {}, world size: {} =='.format(args.local_rank, args.rank, args.world_size))
+
+
+
+    print ("loading data")
+    #dataset = HDF5Dataset('/eos/user/e/eneren/scratch/40GeV40k.hdf5', transform=None, train_size=40000)
+    #dataset = HDF5Dataset('/eos/user/e/eneren/scratch/60GeV20k.hdf5', transform=None, train_size=20000)
+    dataset = HDF5Dataset('/eos/user/e/eneren/scratch/4060GeV.hdf5', transform=None, train_size=60000)
+
+
+    sampler = DistributedSampler(dataset, shuffle=True)    
+    train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=1, drop_last=True, pin_memory=False)
+
+
+
+    mCrit = CriticEMB().to(device)
+    mGen = Hcal_ecalEMB(args.ngf, 32, args.nz, emb_size=32).to(device)
+    
+
+
+    if args.world_size > 1: 
+        Distributor = nn.parallel.DistributedDataParallel if use_cuda \
+            else nn.parallel.DistributedDataParallelCPU
+        mCrit = Distributor(mCrit, device_ids=[args.local_rank], output_device=args.local_rank )
+        mGen = Distributor(mGen, device_ids=[args.local_rank], output_device=args.local_rank)
+
+    
+    optimizerG = optim.Adam(mGen.parameters(), lr=args.lrGen, betas=(0.5, 0.9))
+    optimizerD = optim.Adam(mCrit.parameters(), lr=args.lrCrit, betas=(0.5, 0.9))
+
+    if (args.chpt):
+        critic_checkpoint = torch.load(args.chpt_base + args.exp + "_critic_"+ str(args.chpt_eph) + ".pt")
+        gen_checkpoint = torch.load(args.chpt_base + args.exp + "_generator_"+ str(args.chpt_eph) + ".pt")
+        
+        mGen.load_state_dict(gen_checkpoint['model_state_dict'])
+        optimizerG.load_state_dict(gen_checkpoint['optimizer_state_dict'])
+
+        mCrit.load_state_dict(critic_checkpoint['model_state_dict'])
+        optimizerD.load_state_dict(critic_checkpoint['optimizer_state_dict'])
+        
+        eph = gen_checkpoint['epoch']
+    
+    else: 
+        eph = 0
+        print ("init models")
+
+    
+    experiment.set_model_graph(str(mGen), overwrite=False)
+
+    print ("starting training...")
+    for epoch in range(1, args.epochs + 1):
+        epoch += eph
+        train_loader.sampler.set_epoch(epoch)
+        train(args, mCrit, mGen, device, train_loader, optimizerD, optimizerG, epoch, experiment)
+        if args.rank == 0:
+            gPATH = args.chpt_base + args.exp + "_generator_"+ str(epoch) + ".pt"
+            cPATH = args.chpt_base + args.exp + "_critic_"+ str(epoch) + ".pt"
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': mGen.state_dict(),
+                'optimizer_state_dict': optimizerG.state_dict()
+                }, gPATH)
+            
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': mCrit.state_dict(),
+                'optimizer_state_dict': optimizerD.state_dict()
+                }, cPATH)
+
+
+    print ("end training")
+        
+
+
+def main():
+    args = parse_args()
+    run(args)
+            
+if __name__ == '__main__':
+    main()
\ No newline at end of file