diff --git a/pytorch_job_wganHCAL_single.yaml b/pytorch_job_wganHCAL_single.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a940079a918a86f859543a7260e2cf741988a49c
--- /dev/null
+++ b/pytorch_job_wganHCAL_single.yaml
@@ -0,0 +1,54 @@
+apiVersion: "kubeflow.org/v1"
+kind: "PyTorchJob"
+metadata:
+  name: "pytorch-dist-wganHCAL-nccl"
+spec:
+  pytorchReplicaSpecs:
+    Master:
+      replicas: 1
+      restartPolicy: OnFailure
+      template:
+        metadata:
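+          # keep Istio from injecting a sidecar proxy; it can interfere with the job's network traffic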
+          annotations:
+            sidecar.istio.io/inject: "false"
+        spec:
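+          # host mounts: EOS storage, the Kerberos credential secret, and the host NVIDIA driver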
+          volumes:
+          - name: eos
+            hostPath:
+              path: /var/eos
+          - name: krb-secret-vol
+            secret:
+              secretName: krb-secret
+          - name: nvidia-driver
+            hostPath:
+              path: /opt/nvidia-driver
+              type: ""
+          containers:
+            - name: pytorch
+              image: gitlab-registry.cern.ch/eneren/pytorchjob:ddp
+              imagePullPolicy: Always
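+              # PATH/LD_LIBRARY_PATH point at the host NVIDIA driver mounted into the container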
+              env:
+              - name: PATH
+                value: /opt/conda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/nvidia-driver/bin
+              - name: LD_LIBRARY_PATH
+                value: /opt/nvidia-driver/lib64
+              - name: PYTHONUNBUFFERED
+                value: "1"
+              command: [sh, -c]
+              args:
+                - cp /secret/krb-secret-vol/krb5cc_1000 /tmp/krb5cc_0 && chmod 600 /tmp/krb5cc_0
+                  && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --lrCrit 0.0001 --lrGen 0.00001;
+              volumeMounts:
+              - name: eos
+                mountPath: /eos
+              - name: krb-secret-vol
+                mountPath: "/secret/krb-secret-vol"
+              - name: nvidia-driver
+                mountPath: /opt/nvidia-driver
+              resources:
+                limits:
+                  nvidia.com/gpu: 1
diff --git a/wganHCAL.py b/wganHCAL.py
index 822e2866667219aa133b806f78c06f7ff9d8506b..18e0bf135f95d334f0a3e457b867790a2f139b75 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -241,10 +241,11 @@ def run(args):
     device = torch.device("cuda" if use_cuda else "cpu")
     
 
-
-    print('Using distributed PyTorch with {} backend'.format(args.backend))
-    dist.init_process_group(backend=args.backend)
-
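+    # initialize the process group only for multi-process runs; a single process skips DDP setup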
+    if args.world_size > 1:
+        print('Using distributed PyTorch with {} backend'.format(args.backend))
+        dist.init_process_group(backend=args.backend)
+
     print('[init] == local rank: {}, global rank: {}, world size: {} =='.format(args.local_rank, args.rank, args.world_size))
 
 
@@ -255,8 +255,12 @@ def run(args):
     dataset = HDF5Dataset('/eos/user/e/eneren/scratch/4060GeV.hdf5', transform=None, train_size=60000)
 
 
-    sampler = DistributedSampler(dataset, shuffle=True)    
-    train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.nworkers, drop_last=True, pin_memory=False)
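+    # shard the data across ranks when distributed; otherwise use a plain shuffled DataLoader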
+    if args.world_size > 1:
+        sampler = DistributedSampler(dataset, shuffle=True)
+        train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.nworkers, drop_last=True, pin_memory=False)
+    else:
+        train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.nworkers, shuffle=True, drop_last=True, pin_memory=False)