diff --git a/.gitignore b/.gitignore index 0fd481a5848990b0d291057d41bf247214afbd7a..d60834e3613e5f32dfb89da72b6ef61b9e8f629a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ models/.ipynb_checkpoints models/.ipynb_checkpoints/* interactive/__pycache__/ models/__pycache__/ +__pycache__/ +API_keys.py interactive/jsd/ interactive/plots/ interactive/*.png diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 0bf3e8f3bda6a7774980fc6093459f23ba37ff46..001cc8d20ae1965ba62c4f1730e4211688fa1497 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -3,7 +3,7 @@ build_kaniko_command: variables: # To push to a specific docker tag other than latest(the default), amend the --destination parameter, e.g. --destination $CI_REGISTRY_IMAGE:$CI_BUILD_REF_NAME # See https://docs.gitlab.com/ee/ci/variables/predefined_variables.html#variables-reference for available variables - IMAGE_DESTINATION: ${CI_REGISTRY_IMAGE}:Peter_image + IMAGE_DESTINATION: ${CI_REGISTRY_IMAGE}:SingleGen image: # We recommend using the CERN version of the Kaniko image: gitlab-registry.cern.ch/ci-tools/docker-image-builder name: gitlab-registry.cern.ch/ci-tools/docker-image-builder @@ -16,9 +16,3 @@ build_kaniko_command: # Print the full registry path of the pushed image - echo "Image pushed successfully to ${IMAGE_DESTINATION}" -check_python_run: - stage: test - image: pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime - script: - - pip install pyflakes - - pyflakes $CI_PROJECT_DIR/wgan_ECAL_HCAL_3crit.py diff --git a/Dockerfile b/Dockerfile index f31165b449627cd83e471d7ad02b8a25accc2f9a..5279f8e093f67438bff4f268a338b157b586a146 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,10 +24,12 @@ RUN mkdir -p ${HOME}/models \ WORKDIR ${HOME} +ADD wganSingleGen.py ${HOME}/wganSingleGen.py ADD wgan.py ${HOME}/wgan.py ADD wganHCAL.py ${HOME}/wganHCAL.py ADD wgan_ECAL_HCAL_3crit.py ${HOME}/wgan_ECAL_HCAL_3crit.py + COPY ./models/* ${HOME}/models/ COPY docker/krb5.conf /etc/krb5.conf diff --git a/gitlab-ci.yml b/gitlab-ci.yml deleted file mode 100644 index 613873bbfc02280b1ec2284b2af070063998e48c..0000000000000000000000000000000000000000 --- a/gitlab-ci.yml +++ /dev/null @@ -1,18 +0,0 @@ -build_kaniko_command: - stage: build - variables: - # To push to a specific docker tag other than latest(the default), amend the --destination parameter, e.g. --destination $CI_REGISTRY_IMAGE:$CI_BUILD_REF_NAME - # See https://docs.gitlab.com/ee/ci/variables/predefined_variables.html#variables-reference for available variables - IMAGE_DESTINATION: ${CI_REGISTRY_IMAGE}:latest - image: - # We recommend using the CERN version of the Kaniko image: gitlab-registry.cern.ch/ci-tools/docker-image-builder - name: gitlab-registry.cern.ch/ci-tools/docker-image-builder - entrypoint: [""] - script: - # Prepare Kaniko configuration file - - echo "{\"auths\":{\"$CI_REGISTRY\":{\"username\":\"$CI_REGISTRY_USER\",\"password\":\"$CI_REGISTRY_PASSWORD\"}}}" > /kaniko/.docker/config.json - # Build and push the image from the Dockerfile at the root of the project. 
- - /kaniko/executor --context $CI_PROJECT_DIR --dockerfile $CI_PROJECT_DIR/Dockerfile --destination $IMAGE_DESTINATION - # Print the full registry path of the pushed image - - echo "Image pushed successfully to ${IMAGE_DESTINATION}" - diff --git a/interactive/model_testing.ipynb b/interactive/model_testing.ipynb index a6219d4d1c6fd826b193c27ffbf009d531da3d91..448893be473c088cbc5ba5fca0033f1462aa5dc4 100644 --- a/interactive/model_testing.ipynb +++ b/interactive/model_testing.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -23,75 +23,112 @@ "from torch.utils import data\n", "\n", " \n", - "class DCGAN_G(nn.Module):\n", - " \"\"\" \n", - " generator component of WGAN\n", + "class Combined_Generator(nn.Module):\n", + " \"\"\"\n", + " combined generator for WGAN that generates both ECAL and HCAL components of shower, with a branching structure\n", + " \n", " \"\"\"\n", + " \n", " def __init__(self, ngf, nz):\n", - " super(DCGAN_G, self).__init__()\n", + " super(Combined_Generator, self).__init__()\n", " \n", " self.ngf = ngf\n", " self.nz = nz\n", - "\n", - " kernel = 4\n", " \n", - " # input energy shape [batch x 1 x 1 x 1 ] going into convolutional\n", - " self.conv1_1 = nn.ConvTranspose3d(1, ngf*4, kernel, 1, 0, bias=False)\n", - " # state size [ ngf*4 x 4 x 4 x 4]\n", + " # input: cat(z, E)\n", " \n", - " # input noise shape [batch x nz x 1 x 1] going into convolutional\n", - " self.conv1_100 = nn.ConvTranspose3d(nz, ngf*4, kernel, 1, 0, bias=False)\n", - " # state size [ ngf*4 x 4 x 4 x 4]\n", + " # fully connected layers for label embedding\n", + " self.cond1 = torch.nn.Linear(self.nz+1, int(self.nz*1.5), bias=True) # +1 for both the energy label\n", + " self.cond2 = torch.nn.Linear(int(self.nz*1.5), ngf*200, bias=True)\n", + " self.cond3 = torch.nn.Linear(ngf*200, 500*ngf, bias=True)\n", + " self.cond4 = torch.nn.Linear(ngf*500, 250*ngf, bias=True)\n", " \n", + " ####### ECAL Branch #############\n", + " # ECAL transpose convolutions to desired shape\n", + " # input [5, 5, 5]\n", + " self.ECALdeconv1 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,5,5), stride=(2,2,2), padding=0, bias=False)\n", + " self.ECALlndc1 = torch.nn.LayerNorm([13,13,13])\n", + " self.ECALdeconv2 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(4,4,4), stride=(2,2,2), padding=0, bias=False)\n", + " self.ECALlndc2 = torch.nn.LayerNorm([28,28,28])\n", + " self.ECALdeconv3 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(3,3,3), stride=(1,1,1), padding=0, bias=False)\n", + " self.ECALlndc3 = torch.nn.LayerNorm([30,30,30])\n", + " \n", + " # ECAL conv layers\n", + " self.ECALconv1 = torch.nn.Conv3d(ngf, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n", + " self.ECALlnc1 = torch.nn.LayerNorm([30,30,30])\n", + " self.ECALconv2 = torch.nn.Conv3d(ngf*2, ngf*4, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n", + " self.ECALlnc2 = torch.nn.LayerNorm([30,30,30])\n", + " self.ECALconv3 = torch.nn.Conv3d(ngf*4, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) \n", + " self.ECALlnc3 = torch.nn.LayerNorm([30,30,30])\n", + " self.ECALconv4 = torch.nn.Conv3d(ngf*2, 1, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n", + " \n", + " ####### HCAL Branch ############\n", + " # HCAL transpose convolutions to deisred shape\n", + " # input [5, 5, 5]\n", + " self.HCALdeconv1 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(6,5,5), 
stride=(2,2,2), padding=0, bias=False)\n", + " self.HCALlndc1 = torch.nn.LayerNorm([14,13,13])\n", + " self.HCALdeconv2 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,4,4), stride=(3,2,2), padding=0, bias=False)\n", + " self.HCALlndc2 = torch.nn.LayerNorm([44,28,28])\n", + " self.HCALdeconv3 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,3,3), stride=(1,1,1), padding=0, bias=False)\n", + " self.HCALlndc3 = torch.nn.LayerNorm([48,30,30])\n", + " \n", + " # HCAL conv layers\n", + " self.HCALconv1 = torch.nn.Conv3d(ngf, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n", + " self.HCALlnc1 = torch.nn.LayerNorm([48,30,30])\n", + " self.HCALconv2 = torch.nn.Conv3d(ngf*2, ngf*4, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n", + " self.HCALlnc2 = torch.nn.LayerNorm([48,30,30])\n", + " self.HCALconv3 = torch.nn.Conv3d(ngf*4, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) \n", + " self.HCALlnc3 = torch.nn.LayerNorm([48,30,30])\n", + " self.HCALconv4 = torch.nn.Conv3d(ngf*2, 1, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False)\n", " \n", - " # outs from first convolutions concatenate state size [ ngf*8 x 4 x 4]\n", - " # and going into main convolutional part of Generator\n", - " self.main_conv = nn.Sequential(\n", - " \n", - " nn.ConvTranspose3d(ngf*8, ngf*4, kernel, 2, 1, bias=False),\n", - " nn.LayerNorm([8, 8, 8]),\n", - " nn.ReLU(),\n", - " # state shape [ (ndf*4) x 8 x 8 ]\n", - "\n", - " nn.ConvTranspose3d(ngf*4, ngf*2, kernel, 2, 1, bias=False),\n", - " nn.LayerNorm([16, 16, 16]),\n", - " nn.ReLU(),\n", - " # state shape [ (ndf*2) x 16 x 16 ]\n", - "\n", - " nn.ConvTranspose3d(ngf*2, ngf, kernel, 2, 1, bias=False),\n", - " nn.LayerNorm([32, 32, 32]),\n", - " nn.ReLU(),\n", - " # state shape [ (ndf) x 32 x 32 ]\n", - "\n", - " nn.ConvTranspose3d(ngf, 1, 3, 1, 2, bias=False),\n", - " nn.ReLU()\n", - " # state shape [ 1 x 30 x 30 x 30 ]\n", - " )\n", - "\n", " def forward(self, noise, energy):\n", - " energy_trans = self.conv1_1(energy)\n", - " noise_trans = self.conv1_100(noise)\n", - " input = torch.cat((energy_trans, noise_trans), 1)\n", - " x = self.main_conv(input)\n", - " x = x.view(-1, 30, 30, 30)\n", - " return x\n", - "\n", - "\n", - "def weights_init(m):\n", - " classname = m.__class__.__name__\n", - " if classname.find('Conv') != -1:\n", - " nn.init.normal_(m.weight.data, 0.0, 0.02)\n", - " elif classname.find('LayerNorm') != -1:\n", - " nn.init.normal_(m.weight.data, 1.0, 0.02)\n", - " nn.init.constant_(m.bias.data, 0)\n", - " elif classname.find('Linear') != -1:\n", - " m.weight.data.normal_(0.0, 0.02)\n", - " m.bias.data.fill_(0)" + " z = torch.cat((noise, energy), 1)\n", + " z = z.view(-1, self.nz+1)\n", + " \n", + " # label embedding\n", + " x = F.leaky_relu(self.cond1(z), 0.2, inplace=True)\n", + " x = F.leaky_relu(self.cond2(x), 0.2, inplace=True)\n", + " x = F.leaky_relu(self.cond3(x), 0.2, inplace=True)\n", + " x = F.leaky_relu(self.cond4(x), 0.2, inplace=True)\n", + " \n", + " # split x to spread it between ECAL and HCAL generator branches\n", + " x_ECAL, x_HCAL = torch.tensor_split(x, 2)\n", + " x_ECAL = x_ECAL.view(-1, self.ngf, 5, 5, 5)\n", + " x_HCAL = x_HCAL.view(-1, self.ngf, 5, 5, 5)\n", + " \n", + " ########## ECAL Branch #########\n", + " # ECAL deconvolutions up to desired shape\n", + " x_ECAL = F.leaky_relu(self.ECALlndc1(self.ECALdeconv1(x_ECAL)), 0.2, inplace=True)\n", + " x_ECAL = F.leaky_relu(self.ECALlndc2(self.ECALdeconv2(x_ECAL)), 0.2, inplace=True)\n", + " 
x_ECAL = F.leaky_relu(self.ECALlndc3(self.ECALdeconv3(x_ECAL)), 0.2, inplace=True)\n", + " \n", + " # ECAL convolutions\n", + " x_ECAL = F.leaky_relu(self.ECALlnc1(self.ECALconv1(x_ECAL)), 0.2, inplace=True)\n", + " x_ECAL = F.leaky_relu(self.ECALlnc2(self.ECALconv2(x_ECAL)), 0.2, inplace=True)\n", + " x_ECAL = F.leaky_relu(self.ECALlnc3(self.ECALconv3(x_ECAL)), 0.2, inplace=True)\n", + " x_ECAL = F.relu(self.ECALconv4(x_ECAL), inplace=True)\n", + " \n", + " ######### HCAL Branch ##########\n", + " # HCAL deconvolutions up to desired shape\n", + " x_HCAL = F.leaky_relu(self.HCALlndc1(self.HCALdeconv1(x_HCAL)), 0.2, inplace=True)\n", + " x_HCAL = F.leaky_relu(self.HCALlndc2(self.HCALdeconv2(x_HCAL)), 0.2, inplace=True)\n", + " x_HCAL = F.leaky_relu(self.HCALlndc3(self.HCALdeconv3(x_HCAL)), 0.2, inplace=True)\n", + " \n", + " # HCAL convolutions\n", + " x_HCAL = F.leaky_relu(self.HCALlnc1(self.HCALconv1(x_HCAL)), 0.2, inplace=True)\n", + " x_HCAL = F.leaky_relu(self.HCALlnc2(self.HCALconv2(x_HCAL)), 0.2, inplace=True)\n", + " x_HCAL = F.leaky_relu(self.HCALlnc3(self.HCALconv3(x_HCAL)), 0.2, inplace=True)\n", + " x_HCAL = F.relu(self.HCALconv4(x_HCAL), inplace=True)\n", + " \n", + " #x_ECAL = x_ECAL.view(-1, 30, 30, 30)\n", + " #x_HCAL = x_HCAL.view(-1, 48, 30, 30)\n", + " \n", + " return x_ECAL, x_HCAL" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -100,7 +137,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -117,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -128,33 +165,54 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ - "aG = DCGAN_G(16, LATENT).to(device)\n", + "aG = Combined_Generator(8, LATENT).to(device)\n", "\n", "with torch.no_grad():\n", " noisev = noise # totally freeze G, training D\n", - "\n", - "fake_data = aG(noisev, real_label).detach()\n", " " ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([10, 1, 30, 30, 30])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "fake_data.shape" + "fake_ecal, fake_hcal = aG(noisev, real_label)\n", + "fake_ecal.shape" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "15309618" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "sum(p.numel() for p in aG.parameters() if p.requires_grad)\n", "#aG" @@ -263,12 +321,95 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "class Global_Discriminator(nn.Module):\n", + " def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=64):\n", + " super(Global_Discriminator, self).__init__() \n", + " self.ndf = ndf\n", + " self.isize_1 = isize_1\n", + " self.isize_2 = isize_2\n", + " self.nc = nc\n", + " self.size_embed = 16\n", + " self.conv1_bias = False\n", + "\n", + " \n", + " \n", + " # ECAL component of convolutions\n", + " # Designed for input 30*30*30\n", + " self.conv_ECAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=(2,2,2), stride=(1,1,1), padding=0, 
bias=False)\n", + " self.ln_ECAL_1 = torch.nn.LayerNorm([29,29,29])\n", + " self.conv_ECAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)\n", + " self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14])\n", + " self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)\n", + " # output [batch_size, ndf, 7, 7, 7]\n", + " \n", + " # HCAL component of convolutions\n", + " # Designed for input 48*30*30\n", + " self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=(5,0,0), bias=False)\n", + " self.ln_HCAL_1 = torch.nn.LayerNorm([29,29,29])\n", + " self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)\n", + " self.ln_HCAL_2 = torch.nn.LayerNorm([14,14,14])\n", + " self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)\n", + " # output [batch_size, ndf, 7, 7, 7]\n", + " \n", + " # alternative structure for 48*25*25 HCAL\n", + " #self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=0, bias=False)\n", + " #self.ln_HCAL_1 = torch.nn.LayerNorm([24,24,24])\n", + " #self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False)\n", + " #self.ln_HCAL_2 = torch.nn.LayerNorm([12,12,12])\n", + " #self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False)\n", + " \n", + "\n", + " self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64) \n", + " self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, 64) \n", + " \n", + " self.econd_lin = torch.nn.Linear(1, 64) # label embedding\n", + "\n", + " self.fc1 = torch.nn.Linear(64*3, 128) # 3 components after cat\n", + " self.fc2 = torch.nn.Linear(128, 64)\n", + " self.fc3 = torch.nn.Linear(64, 1)\n", + "\n", + "\n", + " def forward(self, img_ECAL, img_HCAL, E_true):\n", + " batch_size = img_ECAL.size(0)\n", + " # input: img_ECAL = [batch_size, 1, 30, 30, 30]\n", + " # img_HCAL = [batch_size, 1, 48, 30, 30] \n", + " \n", + " # ECAL \n", + " x_ECAL = F.leaky_relu(self.ln_ECAL_1(self.conv_ECAL_1(img_ECAL)), 0.2)\n", + " x_ECAL = F.leaky_relu(self.ln_ECAL_2(self.conv_ECAL_2(x_ECAL)), 0.2)\n", + " x_ECAL = F.leaky_relu(self.conv_ECAL_3(x_ECAL), 0.2)\n", + " x_ECAL = x_ECAL.view(-1, self.ndf*7*7*7)\n", + " x_ECAL = F.leaky_relu(self.conv_lin_ECAL(x_ECAL), 0.2)\n", + " \n", + " # HCAL\n", + " x_HCAL = F.leaky_relu(self.ln_HCAL_1(self.conv_HCAL_1(img_HCAL)), 0.2)\n", + " x_HCAL = F.leaky_relu(self.ln_HCAL_2(self.conv_HCAL_2(x_HCAL)), 0.2)\n", + " x_HCAL = F.leaky_relu(self.conv_HCAL_3(x_HCAL), 0.2)\n", + " x_HCAL = x_HCAL.view(-1, self.ndf*7*7*7)\n", + " x_HCAL = F.leaky_relu(self.conv_lin_HCAL(x_HCAL), 0.2) \n", + " \n", + " x_E = F.leaky_relu(self.econd_lin(E_true), 0.2)\n", + " \n", + " xa = torch.cat((x_ECAL, x_HCAL, x_E), 1)\n", + " \n", + " xa = F.leaky_relu(self.fc1(xa), 0.2)\n", + " xa = F.leaky_relu(self.fc2(xa), 0.2)\n", + " xa = self.fc3(xa)\n", + " \n", + " return xa ### flattens" + ] + }, + { + "cell_type": "code", + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ - "from criticRes import *\n", - "#import criticRes" + "aD = Global_Discriminator(ndf=128)" ] }, { @@ -277,7 +418,7 @@ "metadata": {}, "outputs": [], "source": [ - "aD = generate_model(34).to(device)\n", + "\n", "for batch_idx, (data, energy) in enumerate(train_loader):\n", " #print (data.shape, energy.shape)\n", " data = data.to(device).unsqueeze(1) \n", @@ -290,9 +431,20 @@ }, { 
"cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "8122869" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "sum(p.numel() for p in aD.parameters() if p.requires_grad)\n" ] @@ -408,7 +560,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/models/combined_generator_split.py b/models/combined_generator_split.py new file mode 100644 index 0000000000000000000000000000000000000000..2acac208656627fe461fbce61544e54c71aea2bf --- /dev/null +++ b/models/combined_generator_split.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.nn.functional as F +from torch.autograd import Variable +from torch.utils import data + +class Combined_Generator(nn.Module): + """ + combined generator for WGAN that generates both ECAL and HCAL components of shower, with a branching structure + + """ + + def __init__(self, ngf, nz): + super(Combined_Generator, self).__init__() + + self.ngf = ngf + self.nz = nz + + # input: cat(z, E) + + # fully connected layers for label embedding + self.cond1 = torch.nn.Linear(self.nz+1, int(self.nz*1.5), bias=True) # +1 for both the energy label + self.cond2 = torch.nn.Linear(int(self.nz*1.5), ngf*200, bias=True) + self.cond3 = torch.nn.Linear(ngf*200, 500*ngf, bias=True) + self.cond4 = torch.nn.Linear(ngf*500, 250*ngf, bias=True) + + ####### ECAL Branch ############# + # ECAL transpose convolutions to desired shape + # input [5, 5, 5] + self.ECALdeconv1 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,5,5), stride=(2,2,2), padding=0, bias=False) + self.ECALlndc1 = torch.nn.LayerNorm([13,13,13]) + self.ECALdeconv2 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(4,4,4), stride=(2,2,2), padding=0, bias=False) + self.ECALlndc2 = torch.nn.LayerNorm([28,28,28]) + self.ECALdeconv3 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(3,3,3), stride=(1,1,1), padding=0, bias=False) + self.ECALlndc3 = torch.nn.LayerNorm([30,30,30]) + + # ECAL conv layers + self.ECALconv1 = torch.nn.Conv3d(ngf, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) + self.ECALlnc1 = torch.nn.LayerNorm([30,30,30]) + self.ECALconv2 = torch.nn.Conv3d(ngf*2, ngf*4, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) + self.ECALlnc2 = torch.nn.LayerNorm([30,30,30]) + self.ECALconv3 = torch.nn.Conv3d(ngf*4, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) + self.ECALlnc3 = torch.nn.LayerNorm([30,30,30]) + self.ECALconv4 = torch.nn.Conv3d(ngf*2, 1, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) + + ####### HCAL Branch ############ + # HCAL transpose convolutions to deisred shape + # input [5, 5, 5] + self.HCALdeconv1 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(6,5,5), stride=(2,2,2), padding=0, bias=False) + self.HCALlndc1 = torch.nn.LayerNorm([14,13,13]) + self.HCALdeconv2 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,4,4), stride=(3,2,2), padding=0, bias=False) + self.HCALlndc2 = torch.nn.LayerNorm([44,28,28]) + self.HCALdeconv3 = torch.nn.ConvTranspose3d(ngf, ngf, kernel_size=(5,3,3), stride=(1,1,1), padding=0, bias=False) + self.HCALlndc3 = torch.nn.LayerNorm([48,30,30]) + + # HCAL conv layers + self.HCALconv1 = torch.nn.Conv3d(ngf, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) + 
self.HCALlnc1 = torch.nn.LayerNorm([48,30,30]) + self.HCALconv2 = torch.nn.Conv3d(ngf*2, ngf*4, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) + self.HCALlnc2 = torch.nn.LayerNorm([48,30,30]) + self.HCALconv3 = torch.nn.Conv3d(ngf*4, ngf*2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) + self.HCALlnc3 = torch.nn.LayerNorm([48,30,30]) + self.HCALconv4 = torch.nn.Conv3d(ngf*2, 1, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1), bias=False) + + def forward(self, noise, energy): + z = torch.cat((noise, energy), 1) + z = z.view(-1, self.nz+1) + + # label embedding + x = F.leaky_relu(self.cond1(z), 0.2, inplace=True) + x = F.leaky_relu(self.cond2(x), 0.2, inplace=True) + x = F.leaky_relu(self.cond3(x), 0.2, inplace=True) + x = F.leaky_relu(self.cond4(x), 0.2, inplace=True) + + # split x to spread it between ECAL and HCAL generator branches + x_ECAL, x_HCAL = torch.tensor_split(x, 2) + x_ECAL = x_ECAL.view(-1, self.ngf, 5, 5, 5) + x_HCAL = x_HCAL.view(-1, self.ngf, 5, 5, 5) + + ########## ECAL Branch ######### + # ECAL deconvolutions up to desired shape + x_ECAL = F.leaky_relu(self.ECALlndc1(self.ECALdeconv1(x_ECAL)), 0.2, inplace=True) + x_ECAL = F.leaky_relu(self.ECALlndc2(self.ECALdeconv2(x_ECAL)), 0.2, inplace=True) + x_ECAL = F.leaky_relu(self.ECALlndc3(self.ECALdeconv3(x_ECAL)), 0.2, inplace=True) + + # ECAL convolutions + x_ECAL = F.leaky_relu(self.ECALlnc1(self.ECALconv1(x_ECAL)), 0.2, inplace=True) + x_ECAL = F.leaky_relu(self.ECALlnc2(self.ECALconv2(x_ECAL)), 0.2, inplace=True) + x_ECAL = F.leaky_relu(self.ECALlnc3(self.ECALconv3(x_ECAL)), 0.2, inplace=True) + x_ECAL = F.relu(self.ECALconv4(x_ECAL), inplace=True) + + ######### HCAL Branch ########## + # HCAL deconvolutions up to desired shape + x_HCAL = F.leaky_relu(self.HCALlndc1(self.HCALdeconv1(x_HCAL)), 0.2, inplace=True) + x_HCAL = F.leaky_relu(self.HCALlndc2(self.HCALdeconv2(x_HCAL)), 0.2, inplace=True) + x_HCAL = F.leaky_relu(self.HCALlndc3(self.HCALdeconv3(x_HCAL)), 0.2, inplace=True) + + # HCAL convolutions + x_HCAL = F.leaky_relu(self.HCALlnc1(self.HCALconv1(x_HCAL)), 0.2, inplace=True) + x_HCAL = F.leaky_relu(self.HCALlnc2(self.HCALconv2(x_HCAL)), 0.2, inplace=True) + x_HCAL = F.leaky_relu(self.HCALlnc3(self.HCALconv3(x_HCAL)), 0.2, inplace=True) + x_HCAL = F.relu(self.HCALconv4(x_HCAL), inplace=True) + + #x_ECAL = x_ECAL.view(-1, 30, 30, 30) + #x_HCAL = x_HCAL.view(-1, 48, 30, 30) + + return x_ECAL, x_HCAL + \ No newline at end of file diff --git a/models/global_disc.py b/models/global_disc.py new file mode 100644 index 0000000000000000000000000000000000000000..7f370f5d57dd594cb2e8f900fe018ef7942807d0 --- /dev/null +++ b/models/global_disc.py @@ -0,0 +1,86 @@ +import numpy as np +import argparse +import torch +import torch.utils.data +from torch import nn, optim +from torch.nn import functional as F +from torchvision import datasets, transforms +from torchvision.utils import save_image +from torch import autograd + +class Global_Discriminator(nn.Module): + def __init__(self, isize_1=30, isize_2=48, nc=2, ndf=64): + super(Global_Discriminator, self).__init__() + self.ndf = ndf + self.isize_1 = isize_1 + self.isize_2 = isize_2 + self.nc = nc + self.size_embed = 16 + self.conv1_bias = False + + + + # ECAL component of convolutions + # Designed for input 30*30*30 + self.conv_ECAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=(2,2,2), stride=(1,1,1), padding=0, bias=False) + self.ln_ECAL_1 = torch.nn.LayerNorm([29,29,29]) + self.conv_ECAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, 
stride=(2,2,2), padding=0, bias=False) + self.ln_ECAL_2 = torch.nn.LayerNorm([14,14,14]) + self.conv_ECAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False) + # output [batch_size, ndf, 7, 7, 7] + + # HCAL component of convolutions + # Designed for input 48*30*30 + self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=(5,0,0), bias=False) + self.ln_HCAL_1 = torch.nn.LayerNorm([29,29,29]) + self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False) + self.ln_HCAL_2 = torch.nn.LayerNorm([14,14,14]) + self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False) + # output [batch_size, ndf, 7, 7, 7] + + # alternative structure for 48*25*25 HCAL + #self.conv_HCAL_1 = torch.nn.Conv3d(1, ndf, kernel_size=2, stride=(2,1,1), padding=0, bias=False) + #self.ln_HCAL_1 = torch.nn.LayerNorm([24,24,24]) + #self.conv_HCAL_2 = torch.nn.Conv3d(ndf, ndf, kernel_size=2, stride=(2,2,2), padding=0, bias=False) + #self.ln_HCAL_2 = torch.nn.LayerNorm([12,12,12]) + #self.conv_HCAL_3 = torch.nn.Conv3d(ndf, ndf, kernel_size=4, stride=(2,2,2), padding=(1,1,1), bias=False) + + + self.conv_lin_ECAL = torch.nn.Linear(7*7*7*ndf, 64) + self.conv_lin_HCAL = torch.nn.Linear(7*7*7*ndf, 64) + + self.econd_lin = torch.nn.Linear(1, 64) # label embedding + + self.fc1 = torch.nn.Linear(64*3, 128) # 3 components after cat + self.fc2 = torch.nn.Linear(128, 64) + self.fc3 = torch.nn.Linear(64, 1) + + + def forward(self, img_ECAL, img_HCAL, E_true): + batch_size = img_ECAL.size(0) + # input: img_ECAL = [batch_size, 1, 30, 30, 30] + # img_HCAL = [batch_size, 1, 48, 30, 30] + + # ECAL + x_ECAL = F.leaky_relu(self.ln_ECAL_1(self.conv_ECAL_1(img_ECAL)), 0.2) + x_ECAL = F.leaky_relu(self.ln_ECAL_2(self.conv_ECAL_2(x_ECAL)), 0.2) + x_ECAL = F.leaky_relu(self.conv_ECAL_3(x_ECAL), 0.2) + x_ECAL = x_ECAL.view(-1, self.ndf*7*7*7) + x_ECAL = F.leaky_relu(self.conv_lin_ECAL(x_ECAL), 0.2) + + # HCAL + x_HCAL = F.leaky_relu(self.ln_HCAL_1(self.conv_HCAL_1(img_HCAL)), 0.2) + x_HCAL = F.leaky_relu(self.ln_HCAL_2(self.conv_HCAL_2(x_HCAL)), 0.2) + x_HCAL = F.leaky_relu(self.conv_HCAL_3(x_HCAL), 0.2) + x_HCAL = x_HCAL.view(-1, self.ndf*7*7*7) + x_HCAL = F.leaky_relu(self.conv_lin_HCAL(x_HCAL), 0.2) + + x_E = F.leaky_relu(self.econd_lin(E_true), 0.2) + + xa = torch.cat((x_ECAL, x_HCAL, x_E), 1) + + xa = F.leaky_relu(self.fc1(xa), 0.2) + xa = F.leaky_relu(self.fc2(xa), 0.2) + xa = self.fc3(xa) + + return xa ### flattens \ No newline at end of file diff --git a/pytorch_job_wganSingleGen_ncc.yaml b/pytorch_job_wganSingleGen_ncc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b980cec96d582d2027f729e672cf0caaded7cfa3 --- /dev/null +++ b/pytorch_job_wganSingleGen_ncc.yaml @@ -0,0 +1,76 @@ +apiVersion: "kubeflow.org/v1" +kind: "PyTorchJob" +metadata: + name: "pytorch-dist-wganhcal-nccl" +spec: + pytorchReplicaSpecs: + Master: + replicas: 1 + restartPolicy: OnFailure + template: + metadata: + labels: + mount-kerberos-secret: "true" + mount-eos: "true" + mount-nvidia-driver: "true" + annotations: + sidecar.istio.io/inject: "false" + spec: + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: magnum.openstack.org/nodegroup + operator: In + values: + - v100 + - v100s + containers: + - name: pytorch + image: gitlab-registry.cern.ch/eneren/pytorchjob:SingleGen + imagePullPolicy: Always + env: + - name: 
PYTHONUNBUFFERED + value: "1" + command: [sh, -c] + args: + - python -u wganSingleGen.py --backend nccl --epochs 50 --exp wganSingleGenV1 --batch-size 64 --ncrit 4 + resources: + limits: + nvidia.com/gpu: 1 + Worker: + replicas: 4 + restartPolicy: OnFailure + template: + metadata: + labels: + mount-kerberos-secret: "true" + mount-eos: "true" + mount-nvidia-driver: "true" + annotations: + sidecar.istio.io/inject: "false" + spec: + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: magnum.openstack.org/nodegroup + operator: In + values: + - v100 + - v100s + containers: + - name: pytorch + image: gitlab-registry.cern.ch/eneren/pytorchjob:SingleGen + imagePullPolicy: Always + env: + - name: PYTHONUNBUFFERED + value: "1" + command: [sh, -c] + args: + - python -u wganSingleGen.py --backend nccl --epochs 50 --exp wganSingleGenV1 --batch-size 64 --ncrit 4 + resources: + limits: + nvidia.com/gpu: 1 diff --git a/wganSingleGen.py b/wganSingleGen.py new file mode 100644 index 0000000000000000000000000000000000000000..1ece4d0cce35c0e75c722d2c632008b2b2aa8e47 --- /dev/null +++ b/wganSingleGen.py @@ -0,0 +1,344 @@ +from comet_ml import Experiment +import argparse +import os, sys +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +from torch import autograd +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader +from torch.autograd import Variable + +from API_keys import api_key + +torch.autograd.set_detect_anomaly(True) + +os.environ['MKL_THREADING_LAYER'] = 'GNU' + +sys.path.append('/opt/regressor/src') + + + +#from models.combined_generator import Combined_Generator +from models.combined_generator_split import Combined_Generator +from models.global_disc import Global_Discriminator +from models.data_loaderFull import HDF5Dataset + + +def calc_gradient_penalty_ECAL_HCAL(netD, real_ecal, real_hcal, fake_ecal, fake_hcal, real_label, BATCH_SIZE, device, layer, layer_hcal, xsize, ysize): + + alphaE = torch.rand(BATCH_SIZE, 1) + alphaE = alphaE.expand(BATCH_SIZE, int(real_ecal.nelement()/BATCH_SIZE)).contiguous() + alphaE = alphaE.view(BATCH_SIZE, 1, layer, xsize, ysize) + alphaE = alphaE.to(device) + + + alphaH = torch.rand(BATCH_SIZE, 1) + alphaH = alphaH.expand(BATCH_SIZE, int(real_hcal.nelement()/BATCH_SIZE)).contiguous() + alphaH = alphaH.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize) + alphaH = alphaH.to(device) + + fake_hcal = fake_hcal.view(BATCH_SIZE, 1, layer_hcal, xsize, ysize) + fake_ecal = fake_ecal.view(BATCH_SIZE, 1, layer, xsize, ysize) + + interpolatesHCAL = alphaH * real_hcal.detach() + ((1 - alphaH) * fake_hcal.detach()) + interpolatesECAL = alphaE * real_ecal.detach() + ((1 - alphaE) * fake_ecal.detach()) + + + interpolatesHCAL = interpolatesHCAL.to(device) + interpolatesHCAL.requires_grad_(True) + + interpolatesECAL = interpolatesECAL.to(device) + interpolatesECAL.requires_grad_(True) + + disc_interpolates = netD(interpolatesECAL.float(), interpolatesHCAL.float(), real_label.float()) + + gradients = autograd.grad(outputs=disc_interpolates, inputs=[interpolatesECAL, interpolatesHCAL], + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + gradients = gradients.view(gradients.size(0), -1) + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + return gradient_penalty + +def train(args, netD_E_H, netG_E_H, device, 
train_loader, optimizer_D_E_H, optimizer_G_E_H, epoch, experiment): + Tensor = torch.cuda.FloatTensor + + for batch_idx, (dataE, dataH, energy) in enumerate(train_loader): + ## GLOBAL CRITIC TRAINING + netD_E_H.train() + netG_E_H.eval() + + # zero out critic gradients + optimizer_D_E_H.zero_grad() + + ## Get Real data + real_dataECAL = dataE.to(device).unsqueeze(1).float() + real_dataHCAL = dataH.to(device).unsqueeze(1).float() + label = energy.to(device).float() + + ## Generate fake ECAL and HCAL + z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=False) + fake_ecal, fake_hcal = netG_E_H(z, label.view(-1, 1, 1, 1, 1)) + fake_ecal = fake_ecal.detach() + fake_hcal = fake_hcal.detach() + + ## Global Critic forward pass on Real + disc_real_E_H = netD_E_H(real_dataECAL, real_dataHCAL, label) + + ## Calculate Gradient Penalty Term + gradient_penalty_E_H = calc_gradient_penalty_ECAL_HCAL(netD_E_H, real_dataECAL, real_dataHCAL, fake_ecal, fake_hcal, label, args.batch_size, device, layer=30, layer_hcal=48, xsize=30, ysize=30) + + ## Global Critic forward pass on fake data + disc_fake_E_H = netD_E_H(fake_ecal, fake_hcal, label) + + ## Wasserstein-1 distance for critic + w_dist_E_H = torch.mean(disc_fake_E_H) - torch.mean(disc_real_E_H) + + ## final disc cost + disc_cost_E_H = w_dist_E_H + args.lambd * gradient_penalty_E_H + + disc_cost_E_H.backward() + optimizer_D_E_H.step() + + ## GENERATOR TRAINING + ## train the generator once every ncrit critic updates + if (batch_idx % args.ncrit == 0) and (batch_idx != 0): + netD_E_H.eval() + netG_E_H.train() + + # zero out generator gradients + optimizer_G_E_H.zero_grad() + + ## Generate fake ECAL and HCAL + z = Variable(Tensor(np.random.uniform(-1, 1, (args.batch_size, args.nz, 1, 1, 1))), requires_grad=True) + fake_ecal, fake_hcal = netG_E_H(z, label.view(-1, 1, 1, 1, 1)) + + + ## Loss function for the combined generator + gen_E_H_cost = netD_E_H(fake_ecal, fake_hcal, label) + g_E_H_cost = -torch.mean(gen_E_H_cost) + g_E_H_cost.backward() + optimizer_G_E_H.step() + + if batch_idx % args.log_interval == 0: + print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f}'.format( + epoch, batch_idx * len(dataH), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), g_E_H_cost.item())) + + niter = epoch * len(train_loader) + batch_idx + experiment.log_metric("L_Gen_E_H", g_E_H_cost, step=niter) + + print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format( + epoch, batch_idx * len(dataH), len(train_loader.dataset), + 100.
* batch_idx / len(train_loader), disc_cost_E_H.item())) + experiment.log_metric("L_crit_H_E", disc_cost_E_H, step=niter) + experiment.log_metric("gradient_pen_E_H", gradient_penalty_E_H, step=niter) + experiment.log_metric("Wasserstein Dist E H", w_dist_E_H, step=niter) + experiment.log_metric("Critic Score E H (Real)", torch.mean(disc_real_E_H), step=niter) + experiment.log_metric("Critic Score E H (Fake)", torch.mean(disc_fake_E_H), step=niter) + + +def parse_args(): + parser = argparse.ArgumentParser(description='WGAN training on hadron showers') + parser.add_argument('--batch-size', type=int, default=64*4, metavar='N', + help='input batch size for training (default: 64*4)') + + parser.add_argument('--nz', type=int, default=100, metavar='N', + help='latent space for generator (default: 100)') + + parser.add_argument('--lambd', type=int, default=15, metavar='N', + help='weight of gradient penalty (default: 15)') + + parser.add_argument('--ndf', type=int, default=128, metavar='N', + help='n-feature of critic (default: 128)') + + parser.add_argument('--ngf', type=int, default=8, metavar='N', + help='n-feature of generator (default: 8)') + + parser.add_argument('--ncrit', type=int, default=10, metavar='N', + help='critic updates before generator one (default: 10)') + + parser.add_argument('--epochs', type=int, default=100, metavar='N', + help='number of epochs to train (default: 100)') + + parser.add_argument('--nworkers', type=int, default=1, metavar='N', + help='number of data loader workers (default: 1)') + + parser.add_argument('--lrCrit_E_H', type=float, default=0.00001, metavar='LR', + help='learning rate Critic_E_H (default: 0.00001)') + + parser.add_argument('--chpt_base', type=str, default='/eos/user/e/eneren/experiments/', + help='base directory for saving and loading checkpoints') + + parser.add_argument('--lrGen_E_H', type=float, default=0.0001, metavar='LR', + help='learning rate Generator_E_H (default: 0.0001)') + + parser.add_argument('--chpt', action='store_true', default=False, + help='continue training from a saved model') + + parser.add_argument('--exp', type=str, default='dist_wgan', + help='name of the experiment') + + parser.add_argument('--chpt_eph', type=int, default=1, + help='continue checkpoint epoch') + + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=100, metavar='N', + help='how many batches to wait before logging training status') + + if dist.is_available(): + parser.add_argument('--backend', type=str, help='Distributed backend', + choices=[dist.Backend.GLOO, dist.Backend.NCCL, dist.Backend.MPI], + default=dist.Backend.GLOO) + + parser.add_argument('--local_rank', type=int, default=0) + + args = parser.parse_args() + + + args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank)) + args.rank = int(os.environ.get('RANK', 0)) + args.world_size = int(os.environ.get('WORLD_SIZE', 1)) + + + # postprocess args + args.device = 'cuda:{}'.format(args.local_rank) # PytorchJob/launch.py + args.batch_size = max(args.batch_size, + args.world_size * 2) # min valid batchsize + return args + + +def run(args): + # Training settings + + use_cuda = not args.no_cuda and torch.cuda.is_available() + if use_cuda: + print('Using CUDA') + + # set up COMET ML experiment for logging + experiment = Experiment(api_key=api_key, + project_name="ecal-hcal-shower", workspace="engineren",
auto_output_logging="simple") + experiment.add_tag(args.exp) + + experiment.log_parameters( + { + "batch_size" : args.batch_size, + "latent": args.nz, + "lambda": args.lambd, + "ncrit" : args.ncrit, + "ngf": args.ngf, + "ndf": args.ndf + } + ) + + torch.manual_seed(args.seed) + + device = torch.device("cuda" if use_cuda else "cpu") + + if args.world_size > 1: + print('Using distributed PyTorch with {} backend'.format(args.backend)) + dist.init_process_group(backend=args.backend) + + print('[init] == local rank: {}, global rank: {}, world size: {} =='.format(args.local_rank, args.rank, args.world_size)) + + + print ("loading data") + + dataset = HDF5Dataset('/eos/user/e/eneren/scratch/50GeV75k.hdf5', transform=None, train_size=75000) + + + + if args.world_size > 1: + sampler = DistributedSampler(dataset, shuffle=True) + train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.nworkers, drop_last=True, pin_memory=False) + else: + train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.nworkers, shuffle=True, drop_last=True, pin_memory=False) + + + ## Global critic + Crit_E_H = Global_Discriminator(ndf=args.ndf) + + ## Combined generator + Gen_E_H = Combined_Generator(args.ngf, args.nz) + + + Crit_E_H = Crit_E_H.to(device) + Gen_E_H = Gen_E_H.to(device) + + print('Critic trainable params:', sum(p.numel() for p in Crit_E_H.parameters() if p.requires_grad)) + print('Generator trainable params:', sum(p.numel() for p in Gen_E_H.parameters() if p.requires_grad)) + + if args.world_size > 1: + Distributor = nn.parallel.DistributedDataParallel if use_cuda \ + else nn.parallel.DistributedDataParallelCPU + Crit_E_H = Distributor(Crit_E_H, device_ids=[args.local_rank], output_device=args.local_rank) + Gen_E_H = Distributor(Gen_E_H, device_ids=[args.local_rank], output_device=args.local_rank) + + else: + Crit_E_H = nn.parallel.DataParallel(Crit_E_H) + Gen_E_H = nn.parallel.DataParallel(Gen_E_H) + + + # instantiate optimizers + optimizerG_E_H = optim.Adam(Gen_E_H.parameters(), lr=args.lrGen_E_H, betas=(0.5, 0.9)) + optimizerD_E_H = optim.Adam(Crit_E_H.parameters(), lr=args.lrCrit_E_H, betas=(0.5, 0.9)) + + + + # load from checkpoint if desired + if (args.chpt): + critic_E_H_checkpoint = torch.load(args.chpt_base + args.exp + "_criticE_"+ str(args.chpt_eph) + ".pt") + gen_E_H_checkpoint = torch.load(args.chpt_base + args.exp + "_generatorE_H_"+ str(args.chpt_eph) + ".pt") + + Crit_E_H.load_state_dict(critic_E_H_checkpoint['model_state_dict']) + optimizerD_E_H.load_state_dict(critic_E_H_checkpoint['optimizer_state_dict']) + + Gen_E_H.load_state_dict(gen_E_H_checkpoint['model_state_dict']) + optimizerG_E_H.load_state_dict(gen_E_H_checkpoint['optimizer_state_dict']) + + eph = gen_E_H_checkpoint['epoch'] + + else: + eph = 0 + print ("init models") + + experiment.set_model_graph(str(Crit_E_H), overwrite=False) + experiment.set_model_graph(str(Gen_E_H), overwrite=False) + + print('starting training...') + + for epoch in range(1, args.epochs + 1): + epoch += eph + + train(args, Crit_E_H, Gen_E_H, device, train_loader, optimizerD_E_H, optimizerG_E_H, epoch, experiment) + + # saving to checkpoints + g_E_H_path = args.chpt_base + args.exp + "_generator_E_H_"+ str(epoch) + ".pt" + c_E_H_path = args.chpt_base + args.exp + "_critic_E_H_"+ str(epoch) + ".pt" + + torch.save({ + 'epoch': epoch, + 'model_state_dict': Gen_E_H.state_dict(), + 'optimizer_state_dict': optimizerG_E_H.state_dict() + }, g_E_H_path) + + torch.save({ + 'epoch': epoch, +
'model_state_dict': Crit_E_H.state_dict(), + 'optimizer_state_dict': optimizerD_E_H.state_dict() + }, c_E_H_path) + + print('end training') + +def main(): + args = parse_args() + run(args) + +if __name__ == '__main__': + main() \ No newline at end of file
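
Reviewer note: a minimal shape check for the two new modules, mirroring the notebook cells above. It assumes the repository root is on PYTHONPATH; the batch size of 10 and the 50 GeV conditioning value are illustrative only, while ngf=8, ndf=128 and nz=100 match the training-script defaults.

# Sketch: sanity-check output shapes of Combined_Generator and Global_Discriminator.
import torch
from models.combined_generator_split import Combined_Generator
from models.global_disc import Global_Discriminator

nz, ngf, ndf, batch = 100, 8, 128, 10
netG = Combined_Generator(ngf, nz)
netD = Global_Discriminator(ndf=ndf)

noise = torch.empty(batch, nz, 1, 1, 1).uniform_(-1, 1)       # latent input, uniform in [-1, 1]
energy = torch.full((batch, 1, 1, 1, 1), 50.0)                # energy label (illustrative 50 GeV)

fake_ecal, fake_hcal = netG(noise, energy)
print(fake_ecal.shape)   # expected: torch.Size([10, 1, 30, 30, 30])
print(fake_hcal.shape)   # expected: torch.Size([10, 1, 48, 30, 30])

score = netD(fake_ecal, fake_hcal, energy.view(batch, 1))
print(score.shape)       # expected: torch.Size([10, 1])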
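
Reviewer note: the checkpoints written at the end of each epoch are plain dicts ('epoch', 'model_state_dict', 'optimizer_state_dict'), so a saved generator can be reloaded for standalone sampling roughly as below. The experiment name and epoch number in the path are placeholders taken from the PyTorchJob manifest; note that the script wraps the models in (Distributed)DataParallel before saving, so the state-dict keys carry a "module." prefix that has to be stripped for a bare module.

# Sketch: reload a generator checkpoint written by wganSingleGen.py for inference.
import torch
from models.combined_generator_split import Combined_Generator

ckpt_path = "/eos/user/e/eneren/experiments/wganSingleGenV1_generator_E_H_50.pt"  # placeholder
ckpt = torch.load(ckpt_path, map_location="cpu")

netG = Combined_Generator(ngf=8, nz=100)   # defaults from the argparse configuration

# strip the 'module.' prefix added by (Distributed)DataParallel before loading
state = {(k[len("module."):] if k.startswith("module.") else k): v
         for k, v in ckpt["model_state_dict"].items()}
netG.load_state_dict(state)
netG.eval()

noise = torch.empty(1, 100, 1, 1, 1).uniform_(-1, 1)
energy = torch.full((1, 1, 1, 1, 1), 50.0)
with torch.no_grad():
    ecal, hcal = netG(noise, energy)       # one generated ECAL/HCAL shower pair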