Commit fb6a56d4 authored by Engin Eren

re-arranging some stuff

parent 58086f18
1 merge request: !43 crit peter
Pipeline #4286584 passed
@@ -180,10 +180,22 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
         g_E_cost.backward()
         optimizer_g_E.step()
 
+        if batch_idx % args.log_interval == 0 :
+            print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), g_E_cost.item()))
+
+            niter = epoch * len(train_loader) + batch_idx
+            experiment.log_metric("L_Gen_E", g_E_cost, step=niter)
+
         ## HCAL CRITIC TRAINING
         aDH.train()
         aGH.eval()
         aGE.eval()
 
         # zero out critic gradients
         optimizer_d_H.zero_grad()
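Every logging block this commit adds follows the same pattern: derive a global step from the epoch and batch index, then log the scalar to Comet ML at that step. A minimal, self-contained sketch of the pattern (the bare Experiment() call and the placeholder loss are assumptions for illustration, not this repository's setup):

import torch
from comet_ml import Experiment

# Step-indexed metric logging, mirroring the blocks added in this commit.
# Assumes a Comet API key is already configured in the environment.
experiment = Experiment()
epoch, num_batches, log_interval = 0, 100, 10

for batch_idx in range(num_batches):
    g_E_cost = torch.rand(1)  # placeholder for the generator loss on this batch
    if batch_idx % log_interval == 0:
        # Global step = batches completed across all epochs, so metric
        # curves from successive epochs share one x-axis.
        niter = epoch * num_batches + batch_idx
        experiment.log_metric("L_Gen_E", g_E_cost.item(), step=niter)

Logging g_E_cost.item() rather than the live tensor, as the diff does, also detaches the value from the autograd graph before it reaches the logger.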
@@ -239,6 +251,15 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
         g_H_cost.backward()
         optimizer_g_H.step()
 
+        if batch_idx % args.log_interval == 0 :
+            print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGH={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), g_H_cost.item()))
+
+            niter = epoch * len(train_loader) + batch_idx
+            experiment.log_metric("L_Gen_H", g_H_cost, step=niter)
+
         ## ECAL + HCAL CRITIC TRAINING
         aD_H_E.train()
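The context lines around each critic update show the mode-toggling idiom this file relies on: the critic being optimized goes into train mode while the generators sit in eval mode. A small sketch of that idiom with placeholder modules (not the repository's networks); note that eval() only switches layers such as Dropout and BatchNorm to inference behavior and does not by itself stop gradients:

import torch.nn as nn

# Placeholder stand-ins for a critic and a generator.
critic = nn.Sequential(nn.Linear(8, 1))
generator = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.2))

# Before a critic step: train-mode critic, deterministic generator.
critic.train()
generator.eval()

# Before a generator step the roles flip.
generator.train()
critic.eval()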
@@ -279,7 +300,28 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
         optimizer_d_H_E.step()
 
+        if batch_idx % args.log_interval == 0:
+            print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
+                epoch, batch_idx * len(dataH), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), disc_cost_H_E.item()))
+
+            niter = epoch * len(train_loader) + batch_idx
+            experiment.log_metric("L_crit_E", disc_cost_E, step=niter)
+            experiment.log_metric("L_crit_H", disc_cost_H, step=niter)
+            experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter)
+            experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter)
+            experiment.log_metric("gradient_pen_H", gradient_penalty_H, step=niter)
+            experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter)
+            experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter)
+            experiment.log_metric("Wasserstein Dist H", w_dist_H, step=niter)
+            experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter)
+            experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter)
+            experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter)
+            experiment.log_metric("Critic Score H (Real)", torch.mean(disc_real_H), step=niter)
+            experiment.log_metric("Critic Score H (Fake)", torch.mean(disc_fake_H), step=niter)
+            experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter)
+            experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter)
+
         ## training generator per ncrit
         if (batch_idx % args.ncrit == 0) and (batch_idx != 0):
 
             ## GENERATOR TRAINING
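The block above only logs the gradient_penalty_* and w_dist_* values; their computation sits outside this diff. For reference, the standard WGAN-GP gradient penalty (Gulrajani et al.) looks like the sketch below; this is the usual recipe, not necessarily this repository's exact implementation, and the logged Wasserstein distance is typically estimated as the mean critic score on real samples minus the mean score on fakes:

import torch

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    # Sample points on straight lines between the real and generated batches.
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    interp = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    scores = critic(interp)
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # the penalty itself must be differentiable
    )
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    # Penalize deviation of the critic's gradient norm from 1 (soft Lipschitz constraint).
    return lambda_gp * ((grad_norm - 1.0) ** 2).mean()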
@@ -308,37 +350,17 @@ def train(args, aDE, aDH, aD_H_E, aGE, aGH, device, train_loader, optimizer_d_E,
             g_cost.backward()
             optimizer_g_H_E.step()
 
-            if batch_idx % args.log_interval == 0:
-                print('Critic --> Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
-                    epoch, batch_idx * len(dataH), len(train_loader.dataset),
-                    100. * batch_idx / len(train_loader), disc_cost_H_E.item()))
-
-                niter = epoch * len(train_loader) + batch_idx
-                experiment.log_metric("L_crit_E", disc_cost_E, step=niter)
-                experiment.log_metric("L_crit_H", disc_cost_H, step=niter)
-                experiment.log_metric("L_crit_H_E", disc_cost_H_E, step=niter)
-                experiment.log_metric("gradient_pen_E", gradient_penalty_E, step=niter)
-                experiment.log_metric("gradient_pen_H", gradient_penalty_H, step=niter)
-                experiment.log_metric("gradient_pen_H_E", gradient_penalty_H_E, step=niter)
-                experiment.log_metric("Wasserstein Dist E", w_dist_E, step=niter)
-                experiment.log_metric("Wasserstein Dist H", w_dist_H, step=niter)
-                experiment.log_metric("Wasserstein Dist E H", w_dist_H_E, step=niter)
-                experiment.log_metric("Critic Score E (Real)", torch.mean(disc_real_E), step=niter)
-                experiment.log_metric("Critic Score E (Fake)", torch.mean(disc_fake_E), step=niter)
-                experiment.log_metric("Critic Score H (Real)", torch.mean(disc_real_H), step=niter)
-                experiment.log_metric("Critic Score H (Fake)", torch.mean(disc_fake_H), step=niter)
-                experiment.log_metric("Critic Score H E (Real)", torch.mean(disc_real_H_E), step=niter)
-                experiment.log_metric("Critic Score H E (Fake)", torch.mean(disc_fake_H_E), step=niter)
-
-            if batch_idx % args.log_interval == 0 :
-                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\tlossGE={:.4f} lossGH={:.4f} lossGHE={:.4f}'.format(
+            if batch_idx % args.log_interval == 0 :
+                print('Generator --> Train Epoch: {} [{}/{} ({:.0f}%)]\t lossGHE={:.4f}'.format(
                     epoch, batch_idx * len(dataH), len(train_loader.dataset),
-                    100. * batch_idx / len(train_loader), g_E_cost.item(), g_H_cost.item(), g_cost.item()))
+                    100. * batch_idx / len(train_loader), g_cost.item()))
 
                 niter = epoch * len(train_loader) + batch_idx
-                experiment.log_metric("L_Gen_E", g_E_cost, step=niter)
-                experiment.log_metric("L_Gen_H", g_H_cost, step=niter)
                 experiment.log_metric("L_Gen_H_E", g_cost, step=niter)
 def is_distributed():
     return dist.is_available() and dist.is_initialized()
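is_distributed() is a small guard that lets the same script run with or without torch.distributed being initialized. A hypothetical usage example (mean_across_ranks is an illustrative helper, not part of this file):

import torch
import torch.distributed as dist

def is_distributed():
    return dist.is_available() and dist.is_initialized()

def mean_across_ranks(value):
    # Average a scalar over all workers; fall back to the local value
    # when no process group has been initialized.
    t = torch.tensor([float(value)])
    if is_distributed():
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        t /= dist.get_world_size()
    return t.item()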