diff --git a/wgan_ECAL_HCAL_3crit.py b/wgan_ECAL_HCAL_3crit.py
index 0e6bc93fd6dae480faf7cd3e230a815d25549696..d3da95e137d3a0d29200b5337f510fa5539a33ae 100644
--- a/wgan_ECAL_HCAL_3crit.py
+++ b/wgan_ECAL_HCAL_3crit.py
@@ -180,10 +180,22 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
         g_E_cost.backward()
         optimizer_g_E.step()
 
+        if batch_idx % args.log_interval == 0 :
+            print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), g_E_cost.item()))
+
+        niter = epoch * len(train_loader) + batch_idx
+        experiment.log_metric("L_Gen_E", g_E_cost, step=niter)
+
+
+
 
         ## HCAL CRITIC TRAINING
         aDH.train()
         aGH.eval()
+        aGE.eval()
+
         # zero out critic gradients
         optimizer_d_H.zero_grad()
 
@@ -239,6 +251,15 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
         g_H_cost.backward()
         optimizer_g_H.step()
 
+        if batch_idx % args.log_interval == 0 :
+            print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGH={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), g_H_cost.item()))
+
+        niter = epoch * len(train_loader) + batch_idx
+        experiment.log_metric("L_Gen_H", g_H_cost, step=niter)
+
+
 
         ## ECAL + HCAL CRITIC TRAINING
         aD_H_E.train()
@@ -279,7 +300,28 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
 
         optimizer_d_H_E.step()
 
-
+        if batch_idx % args.log_interval == 0:
+            print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), disc_cost_H_E.item()))
+        niter = epoch * len(train_loader) + batch_idx
+        experiment.log_metric("L_crit_E", disc_cost_E, step=niter)
+        experiment.log_metric("L_crit_H", disc_cost_H, step=niter)
+        experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter)
+        experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter)
+        experiment.log_metric("gradient_pen_H", gradient_penalty_H, step=niter)
+        experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter)
+        experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter)
+        experiment.log_metric("Wasserstein Dist H", w_dist_H, step=niter)
+        experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter)
+        experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter)
+        experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter)
+        experiment.log_metric("Critic Score H (Real)", torch.mean(disc_real_H), step=niter)
+        experiment.log_metric("Critic Score H (Fake)", torch.mean(disc_fake_H), step=niter)
+        experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter)
+        experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter)
+
+
         ## training generator per ncrit
         if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
             ## GENERATOR TRAINING
@@ -308,37 +350,17 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
             g_cost.backward()
             optimizer_g_H_E.step()
 
-            if batch_idx % args.log_interval == 0:
-                print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
-                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
-                    100. * batch_idx / len(train_loader), disc_cost_H_E.item()))
-            niter = epoch * len(train_loader) + batch_idx
-            experiment.log_metric("L_crit_E", disc_cost_E, step=niter)
-            experiment.log_metric("L_crit_H", disc_cost_H, step=niter)
-            experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter)
-            experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter)
-            experiment.log_metric("gradient_pen_H", gradient_penalty_H, step=niter)
-            experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter)
-            experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter)
-            experiment.log_metric("Wasserstein Dist H", w_dist_H, step=niter)
-            experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter)
-            experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter)
-            experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter)
-            experiment.log_metric("Critic Score H (Real)", torch.mean(disc_real_H), step=niter)
-            experiment.log_metric("Critic Score H (Fake)", torch.mean(disc_fake_H), step=niter)
-            experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter)
-            experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter)
-
-            if batch_idx % args.log_interval == 0 :
-                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f} lossGH={:.4f} lossGHE={:.4f}'.format(
+            if batch_idx % args.log_interval == 0 :
+                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\t lossGHE={:.4f}'.format(
                     epoch, batch_idx * len(dataH), len(train_loader.dataset),
-                    100. * batch_idx / len(train_loader), g_E_cost.item(), g_H_cost.item(), g_cost.item()))
+                    100. * batch_idx / len(train_loader), g_cost.item()))
+
             niter = epoch * len(train_loader) + batch_idx
-            experiment.log_metric("L_Gen_E", g_E_cost, step=niter)
-            experiment.log_metric("L_Gen_H", g_H_cost, step=niter)
             experiment.log_metric("L_Gen_H_E", g_cost, step=niter)
+
+
 
 
 
 def is_distributed():
     return dist.is_available() and dist.is_initialized()