From 037662f0bae7b7280c55931164cba3d6cd66c362 Mon Sep 17 00:00:00 2001 From: Jean Yves Beaucamp <jean.yves.beaucamp@cern.ch> Date: Tue, 5 Nov 2024 01:40:25 +0100 Subject: [PATCH 1/4] Implementation of ONNX model score computation and decoration for Taus (GNTau) Implement Bertand's suggested solution to vertex variables calculation - which also only has to run once per tau! --- .../tauRec/python/TauConfigFlags.py | 3 + Reconstruction/tauRec/python/TauToolHolder.py | 48 + Reconstruction/tauRecTools/CMakeLists.txt | 9 +- Reconstruction/tauRecTools/Root/TauGNN.cxx | 206 ++++ .../tauRecTools/Root/TauGNNEvaluator.cxx | 175 ++++ .../tauRecTools/Root/TauGNNUtils.cxx | 914 ++++++++++++++++++ .../src/components/tauRecTools_entries.cxx | 2 + .../tauRecTools/tauRecTools/TauGNN.h | 110 +++ .../tauRecTools/tauRecTools/TauGNNEvaluator.h | 69 ++ .../tauRecTools/tauRecTools/TauGNNUtils.h | 311 ++++++ 10 files changed, 1843 insertions(+), 4 deletions(-) create mode 100644 Reconstruction/tauRecTools/Root/TauGNN.cxx create mode 100644 Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx create mode 100644 Reconstruction/tauRecTools/Root/TauGNNUtils.cxx create mode 100644 Reconstruction/tauRecTools/tauRecTools/TauGNN.h create mode 100644 Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h create mode 100644 Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h diff --git a/Reconstruction/tauRec/python/TauConfigFlags.py b/Reconstruction/tauRec/python/TauConfigFlags.py index f4c7ff650c38..d17807f2bf96 100644 --- a/Reconstruction/tauRec/python/TauConfigFlags.py +++ b/Reconstruction/tauRec/python/TauConfigFlags.py @@ -60,6 +60,9 @@ def createTauConfigFlags(): # R22 DeepSet tau ID tune without track RNN scores, for now define a second set of flags, but ultimately we'll choose one and drop the other tau_cfg.addFlag("Tau.TauJetDeepSetConfig_v2", ["tauid_1p_R22_dpst_noTrackScore.json", "tauid_2p_R22_dpst_noTrackScore.json", "tauid_3p_R22_dpst_noTrackScore.json"]) tau_cfg.addFlag("Tau.TauJetDeepSetWP_v2", ["model_1p_R22_dpst_noTrackScore.root", "model_2p_R22_dpst_noTrackScore.root", "model_3p_R22_dpst_noTrackScore.root"]) + # GNTau ID tune file (need to add another version for noAux) + tau_cfg.addFlag("Tau.TauGNNConfig", ["GNTau_noAux_simplified.onnx"]) + tau_cfg.addFlag("Tau.TauGNNWP_v0", ["GNTauNA_flat_model_1p.root", "GNTauNA_flat_model_2p.root", "GNTauNA_flat_model_3p.root"]) # PanTau config flags from PanTauAlgs.PanTauConfigFlags import createPanTauConfigFlags diff --git a/Reconstruction/tauRec/python/TauToolHolder.py b/Reconstruction/tauRec/python/TauToolHolder.py index 679fb4c86f9f..35aea18ff2c9 100644 --- a/Reconstruction/tauRec/python/TauToolHolder.py +++ b/Reconstruction/tauRec/python/TauToolHolder.py @@ -851,6 +851,54 @@ def TauWPDecoratorJetDeepSetCfg(flags, version=None): result.setPrivateTools(myTauWPDecorator) return result +def TauGNNEvaluatorCfg(flags): + result = ComponentAccumulator() + _name = flags.Tau.ActiveConfig.prefix + 'TauGNN' + + TauGNNEvaluator = CompFactory.getComp("TauGNNEvaluator") + GNNConf = flags.Tau.TauGNNConfig + myTauGNNEvaluator = TauGNNEvaluator(name = _name, + NetworkFile = GNNConf[0], + OutputVarname = "GNTauScore", + OutputPTau = "GNTauProbTau", + OutputPJet = "GNTauProbJet", + MaxTracks = 30, + MaxClusters = 20, + MaxClusterDR = 15.0, + VertexCorrection = True, + DecorateTracks = False, + InputLayerScalar = "tau_vars", + InputLayerTracks = "track_vars", + InputLayerClusters = "cluster_vars", + NodeNameTau="GN2TauNoAux_pb", + NodeNameJet="GN2TauNoAux_pu") + + result.setPrivateTools(myTauGNNEvaluator) + return result + +def TauWPDecoratorGNNCfg(flags): + result = ComponentAccumulator() + _name = flags.Tau.ActiveConfig.prefix + 'TauWPDecoratorGNN' + + TauWPDecorator = CompFactory.getComp("TauWPDecorator") + WPConf = flags.Tau.TauGNNWP_v0 + decorWPNames = ["GNTauVL_v0", "GNTauL_v0", "GNTauM_v0", "GNTauT_v0"] + scoreName = "GNTauScore" + newScoreName = "GNTauScoreSigTrans_v0" + myTauWPDecorator = TauWPDecorator(name=_name, + flatteningFile1Prong = WPConf[0], + flatteningFile2Prong = WPConf[1], + flatteningFile3Prong = WPConf[2], + DecorWPNames = decorWPNames, + DecorWPCutEffs1P = [0.95, 0.85, 0.75, 0.60], + DecorWPCutEffs2P = [0.95, 0.75, 0.60, 0.45], + DecorWPCutEffs3P = [0.95, 0.75, 0.60, 0.45], + ScoreName = scoreName, + NewScoreName = newScoreName, + DefineWPs = True) + result.setPrivateTools(myTauWPDecorator) + return result + def TauEleRNNEvaluatorCfg(flags): result = ComponentAccumulator() _name = flags.Tau.ActiveConfig.prefix + 'TauEleRNN' diff --git a/Reconstruction/tauRecTools/CMakeLists.txt b/Reconstruction/tauRecTools/CMakeLists.txt index 40a34801c23a..87b3c12ef094 100644 --- a/Reconstruction/tauRecTools/CMakeLists.txt +++ b/Reconstruction/tauRecTools/CMakeLists.txt @@ -8,6 +8,7 @@ find_package( Boost ) find_package( Eigen ) find_package( ROOT COMPONENTS Core Tree Hist RIO ) find_package( lwtnn ) +find_package( onnxruntime ) # Optional dependencies. set( extra_public_libs ) @@ -28,12 +29,12 @@ atlas_add_library( tauRecToolsLib tauRecTools/*.h Root/*.cxx tauRecTools/lwtnn/*.h Root/lwtnn/*.cxx PUBLIC_HEADERS tauRecTools INCLUDE_DIRS ${Boost_INCLUDE_DIRS} ${EIGEN_INCLUDE_DIRS} - ${LWTNN_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} + ${LWTNN_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} ${ONNXRUNTIME_INCLUDE_DIRS} LINK_LIBRARIES ${Boost_LIBRARIES} ${EIGEN_LIBRARIES} ${LWTNN_LIBRARIES} - ${ROOT_LIBRARIES} + ${ROOT_LIBRARIES} ${ONNXRUNTIME_LIBRARIES} CxxUtils AthLinks AsgMessagingLib AsgDataHandlesLib AsgTools xAODCaloEvent xAODEventInfo xAODJet xAODParticleEvent xAODPFlow xAODTau xAODTracking xAODEventShape - MVAUtils ${extra_public_libs} + MVAUtils FlavorTagDiscriminants ${extra_public_libs} PRIVATE_LINK_LIBRARIES CaloGeoHelpers FourMomUtils PathResolver ) atlas_add_dictionary( tauRecToolsDict @@ -46,5 +47,5 @@ if( NOT XAOD_STANDALONE ) INCLUDE_DIRS ${Boost_INCLUDE_DIRS} LINK_LIBRARIES ${Boost_LIBRARIES} AsgTools AsgMessagingLib xAODBase xAODCaloEvent xAODJet xAODPFlow xAODTau FourMomUtils tauRecToolsLib PFlowUtilsLib - ${extra_private_libs} ) + FlavorTagDiscriminants ${extra_private_libs} ) endif() diff --git a/Reconstruction/tauRecTools/Root/TauGNN.cxx b/Reconstruction/tauRecTools/Root/TauGNN.cxx new file mode 100644 index 000000000000..65f13c4c2f6b --- /dev/null +++ b/Reconstruction/tauRecTools/Root/TauGNN.cxx @@ -0,0 +1,206 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "tauRecTools/TauGNN.h" +#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "PathResolver/PathResolver.h" + +#include <algorithm> +#include <fstream> + +#include "lwtnn/LightweightGraph.hh" +#include "lwtnn/Exceptions.hh" +//#include "lwtnn/parse_json.hh" + +#include "tauRecTools/TauGNNUtils.h" + +TauGNN::TauGNN(const std::string &nnFile, const Config &config): + asg::AsgMessaging("TauGNN"), + m_onnxUtil(nullptr) + { + //==================================================// + // This part is ported from FTagDiscriminant GNN.cxx// + //==================================================// + + m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile); + + // get the configuration of the model outputs + FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig(); + + //Let's see the output! + for (const auto& out_node: gnn_output_config) { + if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name); + } + + //Get model config (for inputs) + auto lwtnn_config = m_onnxUtil->getLwtConfig(); + + //===================================================// + // This part is ported from tauRecTools TauJetRNN.cxx// + //===================================================// + + // Search for input layer names specified in 'config' + auto node_is_scalar = [&config](const lwt::InputNodeConfig &in_node) { + return in_node.name == config.input_layer_scalar; + }; + auto node_is_track = [&config](const lwt::InputNodeConfig &in_node) { + return in_node.name == config.input_layer_tracks; + }; + auto node_is_cluster = [&config](const lwt::InputNodeConfig &in_node) { + return in_node.name == config.input_layer_clusters; + }; + + auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(), + lwtnn_config.inputs.cend(), + node_is_scalar); + + auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(), + lwtnn_config.input_sequences.cend(), + node_is_track); + + auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(), + lwtnn_config.input_sequences.cend(), + node_is_cluster); + + // Check which input layers were found + auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend(); + auto has_track_node = track_node != lwtnn_config.input_sequences.cend(); + auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend(); + if(!has_scalar_node) ATH_MSG_WARNING("No scalar node with name "<<config.input_layer_scalar<<" found!"); + if(!has_track_node) ATH_MSG_WARNING("No track node with name "<<config.input_layer_tracks<<" found!"); + if(!has_cluster_node) ATH_MSG_WARNING("No cluster node with name "<<config.input_layer_clusters<<" found!"); + + // Fill the variable names of each input layer into the corresponding vector + if (has_scalar_node) { + for (const auto &in : scalar_node->variables) { + std::string name = in.name; + m_scalarCalc_inputs.push_back(name); + } + } + + if (has_track_node) { + for (const auto &in : track_node->variables) { + std::string name = in.name; + m_trackCalc_inputs.push_back(name); + } + } + + if (has_cluster_node) { + for (const auto &in : cluster_node->variables) { + std::string name = in.name; + m_clusterCalc_inputs.push_back(name); + } + } + // Load the variable calculator + m_var_calc = TauGNNUtils::get_calculator(m_scalarCalc_inputs, m_trackCalc_inputs, m_clusterCalc_inputs); + ATH_MSG_INFO("TauGNN object initialized successfully!"); +} + +TauGNN::~TauGNN() {} + +std::tuple< + std::map<std::string, float>, + std::map<std::string, std::vector<char>>, + std::map<std::string, std::vector<float>> > +TauGNN::compute(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const { + InputMap scalarInputs; + InputSequenceMap vectorInputs; + std::map<std::string, input_pair> gnn_input; + ATH_MSG_DEBUG("Starting compute..."); + //Prepare input variables + if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) { + ATH_MSG_FATAL("Failed calculateInputVariables"); + throw StatusCode::FAILURE; + } + + // Add TauJet-level features to the input + std::vector<float> tau_feats; + for (const auto &varname : m_scalarCalc_inputs) { + tau_feats.push_back(static_cast<float>(scalarInputs[m_config.input_layer_scalar][varname])); + } + std::vector<int64_t> tau_feats_dim = {1, static_cast<int64_t>(tau_feats.size())}; + input_pair tau_info (tau_feats, tau_feats_dim); + gnn_input.insert({"tau_vars", tau_info}); + + //Add track-level features to the input + std::vector<float> trk_feats; + int num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_tracks][m_trackCalc_inputs.at(0)].size()); + int num_node_vars=static_cast<int>(m_trackCalc_inputs.size()); + trk_feats.resize(num_nodes * num_node_vars); + int var_idx=0; + for (const auto &varname : m_trackCalc_inputs) { + for (int node_idx=0; node_idx<num_nodes; node_idx++){ + trk_feats.at(node_idx*num_node_vars + var_idx) + = static_cast<float>(vectorInputs[m_config.input_layer_tracks][varname].at(node_idx)); + } + var_idx++; + } + std::vector<int64_t> trk_feats_dim = {num_nodes, num_node_vars}; + input_pair trk_info (trk_feats, trk_feats_dim); + gnn_input.insert({"track_vars", trk_info}); + + //Add cluster-level features to the input + std::vector<float> cls_feats; + num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_clusters][m_clusterCalc_inputs.at(0)].size()); + num_node_vars=static_cast<int>(m_clusterCalc_inputs.size()); + cls_feats.resize(num_nodes * num_node_vars); + var_idx=0; + for (const auto &varname : m_clusterCalc_inputs) { + for (int node_idx=0; node_idx<num_nodes; node_idx++){ + cls_feats.at(node_idx*num_node_vars + var_idx) + = static_cast<float>(vectorInputs[m_config.input_layer_clusters][varname].at(node_idx)); + } + var_idx++; + } + std::vector<int64_t> cls_feats_dim = {num_nodes, num_node_vars}; + input_pair cls_info (cls_feats, cls_feats_dim); + gnn_input.insert({"cluster_vars", cls_info}); + + //RUN THE INFERENCE!!! + ATH_MSG_DEBUG("Prepared inputs, running inference..."); + auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + ATH_MSG_DEBUG("Finished compute!"); + return std::make_tuple(out_f, out_vc, out_vf); +} + +bool TauGNN::calculateInputVariables(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::map<std::string, std::map<std::string, double>>& scalarInputs, + std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const { + scalarInputs.clear(); + vectorInputs.clear(); + // Populate input (sequence) map with input variables + for (const auto &varname : m_scalarCalc_inputs) { + if (!m_var_calc->compute(varname, tau, + scalarInputs[m_config.input_layer_scalar][varname])) { + ATH_MSG_WARNING("Error computing '" << varname + << "' returning default"); + return false; + } + } + + for (const auto &varname : m_trackCalc_inputs) { + if (!m_var_calc->compute(varname, tau, tracks, + vectorInputs[m_config.input_layer_tracks][varname])) { + ATH_MSG_WARNING("Error computing '" << varname + << "' returning default"); + return false; + } + } + + for (const auto &varname : m_clusterCalc_inputs) { + if (!m_var_calc->compute(varname, tau, clusters, + vectorInputs[m_config.input_layer_clusters][varname])) { + ATH_MSG_WARNING("Error computing '" << varname + << "' returning default"); + return false; + } + } + return true; +} diff --git a/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx new file mode 100644 index 000000000000..0be8819d79ec --- /dev/null +++ b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx @@ -0,0 +1,175 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "tauRecTools/TauGNNEvaluator.h" +#include "tauRecTools/TauGNN.h" +#include "tauRecTools/HelperFunctions.h" + +#include "PathResolver/PathResolver.h" + +#include <algorithm> + + +TauGNNEvaluator::TauGNNEvaluator(const std::string &name): + TauRecToolBase(name), + m_net(nullptr){ + + declareProperty("NetworkFile", m_weightfile = ""); + declareProperty("OutputVarname", m_output_varname = "GNTauScore"); + declareProperty("OutputPTau", m_output_ptau = "GNTauProbTau"); + declareProperty("OutputPJet", m_output_pjet = "GNTauProbJet"); + declareProperty("MaxTracks", m_max_tracks = 30); + declareProperty("MaxClusters", m_max_clusters = 20); + declareProperty("MaxClusterDR", m_max_cluster_dr = 1.0f); + declareProperty("VertexCorrection", m_doVertexCorrection = true); + declareProperty("DecorateTracks", m_decorateTracks = false); + declareProperty("TrackClassification", m_doTrackClassification = true); + + // Naming conventions for the network weight files: + declareProperty("InputLayerScalar", m_input_layer_scalar = "tau_vars"); + declareProperty("InputLayerTracks", m_input_layer_tracks = "track_vars"); + declareProperty("InputLayerClusters", m_input_layer_clusters = "cluster_vars"); + declareProperty("NodeNameTau", m_outnode_tau = "GN2TauNoAux_pb"); + declareProperty("NodeNameJet", m_outnode_jet = "GN2TauNoAux_pu"); + } + +TauGNNEvaluator::~TauGNNEvaluator() {} + +StatusCode TauGNNEvaluator::initialize() { + ATH_MSG_INFO("Initializing TauGNNEvaluator"); + + std::string weightfile(""); + + // Use PathResolver to search for the weight files + if (!m_weightfile.empty()) { + weightfile = find_file(m_weightfile); + if (weightfile.empty()) { + ATH_MSG_ERROR("Could not find network weights: " << m_weightfile); + return StatusCode::FAILURE; + } else { + ATH_MSG_INFO("Using network config: " << weightfile); + } + } + + // Set the layer and node names in the weight file + TauGNN::Config config; + config.input_layer_scalar = m_input_layer_scalar; + config.input_layer_tracks = m_input_layer_tracks; + config.input_layer_clusters = m_input_layer_clusters; + config.output_node_tau = m_outnode_tau; + config.output_node_jet = m_outnode_jet; + + // Load the weights and create the network + if (!weightfile.empty()) { + m_net = std::make_unique<TauGNN>(weightfile, config); + if (!m_net) { + ATH_MSG_ERROR("No network configured."); + return StatusCode::FAILURE; + } + } + + return StatusCode::SUCCESS; +} + +StatusCode TauGNNEvaluator::execute(xAOD::TauJet &tau) const { + // Output variable Decorators + const SG::AuxElement::Accessor<float> output(m_output_varname); + const SG::AuxElement::Accessor<float> out_ptau(m_output_ptau); + const SG::AuxElement::Accessor<float> out_pjet(m_output_pjet); + const SG::AuxElement::Decorator<char> out_trkclass("GNTau_TrackClass"); + // Set default score and overwrite later + output(tau) = -1111.0f; + out_ptau(tau) = -1111.0f; + out_pjet(tau) = -1111.0f; + //Skip execution for low-pT taus to save resources + if(tau.pt()<13000) { + return StatusCode::SUCCESS; + } + + // Get input objects + std::vector<const xAOD::TauTrack *> tracks; + ATH_CHECK(get_tracks(tau, tracks)); + std::vector<xAOD::CaloVertexedTopoCluster> clusters; + ATH_CHECK(get_clusters(tau, clusters)); + + // Truncate tracks + int numTracksMax = std::min(m_max_tracks, tracks.size()); + std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax); + // Evaluate networks + if (m_net) { + auto [out_f, out_vc, out_vf] = m_net->compute(tau, trackVec, clusters); + output(tau)=std::log10(1/(1-out_f.at(m_outnode_tau))); + out_ptau(tau)=out_f.at(m_outnode_tau); + out_pjet(tau)=out_f.at(m_outnode_jet); + if (m_decorateTracks){ + for(unsigned int i=0;i<tracks.size();i++){ + if(i<out_vc.at("track_class").size()){out_trkclass(*tracks.at(i))=out_vc.at("track_class").at(i);} + else{out_trkclass(*tracks.at(i))='9';} //Dummy value for tracks outside range of out_vc + } + } + } + + return StatusCode::SUCCESS; +} + +const TauGNN* TauGNNEvaluator::get_gnn() const { + return m_net.get(); +} + + +StatusCode TauGNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const { + std::vector<const xAOD::TauTrack*> tracks = tau.allTracks(); + + // Skip unclassified tracks: + // - the track is a LRT and classifyLRT = false + // - the track is not among the MaxNtracks highest-pt tracks in the track classifier + // - track classification is not run (trigger) + if(m_doTrackClassification) { + std::vector<const xAOD::TauTrack*>::iterator it = tracks.begin(); + while(it != tracks.end()) { + if((*it)->flag(xAOD::TauJetParameters::unclassified)) { + it = tracks.erase(it); + } + else { + ++it; + } + } + } + + // Sort by descending pt + auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) { + return lhs->pt() > rhs->pt(); + }; + std::sort(tracks.begin(), tracks.end(), cmp_pt); + out = std::move(tracks); + + return StatusCode::SUCCESS; +} + +StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const { + + TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection); + + std::vector<xAOD::CaloVertexedTopoCluster> vertexedClusterList = tau.vertexedClusters(); + for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : vertexedClusterList) { + TLorentzVector clusterP4 = vertexedCluster.p4(); + if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue; + + clusters.push_back(vertexedCluster); + } + + // Sort by descending et + auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs, + const xAOD::CaloVertexedTopoCluster& rhs) { + return lhs.p4().Et() > rhs.p4().Et(); + }; + std::sort(clusters.begin(), clusters.end(), et_cmp); + + // Truncate clusters + if (clusters.size() > m_max_clusters) { + clusters.resize(m_max_clusters, clusters[0]); + } + + return StatusCode::SUCCESS; +} diff --git a/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx new file mode 100644 index 000000000000..8da2fe5ea6ec --- /dev/null +++ b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx @@ -0,0 +1,914 @@ +/* + Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration +*/ + +#include "tauRecTools/TauGNNUtils.h" +#include "tauRecTools/HelperFunctions.h" +#include <algorithm> +#include <iostream> +#define GeV 1000 + +namespace TauGNNUtils { + +GNNVarCalc::GNNVarCalc() : asg::AsgMessaging("TauGNNUtils::GNNVarCalc") { +} + +bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau, + double &out) const { + // Retrieve calculator function + ScalarCalc func = nullptr; + try { + func = m_scalar_map.at(name); + } catch (const std::out_of_range &e) { + ATH_MSG_ERROR("Variable '" << name << "' not defined"); + throw; + } + + // Calculate variable + return func(tau, out); +} + +bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + std::vector<double> &out) const { + out.clear(); + + // Retrieve calculator function + TrackCalc func = nullptr; + try { + func = m_track_map.at(name); + } catch (const std::out_of_range &e) { + ATH_MSG_ERROR("Variable '" << name << "' not defined"); + throw; + } + + // Calculate variables for selected tracks + bool success = true; + double value; + for (const auto *const trk : tracks) { + success = success && func(tau, *trk, value); + out.push_back(value); + } + + return success; +} + +bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::vector<double> &out) const { + out.clear(); + + // Retrieve calculator function + ClusterCalc func = nullptr; + try { + func = m_cluster_map.at(name); + } catch (const std::out_of_range &e) { + ATH_MSG_ERROR("Variable '" << name << "' not defined"); + throw; + } + + // Calculate variables for selected clusters + bool success = true; + double value; + for (const xAOD::CaloVertexedTopoCluster& cluster : clusters) { + success = success && func(tau, cluster, value); + out.push_back(value); + } + + return success; +} + +void GNNVarCalc::insert(const std::string &name, ScalarCalc func, const std::vector<std::string>& scalar_vars) { + if (std::find(scalar_vars.begin(), scalar_vars.end(), name) == scalar_vars.end()) { + return; + } + if (!func) { + throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert"); + } + m_scalar_map[name] = func; +} + +void GNNVarCalc::insert(const std::string &name, TrackCalc func, const std::vector<std::string>& track_vars) { + if (std::find(track_vars.begin(), track_vars.end(), name) == track_vars.end()) { + return; + } + if (!func) { + throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert"); + } + m_track_map[name] = func; +} + +void GNNVarCalc::insert(const std::string &name, ClusterCalc func, const std::vector<std::string>& cluster_vars) { + if (std::find(cluster_vars.begin(), cluster_vars.end(), name) == cluster_vars.end()) { + return; + } + if (!func) { + throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert"); + } + m_cluster_map[name] = func; +} + +std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scalar_vars, + const std::vector<std::string>& track_vars, + const std::vector<std::string>& cluster_vars) { + auto calc = std::make_unique<GNNVarCalc>(); + + // Scalar variable calculator functions + calc->insert("absEta", Variables::absEta, scalar_vars); + calc->insert("isolFrac", Variables::isolFrac, scalar_vars); + calc->insert("centFrac", Variables::centFrac, scalar_vars); + calc->insert("etOverPtLeadTrk", Variables::etOverPtLeadTrk, scalar_vars); + calc->insert("innerTrkAvgDist", Variables::innerTrkAvgDist, scalar_vars); + calc->insert("absipSigLeadTrk", Variables::absipSigLeadTrk, scalar_vars); + calc->insert("SumPtTrkFrac", Variables::SumPtTrkFrac, scalar_vars); + calc->insert("sumEMCellEtOverLeadTrkPt", Variables::sumEMCellEtOverLeadTrkPt, scalar_vars); + calc->insert("EMPOverTrkSysP", Variables::EMPOverTrkSysP, scalar_vars); + calc->insert("ptRatioEflowApprox", Variables::ptRatioEflowApprox, scalar_vars); + calc->insert("mEflowApprox", Variables::mEflowApprox, scalar_vars); + calc->insert("dRmax", Variables::dRmax, scalar_vars); + calc->insert("trFlightPathSig", Variables::trFlightPathSig, scalar_vars); + calc->insert("massTrkSys", Variables::massTrkSys, scalar_vars); + calc->insert("pt", Variables::pt, scalar_vars); + calc->insert("pt_tau_log", Variables::pt_tau_log, scalar_vars); + calc->insert("ptDetectorAxis", Variables::ptDetectorAxis, scalar_vars); + calc->insert("ptIntermediateAxis", Variables::ptIntermediateAxis, scalar_vars); + //---added for the eVeto + calc->insert("ptJetSeed_log", Variables::ptJetSeed_log, scalar_vars); + calc->insert("absleadTrackEta", Variables::absleadTrackEta, scalar_vars); + calc->insert("leadTrackDeltaEta", Variables::leadTrackDeltaEta, scalar_vars); + calc->insert("leadTrackDeltaPhi", Variables::leadTrackDeltaPhi, scalar_vars); + calc->insert("leadTrackProbNNorHT", Variables::leadTrackProbNNorHT, scalar_vars); + calc->insert("EMFracFixed", Variables::EMFracFixed, scalar_vars); + calc->insert("etHotShotWinOverPtLeadTrk", Variables::etHotShotWinOverPtLeadTrk, scalar_vars); + calc->insert("hadLeakFracFixed", Variables::hadLeakFracFixed, scalar_vars); + calc->insert("PSFrac", Variables::PSFrac, scalar_vars); + calc->insert("ClustersMeanCenterLambda", Variables::ClustersMeanCenterLambda, scalar_vars); + calc->insert("ClustersMeanFirstEngDens", Variables::ClustersMeanFirstEngDens, scalar_vars); + calc->insert("ClustersMeanPresamplerFrac", Variables::ClustersMeanPresamplerFrac, scalar_vars); + + // Track variable calculator functions + calc->insert("pt_log", Variables::Track::pt_log, track_vars); + calc->insert("trackPt", Variables::Track::trackPt, track_vars); + calc->insert("trackEta", Variables::Track::trackEta, track_vars); + calc->insert("trackPhi", Variables::Track::trackPhi, track_vars); + calc->insert("pt_tau_log", Variables::Track::pt_tau_log, track_vars); + calc->insert("pt_jetseed_log", Variables::Track::pt_jetseed_log, track_vars); + calc->insert("d0_abs_log", Variables::Track::d0_abs_log, track_vars); + calc->insert("z0sinThetaTJVA_abs_log", Variables::Track::z0sinThetaTJVA_abs_log, track_vars); + calc->insert("z0sinthetaTJVA", Variables::Track::z0sinthetaTJVA, track_vars); + calc->insert("z0sinthetaSigTJVA", Variables::Track::z0sinthetaSigTJVA, track_vars); + calc->insert("d0TJVA", Variables::Track::d0TJVA, track_vars); + calc->insert("d0SigTJVA", Variables::Track::d0SigTJVA, track_vars); + calc->insert("dEta", Variables::Track::dEta, track_vars); + calc->insert("dPhi", Variables::Track::dPhi, track_vars); + calc->insert("nInnermostPixelHits", Variables::Track::nInnermostPixelHits, track_vars); + calc->insert("nPixelHits", Variables::Track::nPixelHits, track_vars); + calc->insert("nSCTHits", Variables::Track::nSCTHits, track_vars); + calc->insert("nIBLHitsAndExp", Variables::Track::nIBLHitsAndExp, track_vars); + calc->insert("nPixelHitsPlusDeadSensors", Variables::Track::nPixelHitsPlusDeadSensors, track_vars); + calc->insert("nSCTHitsPlusDeadSensors", Variables::Track::nSCTHitsPlusDeadSensors, track_vars); + calc->insert("eProbabilityHT", Variables::Track::eProbabilityHT, track_vars); + calc->insert("eProbabilityNN", Variables::Track::eProbabilityNN, track_vars); + calc->insert("eProbabilityNNorHT", Variables::Track::eProbabilityNNorHT, track_vars); + calc->insert("chargedScoreRNN", Variables::Track::chargedScoreRNN, track_vars); + calc->insert("isolationScoreRNN", Variables::Track::isolationScoreRNN, track_vars); + calc->insert("conversionScoreRNN", Variables::Track::conversionScoreRNN, track_vars); + calc->insert("fakeScoreRNN", Variables::Track::fakeScoreRNN, track_vars); + //Extension - variables for GNTau + calc->insert("numberOfInnermostPixelLayerHits", Variables::Track::numberOfInnermostPixelLayerHits, track_vars); + calc->insert("numberOfPixelHits", Variables::Track::numberOfPixelHits, track_vars); + calc->insert("numberOfPixelSharedHits", Variables::Track::numberOfPixelSharedHits, track_vars); + calc->insert("numberOfPixelDeadSensors", Variables::Track::numberOfPixelDeadSensors, track_vars); + calc->insert("numberOfSCTHits", Variables::Track::numberOfSCTHits, track_vars); + calc->insert("numberOfSCTSharedHits", Variables::Track::numberOfSCTSharedHits, track_vars); + calc->insert("numberOfSCTDeadSensors", Variables::Track::numberOfSCTDeadSensors, track_vars); + calc->insert("numberOfTRTHighThresholdHits", Variables::Track::numberOfTRTHighThresholdHits, track_vars); + calc->insert("numberOfTRTHits", Variables::Track::numberOfTRTHits, track_vars); + calc->insert("nSiHits", Variables::Track::nSiHits, track_vars); + calc->insert("expectInnermostPixelLayerHit", Variables::Track::expectInnermostPixelLayerHit, track_vars); + calc->insert("expectNextToInnermostPixelLayerHit", Variables::Track::expectNextToInnermostPixelLayerHit, track_vars); + calc->insert("numberOfContribPixelLayers", Variables::Track::numberOfContribPixelLayers, track_vars); + calc->insert("numberOfPixelHoles", Variables::Track::numberOfPixelHoles, track_vars); + calc->insert("d0_old", Variables::Track::d0_old, track_vars); + calc->insert("qOverP", Variables::Track::qOverP, track_vars); + calc->insert("theta", Variables::Track::theta, track_vars); + calc->insert("z0TJVA", Variables::Track::z0TJVA, track_vars); + calc->insert("charge", Variables::Track::charge, track_vars); + calc->insert("dz0_TV_PV0", Variables::Track::dz0_TV_PV0, track_vars); + calc->insert("log_sumpt_TV", Variables::Track::log_sumpt_TV, track_vars); + calc->insert("log_sumpt2_TV", Variables::Track::log_sumpt2_TV, track_vars); + calc->insert("log_sumpt_PV0", Variables::Track::log_sumpt_PV0, track_vars); + calc->insert("log_sumpt2_PV0", Variables::Track::log_sumpt2_PV0, track_vars); + + // Cluster variable calculator functions + calc->insert("et_log", Variables::Cluster::et_log, cluster_vars); + calc->insert("pt_tau_log", Variables::Cluster::pt_tau_log, cluster_vars); + calc->insert("pt_jetseed_log", Variables::Cluster::pt_jetseed_log, cluster_vars); + calc->insert("dEta", Variables::Cluster::dEta, cluster_vars); + calc->insert("dPhi", Variables::Cluster::dPhi, cluster_vars); + calc->insert("SECOND_R", Variables::Cluster::SECOND_R, cluster_vars); + calc->insert("SECOND_LAMBDA", Variables::Cluster::SECOND_LAMBDA, cluster_vars); + calc->insert("CENTER_LAMBDA", Variables::Cluster::CENTER_LAMBDA, cluster_vars); + //---added for the eVeto + calc->insert("SECOND_LAMBDAOverClustersMeanSecondLambda", Variables::Cluster::SECOND_LAMBDAOverClustersMeanSecondLambda, cluster_vars); + calc->insert("CENTER_LAMBDAOverClustersMeanCenterLambda", Variables::Cluster::CENTER_LAMBDAOverClustersMeanCenterLambda, cluster_vars); + calc->insert("FirstEngDensOverClustersMeanFirstEngDens" , Variables::Cluster::FirstEngDensOverClustersMeanFirstEngDens, cluster_vars); + + //Extension - Variables for GNTau + calc->insert("e", Variables::Cluster::e, cluster_vars); + calc->insert("et", Variables::Cluster::et, cluster_vars); + calc->insert("FIRST_ENG_DENS", Variables::Cluster::FIRST_ENG_DENS, cluster_vars); + calc->insert("EM_PROBABILITY", Variables::Cluster::EM_PROBABILITY, cluster_vars); + calc->insert("CENTER_MAG", Variables::Cluster::CENTER_MAG, cluster_vars); + return calc; +} + + +namespace Variables { +using TauDetail = xAOD::TauJetParameters::Detail; + +bool absEta(const xAOD::TauJet &tau, double &out) { + out = std::abs(tau.eta()); + return true; +} + +bool centFrac(const xAOD::TauJet &tau, double &out) { + float centFrac; + const auto success = tau.detail(TauDetail::centFrac, centFrac); + //out = std::min(centFrac, 1.0f); + out = centFrac; + return success; +} + +bool isolFrac(const xAOD::TauJet &tau, double &out) { + float isolFrac; + const auto success = tau.detail(TauDetail::isolFrac, isolFrac); + //out = std::min(isolFrac, 1.0f); + out = isolFrac; + return success; +} + +bool etOverPtLeadTrk(const xAOD::TauJet &tau, double &out) { + float etOverPtLeadTrk; + const auto success = tau.detail(TauDetail::etOverPtLeadTrk, etOverPtLeadTrk); + out = etOverPtLeadTrk; + return success; +} + +bool innerTrkAvgDist(const xAOD::TauJet &tau, double &out) { + float innerTrkAvgDist; + const auto success = tau.detail(TauDetail::innerTrkAvgDist, innerTrkAvgDist); + out = innerTrkAvgDist; + return success; +} + +bool absipSigLeadTrk(const xAOD::TauJet &tau, double &out) { + float ipSigLeadTrk = (tau.nTracks()>0) ? tau.track(0)->d0SigTJVA() : 0.; + //out = std::min(std::abs(ipSigLeadTrk), 30.0f); + out = std::abs(ipSigLeadTrk); + return true; +} + +bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, double &out) { + float sumEMCellEtOverLeadTrkPt; + const auto success = tau.detail(TauDetail::sumEMCellEtOverLeadTrkPt, sumEMCellEtOverLeadTrkPt); + out = sumEMCellEtOverLeadTrkPt; + return success; +} + +bool SumPtTrkFrac(const xAOD::TauJet &tau, double &out) { + float SumPtTrkFrac; + const auto success = tau.detail(TauDetail::SumPtTrkFrac, SumPtTrkFrac); + out = SumPtTrkFrac; + return success; +} + +bool EMPOverTrkSysP(const xAOD::TauJet &tau, double &out) { + float EMPOverTrkSysP; + const auto success = tau.detail(TauDetail::EMPOverTrkSysP, EMPOverTrkSysP); + out = EMPOverTrkSysP; + return success; +} + +bool ptRatioEflowApprox(const xAOD::TauJet &tau, double &out) { + float ptRatioEflowApprox; + const auto success = tau.detail(TauDetail::ptRatioEflowApprox, ptRatioEflowApprox); + //out = std::min(ptRatioEflowApprox, 4.0f); + out = ptRatioEflowApprox; + return success; +} + +bool mEflowApprox(const xAOD::TauJet &tau, double &out) { + float mEflowApprox; + const auto success = tau.detail(TauDetail::mEflowApprox, mEflowApprox); + out = mEflowApprox; + return success; +} + +bool dRmax(const xAOD::TauJet &tau, double &out) { + float dRmax; + const auto success = tau.detail(TauDetail::dRmax, dRmax); + out = dRmax; + return success; +} + +bool trFlightPathSig(const xAOD::TauJet &tau, double &out) { + float trFlightPathSig; + const auto success = tau.detail(TauDetail::trFlightPathSig, trFlightPathSig); + out = trFlightPathSig; + return success; +} + +bool massTrkSys(const xAOD::TauJet &tau, double &out) { + float massTrkSys; + const auto success = tau.detail(TauDetail::massTrkSys, massTrkSys); + out = massTrkSys; + return success; +} + +bool pt(const xAOD::TauJet &tau, double &out) { + out = tau.pt(); + return true; +} + +bool pt_tau_log(const xAOD::TauJet &tau, double &out) { + out = std::log10(std::max(tau.pt() / GeV, 1e-6)); + return true; +} + +bool ptDetectorAxis(const xAOD::TauJet &tau, double &out) { + out = tau.ptDetectorAxis(); + return true; +} + +bool ptIntermediateAxis(const xAOD::TauJet &tau, double &out) { + out = tau.ptIntermediateAxis(); + return true; +} + +bool ptJetSeed_log(const xAOD::TauJet &tau, double &out) { + out = std::log10(std::max(tau.ptJetSeed(), 1e-3)); + return true; +} + +bool absleadTrackEta(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_absEtaLeadTrack("ABS_ETA_LEAD_TRACK"); + float absEtaLeadTrack = acc_absEtaLeadTrack(tau); + out = std::max(0.f, absEtaLeadTrack); + return true; +} + +bool leadTrackDeltaEta(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_absDeltaEta("TAU_ABSDELTAETA"); + float absDeltaEta = acc_absDeltaEta(tau); + out = std::max(0.f, absDeltaEta); + return true; +} + +bool leadTrackDeltaPhi(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_absDeltaPhi("TAU_ABSDELTAPHI"); + float absDeltaPhi = acc_absDeltaPhi(tau); + out = std::max(0.f, absDeltaPhi); + return true; +} + +bool leadTrackProbNNorHT(const xAOD::TauJet &tau, double &out){ + auto tracks = tau.allTracks(); + + // Sort tracks in descending pt order + if (!tracks.empty()) { + auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) { + return lhs->pt() > rhs->pt(); + }; + std::sort(tracks.begin(), tracks.end(), cmp_pt); + + const xAOD::TauTrack* tauLeadTrack = tracks.at(0); + const xAOD::TrackParticle* xTrackParticle = tauLeadTrack->track(); + float eProbabilityHT = xTrackParticle->summaryValue(eProbabilityHT, xAOD::eProbabilityHT); + static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN"); + float eProbabilityNN = acc_eProbabilityNN(*xTrackParticle); + out = (tauLeadTrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT; + } + else { + out = 0.; + } + return true; +} + +bool EMFracFixed(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_emFracFixed("EMFracFixed"); + float emFracFixed = acc_emFracFixed(tau); + out = std::max(emFracFixed, 0.0f); + return true; +} + +bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_etHotShotWinOverPtLeadTrk("etHotShotWinOverPtLeadTrk"); + float etHotShotWinOverPtLeadTrk = acc_etHotShotWinOverPtLeadTrk(tau); + out = std::max(etHotShotWinOverPtLeadTrk, 1e-6f); + return true; +} + +bool hadLeakFracFixed(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_hadLeakFracFixed("hadLeakFracFixed"); + float hadLeakFracFixed = acc_hadLeakFracFixed(tau); + out = std::max(0.f, hadLeakFracFixed); + return true; +} + +bool PSFrac(const xAOD::TauJet &tau, double &out){ + float PSFrac; + const auto success = tau.detail(TauDetail::PSSFraction, PSFrac); + out = std::max(0.f,PSFrac); + return success; +} + +bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, double &out){ + float ClustersMeanCenterLambda; + const auto success = tau.detail(TauDetail::ClustersMeanCenterLambda, ClustersMeanCenterLambda); + out = std::max(0.f, ClustersMeanCenterLambda); + return success; +} + +bool ClustersMeanEMProbability(const xAOD::TauJet &tau, double &out){ + float ClustersMeanEMProbability; + const auto success = tau.detail(TauDetail::ClustersMeanEMProbability, ClustersMeanEMProbability); + out = std::max(0.f, ClustersMeanEMProbability); + return success; +} + +bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, double &out){ + float ClustersMeanFirstEngDens; + const auto success = tau.detail(TauDetail::ClustersMeanFirstEngDens, ClustersMeanFirstEngDens); + out = std::max(-10.f, ClustersMeanFirstEngDens); + return success; +} + +bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, double &out){ + float ClustersMeanPresamplerFrac; + const auto success = tau.detail(TauDetail::ClustersMeanPresamplerFrac, ClustersMeanPresamplerFrac); + out = std::max(0.f, ClustersMeanPresamplerFrac); + return success; +} + +bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, double &out){ + float ClustersMeanSecondLambda; + const auto success = tau.detail(TauDetail::ClustersMeanSecondLambda, ClustersMeanSecondLambda); + out = std::max(0.f, ClustersMeanSecondLambda); + return success; +} + +namespace Track { + +bool pt_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = std::log10(track.pt()); + return true; +} + +bool trackPt(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.pt(); + return true; +} + +bool trackEta(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.eta(); + return true; +} + +bool trackPhi(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.phi(); + return true; +} + +bool pt_tau_log(const xAOD::TauJet &tau, const xAOD::TauTrack& /*track*/, double &out) { + out = std::log10(std::max(tau.pt(), 1e-6)); + return true; +} + +bool pt_jetseed_log(const xAOD::TauJet &tau, const xAOD::TauTrack& /*track*/, double &out) { + out = std::log10(tau.ptJetSeed()); + return true; +} + +bool d0_abs_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = std::log10(std::abs(track.d0TJVA()) + 1e-6); + return true; +} + +bool z0sinThetaTJVA_abs_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = std::log10(std::abs(track.z0sinthetaTJVA()) + 1e-6); + return true; +} + +bool z0sinthetaTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.z0sinthetaTJVA(); + return true; +} + +bool z0sinthetaSigTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.z0sinthetaSigTJVA(); + return true; +} + +bool d0TJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.d0TJVA(); + return true; +} + +bool d0SigTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.d0SigTJVA(); + return true; +} + +bool dEta(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { + out = track.eta() - tau.eta(); + return true; +} + +bool dPhi(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { + out = track.p4().DeltaPhi(tau.p4()); + return true; +} + +bool nInnermostPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t inner_pixel_hits; + const auto success = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits); + out = inner_pixel_hits; + return success; +} + +bool nPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t pixel_hits; + const auto success = track.track()->summaryValue(pixel_hits, xAOD::numberOfPixelHits); + out = pixel_hits; + return success; +} + +bool nSCTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t sct_hits; + const auto success = track.track()->summaryValue(sct_hits, xAOD::numberOfSCTHits); + out = sct_hits; + return success; +} + +// same as in tau track classification for trigger +bool nIBLHitsAndExp(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t inner_pixel_hits, inner_pixel_exp; + const auto success1 = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits); + const auto success2 = track.track()->summaryValue(inner_pixel_exp, xAOD::expectInnermostPixelLayerHit); + out = inner_pixel_exp ? inner_pixel_hits : 1.; + return success1 && success2; +} + +bool nPixelHitsPlusDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t pixel_hits, pixel_dead; + const auto success1 = track.track()->summaryValue(pixel_hits, xAOD::numberOfPixelHits); + const auto success2 = track.track()->summaryValue(pixel_dead, xAOD::numberOfPixelDeadSensors); + out = pixel_hits + pixel_dead; + return success1 && success2; +} + +bool nSCTHitsPlusDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t sct_hits, sct_dead; + const auto success1 = track.track()->summaryValue(sct_hits, xAOD::numberOfSCTHits); + const auto success2 = track.track()->summaryValue(sct_dead, xAOD::numberOfSCTDeadSensors); + out = sct_hits + sct_dead; + return success1 && success2; +} + +bool eProbabilityHT(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + float eProbabilityHT; + const auto success = track.track()->summaryValue(eProbabilityHT, xAOD::eProbabilityHT); + out = eProbabilityHT; + return success; +} + +bool eProbabilityNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN"); + out = acc_eProbabilityNN(track); + return true; +} + +bool eProbabilityNNorHT(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + auto atrack = track.track(); + float eProbabilityHT = atrack->summaryValue(eProbabilityHT, xAOD::eProbabilityHT); + static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN"); + float eProbabilityNN = acc_eProbabilityNN(*atrack); + out = (atrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT; + return true; +} + +bool chargedScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_chargedScoreRNN("rnn_chargedScore"); + out = acc_chargedScoreRNN(track); + return true; +} + +bool isolationScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_isolationScoreRNN("rnn_isolationScore"); + out = acc_isolationScoreRNN(track); + return true; +} + +bool conversionScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_conversionScoreRNN("rnn_conversionScore"); + out = acc_conversionScoreRNN(track); + return true; +} + +bool fakeScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_fakeScoreRNN("rnn_fakeScore"); + out = acc_fakeScoreRNN(track); + return true; +} + +//Extension - variables for GNTau +bool numberOfInnermostPixelLayerHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfInnermostPixelLayerHits); + out = trk_val; + return success; +} + +bool numberOfPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelHits); + out = trk_val; + return success; +} + +bool numberOfPixelSharedHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelSharedHits); + out = trk_val; + return success; +} + +bool numberOfPixelDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelDeadSensors); + out = trk_val; + return success; +} + +bool numberOfSCTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTHits); + out = trk_val; + return success; +} + +bool numberOfSCTSharedHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTSharedHits); + out = trk_val; + return success; +} + +bool numberOfSCTDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTDeadSensors); + out = trk_val; + return success; +} + +bool numberOfTRTHighThresholdHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfTRTHighThresholdHits); + out = trk_val; + return success; +} + +bool numberOfTRTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfTRTHits); + out = trk_val; + return success; +} + +bool nSiHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t pix_hit = 0;uint8_t pix_dead = 0;uint8_t sct_hit = 0;uint8_t sct_dead = 0; + const auto success1 = track.track()->summaryValue(pix_hit, xAOD::numberOfPixelHits); + const auto success2 = track.track()->summaryValue(pix_dead, xAOD::numberOfPixelDeadSensors); + const auto success3 = track.track()->summaryValue(sct_hit, xAOD::numberOfSCTHits); + const auto success4 = track.track()->summaryValue(sct_dead, xAOD::numberOfSCTDeadSensors); + out = pix_hit + pix_dead + sct_hit + sct_dead; + return success1 && success2 && success3 && success4; +} + +bool expectInnermostPixelLayerHit(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::expectInnermostPixelLayerHit); + out = trk_val; + return success; +} + +bool expectNextToInnermostPixelLayerHit(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::expectNextToInnermostPixelLayerHit); + out = trk_val; + return success; +} + +bool numberOfContribPixelLayers(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfContribPixelLayers); + out = trk_val; + return success; +} + +bool numberOfPixelHoles(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelHoles); + out = trk_val; + return success; +} + +bool d0_old(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->d0(); + //out = trk_val; + return true; +} + +bool qOverP(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->qOverP(); + return true; +} + +bool theta(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->theta(); + return true; +} + +bool z0TJVA(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out) { + out = track.track()->z0() + track.track()->vz() - tau.vertex()->z(); + return true; +} + +bool charge(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->charge(); + return true; +} + +bool dz0_TV_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out = 0.; + static const SG::AuxElement::ConstAccessor<float> acc_dz0TVPV0("dz0_TV_PV0"); + if (tau.isAvailable<float>("dz0_TV_PV0")){out = acc_dz0TVPV0(tau);} + return true; +} + +bool log_sumpt_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumptTV("log_sumpt_TV"); + if (tau.isAvailable<float>("log_sumpt_TV")){out=acc_logsumptTV(tau);} + return true; +} + +bool log_sumpt2_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumpt2TV("log_sumpt2_TV"); + if (tau.isAvailable<float>("log_sumpt2_TV")){out=acc_logsumpt2TV(tau);} + return true; +} + +bool log_sumpt_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumptPV0("log_sumpt_PV0"); + if (tau.isAvailable<float>("log_sumpt_PV0")){out=acc_logsumptPV0(tau);} + return true; +} + +bool log_sumpt2_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumpt2PV0("log_sumpt2_PV0"); + if (tau.isAvailable<float>("log_sumpt2_PV0")){out=acc_logsumpt2PV0(tau);} + return true; +} + +} // namespace Track + + +namespace Cluster { +using MomentType = xAOD::CaloCluster::MomentType; + +bool et_log(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = std::log10(cluster.p4().Et()); + return true; +} + +bool pt_tau_log(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster& /*cluster*/, double &out) { + out = std::log10(std::max(tau.pt(), 1e-6)); + return true; +} + +bool pt_jetseed_log(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster& /*cluster*/, double &out) { + out = std::log10(tau.ptJetSeed()); + return true; +} + +bool dEta(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.eta() - tau.eta(); + return true; +} + +bool dPhi(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.p4().DeltaPhi(tau.p4()); + return true; +} + +bool SECOND_R(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_R, out); + return success; +} + +bool SECOND_LAMBDA(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_LAMBDA, out); + return success; +} + +bool CENTER_LAMBDA(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + const auto success = cluster.clust().retrieveMoment(MomentType::CENTER_LAMBDA, out); + return success; +} + +bool SECOND_LAMBDAOverClustersMeanSecondLambda(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanSecondLambda("ClustersMeanSecondLambda"); + float ClustersMeanSecondLambda = acc_ClustersMeanSecondLambda(tau); + double secondLambda(0); + const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_LAMBDA, secondLambda); + out = (ClustersMeanSecondLambda != 0.) ? secondLambda/ClustersMeanSecondLambda : 0.; + return success; +} + +bool CENTER_LAMBDAOverClustersMeanCenterLambda(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanCenterLambda("ClustersMeanCenterLambda"); + float ClustersMeanCenterLambda = acc_ClustersMeanCenterLambda(tau); + double centerLambda(0); + const auto success = cluster.clust().retrieveMoment(MomentType::CENTER_LAMBDA, centerLambda); + if (ClustersMeanCenterLambda == 0.){ + out = 250.; + }else { + out = centerLambda/ClustersMeanCenterLambda; + } + + out = std::min(out, 250.); + + return success; +} + + +bool FirstEngDensOverClustersMeanFirstEngDens(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + // the ClustersMeanFirstEngDens is the log10 of the energy weighted average of the First_ENG_DENS + // divided by ETot to make it dimension-less, + // so we need to evaluate the difference of log10(clusterFirstEngDens/clusterTotalEnergy) and the ClustersMeanFirstEngDens + double clusterFirstEngDens = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::FIRST_ENG_DENS, clusterFirstEngDens); + if (clusterFirstEngDens < 1e-6) clusterFirstEngDens = 1e-6; + + static const SG::AuxElement::ConstAccessor<float> acc_ClusterTotalEnergy("ClusterTotalEnergy"); + float clusterTotalEnergy = acc_ClusterTotalEnergy(tau); + if (clusterTotalEnergy < 1e-6) clusterTotalEnergy = 1e-6; + + static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanFirstEngDens("ClustersMeanFirstEngDens"); + float clustersMeanFirstEngDens = acc_ClustersMeanFirstEngDens(tau); + + out = std::log10(clusterFirstEngDens/clusterTotalEnergy) - clustersMeanFirstEngDens; + + return status; +} + +//Extension - Variables for GNTau +bool e(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.p4().E(); + return true; +} + +bool et(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.p4().Et(); + return true; +} + +bool FIRST_ENG_DENS(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + double clusterFirstEngDens = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::FIRST_ENG_DENS, clusterFirstEngDens); + out = clusterFirstEngDens; + return status; +} + +bool EM_PROBABILITY(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + double clusterEMprob = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::EM_PROBABILITY, clusterEMprob); + out = clusterEMprob; + return status; +} + +bool CENTER_MAG(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + double clusterCenterMag = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::CENTER_MAG, clusterCenterMag); + out = clusterCenterMag; + return status; +} + +} // namespace Cluster +} // namespace Variables +} // namespace TauGNNUtils diff --git a/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx b/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx index d4d5db31542a..0e78bf7d357e 100644 --- a/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx +++ b/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx @@ -25,6 +25,7 @@ #include "tauRecTools/TauWPDecorator.h" #include "tauRecTools/TauIDVarCalculator.h" #include "tauRecTools/TauJetRNNEvaluator.h" +#include "tauRecTools/TauGNNEvaluator.h" #include "tauRecTools/TauDecayModeNNClassifier.h" #include "tauRecTools/TauVertexedClusterDecorator.h" #include "tauRecTools/TauAODSelector.h" @@ -59,6 +60,7 @@ DECLARE_COMPONENT( TauPi0Selector ) DECLARE_COMPONENT( TauWPDecorator ) DECLARE_COMPONENT( TauIDVarCalculator ) DECLARE_COMPONENT( TauJetRNNEvaluator ) +DECLARE_COMPONENT( TauGNNEvaluator ) DECLARE_COMPONENT( TauDecayModeNNClassifier ) DECLARE_COMPONENT( TauVertexedClusterDecorator ) DECLARE_COMPONENT( TauAODSelector ) diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNN.h b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h new file mode 100644 index 000000000000..f2e39c02cbbc --- /dev/null +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h @@ -0,0 +1,110 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef TAURECTOOLS_TAUGNN_H +#define TAURECTOOLS_TAUGNN_H + +#include "xAODTau/TauJet.h" +#include "xAODCaloEvent/CaloVertexedTopoCluster.h" + +#include "AsgMessaging/AsgMessaging.h" + +#include "FlavorTagDiscriminants/OnnxUtil.h" + +#include <memory> +#include <string> +#include <map> + +// Forward declaration +namespace lwt { + class LightweightGraph; +} + +namespace TauGNNUtils { + class GNNVarCalc; +} + +namespace FlavorTagDiscriminants{ + class OnnxUtil; +} + +/** + * @brief Wrapper around ONNXUtil to compute the output score of a model + * + * Configures the network and computes the network outputs given the input + * objects. Retrieval of input variables is handled internally. + * + * @author N.M. Tamir + * + */ +class TauGNN : public asg::AsgMessaging { +public: + // Configuration of the weight file structure + struct Config { + std::string input_layer_scalar; + std::string input_layer_tracks; + std::string input_layer_clusters; + std::string output_node_tau; + std::string output_node_jet; + }; + std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil; +public: + TauGNN(const std::string &nnFile, const Config &config); + ~TauGNN(); + + // Output the OnnxUtil tuple + std::tuple< + std::map<std::string, float>, + std::map<std::string, std::vector<char>>, + std::map<std::string, std::vector<float>> > + compute(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const; + + // Compute all input variables and store them in the maps that are passed by reference + bool calculateInputVariables(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::map<std::string, std::map<std::string, double>>& scalarInputs, + std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const; + + // Getter for the variable calculator + const TauGNNUtils::GNNVarCalc* variable_calculator() const { + return m_var_calc.get(); + } + + explicit operator bool() const { + return static_cast<bool>(m_graph); + } + + //Make the output config transparent to external tools + FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config; + +private: + using input_pair = FlavorTagDiscriminants::input_pair; + // Abbreviations for lwtnn + using VariableMap = std::map<std::string, double>; + using VectorMap = std::map<std::string, std::vector<double>>; + + using InputMap = std::map<std::string, VariableMap>; + using InputSequenceMap = std::map<std::string, VectorMap>; + +private: + const Config m_config; + std::unique_ptr<const lwt::LightweightGraph> m_graph; + + // Names of the input variables + std::vector<std::string> m_scalar_inputs; + std::vector<std::string> m_track_inputs; + std::vector<std::string> m_cluster_inputs; + // Names passed to the variable calculator + std::vector<std::string> m_scalarCalc_inputs; + std::vector<std::string> m_trackCalc_inputs; + std::vector<std::string> m_clusterCalc_inputs; + + // Variable calculator to calculate input variables on the fly + std::unique_ptr<TauGNNUtils::GNNVarCalc> m_var_calc; +}; + +#endif // TAURECTOOLS_TAUGNN_H diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h new file mode 100644 index 000000000000..4e5cebf53ad5 --- /dev/null +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h @@ -0,0 +1,69 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef TAURECTOOLS_TAUGNNEVALUATOR_H +#define TAURECTOOLS_TAUGNNEVALUATOR_H + +#include "tauRecTools/TauRecToolBase.h" + +#include "xAODTau/TauJet.h" +#include "xAODCaloEvent/CaloVertexedTopoCluster.h" + +#include <memory> + +class TauGNN; + +/** + * @brief Tool to calculate tau identification score from .onnx inputs + * + * The network configuration is supplied in .onnx format. + * Currently runs on a prongness-inclusive model + * Based off of TauJetRNNEvaluator.h format! + * @author N.M. Tamir + * + */ +class TauGNNEvaluator : public TauRecToolBase { +public: + ASG_TOOL_CLASS2(TauGNNEvaluator, TauRecToolBase, ITauToolBase) + + TauGNNEvaluator(const std::string &name = "TauGNNEvaluator"); + virtual ~TauGNNEvaluator(); + + virtual StatusCode initialize() override; + virtual StatusCode execute(xAOD::TauJet &tau) const override; + // Getter for the underlying RNN implementation + const TauGNN* get_gnn() const; + + // Selects tracks to be used as input to the network + StatusCode get_tracks(const xAOD::TauJet &tau, + std::vector<const xAOD::TauTrack *> &out) const; + + // Selects clusters to be used as input to the network + StatusCode get_clusters(const xAOD::TauJet &tau, + std::vector<xAOD::CaloVertexedTopoCluster> &out) const; + +private: + std::string m_output_varname; + std::string m_output_ptau; + std::string m_output_pjet; + std::string m_weightfile; + std::size_t m_max_tracks; + std::size_t m_max_clusters; + float m_max_cluster_dr; + bool m_doVertexCorrection; + bool m_doTrackClassification; + bool m_decorateTracks; + + // Configuration of the network file + std::string m_input_layer_scalar; + std::string m_input_layer_tracks; + std::string m_input_layer_clusters; + std::string m_outnode_tau; + std::string m_outnode_jet; + + // Wrappers for lwtnn + std::unique_ptr<TauGNN> m_net; //! +}; + +#endif // TAURECTOOLS_TAUGNNEVALUATOR_H diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h new file mode 100644 index 000000000000..28a983dd5e8a --- /dev/null +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h @@ -0,0 +1,311 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef TAURECTOOLS_TAUGNNUTILS_H +#define TAURECTOOLS_TAUGNNUTILS_H + +#include "xAODTau/TauJet.h" +#include "xAODCaloEvent/CaloVertexedTopoCluster.h" +#include "xAODEventInfo/EventInfo.h" +#include "xAODTracking/VertexContainer.h" +#include "AsgTools/AsgTool.h" +#include "AsgMessaging/AsgMessaging.h" +#include <unordered_map> + + +namespace TauGNNUtils { + +/** + * @brief Tool to calculate input variables for the GNN-based tau identification + * + * Used to calculate input variables for (onnx)GNN-based tau identification on + * the fly by providing a mapping between variable names (strings) and + * functions to calculate these variables. + * + * @author C. Deutsch + * @author W. Davey + * @author N.M. Tamir + * + */ +class GNNVarCalc : public asg::AsgMessaging { +public: + // Pointers to calculator functions + using ScalarCalc = bool (*)(const xAOD::TauJet &, double &); + + using TrackCalc = bool (*)(const xAOD::TauJet &, const xAOD::TauTrack &, + double &); + + using ClusterCalc = bool (*)(const xAOD::TauJet &, + const xAOD::CaloVertexedTopoCluster &, double &); + +public: + GNNVarCalc(); + ~GNNVarCalc() = default; + + // Methods to compute the output (vector) based on the variable name + + // Computes high-level ID variables + bool compute(const std::string &name, const xAOD::TauJet &tau, double &out) const; + + // Computes track variables + bool compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + std::vector<double> &out) const; + + // Computes cluster variables + bool compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::vector<double> &out) const; + + // Methods to insert calculator functions into the lookup table + void insert(const std::string &name, ScalarCalc func, const std::vector<std::string>& scalar_vars); + void insert(const std::string &name, TrackCalc func, const std::vector<std::string>& track_vars); + void insert(const std::string &name, ClusterCalc func, const std::vector<std::string>& cluster_vars); + +private: + // Lookup tables + std::unordered_map<std::string, ScalarCalc> m_scalar_map; + std::unordered_map<std::string, TrackCalc> m_track_map; + std::unordered_map<std::string, ClusterCalc> m_cluster_map; +}; + +// Factory function to create a variable calculator populated with default +// variables +std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scalar_vars, + const std::vector<std::string>& track_vars, + const std::vector<std::string>& cluster_vars); + + +namespace Variables { + +// Functions to calculate (scalar) input variables +// Returns a status code indicating success +bool absEta(const xAOD::TauJet &tau, double &out); + +bool centFrac(const xAOD::TauJet &tau, double &out); + +bool isolFrac(const xAOD::TauJet &tau, double &out); + +bool etOverPtLeadTrk(const xAOD::TauJet &tau, double &out); + +bool innerTrkAvgDist(const xAOD::TauJet &tau, double &out); + +bool absipSigLeadTrk(const xAOD::TauJet &tau, double &out); + +bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, double &out); + +bool SumPtTrkFrac(const xAOD::TauJet &tau, double &out); + +bool EMPOverTrkSysP(const xAOD::TauJet &tau, double &out); + +bool ptRatioEflowApprox(const xAOD::TauJet &tau, double &out); + +bool mEflowApprox(const xAOD::TauJet &tau, double &out); + +bool dRmax(const xAOD::TauJet &tau, double &out); + +bool trFlightPathSig(const xAOD::TauJet &tau, double &out); + +bool massTrkSys(const xAOD::TauJet &tau, double &out); + +bool pt(const xAOD::TauJet &tau, double &out); + +bool pt_tau_log(const xAOD::TauJet &tau, double &out); + +bool ptDetectorAxis(const xAOD::TauJet &tau, double &out); + +bool ptIntermediateAxis(const xAOD::TauJet &tau, double &out); + +//functions to calculate input variables needed for the eVeto RNN +bool ptJetSeed_log (const xAOD::TauJet &tau, double &out); +bool absleadTrackEta (const xAOD::TauJet &tau, double &out); +bool leadTrackDeltaEta (const xAOD::TauJet &tau, double &out); +bool leadTrackDeltaPhi (const xAOD::TauJet &tau, double &out); +bool leadTrackProbNNorHT (const xAOD::TauJet &tau, double &out); +bool EMFracFixed (const xAOD::TauJet &tau, double &out); +bool etHotShotWinOverPtLeadTrk (const xAOD::TauJet &tau, double &out); +bool hadLeakFracFixed (const xAOD::TauJet &tau, double &out); +bool PSFrac (const xAOD::TauJet &tau, double &out); +bool ClustersMeanCenterLambda (const xAOD::TauJet &tau, double &out); +bool ClustersMeanEMProbability (const xAOD::TauJet &tau, double &out); +bool ClustersMeanFirstEngDens (const xAOD::TauJet &tau, double &out); +bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, double &out); +bool ClustersMeanSecondLambda (const xAOD::TauJet &tau, double &out); +bool EMPOverTrkSysP (const xAOD::TauJet &tau, double &out); + + +namespace Track { + +// Functions to calculate input variables for each track +// Returns a status code indicating success + +bool pt_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool trackPt( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool trackEta( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool trackPhi( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool pt_tau_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool pt_jetseed_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool d0_abs_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool z0sinThetaTJVA_abs_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool z0sinthetaTJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool z0sinthetaSigTJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool d0TJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool d0SigTJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool dEta( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool dPhi( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nInnermostPixelHits( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nPixelHits( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nSCTHits( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +// trigger variants +bool nIBLHitsAndExp ( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nPixelHitsPlusDeadSensors ( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nSCTHitsPlusDeadSensors ( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool eProbabilityHT( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool eProbabilityNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool eProbabilityNNorHT( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool chargedScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool isolationScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool conversionScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool fakeScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +//Extension - variables for GNTau +bool numberOfInnermostPixelLayerHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelSharedHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelDeadSensors(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfSCTHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfSCTSharedHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfSCTDeadSensors(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfTRTHighThresholdHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfTRTHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool nSiHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool expectInnermostPixelLayerHit(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool expectNextToInnermostPixelLayerHit(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfContribPixelLayers(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelHoles(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool d0_old(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool qOverP(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool theta(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool z0TJVA(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool charge(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool dz0_TV_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt2_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt2_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +} // namespace Track + + +namespace Cluster { + +// Functions to calculate input variables for each cluster +// Returns a status code indicating success + +bool et_log( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool pt_tau_log( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool pt_jetseed_log( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool dEta( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool dPhi( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool SECOND_R( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool SECOND_LAMBDA( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool CENTER_LAMBDA( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool SECOND_LAMBDAOverClustersMeanSecondLambda( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool CENTER_LAMBDAOverClustersMeanCenterLambda( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool FirstEngDensOverClustersMeanFirstEngDens( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +//Extension - Variables for GNTau +bool e( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool et( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool FIRST_ENG_DENS( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool EM_PROBABILITY( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool CENTER_MAG( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); +} // namespace Cluster +} // namespace Variables +} // namespace TauJetGNNUtils + +#endif // TAURECTOOLS_TAUGNNUTILS_H -- GitLab From 83e7173098d0758143fa8f982d7a566bb499f99a Mon Sep 17 00:00:00 2001 From: Jean Yves Beaucamp <jean.yves.beaucamp@cern.ch> Date: Tue, 5 Nov 2024 01:41:42 +0100 Subject: [PATCH 2/4] DerivationFrameworkTau, tauRec, tauRecTools: optimise muon-tau removal sequence DerivationFrameworkTau, tauRec, tauRecTools: optimise muon-tau removal sequence Hello, This MR is addressing the CPU increase in DAOD_PHYS/LITE reported in ATLASG-2712 coming from the recent addition of GNTau ID. The main changes are in TauAODRunnerAlg, which first run a tool to remove muon tracks and clusters associated with tau candidate, then reruns most of the tau reconstruction with muon-free inputs. Now, in the muon-tau removal, if no muon track nor cluster is found near the tau, we discard the tau candidate by effectively removing it from the container. This prevents afterburner tools like GNN tau ID from running over the full container (which so far includes irrelevant tau candidates removed later on by a thinning algorithm), thereby saving CPU. I've checked that the TauJets_MuonRM_TauIDDecorKernel CPU time is reduced, it's no longer visible in the SPOT test summary. The DAOD output is unchanged (checked over 1000 events). Adding the urgent flag in case it would sill arrive in time for the imminent DAOD bulk prod. Cheers, Bertrand --- .../tauRec/python/TauConfigFlags.py | 3 +- Reconstruction/tauRec/python/TauToolHolder.py | 1 + Reconstruction/tauRec/src/TauAODRunnerAlg.cxx | 75 ++++++++++--------- Reconstruction/tauRecTools/Root/TauGNN.cxx | 17 ++--- .../tauRecTools/Root/TauGNNEvaluator.cxx | 9 ++- .../tauRecTools/Root/TauGNNUtils.cxx | 4 +- .../tauRecTools/tauRecTools/TauGNN.h | 12 +-- .../tauRecTools/tauRecTools/TauGNNEvaluator.h | 1 + 8 files changed, 58 insertions(+), 64 deletions(-) diff --git a/Reconstruction/tauRec/python/TauConfigFlags.py b/Reconstruction/tauRec/python/TauConfigFlags.py index d17807f2bf96..a0b3743ef4c6 100644 --- a/Reconstruction/tauRec/python/TauConfigFlags.py +++ b/Reconstruction/tauRec/python/TauConfigFlags.py @@ -1,4 +1,4 @@ -# Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration import unittest from AthenaConfiguration.AthConfigFlags import AthConfigFlags @@ -48,6 +48,7 @@ def createTauConfigFlags(): tau_cfg.addFlag("Tau.MvaTESConfig", "MvaTES_R23.root") tau_cfg.addFlag("Tau.MinPt0p", 9.25*Units.GeV) tau_cfg.addFlag("Tau.MinPt", 6.75*Units.GeV) + tau_cfg.addFlag("Tau.MinPtDAOD", 13*Units.GeV) tau_cfg.addFlag("Tau.TauJetRNNConfig", ["tauid_rnn_1p_R22_v1.json", "tauid_rnn_2p_R22_v1.json", "tauid_rnn_3p_R22_v1.json"]) tau_cfg.addFlag("Tau.TauJetRNNWPConfig", ["tauid_rnnWP_1p_R22_v0.root", "tauid_rnnWP_2p_R22_v0.root", "tauid_rnnWP_3p_R22_v0.root"]) tau_cfg.addFlag("Tau.TauEleRNNConfig", ["taueveto_rnn_config_1P_r22.json", "taueveto_rnn_config_3P_r22.json"]) diff --git a/Reconstruction/tauRec/python/TauToolHolder.py b/Reconstruction/tauRec/python/TauToolHolder.py index 35aea18ff2c9..fd96be80c82e 100644 --- a/Reconstruction/tauRec/python/TauToolHolder.py +++ b/Reconstruction/tauRec/python/TauToolHolder.py @@ -865,6 +865,7 @@ def TauGNNEvaluatorCfg(flags): MaxTracks = 30, MaxClusters = 20, MaxClusterDR = 15.0, + MinTauPt = flags.Tau.MinPtDAOD, VertexCorrection = True, DecorateTracks = False, InputLayerScalar = "tau_vars", diff --git a/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx b/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx index 02ed3802ab71..5a9a2d01c0e6 100644 --- a/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx +++ b/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx @@ -1,5 +1,5 @@ /* - Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ #include "TauAODRunnerAlg.h" @@ -66,6 +66,7 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const { xAOD::TauJetContainer *newTauCon = outputTauHandle.ptr(); static const SG::AuxElement::Accessor<ElementLink<xAOD::TauJetContainer>> acc_ori_tau_link("originalTauJet"); + static const SG::AuxElement::Accessor<char> acc_modified("ModifiedInAOD"); for (const xAOD::TauJet *tau : *pTauContainer) { // deep copy the tau container @@ -88,6 +89,21 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const { linkToTauTrack.toContainedElement(*newTauTrkCon, newTauTrk); newTau->addTauTrackLink(linkToTauTrack); } + + // 'ModifiedInAOD' will be overriden by modification tools for relevant candidates + acc_modified(*newTau) = static_cast<char>(false); + + StatusCode sc; + for (const ToolHandle<ITauToolBase> &tool : m_modificationTools) { + ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name()); + sc = tool->execute(*newTau); + if (sc.isFailure()) break; + } + + // if tau candidate was not modified, remove it from container, track cleanup performed by thinning algorithm downstream + if (!acc_modified(*newTau)) { + newTauCon->pop_back(); + } } // Read the CaloClusterContainer @@ -123,49 +139,34 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const { ATH_CHECK(vertOutHandle.record(std::make_unique<xAOD::VertexContainer>(), std::make_unique<xAOD::VertexAuxContainer>())); xAOD::VertexContainer* pSecVtxContainer = vertOutHandle.ptr(); - int n_tau_modified = 0; - static const SG::AuxElement::Accessor<char> acc_modified("ModifiedInAOD"); - for (xAOD::TauJet *pTau : *newTauCon) { - // Loop stops when Failure indicated by one of the tools StatusCode sc; - //add a identifier of if the tau is modifed by the mod tools - acc_modified(*pTau) = static_cast<char>(false); - // iterate over the copy - for (const ToolHandle<ITauToolBase> &tool : m_modificationTools) { + for (const ToolHandle<ITauToolBase> &tool : m_officialTools) { ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name()); - sc = tool->execute(*pTau); + if (tool->type() == "TauPi0ClusterCreator") + sc = tool->executePi0ClusterCreator(*pTau, *neutralPFOContainer, *hadronicClusterPFOContainer, *pi0ClusterContainer); + else if (tool->type() == "TauVertexVariables") + sc = tool->executeVertexVariables(*pTau, *pSecVtxContainer); + else if (tool->type() == "TauPi0ClusterScaler") + sc = tool->executePi0ClusterScaler(*pTau, *neutralPFOContainer, *chargedPFOContainer); + else if (tool->type() == "TauPi0ScoreCalculator") + sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer); + else if (tool->type() == "TauPi0Selector") + sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer); + else if (tool->type() == "PanTau::PanTauProcessor") + sc = tool->executePanTau(*pTau, *pi0Container, *neutralPFOContainer); + else if (tool->type() == "tauRecTools::TauTrackRNNClassifier") + sc = tool->executeTrackClassifier(*pTau, *newTauTrkCon); + else + sc = tool->execute(*pTau); if (sc.isFailure()) break; } - if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked modification tools."); - // if tau is not modified by the above tools, never mind running the tools afterward - if (static_cast<bool>(isTauModified(pTau))) { - n_tau_modified++; - for (const ToolHandle<ITauToolBase> &tool : m_officialTools) { - ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name()); - if (tool->type() == "TauPi0ClusterCreator") - sc = tool->executePi0ClusterCreator(*pTau, *neutralPFOContainer, *hadronicClusterPFOContainer, *pi0ClusterContainer); - else if (tool->type() == "TauVertexVariables") - sc = tool->executeVertexVariables(*pTau, *pSecVtxContainer); - else if (tool->type() == "TauPi0ClusterScaler") - sc = tool->executePi0ClusterScaler(*pTau, *neutralPFOContainer, *chargedPFOContainer); - else if (tool->type() == "TauPi0ScoreCalculator") - sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer); - else if (tool->type() == "TauPi0Selector") - sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer); - else if (tool->type() == "PanTau::PanTauProcessor") - sc = tool->executePanTau(*pTau, *pi0Container, *neutralPFOContainer); - else if (tool->type() == "tauRecTools::TauTrackRNNClassifier") - sc = tool->executeTrackClassifier(*pTau, *newTauTrkCon); - else - sc = tool->execute(*pTau); - if (sc.isFailure()) break; - } - if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked official tools."); - } + if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked official tools."); } + ATH_MSG_VERBOSE("The tau candidate container has been modified by the rest of the tools"); - ATH_MSG_DEBUG(n_tau_modified << " / " << pTauContainer->size() <<" taus were modified"); + ATH_MSG_DEBUG(newTauCon->size() << " / " << pTauContainer->size() <<" taus were modified"); + return StatusCode::SUCCESS; } diff --git a/Reconstruction/tauRecTools/Root/TauGNN.cxx b/Reconstruction/tauRecTools/Root/TauGNN.cxx index 65f13c4c2f6b..971372e865b0 100644 --- a/Reconstruction/tauRecTools/Root/TauGNN.cxx +++ b/Reconstruction/tauRecTools/Root/TauGNN.cxx @@ -4,15 +4,12 @@ #include "tauRecTools/TauGNN.h" #include "FlavorTagDiscriminants/OnnxUtil.h" +#include "lwtnn/parse_json.hh" #include "PathResolver/PathResolver.h" #include <algorithm> #include <fstream> -#include "lwtnn/LightweightGraph.hh" -#include "lwtnn/Exceptions.hh" -//#include "lwtnn/parse_json.hh" - #include "tauRecTools/TauGNNUtils.h" TauGNN::TauGNN(const std::string &nnFile, const Config &config): @@ -106,11 +103,11 @@ std::tuple< std::map<std::string, std::vector<char>>, std::map<std::string, std::vector<float>> > TauGNN::compute(const xAOD::TauJet &tau, - const std::vector<const xAOD::TauTrack *> &tracks, - const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const { + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const { InputMap scalarInputs; InputSequenceMap vectorInputs; - std::map<std::string, input_pair> gnn_input; + std::map<std::string, Inputs> gnn_input; ATH_MSG_DEBUG("Starting compute..."); //Prepare input variables if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) { @@ -124,7 +121,7 @@ TauGNN::compute(const xAOD::TauJet &tau, tau_feats.push_back(static_cast<float>(scalarInputs[m_config.input_layer_scalar][varname])); } std::vector<int64_t> tau_feats_dim = {1, static_cast<int64_t>(tau_feats.size())}; - input_pair tau_info (tau_feats, tau_feats_dim); + Inputs tau_info (tau_feats, tau_feats_dim); gnn_input.insert({"tau_vars", tau_info}); //Add track-level features to the input @@ -141,7 +138,7 @@ TauGNN::compute(const xAOD::TauJet &tau, var_idx++; } std::vector<int64_t> trk_feats_dim = {num_nodes, num_node_vars}; - input_pair trk_info (trk_feats, trk_feats_dim); + Inputs trk_info (trk_feats, trk_feats_dim); gnn_input.insert({"track_vars", trk_info}); //Add cluster-level features to the input @@ -158,7 +155,7 @@ TauGNN::compute(const xAOD::TauJet &tau, var_idx++; } std::vector<int64_t> cls_feats_dim = {num_nodes, num_node_vars}; - input_pair cls_info (cls_feats, cls_feats_dim); + Inputs cls_info (cls_feats, cls_feats_dim); gnn_input.insert({"cluster_vars", cls_info}); //RUN THE INFERENCE!!! diff --git a/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx index 0be8819d79ec..a1346d67f22b 100644 --- a/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx +++ b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx @@ -25,6 +25,7 @@ TauGNNEvaluator::TauGNNEvaluator(const std::string &name): declareProperty("VertexCorrection", m_doVertexCorrection = true); declareProperty("DecorateTracks", m_decorateTracks = false); declareProperty("TrackClassification", m_doTrackClassification = true); + declareProperty("MinTauPt", m_minTauPt = 0.); // Naming conventions for the network weight files: declareProperty("InputLayerScalar", m_input_layer_scalar = "tau_vars"); @@ -82,8 +83,9 @@ StatusCode TauGNNEvaluator::execute(xAOD::TauJet &tau) const { output(tau) = -1111.0f; out_ptau(tau) = -1111.0f; out_pjet(tau) = -1111.0f; + //Skip execution for low-pT taus to save resources - if(tau.pt()<13000) { + if (tau.pt() < m_minTauPt) { return StatusCode::SUCCESS; } @@ -129,7 +131,7 @@ StatusCode TauGNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<cons std::vector<const xAOD::TauTrack*>::iterator it = tracks.begin(); while(it != tracks.end()) { if((*it)->flag(xAOD::TauJetParameters::unclassified)) { - it = tracks.erase(it); + it = tracks.erase(it); } else { ++it; @@ -151,8 +153,7 @@ StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xA TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection); - std::vector<xAOD::CaloVertexedTopoCluster> vertexedClusterList = tau.vertexedClusters(); - for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : vertexedClusterList) { + for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : tau.vertexedClusters()) { TLorentzVector clusterP4 = vertexedCluster.p4(); if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue; diff --git a/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx index 8da2fe5ea6ec..a372e0c019bf 100644 --- a/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx +++ b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx @@ -1,5 +1,5 @@ /* - Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ #include "tauRecTools/TauGNNUtils.h" @@ -32,6 +32,7 @@ bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau, const std::vector<const xAOD::TauTrack *> &tracks, std::vector<double> &out) const { out.clear(); + out.reserve(tracks.size()); // Retrieve calculator function TrackCalc func = nullptr; @@ -57,6 +58,7 @@ bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau, const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, std::vector<double> &out) const { out.clear(); + out.reserve(clusters.size()); // Retrieve calculator function ClusterCalc func = nullptr; diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNN.h b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h index f2e39c02cbbc..e318762111a0 100644 --- a/Reconstruction/tauRecTools/tauRecTools/TauGNN.h +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h @@ -16,11 +16,6 @@ #include <string> #include <map> -// Forward declaration -namespace lwt { - class LightweightGraph; -} - namespace TauGNNUtils { class GNNVarCalc; } @@ -74,15 +69,11 @@ public: return m_var_calc.get(); } - explicit operator bool() const { - return static_cast<bool>(m_graph); - } - //Make the output config transparent to external tools FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config; private: - using input_pair = FlavorTagDiscriminants::input_pair; + using Inputs = FlavorTagDiscriminants::Inputs; // Abbreviations for lwtnn using VariableMap = std::map<std::string, double>; using VectorMap = std::map<std::string, std::vector<double>>; @@ -92,7 +83,6 @@ private: private: const Config m_config; - std::unique_ptr<const lwt::LightweightGraph> m_graph; // Names of the input variables std::vector<std::string> m_scalar_inputs; diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h index 4e5cebf53ad5..baa5d54a181b 100644 --- a/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h @@ -51,6 +51,7 @@ private: std::size_t m_max_tracks; std::size_t m_max_clusters; float m_max_cluster_dr; + float m_minTauPt; bool m_doVertexCorrection; bool m_doTrackClassification; bool m_decorateTracks; -- GitLab From 405dd5f288330dc4acaaab433f3418ba8fa29b08 Mon Sep 17 00:00:00 2001 From: Jean Yves Beaucamp <jean.yves.beaucamp@cern.ch> Date: Tue, 5 Nov 2024 01:45:15 +0100 Subject: [PATCH 3/4] Small GNTau update Small GNTau update --- .../tauRec/python/TauConfigFlags.py | 20 +++++++++-- Reconstruction/tauRec/python/TauToolHolder.py | 33 ++++++++---------- Reconstruction/tauRecTools/Root/TauGNN.cxx | 4 +-- .../tauRecTools/Root/TauGNNEvaluator.cxx | 9 +++-- .../tauRecTools/Root/TauGNNUtils.cxx | 14 ++++++++ .../Root/lwtnn/LightweightGraph.cxx | 4 +-- .../tauRecTools/Root/lwtnn/Stack.cxx | 34 +++++++++---------- .../tauRecTools/tauRecTools/TauGNNEvaluator.h | 4 +-- .../tauRecTools/tauRecTools/TauGNNUtils.h | 6 ++++ .../tauRecTools/lwtnn/LightweightGraph.h | 4 +-- .../tauRecTools/tauRecTools/lwtnn/Stack.h | 26 +++++++------- 11 files changed, 96 insertions(+), 62 deletions(-) diff --git a/Reconstruction/tauRec/python/TauConfigFlags.py b/Reconstruction/tauRec/python/TauConfigFlags.py index a0b3743ef4c6..ea91293697d6 100644 --- a/Reconstruction/tauRec/python/TauConfigFlags.py +++ b/Reconstruction/tauRec/python/TauConfigFlags.py @@ -62,8 +62,24 @@ def createTauConfigFlags(): tau_cfg.addFlag("Tau.TauJetDeepSetConfig_v2", ["tauid_1p_R22_dpst_noTrackScore.json", "tauid_2p_R22_dpst_noTrackScore.json", "tauid_3p_R22_dpst_noTrackScore.json"]) tau_cfg.addFlag("Tau.TauJetDeepSetWP_v2", ["model_1p_R22_dpst_noTrackScore.root", "model_2p_R22_dpst_noTrackScore.root", "model_3p_R22_dpst_noTrackScore.root"]) # GNTau ID tune file (need to add another version for noAux) - tau_cfg.addFlag("Tau.TauGNNConfig", ["GNTau_noAux_simplified.onnx"]) - tau_cfg.addFlag("Tau.TauGNNWP_v0", ["GNTauNA_flat_model_1p.root", "GNTauNA_flat_model_2p.root", "GNTauNA_flat_model_3p.root"]) + tau_cfg.addFlag("Tau.TauGNNConfig", ["GNTau_pruned_MC23.onnx","GNTau_trunc_MC23.onnx"]) + tau_cfg.addFlag("Tau.TauGNNWP", + [ + ["GNTauNAprune_flat_model_1p.root", "GNTauNAprune_flat_model_2p.root", "GNTauNAprune_flat_model_3p.root"], + ["GNTauNAtrunc_flat_model_1p.root", "GNTauNAtrunc_flat_model_2p.root", "GNTauNAtrunc_flat_model_3p.root"] + ]) + tau_cfg.addFlag("Tau.GNTauScoreName", ["GNTauScore_v0prune","GNTauScore_v1trunc"]) + tau_cfg.addFlag("Tau.GNTauTransScoreName", ["GNTauScoreSigTrans_v0prune","GNTauScoreSigTrans_v1trunc"]) + tau_cfg.addFlag("Tau.GNTauMaxTracks", [30,10]) + tau_cfg.addFlag("Tau.GNTauMaxClusters", [20,6]) + tau_cfg.addFlag("Tau.GNTauNodeNameTau", "GN2TauNoAux_pb") + tau_cfg.addFlag("Tau.GNTauNodeNameJet", "GN2TauNoAux_pu") + tau_cfg.addFlag("Tau.GNTauDecorWPNames", + [ + ["GNTauVL_v0prune", "GNTauL_v0prune", "GNTauM_v0prune", "GNTauT_v0prune"], + ["GNTauVL_v1trunc", "GNTauL_v1trunc", "GNTauM_v1trunc", "GNTauT_v1trunc"] + ]) + # PanTau config flags from PanTauAlgs.PanTauConfigFlags import createPanTauConfigFlags diff --git a/Reconstruction/tauRec/python/TauToolHolder.py b/Reconstruction/tauRec/python/TauToolHolder.py index fd96be80c82e..b4bafbba684b 100644 --- a/Reconstruction/tauRec/python/TauToolHolder.py +++ b/Reconstruction/tauRec/python/TauToolHolder.py @@ -851,19 +851,19 @@ def TauWPDecoratorJetDeepSetCfg(flags, version=None): result.setPrivateTools(myTauWPDecorator) return result -def TauGNNEvaluatorCfg(flags): +def TauGNNEvaluatorCfg(flags, version=0): result = ComponentAccumulator() - _name = flags.Tau.ActiveConfig.prefix + 'TauGNN' + _name = flags.Tau.ActiveConfig.prefix + 'TauGNN_v' + str(version) TauGNNEvaluator = CompFactory.getComp("TauGNNEvaluator") - GNNConf = flags.Tau.TauGNNConfig + GNNConf = flags.Tau.TauGNNConfig[version] myTauGNNEvaluator = TauGNNEvaluator(name = _name, - NetworkFile = GNNConf[0], - OutputVarname = "GNTauScore", + NetworkFile = GNNConf, + OutputVarname = flags.Tau.GNTauScoreName[version], OutputPTau = "GNTauProbTau", OutputPJet = "GNTauProbJet", - MaxTracks = 30, - MaxClusters = 20, + MaxTracks = flags.Tau.GNTauMaxTracks[version], + MaxClusters = flags.Tau.GNTauMaxClusters[version], MaxClusterDR = 15.0, MinTauPt = flags.Tau.MinPtDAOD, VertexCorrection = True, @@ -871,31 +871,28 @@ def TauGNNEvaluatorCfg(flags): InputLayerScalar = "tau_vars", InputLayerTracks = "track_vars", InputLayerClusters = "cluster_vars", - NodeNameTau="GN2TauNoAux_pb", - NodeNameJet="GN2TauNoAux_pu") + NodeNameTau=flags.Tau.GNTauNodeNameTau, + NodeNameJet=flags.Tau.GNTauNodeNameJet) result.setPrivateTools(myTauGNNEvaluator) return result -def TauWPDecoratorGNNCfg(flags): +def TauWPDecoratorGNNCfg(flags, version): result = ComponentAccumulator() - _name = flags.Tau.ActiveConfig.prefix + 'TauWPDecoratorGNN' + _name = flags.Tau.ActiveConfig.prefix + 'TauWPDecoratorGNN_v' + str(version) TauWPDecorator = CompFactory.getComp("TauWPDecorator") - WPConf = flags.Tau.TauGNNWP_v0 - decorWPNames = ["GNTauVL_v0", "GNTauL_v0", "GNTauM_v0", "GNTauT_v0"] - scoreName = "GNTauScore" - newScoreName = "GNTauScoreSigTrans_v0" + WPConf = flags.Tau.TauGNNWP[version] myTauWPDecorator = TauWPDecorator(name=_name, flatteningFile1Prong = WPConf[0], flatteningFile2Prong = WPConf[1], flatteningFile3Prong = WPConf[2], - DecorWPNames = decorWPNames, + DecorWPNames = flags.Tau.GNTauDecorWPNames[version], DecorWPCutEffs1P = [0.95, 0.85, 0.75, 0.60], DecorWPCutEffs2P = [0.95, 0.75, 0.60, 0.45], DecorWPCutEffs3P = [0.95, 0.75, 0.60, 0.45], - ScoreName = scoreName, - NewScoreName = newScoreName, + ScoreName = flags.Tau.GNTauScoreName[version], + NewScoreName = flags.Tau.GNTauTransScoreName[version], DefineWPs = True) result.setPrivateTools(myTauWPDecorator) return result diff --git a/Reconstruction/tauRecTools/Root/TauGNN.cxx b/Reconstruction/tauRecTools/Root/TauGNN.cxx index 971372e865b0..08dfdc7fac4e 100644 --- a/Reconstruction/tauRecTools/Root/TauGNN.cxx +++ b/Reconstruction/tauRecTools/Root/TauGNN.cxx @@ -14,14 +14,12 @@ TauGNN::TauGNN(const std::string &nnFile, const Config &config): asg::AsgMessaging("TauGNN"), - m_onnxUtil(nullptr) + m_onnxUtil(std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile)) { //==================================================// // This part is ported from FTagDiscriminant GNN.cxx// //==================================================// - m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile); - // get the configuration of the model outputs FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig(); diff --git a/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx index a1346d67f22b..bc6b27d0b346 100644 --- a/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx +++ b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx @@ -38,7 +38,7 @@ TauGNNEvaluator::TauGNNEvaluator(const std::string &name): TauGNNEvaluator::~TauGNNEvaluator() {} StatusCode TauGNNEvaluator::initialize() { - ATH_MSG_INFO("Initializing TauGNNEvaluator"); + ATH_MSG_INFO("Initializing TauGNNEvaluator with "<<m_max_tracks<<" tracks and "<<m_max_clusters<<" clusters..."); std::string weightfile(""); @@ -90,13 +90,16 @@ StatusCode TauGNNEvaluator::execute(xAOD::TauJet &tau) const { } // Get input objects + ATH_MSG_DEBUG("Fetching Tracks"); std::vector<const xAOD::TauTrack *> tracks; ATH_CHECK(get_tracks(tau, tracks)); + ATH_MSG_DEBUG("Fetching clusters"); std::vector<xAOD::CaloVertexedTopoCluster> clusters; ATH_CHECK(get_clusters(tau, clusters)); + ATH_MSG_DEBUG("Constituent fetching done..."); // Truncate tracks - int numTracksMax = std::min(m_max_tracks, tracks.size()); + int numTracksMax = std::min(m_max_tracks, static_cast<int>(tracks.size())); std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax); // Evaluate networks if (m_net) { @@ -168,7 +171,7 @@ StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xA std::sort(clusters.begin(), clusters.end(), et_cmp); // Truncate clusters - if (clusters.size() > m_max_clusters) { + if (static_cast<int>(clusters.size()) > m_max_clusters) { clusters.resize(m_max_clusters, clusters[0]); } diff --git a/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx index a372e0c019bf..32b4bc2afb2b 100644 --- a/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx +++ b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx @@ -162,7 +162,9 @@ std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scala calc->insert("d0TJVA", Variables::Track::d0TJVA, track_vars); calc->insert("d0SigTJVA", Variables::Track::d0SigTJVA, track_vars); calc->insert("dEta", Variables::Track::dEta, track_vars); + calc->insert("dEtaJetSeedAxis", Variables::Track::dEtaJetSeedAxis, track_vars); calc->insert("dPhi", Variables::Track::dPhi, track_vars); + calc->insert("dPhiJetSeedAxis", Variables::Track::dPhiJetSeedAxis, track_vars); calc->insert("nInnermostPixelHits", Variables::Track::nInnermostPixelHits, track_vars); calc->insert("nPixelHits", Variables::Track::nPixelHits, track_vars); calc->insert("nSCTHits", Variables::Track::nSCTHits, track_vars); @@ -527,11 +529,23 @@ bool dEta(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { return true; } +bool dEtaJetSeedAxis(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { + TLorentzVector tlvSeedJet = tau.p4(xAOD::TauJetParameters::JetSeed); + out = std::abs(tlvSeedJet.Eta() - track.eta()); + return true; +} + bool dPhi(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { out = track.p4().DeltaPhi(tau.p4()); return true; } +bool dPhiJetSeedAxis(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { + TLorentzVector tlvSeedJet = tau.p4(xAOD::TauJetParameters::JetSeed); + out = tlvSeedJet.DeltaPhi(track.p4()); + return true; +} + bool nInnermostPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { uint8_t inner_pixel_hits; const auto success = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits); diff --git a/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx b/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx index 94a4fd2a6c11..d6b83c9a7c56 100644 --- a/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx +++ b/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx @@ -1,5 +1,5 @@ /* - Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ #include "tauRecTools/lwtnn/LightweightGraph.h" @@ -66,7 +66,7 @@ namespace lwtDev { typedef LightweightGraph::NodeMap NodeMap; LightweightGraph::LightweightGraph(const GraphConfig& config, - std::string default_output): + const std::string& default_output): m_graph(new Graph(config.nodes, config.layers)) { for (const auto& node: config.inputs) { diff --git a/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx b/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx index ad895f07649f..fff78982eb2f 100644 --- a/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx +++ b/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx @@ -1,5 +1,5 @@ /* - Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ #include "tauRecTools/lwtnn/Stack.h" @@ -424,7 +424,7 @@ namespace lwtDev { // __________________________________________________________________ // Recurrent layers - EmbeddingLayer::EmbeddingLayer(int var_row_index, MatrixXd W): + EmbeddingLayer::EmbeddingLayer(int var_row_index, const MatrixXd & W): m_var_row_index(var_row_index), m_W(W) { @@ -470,14 +470,14 @@ namespace lwtDev { // LSTM layer - LSTMLayer::LSTMLayer(ActivationConfig activation, - ActivationConfig inner_activation, - MatrixXd W_i, MatrixXd U_i, VectorXd b_i, - MatrixXd W_f, MatrixXd U_f, VectorXd b_f, - MatrixXd W_o, MatrixXd U_o, VectorXd b_o, - MatrixXd W_c, MatrixXd U_c, VectorXd b_c, - bool go_backwards, - bool return_sequence): + LSTMLayer::LSTMLayer(const ActivationConfig & activation, + const ActivationConfig & inner_activation, + const MatrixXd & W_i, const MatrixXd & U_i, const VectorXd & b_i, + const MatrixXd & W_f, const MatrixXd & U_f, const VectorXd & b_f, + const MatrixXd & W_o, const MatrixXd & U_o, const VectorXd & b_o, + const MatrixXd & W_c, const MatrixXd & U_c, const VectorXd & b_c, + bool go_backwards, + bool return_sequence): m_W_i(W_i), m_U_i(U_i), m_b_i(b_i), @@ -547,11 +547,11 @@ namespace lwtDev { // GRU layer - GRULayer::GRULayer(ActivationConfig activation, - ActivationConfig inner_activation, - MatrixXd W_z, MatrixXd U_z, VectorXd b_z, - MatrixXd W_r, MatrixXd U_r, VectorXd b_r, - MatrixXd W_h, MatrixXd U_h, VectorXd b_h): + GRULayer::GRULayer(const ActivationConfig & activation, + const ActivationConfig & inner_activation, + const MatrixXd & W_z, const MatrixXd & U_z, const VectorXd & b_z, + const MatrixXd & W_r, const MatrixXd & U_r, const VectorXd & b_r, + const MatrixXd & W_h, const MatrixXd & U_h, const VectorXd & b_h): m_W_z(W_z), m_U_z(U_z), m_b_z(b_z), @@ -621,8 +621,8 @@ namespace lwtDev { } MatrixXd BidirectionalLayer::scan( const MatrixXd& x) const{ - MatrixXd forward = m_forward_layer->scan(x); - MatrixXd backward = m_backward_layer->scan(x); + const MatrixXd & forward = m_forward_layer->scan(x); + const MatrixXd & backward = m_backward_layer->scan(x); MatrixXd backward_rev; if (m_return_sequence){ backward_rev = backward.rowwise().reverse(); diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h index baa5d54a181b..776a6a96bd70 100644 --- a/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h @@ -48,8 +48,8 @@ private: std::string m_output_ptau; std::string m_output_pjet; std::string m_weightfile; - std::size_t m_max_tracks; - std::size_t m_max_clusters; + int m_max_tracks; + int m_max_clusters; float m_max_cluster_dr; float m_minTauPt; bool m_doVertexCorrection; diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h index 28a983dd5e8a..7e4913018524 100644 --- a/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h @@ -179,9 +179,15 @@ bool d0SigTJVA( bool dEta( const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); +bool dEtaJetSeedAxis( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + bool dPhi( const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); +bool dPhiJetSeedAxis( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + bool nInnermostPixelHits( const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); diff --git a/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h b/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h index 802d17df2b02..16d1b29d7acd 100644 --- a/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h +++ b/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h @@ -1,5 +1,5 @@ /* - Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ #ifndef LIGHTWEIGHT_GRAPH_HH_TAURECTOOLS @@ -72,7 +72,7 @@ namespace lwtDev { // define a "default" output, so that calling "compute" with no // output specified doesn't lead to ambiguity. LightweightGraph(const GraphConfig& config, - std::string default_output = ""); + const std::string& default_output = ""); ~LightweightGraph(); LightweightGraph(LightweightGraph&) = delete; diff --git a/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h b/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h index 69ea569fb0f3..84d55e46c9ad 100644 --- a/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h +++ b/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h @@ -1,5 +1,5 @@ /* - Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ #ifndef STACK_HH_TAURECTOOLS @@ -222,7 +222,7 @@ namespace lwtDev { class EmbeddingLayer : public IRecurrentLayer { public: - EmbeddingLayer(int var_row_index, MatrixXd W); + EmbeddingLayer(int var_row_index, const MatrixXd & W); virtual ~EmbeddingLayer() {}; virtual MatrixXd scan( const MatrixXd&) const override; @@ -236,12 +236,12 @@ namespace lwtDev { class LSTMLayer : public IRecurrentLayer { public: - LSTMLayer(ActivationConfig activation, - ActivationConfig inner_activation, - MatrixXd W_i, MatrixXd U_i, VectorXd b_i, - MatrixXd W_f, MatrixXd U_f, VectorXd b_f, - MatrixXd W_o, MatrixXd U_o, VectorXd b_o, - MatrixXd W_c, MatrixXd U_c, VectorXd b_c, + LSTMLayer(const ActivationConfig & activation, + const ActivationConfig & inner_activation, + const MatrixXd & W_i, const MatrixXd & U_i, const VectorXd & b_i, + const MatrixXd & W_f, const MatrixXd & U_f, const VectorXd & b_f, + const MatrixXd & W_o, const MatrixXd & U_o, const VectorXd & b_o, + const MatrixXd & W_c, const MatrixXd & U_c, const VectorXd & b_c, bool go_backwards, bool return_sequence); @@ -277,11 +277,11 @@ namespace lwtDev { class GRULayer : public IRecurrentLayer { public: - GRULayer(ActivationConfig activation, - ActivationConfig inner_activation, - MatrixXd W_z, MatrixXd U_z, VectorXd b_z, - MatrixXd W_r, MatrixXd U_r, VectorXd b_r, - MatrixXd W_h, MatrixXd U_h, VectorXd b_h); + GRULayer(const ActivationConfig & activation, + const ActivationConfig & inner_activation, + const MatrixXd & W_z, const MatrixXd & U_z, const VectorXd & b_z, + const MatrixXd & W_r, const MatrixXd & U_r, const VectorXd & b_r, + const MatrixXd & W_h, const MatrixXd & U_h, const VectorXd & b_h); virtual ~GRULayer() {}; virtual MatrixXd scan( const MatrixXd&) const override; -- GitLab From 4e01e714ab98fabb3abc8edb0dd860234a0a1232 Mon Sep 17 00:00:00 2001 From: Jean Yves Beaucamp <jean.yves.beaucamp@cern.ch> Date: Tue, 5 Nov 2024 01:45:35 +0100 Subject: [PATCH 4/4] Fix for the TauGNN implementation Fix for the TauGNN implementation --- Reconstruction/tauRecTools/Root/TauGNN.cxx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Reconstruction/tauRecTools/Root/TauGNN.cxx b/Reconstruction/tauRecTools/Root/TauGNN.cxx index 08dfdc7fac4e..a0603fa2957b 100644 --- a/Reconstruction/tauRecTools/Root/TauGNN.cxx +++ b/Reconstruction/tauRecTools/Root/TauGNN.cxx @@ -14,7 +14,8 @@ TauGNN::TauGNN(const std::string &nnFile, const Config &config): asg::AsgMessaging("TauGNN"), - m_onnxUtil(std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile)) + m_onnxUtil(std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile)), + m_config{config} { //==================================================// // This part is ported from FTagDiscriminant GNN.cxx// -- GitLab