From cb3907407a8a93877533b7cf354e8ef9442bda08 Mon Sep 17 00:00:00 2001
From: Engin Eren <engin.eren@desy.de>
Date: Fri, 8 Apr 2022 15:01:05 +0200
Subject: [PATCH] Add Comet ML experiment tracking to WGAN training

---
 Dockerfile |  2 +-
 wgan.py    | 35 ++++++++++++++++++++++++++++++-----
 2 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 41e0a6f..dfcf4f9 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,7 +12,7 @@ RUN apt-get -qq update && \
 
 RUN mkdir -p /opt/regressor && \ 
     mkdir -p /opt/regressor/src/models \
-    && pip install h5py pyflakes && export MKL_SERVICE_FORCE_INTEL=1
+    && pip install h5py pyflakes comet_ml && export MKL_SERVICE_FORCE_INTEL=1
 
 WORKDIR /opt/regressor/src
 ADD regressor.py /opt/regressor/src/regressor.py
diff --git a/wgan.py b/wgan.py
index 19360a9..1273dc6 100644
--- a/wgan.py
+++ b/wgan.py
@@ -1,4 +1,5 @@
 from __future__ import print_function
+from comet_ml import Experiment
 import argparse
 import os, sys
 import numpy as np
@@ -53,7 +54,7 @@ def calc_gradient_penalty(netD, real_data, fake_data, real_label, BATCH_SIZE, de
     return gradient_penalty
 
     
-def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch):
+def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch, experiment):
     
     ### CRITIC TRAINING
     aD.train()
@@ -95,7 +96,12 @@ def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch):
             print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
                 epoch, batch_idx * len(data), len(train_loader.dataset),
                 100. * batch_idx / len(train_loader), disc_cost.item()))
-            #niter = epoch * len(train_loader) + batch_idx
+            niter = epoch * len(train_loader) + batch_idx
+            experiment.log_metric("L_crit", disc_cost, step=niter)
+            experiment.log_metric("gradient_pen", gradient_penalty, step=niter)
+            experiment.log_metric("Wasserstein Dist", w_dist, step=niter)
+            experiment.log_metric("Critic Score (Real)", torch.mean(disc_real), step=niter)
+            experiment.log_metric("Critic Score (Fake)", torch.mean(disc_fake), step=niter)
 
         
         ## training generator per ncrit 
@@ -136,7 +142,9 @@ def train(args, aD, aG, device, train_loader, optimizer_d, optimizer_g, epoch):
                 print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
                     epoch, batch_idx * len(data), len(train_loader.dataset),
                     100. * batch_idx / len(train_loader), g_cost.item()))
-
+                niter = epoch * len(train_loader) + batch_idx
+                experiment.log_metric("L_Gen", g_cost, step=niter)
+                #experiment.log_metric("L_Gen", torch.mean(gen_cost), step=niter)
 
 
 
@@ -223,7 +231,21 @@ def run(args):
     if use_cuda:
         print('Using CUDA')
 
-   
+    
+    experiment = Experiment(api_key=os.environ["COMET_API_KEY"],
+                        project_name="ecal-hcal-shower", workspace="engineren", auto_output_logging="simple")
+    experiment.add_tag(args.exp)
+
+    experiment.log_parameters(
+        {
+        "batch_size" : args.batch_size,
+        "latent": args.nz,
+        "lambda": args.lambd,
+        "ncrit" : args.ncrit,
+        "resN": args.dres,
+        "ngf": args.ngf
+        }
+    )
 
     torch.manual_seed(args.seed)
 
@@ -282,11 +304,14 @@ def run(args):
     #reg_checkpoint = torch.load(args.chpt_base + "dist_launch_sampler_regressor_2.pt")
     #mReg.load_state_dict(reg_checkpoint['model_state_dict'])
 
+    
+    experiment.set_model_graph(str(mGen), overwrite=False)
+
     print ("starting training...")
     for epoch in range(1, args.epochs + 1):
         epoch += eph
         train_loader.sampler.set_epoch(epoch)
-        train(args, mCrit, mGen, device, train_loader, optimizerD, optimizerG, epoch)
+        train(args, mCrit, mGen, device, train_loader, optimizerD, optimizerG, epoch, experiment)
         if args.rank == 0:
             gPATH = args.chpt_base + args.exp + "_generator_"+ str(epoch) + ".pt"
             cPATH = args.chpt_base + args.exp + "_critic_"+ str(epoch) + ".pt"
-- 
GitLab