From cacdc0ca3f86c1e93e72a0532e18c146b5d5082c Mon Sep 17 00:00:00 2001 From: Julien Maurer Date: Tue, 17 May 2022 14:19:30 +0200 Subject: [PATCH] Merge branch 'ReadLayersFromJSONsNNbasedExtrapWeights' into '21.0' Retrieve relevant layers from networks used to predict the longitudinal hit position and fix GAN code using those networks See merge request atlas/athena!51075 (cherry picked from commit 699dc10fb66571ece90dd5e4d72e0a90a7207729) 58aa031d No hardcoded layers & support NN using all parts 33fdfa32 Remove temporary lines and pdgID for pions 1574c721 Improve handling of errors and fix for GAN b3b49f35 Merge tag 'nightly/21.0/2022-04-16T2123' into ReadLayersFromJSONsNNbasedExtrapWeights --- .../src/TFCSEnergyAndHitGAN.cxx | 15 +++++++-------- .../src/TFCSPredictExtrapWeights.cxx | 16 +++++++--------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSEnergyAndHitGAN.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSEnergyAndHitGAN.cxx index 4b10b666c96..362a6b30533 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSEnergyAndHitGAN.cxx +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSEnergyAndHitGAN.cxx @@ -288,6 +288,13 @@ bool TFCSEnergyAndHitGAN::fillEnergy(TFCSSimulationState& simulstate, const TFCS } } + for(unsigned int ichain=m_bin_start.back();ichainGetName()); + if (simulate_and_retry(chain()[ichain], simulstate, truth, extrapol) != FCSSuccess) { + return FCSFatal; + } + } + vox = 0; for (auto element : binsInLayers){ int layer = element.first; @@ -506,14 +513,6 @@ FCSReturnCode TFCSEnergyAndHitGAN::simulate(TFCSSimulationState& simulstate,cons return FCSSuccess; } - - for(unsigned int ichain=m_bin_start.back();ichainGetName()); - if (simulate_and_retry(chain()[ichain], simulstate, truth, extrapol) != FCSSuccess) { - return FCSFatal; - } - } - return FCSSuccess; } diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSPredictExtrapWeights.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSPredictExtrapWeights.cxx index 63127f6836f..ab700b587f6 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSPredictExtrapWeights.cxx +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSPredictExtrapWeights.cxx @@ -136,6 +136,7 @@ bool TFCSPredictExtrapWeights::getNormInputs(std::string etaBin, std::string Fas inputTXT.close(); } else { ATH_MSG_ERROR(" Unable to open file "); + return false; } return true; @@ -183,12 +184,9 @@ FCSReturnCode TFCSPredictExtrapWeights::simulate(TFCSSimulationState& simulstate // Get predicted extrapolation weights auto outputs = m_nn->compute(inputVariables); - std::vector layers = {0,1,2,3,12}; - if(is_match_pdgid(211) || is_match_pdgid(-211)){ // charged pion - layers.push_back(13); - layers.push_back(14); - } - for(int ilayer : layers){ // loop over layers and decorate simulstate with corresponding predicted extrapolation weight + for (unsigned int i = 0; i < m_relevantLayers->size(); i++) { + int ilayer = m_relevantLayers->at(i); + ATH_MSG_DEBUG("TFCSPredictExtrapWeights::simulate: layer: " << ilayer << " weight: " << outputs["extrapWeight_"+std::to_string(ilayer)]); simulstate.setAuxInfo(ilayer,outputs["extrapWeight_"+std::to_string(ilayer)]); } return FCSSuccess; @@ -206,8 +204,8 @@ FCSReturnCode TFCSPredictExtrapWeights::simulate_hit(Hit& hit, TFCSSimulationSta if(simulstate.hasAuxInfo(cs)){ extrapWeight = simulstate.getAuxInfo(cs); } else{ // missing AuxInfo - simulate(simulstate, truth, extrapol); // decorate simulstate with extrapolation weights - extrapWeight = simulstate.getAuxInfo(cs); // retrieve corresponding extrapolation weight + ATH_MSG_FATAL("Simulstate is not decorated with extrapolation weights"); + return FCSFatal; } double r = (1.-extrapWeight)*extrapol->r(cs, SUBPOS_ENT) + extrapWeight*extrapol->r(cs, SUBPOS_EXT); @@ -255,7 +253,7 @@ bool TFCSPredictExtrapWeights::initializeNetwork(int pid, std::string etaBin, st sin << input.rdbuf(); input.close(); auto config = lwt::parse_json(sin); - m_nn = new lwt::LightweightNeuralNetwork(config.inputs, config.layers, config.outputs); + m_nn = new lwt::LightweightNeuralNetwork(config.inputs, config.layers, config.outputs); if(m_nn==nullptr){ ATH_MSG_ERROR("Could not create LightWeightNeuralNetwork from " << inputFileName ); return false; -- GitLab