diff --git a/wgan_ECAL_HCAL_2crit.py b/wgan_ECAL_HCAL_2crit.py index 11113d4fb0b75afa89d2f31630c0af540f49e3f1..57f94655ac8760d65a34e9f8bd1a0fc816eee565 100644 --- a/wgan_ECAL_HCAL_2crit.py +++ b/wgan_ECAL_HCAL_2crit.py @@ -443,8 +443,14 @@ def run(args): else: eph = 0 gen_E_checkpoint = torch.load("/eos/user/e/eneren/experiments/wganv1_generator_694.pt", map_location=torch.device('cuda')) + critic_E_checkpoint = torch.load("/eos/user/e/eneren/experiments/wganv1_critic_694.pt", map_location=torch.device('cuda')) + mGenE.load_state_dict(gen_E_checkpoint['model_state_dict']) optimizerG_E.load_state_dict(gen_E_checkpoint['optimizer_state_dict']) + + mCritE.load_state_dict(critic_E_checkpoint['model_state_dict']) + optimizerD_E.load_state_dict(critic_E_checkpoint['optimizer_state_dict']) + print ("init models")