diff --git a/pytorch_job_wganHCAL_single.yaml b/pytorch_job_wganHCAL_single.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a940079a918a86f859543a7260e2cf741988a49c
--- /dev/null
+++ b/pytorch_job_wganHCAL_single.yaml
@@ -0,0 +1,51 @@
+apiVersion: "kubeflow.org/v1"
+kind: "PyTorchJob"
+metadata:
+  name: "pytorch-dist-wganHCAL-nccl"
+spec:
+  pytorchReplicaSpecs:
+    Master:
+      replicas: 1
+      restartPolicy: OnFailure
+      template:
+        metadata:
+          annotations:
+            sidecar.istio.io/inject: "false"
+        spec:
+          volumes:
+            - name: eos
+              hostPath:
+                path: /var/eos
+            - name: krb-secret-vol
+              secret:
+                secretName: krb-secret
+            - name: nvidia-driver
+              hostPath:
+                path: /opt/nvidia-driver
+                type: ""
+          containers:
+            - name: pytorch
+              image: gitlab-registry.cern.ch/eneren/pytorchjob:ddp
+              imagePullPolicy: Always
+              env:
+                - name: PATH
+                  value: /opt/conda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/nvidia-driver/bin
+                - name: LD_LIBRARY_PATH
+                  value: /opt/nvidia-driver/lib64
+                - name: PYTHONUNBUFFERED
+                  value: "1"
+              command: [sh, -c]
+              args:
+                - cp /secret/krb-secret-vol/krb5cc_1000 /tmp/krb5cc_0 && chmod 600 /tmp/krb5cc_0
+                  && python -u wganHCAL.py --backend nccl --epochs 50 --exp wganHCALv1 --lrCrit 0.0001 --lrGen 0.00001;
+              volumeMounts:
+                - name: eos
+                  mountPath: /eos
+                - name: krb-secret-vol
+                  mountPath: "/secret/krb-secret-vol"
+                - name: nvidia-driver
+                  mountPath: /opt/nvidia-driver
+              resources:
+                limits:
+                  nvidia.com/gpu: 1
+
diff --git a/wganHCAL.py b/wganHCAL.py
index 822e2866667219aa133b806f78c06f7ff9d8506b..18e0bf135f95d334f0a3e457b867790a2f139b75 100644
--- a/wganHCAL.py
+++ b/wganHCAL.py
@@ -241,10 +241,10 @@ def run(args):
     device = torch.device("cuda" if use_cuda else "cpu")
-
-    print('Using distributed PyTorch with {} backend'.format(args.backend))
-    dist.init_process_group(backend=args.backend)
-
+    if args.world_size > 1:
+        print('Using distributed PyTorch with {} backend'.format(args.backend))
+        dist.init_process_group(backend=args.backend)
+
     print('[init] == local rank: {}, global rank: {}, world size: {} =='.format(args.local_rank, args.rank, args.world_size))
@@ -255,8 +255,11 @@ def run(args):
     dataset = HDF5Dataset('/eos/user/e/eneren/scratch/4060GeV.hdf5', transform=None, train_size=60000)
-    sampler = DistributedSampler(dataset, shuffle=True)
-    train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.nworkers, drop_last=True, pin_memory=False)
+    if args.world_size > 1:
+        sampler = DistributedSampler(dataset, shuffle=True)
+        train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.nworkers, drop_last=True, pin_memory=False)
+    else:
+        train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.nworkers, shuffle=True, drop_last=True, pin_memory=False)
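
For context (not part of the diff): a minimal sketch of how the world_size, rank, and local_rank arguments used above might be populated when the script is launched by the PyTorchJob operator, which is expected to export WORLD_SIZE and RANK to each replica. The actual argument parsing in wganHCAL.py is not shown here, so the flag names and environment-variable defaults below are assumptions.

import os
import argparse

# Hypothetical helper: distributed flags defaulting to the operator's env vars.
def parse_dist_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--backend', type=str, default='nccl')
    parser.add_argument('--world_size', type=int,
                        default=int(os.environ.get('WORLD_SIZE', 1)))  # 1 when only a single Master replica runs
    parser.add_argument('--rank', type=int,
                        default=int(os.environ.get('RANK', 0)))        # global rank of this replica
    parser.add_argument('--local_rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', 0)))  # GPU index within the pod (assumed)
    return parser.parse_args()

Under this assumption, the single-replica job above (replicas: 1, no Worker spec) sees WORLD_SIZE=1, so the new "if args.world_size > 1:" guards skip dist.init_process_group and fall back to the plain shuffled DataLoader.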