diff --git a/Reconstruction/tauRec/python/TauConfigFlags.py b/Reconstruction/tauRec/python/TauConfigFlags.py index f4c7ff650c38985a3f7eb8936ba7d2879c78156b..ea91293697d628a6ce4fec9558ef4a641fbdcc7c 100644 --- a/Reconstruction/tauRec/python/TauConfigFlags.py +++ b/Reconstruction/tauRec/python/TauConfigFlags.py @@ -1,4 +1,4 @@ -# Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration import unittest from AthenaConfiguration.AthConfigFlags import AthConfigFlags @@ -48,6 +48,7 @@ def createTauConfigFlags(): tau_cfg.addFlag("Tau.MvaTESConfig", "MvaTES_R23.root") tau_cfg.addFlag("Tau.MinPt0p", 9.25*Units.GeV) tau_cfg.addFlag("Tau.MinPt", 6.75*Units.GeV) + tau_cfg.addFlag("Tau.MinPtDAOD", 13*Units.GeV) tau_cfg.addFlag("Tau.TauJetRNNConfig", ["tauid_rnn_1p_R22_v1.json", "tauid_rnn_2p_R22_v1.json", "tauid_rnn_3p_R22_v1.json"]) tau_cfg.addFlag("Tau.TauJetRNNWPConfig", ["tauid_rnnWP_1p_R22_v0.root", "tauid_rnnWP_2p_R22_v0.root", "tauid_rnnWP_3p_R22_v0.root"]) tau_cfg.addFlag("Tau.TauEleRNNConfig", ["taueveto_rnn_config_1P_r22.json", "taueveto_rnn_config_3P_r22.json"]) @@ -60,6 +61,25 @@ def createTauConfigFlags(): # R22 DeepSet tau ID tune without track RNN scores, for now define a second set of flags, but ultimately we'll choose one and drop the other tau_cfg.addFlag("Tau.TauJetDeepSetConfig_v2", ["tauid_1p_R22_dpst_noTrackScore.json", "tauid_2p_R22_dpst_noTrackScore.json", "tauid_3p_R22_dpst_noTrackScore.json"]) tau_cfg.addFlag("Tau.TauJetDeepSetWP_v2", ["model_1p_R22_dpst_noTrackScore.root", "model_2p_R22_dpst_noTrackScore.root", "model_3p_R22_dpst_noTrackScore.root"]) + # GNTau ID tune file (need to add another version for noAux) + tau_cfg.addFlag("Tau.TauGNNConfig", ["GNTau_pruned_MC23.onnx","GNTau_trunc_MC23.onnx"]) + tau_cfg.addFlag("Tau.TauGNNWP", + [ + ["GNTauNAprune_flat_model_1p.root", "GNTauNAprune_flat_model_2p.root", "GNTauNAprune_flat_model_3p.root"], + ["GNTauNAtrunc_flat_model_1p.root", "GNTauNAtrunc_flat_model_2p.root", "GNTauNAtrunc_flat_model_3p.root"] + ]) + tau_cfg.addFlag("Tau.GNTauScoreName", ["GNTauScore_v0prune","GNTauScore_v1trunc"]) + tau_cfg.addFlag("Tau.GNTauTransScoreName", ["GNTauScoreSigTrans_v0prune","GNTauScoreSigTrans_v1trunc"]) + tau_cfg.addFlag("Tau.GNTauMaxTracks", [30,10]) + tau_cfg.addFlag("Tau.GNTauMaxClusters", [20,6]) + tau_cfg.addFlag("Tau.GNTauNodeNameTau", "GN2TauNoAux_pb") + tau_cfg.addFlag("Tau.GNTauNodeNameJet", "GN2TauNoAux_pu") + tau_cfg.addFlag("Tau.GNTauDecorWPNames", + [ + ["GNTauVL_v0prune", "GNTauL_v0prune", "GNTauM_v0prune", "GNTauT_v0prune"], + ["GNTauVL_v1trunc", "GNTauL_v1trunc", "GNTauM_v1trunc", "GNTauT_v1trunc"] + ]) + # PanTau config flags from PanTauAlgs.PanTauConfigFlags import createPanTauConfigFlags diff --git a/Reconstruction/tauRec/python/TauToolHolder.py b/Reconstruction/tauRec/python/TauToolHolder.py index 679fb4c86f9fdb56745bf7b9211926cd539f3fb1..b4bafbba684bb6495bd89b21a93b7c5aac2009dc 100644 --- a/Reconstruction/tauRec/python/TauToolHolder.py +++ b/Reconstruction/tauRec/python/TauToolHolder.py @@ -851,6 +851,52 @@ def TauWPDecoratorJetDeepSetCfg(flags, version=None): result.setPrivateTools(myTauWPDecorator) return result +def TauGNNEvaluatorCfg(flags, version=0): + result = ComponentAccumulator() + _name = flags.Tau.ActiveConfig.prefix + 'TauGNN_v' + str(version) + + TauGNNEvaluator = CompFactory.getComp("TauGNNEvaluator") + GNNConf = flags.Tau.TauGNNConfig[version] + myTauGNNEvaluator = TauGNNEvaluator(name = _name, + NetworkFile = GNNConf, + OutputVarname = flags.Tau.GNTauScoreName[version], + OutputPTau = "GNTauProbTau", + OutputPJet = "GNTauProbJet", + MaxTracks = flags.Tau.GNTauMaxTracks[version], + MaxClusters = flags.Tau.GNTauMaxClusters[version], + MaxClusterDR = 15.0, + MinTauPt = flags.Tau.MinPtDAOD, + VertexCorrection = True, + DecorateTracks = False, + InputLayerScalar = "tau_vars", + InputLayerTracks = "track_vars", + InputLayerClusters = "cluster_vars", + NodeNameTau=flags.Tau.GNTauNodeNameTau, + NodeNameJet=flags.Tau.GNTauNodeNameJet) + + result.setPrivateTools(myTauGNNEvaluator) + return result + +def TauWPDecoratorGNNCfg(flags, version): + result = ComponentAccumulator() + _name = flags.Tau.ActiveConfig.prefix + 'TauWPDecoratorGNN_v' + str(version) + + TauWPDecorator = CompFactory.getComp("TauWPDecorator") + WPConf = flags.Tau.TauGNNWP[version] + myTauWPDecorator = TauWPDecorator(name=_name, + flatteningFile1Prong = WPConf[0], + flatteningFile2Prong = WPConf[1], + flatteningFile3Prong = WPConf[2], + DecorWPNames = flags.Tau.GNTauDecorWPNames[version], + DecorWPCutEffs1P = [0.95, 0.85, 0.75, 0.60], + DecorWPCutEffs2P = [0.95, 0.75, 0.60, 0.45], + DecorWPCutEffs3P = [0.95, 0.75, 0.60, 0.45], + ScoreName = flags.Tau.GNTauScoreName[version], + NewScoreName = flags.Tau.GNTauTransScoreName[version], + DefineWPs = True) + result.setPrivateTools(myTauWPDecorator) + return result + def TauEleRNNEvaluatorCfg(flags): result = ComponentAccumulator() _name = flags.Tau.ActiveConfig.prefix + 'TauEleRNN' diff --git a/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx b/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx index 02ed3802ab71ebdf64260bf7fe551b1e03a3316a..5a9a2d01c0e6f769ae321782eb30a6332e15575a 100644 --- a/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx +++ b/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx @@ -1,5 +1,5 @@ /* - Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ #include "TauAODRunnerAlg.h" @@ -66,6 +66,7 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const { xAOD::TauJetContainer *newTauCon = outputTauHandle.ptr(); static const SG::AuxElement::Accessor<ElementLink<xAOD::TauJetContainer>> acc_ori_tau_link("originalTauJet"); + static const SG::AuxElement::Accessor<char> acc_modified("ModifiedInAOD"); for (const xAOD::TauJet *tau : *pTauContainer) { // deep copy the tau container @@ -88,6 +89,21 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const { linkToTauTrack.toContainedElement(*newTauTrkCon, newTauTrk); newTau->addTauTrackLink(linkToTauTrack); } + + // 'ModifiedInAOD' will be overriden by modification tools for relevant candidates + acc_modified(*newTau) = static_cast<char>(false); + + StatusCode sc; + for (const ToolHandle<ITauToolBase> &tool : m_modificationTools) { + ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name()); + sc = tool->execute(*newTau); + if (sc.isFailure()) break; + } + + // if tau candidate was not modified, remove it from container, track cleanup performed by thinning algorithm downstream + if (!acc_modified(*newTau)) { + newTauCon->pop_back(); + } } // Read the CaloClusterContainer @@ -123,49 +139,34 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const { ATH_CHECK(vertOutHandle.record(std::make_unique<xAOD::VertexContainer>(), std::make_unique<xAOD::VertexAuxContainer>())); xAOD::VertexContainer* pSecVtxContainer = vertOutHandle.ptr(); - int n_tau_modified = 0; - static const SG::AuxElement::Accessor<char> acc_modified("ModifiedInAOD"); - for (xAOD::TauJet *pTau : *newTauCon) { - // Loop stops when Failure indicated by one of the tools StatusCode sc; - //add a identifier of if the tau is modifed by the mod tools - acc_modified(*pTau) = static_cast<char>(false); - // iterate over the copy - for (const ToolHandle<ITauToolBase> &tool : m_modificationTools) { + for (const ToolHandle<ITauToolBase> &tool : m_officialTools) { ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name()); - sc = tool->execute(*pTau); + if (tool->type() == "TauPi0ClusterCreator") + sc = tool->executePi0ClusterCreator(*pTau, *neutralPFOContainer, *hadronicClusterPFOContainer, *pi0ClusterContainer); + else if (tool->type() == "TauVertexVariables") + sc = tool->executeVertexVariables(*pTau, *pSecVtxContainer); + else if (tool->type() == "TauPi0ClusterScaler") + sc = tool->executePi0ClusterScaler(*pTau, *neutralPFOContainer, *chargedPFOContainer); + else if (tool->type() == "TauPi0ScoreCalculator") + sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer); + else if (tool->type() == "TauPi0Selector") + sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer); + else if (tool->type() == "PanTau::PanTauProcessor") + sc = tool->executePanTau(*pTau, *pi0Container, *neutralPFOContainer); + else if (tool->type() == "tauRecTools::TauTrackRNNClassifier") + sc = tool->executeTrackClassifier(*pTau, *newTauTrkCon); + else + sc = tool->execute(*pTau); if (sc.isFailure()) break; } - if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked modification tools."); - // if tau is not modified by the above tools, never mind running the tools afterward - if (static_cast<bool>(isTauModified(pTau))) { - n_tau_modified++; - for (const ToolHandle<ITauToolBase> &tool : m_officialTools) { - ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name()); - if (tool->type() == "TauPi0ClusterCreator") - sc = tool->executePi0ClusterCreator(*pTau, *neutralPFOContainer, *hadronicClusterPFOContainer, *pi0ClusterContainer); - else if (tool->type() == "TauVertexVariables") - sc = tool->executeVertexVariables(*pTau, *pSecVtxContainer); - else if (tool->type() == "TauPi0ClusterScaler") - sc = tool->executePi0ClusterScaler(*pTau, *neutralPFOContainer, *chargedPFOContainer); - else if (tool->type() == "TauPi0ScoreCalculator") - sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer); - else if (tool->type() == "TauPi0Selector") - sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer); - else if (tool->type() == "PanTau::PanTauProcessor") - sc = tool->executePanTau(*pTau, *pi0Container, *neutralPFOContainer); - else if (tool->type() == "tauRecTools::TauTrackRNNClassifier") - sc = tool->executeTrackClassifier(*pTau, *newTauTrkCon); - else - sc = tool->execute(*pTau); - if (sc.isFailure()) break; - } - if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked official tools."); - } + if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked official tools."); } + ATH_MSG_VERBOSE("The tau candidate container has been modified by the rest of the tools"); - ATH_MSG_DEBUG(n_tau_modified << " / " << pTauContainer->size() <<" taus were modified"); + ATH_MSG_DEBUG(newTauCon->size() << " / " << pTauContainer->size() <<" taus were modified"); + return StatusCode::SUCCESS; } diff --git a/Reconstruction/tauRecTools/CMakeLists.txt b/Reconstruction/tauRecTools/CMakeLists.txt index 40a34801c23a244adb2a9be46ffa39c086df3fe2..87b3c12ef094a880e7a73652436f1f74dc33a39f 100644 --- a/Reconstruction/tauRecTools/CMakeLists.txt +++ b/Reconstruction/tauRecTools/CMakeLists.txt @@ -8,6 +8,7 @@ find_package( Boost ) find_package( Eigen ) find_package( ROOT COMPONENTS Core Tree Hist RIO ) find_package( lwtnn ) +find_package( onnxruntime ) # Optional dependencies. set( extra_public_libs ) @@ -28,12 +29,12 @@ atlas_add_library( tauRecToolsLib tauRecTools/*.h Root/*.cxx tauRecTools/lwtnn/*.h Root/lwtnn/*.cxx PUBLIC_HEADERS tauRecTools INCLUDE_DIRS ${Boost_INCLUDE_DIRS} ${EIGEN_INCLUDE_DIRS} - ${LWTNN_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} + ${LWTNN_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} ${ONNXRUNTIME_INCLUDE_DIRS} LINK_LIBRARIES ${Boost_LIBRARIES} ${EIGEN_LIBRARIES} ${LWTNN_LIBRARIES} - ${ROOT_LIBRARIES} + ${ROOT_LIBRARIES} ${ONNXRUNTIME_LIBRARIES} CxxUtils AthLinks AsgMessagingLib AsgDataHandlesLib AsgTools xAODCaloEvent xAODEventInfo xAODJet xAODParticleEvent xAODPFlow xAODTau xAODTracking xAODEventShape - MVAUtils ${extra_public_libs} + MVAUtils FlavorTagDiscriminants ${extra_public_libs} PRIVATE_LINK_LIBRARIES CaloGeoHelpers FourMomUtils PathResolver ) atlas_add_dictionary( tauRecToolsDict @@ -46,5 +47,5 @@ if( NOT XAOD_STANDALONE ) INCLUDE_DIRS ${Boost_INCLUDE_DIRS} LINK_LIBRARIES ${Boost_LIBRARIES} AsgTools AsgMessagingLib xAODBase xAODCaloEvent xAODJet xAODPFlow xAODTau FourMomUtils tauRecToolsLib PFlowUtilsLib - ${extra_private_libs} ) + FlavorTagDiscriminants ${extra_private_libs} ) endif() diff --git a/Reconstruction/tauRecTools/Root/TauGNN.cxx b/Reconstruction/tauRecTools/Root/TauGNN.cxx new file mode 100644 index 0000000000000000000000000000000000000000..a0603fa2957bba17acef8230b14702c59b3fe3be --- /dev/null +++ b/Reconstruction/tauRecTools/Root/TauGNN.cxx @@ -0,0 +1,202 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "tauRecTools/TauGNN.h" +#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "lwtnn/parse_json.hh" +#include "PathResolver/PathResolver.h" + +#include <algorithm> +#include <fstream> + +#include "tauRecTools/TauGNNUtils.h" + +TauGNN::TauGNN(const std::string &nnFile, const Config &config): + asg::AsgMessaging("TauGNN"), + m_onnxUtil(std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile)), + m_config{config} + { + //==================================================// + // This part is ported from FTagDiscriminant GNN.cxx// + //==================================================// + + // get the configuration of the model outputs + FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig(); + + //Let's see the output! + for (const auto& out_node: gnn_output_config) { + if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name); + } + + //Get model config (for inputs) + auto lwtnn_config = m_onnxUtil->getLwtConfig(); + + //===================================================// + // This part is ported from tauRecTools TauJetRNN.cxx// + //===================================================// + + // Search for input layer names specified in 'config' + auto node_is_scalar = [&config](const lwt::InputNodeConfig &in_node) { + return in_node.name == config.input_layer_scalar; + }; + auto node_is_track = [&config](const lwt::InputNodeConfig &in_node) { + return in_node.name == config.input_layer_tracks; + }; + auto node_is_cluster = [&config](const lwt::InputNodeConfig &in_node) { + return in_node.name == config.input_layer_clusters; + }; + + auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(), + lwtnn_config.inputs.cend(), + node_is_scalar); + + auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(), + lwtnn_config.input_sequences.cend(), + node_is_track); + + auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(), + lwtnn_config.input_sequences.cend(), + node_is_cluster); + + // Check which input layers were found + auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend(); + auto has_track_node = track_node != lwtnn_config.input_sequences.cend(); + auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend(); + if(!has_scalar_node) ATH_MSG_WARNING("No scalar node with name "<<config.input_layer_scalar<<" found!"); + if(!has_track_node) ATH_MSG_WARNING("No track node with name "<<config.input_layer_tracks<<" found!"); + if(!has_cluster_node) ATH_MSG_WARNING("No cluster node with name "<<config.input_layer_clusters<<" found!"); + + // Fill the variable names of each input layer into the corresponding vector + if (has_scalar_node) { + for (const auto &in : scalar_node->variables) { + std::string name = in.name; + m_scalarCalc_inputs.push_back(name); + } + } + + if (has_track_node) { + for (const auto &in : track_node->variables) { + std::string name = in.name; + m_trackCalc_inputs.push_back(name); + } + } + + if (has_cluster_node) { + for (const auto &in : cluster_node->variables) { + std::string name = in.name; + m_clusterCalc_inputs.push_back(name); + } + } + // Load the variable calculator + m_var_calc = TauGNNUtils::get_calculator(m_scalarCalc_inputs, m_trackCalc_inputs, m_clusterCalc_inputs); + ATH_MSG_INFO("TauGNN object initialized successfully!"); +} + +TauGNN::~TauGNN() {} + +std::tuple< + std::map<std::string, float>, + std::map<std::string, std::vector<char>>, + std::map<std::string, std::vector<float>> > +TauGNN::compute(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const { + InputMap scalarInputs; + InputSequenceMap vectorInputs; + std::map<std::string, Inputs> gnn_input; + ATH_MSG_DEBUG("Starting compute..."); + //Prepare input variables + if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) { + ATH_MSG_FATAL("Failed calculateInputVariables"); + throw StatusCode::FAILURE; + } + + // Add TauJet-level features to the input + std::vector<float> tau_feats; + for (const auto &varname : m_scalarCalc_inputs) { + tau_feats.push_back(static_cast<float>(scalarInputs[m_config.input_layer_scalar][varname])); + } + std::vector<int64_t> tau_feats_dim = {1, static_cast<int64_t>(tau_feats.size())}; + Inputs tau_info (tau_feats, tau_feats_dim); + gnn_input.insert({"tau_vars", tau_info}); + + //Add track-level features to the input + std::vector<float> trk_feats; + int num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_tracks][m_trackCalc_inputs.at(0)].size()); + int num_node_vars=static_cast<int>(m_trackCalc_inputs.size()); + trk_feats.resize(num_nodes * num_node_vars); + int var_idx=0; + for (const auto &varname : m_trackCalc_inputs) { + for (int node_idx=0; node_idx<num_nodes; node_idx++){ + trk_feats.at(node_idx*num_node_vars + var_idx) + = static_cast<float>(vectorInputs[m_config.input_layer_tracks][varname].at(node_idx)); + } + var_idx++; + } + std::vector<int64_t> trk_feats_dim = {num_nodes, num_node_vars}; + Inputs trk_info (trk_feats, trk_feats_dim); + gnn_input.insert({"track_vars", trk_info}); + + //Add cluster-level features to the input + std::vector<float> cls_feats; + num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_clusters][m_clusterCalc_inputs.at(0)].size()); + num_node_vars=static_cast<int>(m_clusterCalc_inputs.size()); + cls_feats.resize(num_nodes * num_node_vars); + var_idx=0; + for (const auto &varname : m_clusterCalc_inputs) { + for (int node_idx=0; node_idx<num_nodes; node_idx++){ + cls_feats.at(node_idx*num_node_vars + var_idx) + = static_cast<float>(vectorInputs[m_config.input_layer_clusters][varname].at(node_idx)); + } + var_idx++; + } + std::vector<int64_t> cls_feats_dim = {num_nodes, num_node_vars}; + Inputs cls_info (cls_feats, cls_feats_dim); + gnn_input.insert({"cluster_vars", cls_info}); + + //RUN THE INFERENCE!!! + ATH_MSG_DEBUG("Prepared inputs, running inference..."); + auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + ATH_MSG_DEBUG("Finished compute!"); + return std::make_tuple(out_f, out_vc, out_vf); +} + +bool TauGNN::calculateInputVariables(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::map<std::string, std::map<std::string, double>>& scalarInputs, + std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const { + scalarInputs.clear(); + vectorInputs.clear(); + // Populate input (sequence) map with input variables + for (const auto &varname : m_scalarCalc_inputs) { + if (!m_var_calc->compute(varname, tau, + scalarInputs[m_config.input_layer_scalar][varname])) { + ATH_MSG_WARNING("Error computing '" << varname + << "' returning default"); + return false; + } + } + + for (const auto &varname : m_trackCalc_inputs) { + if (!m_var_calc->compute(varname, tau, tracks, + vectorInputs[m_config.input_layer_tracks][varname])) { + ATH_MSG_WARNING("Error computing '" << varname + << "' returning default"); + return false; + } + } + + for (const auto &varname : m_clusterCalc_inputs) { + if (!m_var_calc->compute(varname, tau, clusters, + vectorInputs[m_config.input_layer_clusters][varname])) { + ATH_MSG_WARNING("Error computing '" << varname + << "' returning default"); + return false; + } + } + return true; +} diff --git a/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx new file mode 100644 index 0000000000000000000000000000000000000000..bc6b27d0b346277438be6cf56e730589fac9be3f --- /dev/null +++ b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx @@ -0,0 +1,179 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "tauRecTools/TauGNNEvaluator.h" +#include "tauRecTools/TauGNN.h" +#include "tauRecTools/HelperFunctions.h" + +#include "PathResolver/PathResolver.h" + +#include <algorithm> + + +TauGNNEvaluator::TauGNNEvaluator(const std::string &name): + TauRecToolBase(name), + m_net(nullptr){ + + declareProperty("NetworkFile", m_weightfile = ""); + declareProperty("OutputVarname", m_output_varname = "GNTauScore"); + declareProperty("OutputPTau", m_output_ptau = "GNTauProbTau"); + declareProperty("OutputPJet", m_output_pjet = "GNTauProbJet"); + declareProperty("MaxTracks", m_max_tracks = 30); + declareProperty("MaxClusters", m_max_clusters = 20); + declareProperty("MaxClusterDR", m_max_cluster_dr = 1.0f); + declareProperty("VertexCorrection", m_doVertexCorrection = true); + declareProperty("DecorateTracks", m_decorateTracks = false); + declareProperty("TrackClassification", m_doTrackClassification = true); + declareProperty("MinTauPt", m_minTauPt = 0.); + + // Naming conventions for the network weight files: + declareProperty("InputLayerScalar", m_input_layer_scalar = "tau_vars"); + declareProperty("InputLayerTracks", m_input_layer_tracks = "track_vars"); + declareProperty("InputLayerClusters", m_input_layer_clusters = "cluster_vars"); + declareProperty("NodeNameTau", m_outnode_tau = "GN2TauNoAux_pb"); + declareProperty("NodeNameJet", m_outnode_jet = "GN2TauNoAux_pu"); + } + +TauGNNEvaluator::~TauGNNEvaluator() {} + +StatusCode TauGNNEvaluator::initialize() { + ATH_MSG_INFO("Initializing TauGNNEvaluator with "<<m_max_tracks<<" tracks and "<<m_max_clusters<<" clusters..."); + + std::string weightfile(""); + + // Use PathResolver to search for the weight files + if (!m_weightfile.empty()) { + weightfile = find_file(m_weightfile); + if (weightfile.empty()) { + ATH_MSG_ERROR("Could not find network weights: " << m_weightfile); + return StatusCode::FAILURE; + } else { + ATH_MSG_INFO("Using network config: " << weightfile); + } + } + + // Set the layer and node names in the weight file + TauGNN::Config config; + config.input_layer_scalar = m_input_layer_scalar; + config.input_layer_tracks = m_input_layer_tracks; + config.input_layer_clusters = m_input_layer_clusters; + config.output_node_tau = m_outnode_tau; + config.output_node_jet = m_outnode_jet; + + // Load the weights and create the network + if (!weightfile.empty()) { + m_net = std::make_unique<TauGNN>(weightfile, config); + if (!m_net) { + ATH_MSG_ERROR("No network configured."); + return StatusCode::FAILURE; + } + } + + return StatusCode::SUCCESS; +} + +StatusCode TauGNNEvaluator::execute(xAOD::TauJet &tau) const { + // Output variable Decorators + const SG::AuxElement::Accessor<float> output(m_output_varname); + const SG::AuxElement::Accessor<float> out_ptau(m_output_ptau); + const SG::AuxElement::Accessor<float> out_pjet(m_output_pjet); + const SG::AuxElement::Decorator<char> out_trkclass("GNTau_TrackClass"); + // Set default score and overwrite later + output(tau) = -1111.0f; + out_ptau(tau) = -1111.0f; + out_pjet(tau) = -1111.0f; + + //Skip execution for low-pT taus to save resources + if (tau.pt() < m_minTauPt) { + return StatusCode::SUCCESS; + } + + // Get input objects + ATH_MSG_DEBUG("Fetching Tracks"); + std::vector<const xAOD::TauTrack *> tracks; + ATH_CHECK(get_tracks(tau, tracks)); + ATH_MSG_DEBUG("Fetching clusters"); + std::vector<xAOD::CaloVertexedTopoCluster> clusters; + ATH_CHECK(get_clusters(tau, clusters)); + ATH_MSG_DEBUG("Constituent fetching done..."); + + // Truncate tracks + int numTracksMax = std::min(m_max_tracks, static_cast<int>(tracks.size())); + std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax); + // Evaluate networks + if (m_net) { + auto [out_f, out_vc, out_vf] = m_net->compute(tau, trackVec, clusters); + output(tau)=std::log10(1/(1-out_f.at(m_outnode_tau))); + out_ptau(tau)=out_f.at(m_outnode_tau); + out_pjet(tau)=out_f.at(m_outnode_jet); + if (m_decorateTracks){ + for(unsigned int i=0;i<tracks.size();i++){ + if(i<out_vc.at("track_class").size()){out_trkclass(*tracks.at(i))=out_vc.at("track_class").at(i);} + else{out_trkclass(*tracks.at(i))='9';} //Dummy value for tracks outside range of out_vc + } + } + } + + return StatusCode::SUCCESS; +} + +const TauGNN* TauGNNEvaluator::get_gnn() const { + return m_net.get(); +} + + +StatusCode TauGNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const { + std::vector<const xAOD::TauTrack*> tracks = tau.allTracks(); + + // Skip unclassified tracks: + // - the track is a LRT and classifyLRT = false + // - the track is not among the MaxNtracks highest-pt tracks in the track classifier + // - track classification is not run (trigger) + if(m_doTrackClassification) { + std::vector<const xAOD::TauTrack*>::iterator it = tracks.begin(); + while(it != tracks.end()) { + if((*it)->flag(xAOD::TauJetParameters::unclassified)) { + it = tracks.erase(it); + } + else { + ++it; + } + } + } + + // Sort by descending pt + auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) { + return lhs->pt() > rhs->pt(); + }; + std::sort(tracks.begin(), tracks.end(), cmp_pt); + out = std::move(tracks); + + return StatusCode::SUCCESS; +} + +StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const { + + TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection); + + for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : tau.vertexedClusters()) { + TLorentzVector clusterP4 = vertexedCluster.p4(); + if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue; + + clusters.push_back(vertexedCluster); + } + + // Sort by descending et + auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs, + const xAOD::CaloVertexedTopoCluster& rhs) { + return lhs.p4().Et() > rhs.p4().Et(); + }; + std::sort(clusters.begin(), clusters.end(), et_cmp); + + // Truncate clusters + if (static_cast<int>(clusters.size()) > m_max_clusters) { + clusters.resize(m_max_clusters, clusters[0]); + } + + return StatusCode::SUCCESS; +} diff --git a/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx new file mode 100644 index 0000000000000000000000000000000000000000..32b4bc2afb2b815683f91c246ffc3ad276111fcc --- /dev/null +++ b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx @@ -0,0 +1,930 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "tauRecTools/TauGNNUtils.h" +#include "tauRecTools/HelperFunctions.h" +#include <algorithm> +#include <iostream> +#define GeV 1000 + +namespace TauGNNUtils { + +GNNVarCalc::GNNVarCalc() : asg::AsgMessaging("TauGNNUtils::GNNVarCalc") { +} + +bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau, + double &out) const { + // Retrieve calculator function + ScalarCalc func = nullptr; + try { + func = m_scalar_map.at(name); + } catch (const std::out_of_range &e) { + ATH_MSG_ERROR("Variable '" << name << "' not defined"); + throw; + } + + // Calculate variable + return func(tau, out); +} + +bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + std::vector<double> &out) const { + out.clear(); + out.reserve(tracks.size()); + + // Retrieve calculator function + TrackCalc func = nullptr; + try { + func = m_track_map.at(name); + } catch (const std::out_of_range &e) { + ATH_MSG_ERROR("Variable '" << name << "' not defined"); + throw; + } + + // Calculate variables for selected tracks + bool success = true; + double value; + for (const auto *const trk : tracks) { + success = success && func(tau, *trk, value); + out.push_back(value); + } + + return success; +} + +bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::vector<double> &out) const { + out.clear(); + out.reserve(clusters.size()); + + // Retrieve calculator function + ClusterCalc func = nullptr; + try { + func = m_cluster_map.at(name); + } catch (const std::out_of_range &e) { + ATH_MSG_ERROR("Variable '" << name << "' not defined"); + throw; + } + + // Calculate variables for selected clusters + bool success = true; + double value; + for (const xAOD::CaloVertexedTopoCluster& cluster : clusters) { + success = success && func(tau, cluster, value); + out.push_back(value); + } + + return success; +} + +void GNNVarCalc::insert(const std::string &name, ScalarCalc func, const std::vector<std::string>& scalar_vars) { + if (std::find(scalar_vars.begin(), scalar_vars.end(), name) == scalar_vars.end()) { + return; + } + if (!func) { + throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert"); + } + m_scalar_map[name] = func; +} + +void GNNVarCalc::insert(const std::string &name, TrackCalc func, const std::vector<std::string>& track_vars) { + if (std::find(track_vars.begin(), track_vars.end(), name) == track_vars.end()) { + return; + } + if (!func) { + throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert"); + } + m_track_map[name] = func; +} + +void GNNVarCalc::insert(const std::string &name, ClusterCalc func, const std::vector<std::string>& cluster_vars) { + if (std::find(cluster_vars.begin(), cluster_vars.end(), name) == cluster_vars.end()) { + return; + } + if (!func) { + throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert"); + } + m_cluster_map[name] = func; +} + +std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scalar_vars, + const std::vector<std::string>& track_vars, + const std::vector<std::string>& cluster_vars) { + auto calc = std::make_unique<GNNVarCalc>(); + + // Scalar variable calculator functions + calc->insert("absEta", Variables::absEta, scalar_vars); + calc->insert("isolFrac", Variables::isolFrac, scalar_vars); + calc->insert("centFrac", Variables::centFrac, scalar_vars); + calc->insert("etOverPtLeadTrk", Variables::etOverPtLeadTrk, scalar_vars); + calc->insert("innerTrkAvgDist", Variables::innerTrkAvgDist, scalar_vars); + calc->insert("absipSigLeadTrk", Variables::absipSigLeadTrk, scalar_vars); + calc->insert("SumPtTrkFrac", Variables::SumPtTrkFrac, scalar_vars); + calc->insert("sumEMCellEtOverLeadTrkPt", Variables::sumEMCellEtOverLeadTrkPt, scalar_vars); + calc->insert("EMPOverTrkSysP", Variables::EMPOverTrkSysP, scalar_vars); + calc->insert("ptRatioEflowApprox", Variables::ptRatioEflowApprox, scalar_vars); + calc->insert("mEflowApprox", Variables::mEflowApprox, scalar_vars); + calc->insert("dRmax", Variables::dRmax, scalar_vars); + calc->insert("trFlightPathSig", Variables::trFlightPathSig, scalar_vars); + calc->insert("massTrkSys", Variables::massTrkSys, scalar_vars); + calc->insert("pt", Variables::pt, scalar_vars); + calc->insert("pt_tau_log", Variables::pt_tau_log, scalar_vars); + calc->insert("ptDetectorAxis", Variables::ptDetectorAxis, scalar_vars); + calc->insert("ptIntermediateAxis", Variables::ptIntermediateAxis, scalar_vars); + //---added for the eVeto + calc->insert("ptJetSeed_log", Variables::ptJetSeed_log, scalar_vars); + calc->insert("absleadTrackEta", Variables::absleadTrackEta, scalar_vars); + calc->insert("leadTrackDeltaEta", Variables::leadTrackDeltaEta, scalar_vars); + calc->insert("leadTrackDeltaPhi", Variables::leadTrackDeltaPhi, scalar_vars); + calc->insert("leadTrackProbNNorHT", Variables::leadTrackProbNNorHT, scalar_vars); + calc->insert("EMFracFixed", Variables::EMFracFixed, scalar_vars); + calc->insert("etHotShotWinOverPtLeadTrk", Variables::etHotShotWinOverPtLeadTrk, scalar_vars); + calc->insert("hadLeakFracFixed", Variables::hadLeakFracFixed, scalar_vars); + calc->insert("PSFrac", Variables::PSFrac, scalar_vars); + calc->insert("ClustersMeanCenterLambda", Variables::ClustersMeanCenterLambda, scalar_vars); + calc->insert("ClustersMeanFirstEngDens", Variables::ClustersMeanFirstEngDens, scalar_vars); + calc->insert("ClustersMeanPresamplerFrac", Variables::ClustersMeanPresamplerFrac, scalar_vars); + + // Track variable calculator functions + calc->insert("pt_log", Variables::Track::pt_log, track_vars); + calc->insert("trackPt", Variables::Track::trackPt, track_vars); + calc->insert("trackEta", Variables::Track::trackEta, track_vars); + calc->insert("trackPhi", Variables::Track::trackPhi, track_vars); + calc->insert("pt_tau_log", Variables::Track::pt_tau_log, track_vars); + calc->insert("pt_jetseed_log", Variables::Track::pt_jetseed_log, track_vars); + calc->insert("d0_abs_log", Variables::Track::d0_abs_log, track_vars); + calc->insert("z0sinThetaTJVA_abs_log", Variables::Track::z0sinThetaTJVA_abs_log, track_vars); + calc->insert("z0sinthetaTJVA", Variables::Track::z0sinthetaTJVA, track_vars); + calc->insert("z0sinthetaSigTJVA", Variables::Track::z0sinthetaSigTJVA, track_vars); + calc->insert("d0TJVA", Variables::Track::d0TJVA, track_vars); + calc->insert("d0SigTJVA", Variables::Track::d0SigTJVA, track_vars); + calc->insert("dEta", Variables::Track::dEta, track_vars); + calc->insert("dEtaJetSeedAxis", Variables::Track::dEtaJetSeedAxis, track_vars); + calc->insert("dPhi", Variables::Track::dPhi, track_vars); + calc->insert("dPhiJetSeedAxis", Variables::Track::dPhiJetSeedAxis, track_vars); + calc->insert("nInnermostPixelHits", Variables::Track::nInnermostPixelHits, track_vars); + calc->insert("nPixelHits", Variables::Track::nPixelHits, track_vars); + calc->insert("nSCTHits", Variables::Track::nSCTHits, track_vars); + calc->insert("nIBLHitsAndExp", Variables::Track::nIBLHitsAndExp, track_vars); + calc->insert("nPixelHitsPlusDeadSensors", Variables::Track::nPixelHitsPlusDeadSensors, track_vars); + calc->insert("nSCTHitsPlusDeadSensors", Variables::Track::nSCTHitsPlusDeadSensors, track_vars); + calc->insert("eProbabilityHT", Variables::Track::eProbabilityHT, track_vars); + calc->insert("eProbabilityNN", Variables::Track::eProbabilityNN, track_vars); + calc->insert("eProbabilityNNorHT", Variables::Track::eProbabilityNNorHT, track_vars); + calc->insert("chargedScoreRNN", Variables::Track::chargedScoreRNN, track_vars); + calc->insert("isolationScoreRNN", Variables::Track::isolationScoreRNN, track_vars); + calc->insert("conversionScoreRNN", Variables::Track::conversionScoreRNN, track_vars); + calc->insert("fakeScoreRNN", Variables::Track::fakeScoreRNN, track_vars); + //Extension - variables for GNTau + calc->insert("numberOfInnermostPixelLayerHits", Variables::Track::numberOfInnermostPixelLayerHits, track_vars); + calc->insert("numberOfPixelHits", Variables::Track::numberOfPixelHits, track_vars); + calc->insert("numberOfPixelSharedHits", Variables::Track::numberOfPixelSharedHits, track_vars); + calc->insert("numberOfPixelDeadSensors", Variables::Track::numberOfPixelDeadSensors, track_vars); + calc->insert("numberOfSCTHits", Variables::Track::numberOfSCTHits, track_vars); + calc->insert("numberOfSCTSharedHits", Variables::Track::numberOfSCTSharedHits, track_vars); + calc->insert("numberOfSCTDeadSensors", Variables::Track::numberOfSCTDeadSensors, track_vars); + calc->insert("numberOfTRTHighThresholdHits", Variables::Track::numberOfTRTHighThresholdHits, track_vars); + calc->insert("numberOfTRTHits", Variables::Track::numberOfTRTHits, track_vars); + calc->insert("nSiHits", Variables::Track::nSiHits, track_vars); + calc->insert("expectInnermostPixelLayerHit", Variables::Track::expectInnermostPixelLayerHit, track_vars); + calc->insert("expectNextToInnermostPixelLayerHit", Variables::Track::expectNextToInnermostPixelLayerHit, track_vars); + calc->insert("numberOfContribPixelLayers", Variables::Track::numberOfContribPixelLayers, track_vars); + calc->insert("numberOfPixelHoles", Variables::Track::numberOfPixelHoles, track_vars); + calc->insert("d0_old", Variables::Track::d0_old, track_vars); + calc->insert("qOverP", Variables::Track::qOverP, track_vars); + calc->insert("theta", Variables::Track::theta, track_vars); + calc->insert("z0TJVA", Variables::Track::z0TJVA, track_vars); + calc->insert("charge", Variables::Track::charge, track_vars); + calc->insert("dz0_TV_PV0", Variables::Track::dz0_TV_PV0, track_vars); + calc->insert("log_sumpt_TV", Variables::Track::log_sumpt_TV, track_vars); + calc->insert("log_sumpt2_TV", Variables::Track::log_sumpt2_TV, track_vars); + calc->insert("log_sumpt_PV0", Variables::Track::log_sumpt_PV0, track_vars); + calc->insert("log_sumpt2_PV0", Variables::Track::log_sumpt2_PV0, track_vars); + + // Cluster variable calculator functions + calc->insert("et_log", Variables::Cluster::et_log, cluster_vars); + calc->insert("pt_tau_log", Variables::Cluster::pt_tau_log, cluster_vars); + calc->insert("pt_jetseed_log", Variables::Cluster::pt_jetseed_log, cluster_vars); + calc->insert("dEta", Variables::Cluster::dEta, cluster_vars); + calc->insert("dPhi", Variables::Cluster::dPhi, cluster_vars); + calc->insert("SECOND_R", Variables::Cluster::SECOND_R, cluster_vars); + calc->insert("SECOND_LAMBDA", Variables::Cluster::SECOND_LAMBDA, cluster_vars); + calc->insert("CENTER_LAMBDA", Variables::Cluster::CENTER_LAMBDA, cluster_vars); + //---added for the eVeto + calc->insert("SECOND_LAMBDAOverClustersMeanSecondLambda", Variables::Cluster::SECOND_LAMBDAOverClustersMeanSecondLambda, cluster_vars); + calc->insert("CENTER_LAMBDAOverClustersMeanCenterLambda", Variables::Cluster::CENTER_LAMBDAOverClustersMeanCenterLambda, cluster_vars); + calc->insert("FirstEngDensOverClustersMeanFirstEngDens" , Variables::Cluster::FirstEngDensOverClustersMeanFirstEngDens, cluster_vars); + + //Extension - Variables for GNTau + calc->insert("e", Variables::Cluster::e, cluster_vars); + calc->insert("et", Variables::Cluster::et, cluster_vars); + calc->insert("FIRST_ENG_DENS", Variables::Cluster::FIRST_ENG_DENS, cluster_vars); + calc->insert("EM_PROBABILITY", Variables::Cluster::EM_PROBABILITY, cluster_vars); + calc->insert("CENTER_MAG", Variables::Cluster::CENTER_MAG, cluster_vars); + return calc; +} + + +namespace Variables { +using TauDetail = xAOD::TauJetParameters::Detail; + +bool absEta(const xAOD::TauJet &tau, double &out) { + out = std::abs(tau.eta()); + return true; +} + +bool centFrac(const xAOD::TauJet &tau, double &out) { + float centFrac; + const auto success = tau.detail(TauDetail::centFrac, centFrac); + //out = std::min(centFrac, 1.0f); + out = centFrac; + return success; +} + +bool isolFrac(const xAOD::TauJet &tau, double &out) { + float isolFrac; + const auto success = tau.detail(TauDetail::isolFrac, isolFrac); + //out = std::min(isolFrac, 1.0f); + out = isolFrac; + return success; +} + +bool etOverPtLeadTrk(const xAOD::TauJet &tau, double &out) { + float etOverPtLeadTrk; + const auto success = tau.detail(TauDetail::etOverPtLeadTrk, etOverPtLeadTrk); + out = etOverPtLeadTrk; + return success; +} + +bool innerTrkAvgDist(const xAOD::TauJet &tau, double &out) { + float innerTrkAvgDist; + const auto success = tau.detail(TauDetail::innerTrkAvgDist, innerTrkAvgDist); + out = innerTrkAvgDist; + return success; +} + +bool absipSigLeadTrk(const xAOD::TauJet &tau, double &out) { + float ipSigLeadTrk = (tau.nTracks()>0) ? tau.track(0)->d0SigTJVA() : 0.; + //out = std::min(std::abs(ipSigLeadTrk), 30.0f); + out = std::abs(ipSigLeadTrk); + return true; +} + +bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, double &out) { + float sumEMCellEtOverLeadTrkPt; + const auto success = tau.detail(TauDetail::sumEMCellEtOverLeadTrkPt, sumEMCellEtOverLeadTrkPt); + out = sumEMCellEtOverLeadTrkPt; + return success; +} + +bool SumPtTrkFrac(const xAOD::TauJet &tau, double &out) { + float SumPtTrkFrac; + const auto success = tau.detail(TauDetail::SumPtTrkFrac, SumPtTrkFrac); + out = SumPtTrkFrac; + return success; +} + +bool EMPOverTrkSysP(const xAOD::TauJet &tau, double &out) { + float EMPOverTrkSysP; + const auto success = tau.detail(TauDetail::EMPOverTrkSysP, EMPOverTrkSysP); + out = EMPOverTrkSysP; + return success; +} + +bool ptRatioEflowApprox(const xAOD::TauJet &tau, double &out) { + float ptRatioEflowApprox; + const auto success = tau.detail(TauDetail::ptRatioEflowApprox, ptRatioEflowApprox); + //out = std::min(ptRatioEflowApprox, 4.0f); + out = ptRatioEflowApprox; + return success; +} + +bool mEflowApprox(const xAOD::TauJet &tau, double &out) { + float mEflowApprox; + const auto success = tau.detail(TauDetail::mEflowApprox, mEflowApprox); + out = mEflowApprox; + return success; +} + +bool dRmax(const xAOD::TauJet &tau, double &out) { + float dRmax; + const auto success = tau.detail(TauDetail::dRmax, dRmax); + out = dRmax; + return success; +} + +bool trFlightPathSig(const xAOD::TauJet &tau, double &out) { + float trFlightPathSig; + const auto success = tau.detail(TauDetail::trFlightPathSig, trFlightPathSig); + out = trFlightPathSig; + return success; +} + +bool massTrkSys(const xAOD::TauJet &tau, double &out) { + float massTrkSys; + const auto success = tau.detail(TauDetail::massTrkSys, massTrkSys); + out = massTrkSys; + return success; +} + +bool pt(const xAOD::TauJet &tau, double &out) { + out = tau.pt(); + return true; +} + +bool pt_tau_log(const xAOD::TauJet &tau, double &out) { + out = std::log10(std::max(tau.pt() / GeV, 1e-6)); + return true; +} + +bool ptDetectorAxis(const xAOD::TauJet &tau, double &out) { + out = tau.ptDetectorAxis(); + return true; +} + +bool ptIntermediateAxis(const xAOD::TauJet &tau, double &out) { + out = tau.ptIntermediateAxis(); + return true; +} + +bool ptJetSeed_log(const xAOD::TauJet &tau, double &out) { + out = std::log10(std::max(tau.ptJetSeed(), 1e-3)); + return true; +} + +bool absleadTrackEta(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_absEtaLeadTrack("ABS_ETA_LEAD_TRACK"); + float absEtaLeadTrack = acc_absEtaLeadTrack(tau); + out = std::max(0.f, absEtaLeadTrack); + return true; +} + +bool leadTrackDeltaEta(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_absDeltaEta("TAU_ABSDELTAETA"); + float absDeltaEta = acc_absDeltaEta(tau); + out = std::max(0.f, absDeltaEta); + return true; +} + +bool leadTrackDeltaPhi(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_absDeltaPhi("TAU_ABSDELTAPHI"); + float absDeltaPhi = acc_absDeltaPhi(tau); + out = std::max(0.f, absDeltaPhi); + return true; +} + +bool leadTrackProbNNorHT(const xAOD::TauJet &tau, double &out){ + auto tracks = tau.allTracks(); + + // Sort tracks in descending pt order + if (!tracks.empty()) { + auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) { + return lhs->pt() > rhs->pt(); + }; + std::sort(tracks.begin(), tracks.end(), cmp_pt); + + const xAOD::TauTrack* tauLeadTrack = tracks.at(0); + const xAOD::TrackParticle* xTrackParticle = tauLeadTrack->track(); + float eProbabilityHT = xTrackParticle->summaryValue(eProbabilityHT, xAOD::eProbabilityHT); + static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN"); + float eProbabilityNN = acc_eProbabilityNN(*xTrackParticle); + out = (tauLeadTrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT; + } + else { + out = 0.; + } + return true; +} + +bool EMFracFixed(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_emFracFixed("EMFracFixed"); + float emFracFixed = acc_emFracFixed(tau); + out = std::max(emFracFixed, 0.0f); + return true; +} + +bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_etHotShotWinOverPtLeadTrk("etHotShotWinOverPtLeadTrk"); + float etHotShotWinOverPtLeadTrk = acc_etHotShotWinOverPtLeadTrk(tau); + out = std::max(etHotShotWinOverPtLeadTrk, 1e-6f); + return true; +} + +bool hadLeakFracFixed(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_hadLeakFracFixed("hadLeakFracFixed"); + float hadLeakFracFixed = acc_hadLeakFracFixed(tau); + out = std::max(0.f, hadLeakFracFixed); + return true; +} + +bool PSFrac(const xAOD::TauJet &tau, double &out){ + float PSFrac; + const auto success = tau.detail(TauDetail::PSSFraction, PSFrac); + out = std::max(0.f,PSFrac); + return success; +} + +bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, double &out){ + float ClustersMeanCenterLambda; + const auto success = tau.detail(TauDetail::ClustersMeanCenterLambda, ClustersMeanCenterLambda); + out = std::max(0.f, ClustersMeanCenterLambda); + return success; +} + +bool ClustersMeanEMProbability(const xAOD::TauJet &tau, double &out){ + float ClustersMeanEMProbability; + const auto success = tau.detail(TauDetail::ClustersMeanEMProbability, ClustersMeanEMProbability); + out = std::max(0.f, ClustersMeanEMProbability); + return success; +} + +bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, double &out){ + float ClustersMeanFirstEngDens; + const auto success = tau.detail(TauDetail::ClustersMeanFirstEngDens, ClustersMeanFirstEngDens); + out = std::max(-10.f, ClustersMeanFirstEngDens); + return success; +} + +bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, double &out){ + float ClustersMeanPresamplerFrac; + const auto success = tau.detail(TauDetail::ClustersMeanPresamplerFrac, ClustersMeanPresamplerFrac); + out = std::max(0.f, ClustersMeanPresamplerFrac); + return success; +} + +bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, double &out){ + float ClustersMeanSecondLambda; + const auto success = tau.detail(TauDetail::ClustersMeanSecondLambda, ClustersMeanSecondLambda); + out = std::max(0.f, ClustersMeanSecondLambda); + return success; +} + +namespace Track { + +bool pt_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = std::log10(track.pt()); + return true; +} + +bool trackPt(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.pt(); + return true; +} + +bool trackEta(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.eta(); + return true; +} + +bool trackPhi(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.phi(); + return true; +} + +bool pt_tau_log(const xAOD::TauJet &tau, const xAOD::TauTrack& /*track*/, double &out) { + out = std::log10(std::max(tau.pt(), 1e-6)); + return true; +} + +bool pt_jetseed_log(const xAOD::TauJet &tau, const xAOD::TauTrack& /*track*/, double &out) { + out = std::log10(tau.ptJetSeed()); + return true; +} + +bool d0_abs_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = std::log10(std::abs(track.d0TJVA()) + 1e-6); + return true; +} + +bool z0sinThetaTJVA_abs_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = std::log10(std::abs(track.z0sinthetaTJVA()) + 1e-6); + return true; +} + +bool z0sinthetaTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.z0sinthetaTJVA(); + return true; +} + +bool z0sinthetaSigTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.z0sinthetaSigTJVA(); + return true; +} + +bool d0TJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.d0TJVA(); + return true; +} + +bool d0SigTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.d0SigTJVA(); + return true; +} + +bool dEta(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { + out = track.eta() - tau.eta(); + return true; +} + +bool dEtaJetSeedAxis(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { + TLorentzVector tlvSeedJet = tau.p4(xAOD::TauJetParameters::JetSeed); + out = std::abs(tlvSeedJet.Eta() - track.eta()); + return true; +} + +bool dPhi(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { + out = track.p4().DeltaPhi(tau.p4()); + return true; +} + +bool dPhiJetSeedAxis(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { + TLorentzVector tlvSeedJet = tau.p4(xAOD::TauJetParameters::JetSeed); + out = tlvSeedJet.DeltaPhi(track.p4()); + return true; +} + +bool nInnermostPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t inner_pixel_hits; + const auto success = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits); + out = inner_pixel_hits; + return success; +} + +bool nPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t pixel_hits; + const auto success = track.track()->summaryValue(pixel_hits, xAOD::numberOfPixelHits); + out = pixel_hits; + return success; +} + +bool nSCTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t sct_hits; + const auto success = track.track()->summaryValue(sct_hits, xAOD::numberOfSCTHits); + out = sct_hits; + return success; +} + +// same as in tau track classification for trigger +bool nIBLHitsAndExp(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t inner_pixel_hits, inner_pixel_exp; + const auto success1 = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits); + const auto success2 = track.track()->summaryValue(inner_pixel_exp, xAOD::expectInnermostPixelLayerHit); + out = inner_pixel_exp ? inner_pixel_hits : 1.; + return success1 && success2; +} + +bool nPixelHitsPlusDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t pixel_hits, pixel_dead; + const auto success1 = track.track()->summaryValue(pixel_hits, xAOD::numberOfPixelHits); + const auto success2 = track.track()->summaryValue(pixel_dead, xAOD::numberOfPixelDeadSensors); + out = pixel_hits + pixel_dead; + return success1 && success2; +} + +bool nSCTHitsPlusDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t sct_hits, sct_dead; + const auto success1 = track.track()->summaryValue(sct_hits, xAOD::numberOfSCTHits); + const auto success2 = track.track()->summaryValue(sct_dead, xAOD::numberOfSCTDeadSensors); + out = sct_hits + sct_dead; + return success1 && success2; +} + +bool eProbabilityHT(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + float eProbabilityHT; + const auto success = track.track()->summaryValue(eProbabilityHT, xAOD::eProbabilityHT); + out = eProbabilityHT; + return success; +} + +bool eProbabilityNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN"); + out = acc_eProbabilityNN(track); + return true; +} + +bool eProbabilityNNorHT(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + auto atrack = track.track(); + float eProbabilityHT = atrack->summaryValue(eProbabilityHT, xAOD::eProbabilityHT); + static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN"); + float eProbabilityNN = acc_eProbabilityNN(*atrack); + out = (atrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT; + return true; +} + +bool chargedScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_chargedScoreRNN("rnn_chargedScore"); + out = acc_chargedScoreRNN(track); + return true; +} + +bool isolationScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_isolationScoreRNN("rnn_isolationScore"); + out = acc_isolationScoreRNN(track); + return true; +} + +bool conversionScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_conversionScoreRNN("rnn_conversionScore"); + out = acc_conversionScoreRNN(track); + return true; +} + +bool fakeScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_fakeScoreRNN("rnn_fakeScore"); + out = acc_fakeScoreRNN(track); + return true; +} + +//Extension - variables for GNTau +bool numberOfInnermostPixelLayerHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfInnermostPixelLayerHits); + out = trk_val; + return success; +} + +bool numberOfPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelHits); + out = trk_val; + return success; +} + +bool numberOfPixelSharedHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelSharedHits); + out = trk_val; + return success; +} + +bool numberOfPixelDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelDeadSensors); + out = trk_val; + return success; +} + +bool numberOfSCTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTHits); + out = trk_val; + return success; +} + +bool numberOfSCTSharedHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTSharedHits); + out = trk_val; + return success; +} + +bool numberOfSCTDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTDeadSensors); + out = trk_val; + return success; +} + +bool numberOfTRTHighThresholdHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfTRTHighThresholdHits); + out = trk_val; + return success; +} + +bool numberOfTRTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfTRTHits); + out = trk_val; + return success; +} + +bool nSiHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t pix_hit = 0;uint8_t pix_dead = 0;uint8_t sct_hit = 0;uint8_t sct_dead = 0; + const auto success1 = track.track()->summaryValue(pix_hit, xAOD::numberOfPixelHits); + const auto success2 = track.track()->summaryValue(pix_dead, xAOD::numberOfPixelDeadSensors); + const auto success3 = track.track()->summaryValue(sct_hit, xAOD::numberOfSCTHits); + const auto success4 = track.track()->summaryValue(sct_dead, xAOD::numberOfSCTDeadSensors); + out = pix_hit + pix_dead + sct_hit + sct_dead; + return success1 && success2 && success3 && success4; +} + +bool expectInnermostPixelLayerHit(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::expectInnermostPixelLayerHit); + out = trk_val; + return success; +} + +bool expectNextToInnermostPixelLayerHit(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::expectNextToInnermostPixelLayerHit); + out = trk_val; + return success; +} + +bool numberOfContribPixelLayers(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfContribPixelLayers); + out = trk_val; + return success; +} + +bool numberOfPixelHoles(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelHoles); + out = trk_val; + return success; +} + +bool d0_old(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->d0(); + //out = trk_val; + return true; +} + +bool qOverP(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->qOverP(); + return true; +} + +bool theta(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->theta(); + return true; +} + +bool z0TJVA(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out) { + out = track.track()->z0() + track.track()->vz() - tau.vertex()->z(); + return true; +} + +bool charge(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->charge(); + return true; +} + +bool dz0_TV_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out = 0.; + static const SG::AuxElement::ConstAccessor<float> acc_dz0TVPV0("dz0_TV_PV0"); + if (tau.isAvailable<float>("dz0_TV_PV0")){out = acc_dz0TVPV0(tau);} + return true; +} + +bool log_sumpt_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumptTV("log_sumpt_TV"); + if (tau.isAvailable<float>("log_sumpt_TV")){out=acc_logsumptTV(tau);} + return true; +} + +bool log_sumpt2_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumpt2TV("log_sumpt2_TV"); + if (tau.isAvailable<float>("log_sumpt2_TV")){out=acc_logsumpt2TV(tau);} + return true; +} + +bool log_sumpt_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumptPV0("log_sumpt_PV0"); + if (tau.isAvailable<float>("log_sumpt_PV0")){out=acc_logsumptPV0(tau);} + return true; +} + +bool log_sumpt2_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumpt2PV0("log_sumpt2_PV0"); + if (tau.isAvailable<float>("log_sumpt2_PV0")){out=acc_logsumpt2PV0(tau);} + return true; +} + +} // namespace Track + + +namespace Cluster { +using MomentType = xAOD::CaloCluster::MomentType; + +bool et_log(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = std::log10(cluster.p4().Et()); + return true; +} + +bool pt_tau_log(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster& /*cluster*/, double &out) { + out = std::log10(std::max(tau.pt(), 1e-6)); + return true; +} + +bool pt_jetseed_log(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster& /*cluster*/, double &out) { + out = std::log10(tau.ptJetSeed()); + return true; +} + +bool dEta(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.eta() - tau.eta(); + return true; +} + +bool dPhi(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.p4().DeltaPhi(tau.p4()); + return true; +} + +bool SECOND_R(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_R, out); + return success; +} + +bool SECOND_LAMBDA(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_LAMBDA, out); + return success; +} + +bool CENTER_LAMBDA(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + const auto success = cluster.clust().retrieveMoment(MomentType::CENTER_LAMBDA, out); + return success; +} + +bool SECOND_LAMBDAOverClustersMeanSecondLambda(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanSecondLambda("ClustersMeanSecondLambda"); + float ClustersMeanSecondLambda = acc_ClustersMeanSecondLambda(tau); + double secondLambda(0); + const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_LAMBDA, secondLambda); + out = (ClustersMeanSecondLambda != 0.) ? secondLambda/ClustersMeanSecondLambda : 0.; + return success; +} + +bool CENTER_LAMBDAOverClustersMeanCenterLambda(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanCenterLambda("ClustersMeanCenterLambda"); + float ClustersMeanCenterLambda = acc_ClustersMeanCenterLambda(tau); + double centerLambda(0); + const auto success = cluster.clust().retrieveMoment(MomentType::CENTER_LAMBDA, centerLambda); + if (ClustersMeanCenterLambda == 0.){ + out = 250.; + }else { + out = centerLambda/ClustersMeanCenterLambda; + } + + out = std::min(out, 250.); + + return success; +} + + +bool FirstEngDensOverClustersMeanFirstEngDens(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + // the ClustersMeanFirstEngDens is the log10 of the energy weighted average of the First_ENG_DENS + // divided by ETot to make it dimension-less, + // so we need to evaluate the difference of log10(clusterFirstEngDens/clusterTotalEnergy) and the ClustersMeanFirstEngDens + double clusterFirstEngDens = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::FIRST_ENG_DENS, clusterFirstEngDens); + if (clusterFirstEngDens < 1e-6) clusterFirstEngDens = 1e-6; + + static const SG::AuxElement::ConstAccessor<float> acc_ClusterTotalEnergy("ClusterTotalEnergy"); + float clusterTotalEnergy = acc_ClusterTotalEnergy(tau); + if (clusterTotalEnergy < 1e-6) clusterTotalEnergy = 1e-6; + + static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanFirstEngDens("ClustersMeanFirstEngDens"); + float clustersMeanFirstEngDens = acc_ClustersMeanFirstEngDens(tau); + + out = std::log10(clusterFirstEngDens/clusterTotalEnergy) - clustersMeanFirstEngDens; + + return status; +} + +//Extension - Variables for GNTau +bool e(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.p4().E(); + return true; +} + +bool et(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.p4().Et(); + return true; +} + +bool FIRST_ENG_DENS(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + double clusterFirstEngDens = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::FIRST_ENG_DENS, clusterFirstEngDens); + out = clusterFirstEngDens; + return status; +} + +bool EM_PROBABILITY(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + double clusterEMprob = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::EM_PROBABILITY, clusterEMprob); + out = clusterEMprob; + return status; +} + +bool CENTER_MAG(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + double clusterCenterMag = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::CENTER_MAG, clusterCenterMag); + out = clusterCenterMag; + return status; +} + +} // namespace Cluster +} // namespace Variables +} // namespace TauGNNUtils diff --git a/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx b/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx index 94a4fd2a6c110f946a77a00a5901c8682cf14f99..d6b83c9a7c56f9c3c0ba1e6dd9acbc114707e299 100644 --- a/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx +++ b/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx @@ -1,5 +1,5 @@ /* - Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ #include "tauRecTools/lwtnn/LightweightGraph.h" @@ -66,7 +66,7 @@ namespace lwtDev { typedef LightweightGraph::NodeMap NodeMap; LightweightGraph::LightweightGraph(const GraphConfig& config, - std::string default_output): + const std::string& default_output): m_graph(new Graph(config.nodes, config.layers)) { for (const auto& node: config.inputs) { diff --git a/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx b/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx index ad895f07649f9c1193633baa283bf5815a1cd5d0..fff78982eb2f275a44617fc2881e1f8b616169e9 100644 --- a/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx +++ b/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx @@ -1,5 +1,5 @@ /* - Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ #include "tauRecTools/lwtnn/Stack.h" @@ -424,7 +424,7 @@ namespace lwtDev { // __________________________________________________________________ // Recurrent layers - EmbeddingLayer::EmbeddingLayer(int var_row_index, MatrixXd W): + EmbeddingLayer::EmbeddingLayer(int var_row_index, const MatrixXd & W): m_var_row_index(var_row_index), m_W(W) { @@ -470,14 +470,14 @@ namespace lwtDev { // LSTM layer - LSTMLayer::LSTMLayer(ActivationConfig activation, - ActivationConfig inner_activation, - MatrixXd W_i, MatrixXd U_i, VectorXd b_i, - MatrixXd W_f, MatrixXd U_f, VectorXd b_f, - MatrixXd W_o, MatrixXd U_o, VectorXd b_o, - MatrixXd W_c, MatrixXd U_c, VectorXd b_c, - bool go_backwards, - bool return_sequence): + LSTMLayer::LSTMLayer(const ActivationConfig & activation, + const ActivationConfig & inner_activation, + const MatrixXd & W_i, const MatrixXd & U_i, const VectorXd & b_i, + const MatrixXd & W_f, const MatrixXd & U_f, const VectorXd & b_f, + const MatrixXd & W_o, const MatrixXd & U_o, const VectorXd & b_o, + const MatrixXd & W_c, const MatrixXd & U_c, const VectorXd & b_c, + bool go_backwards, + bool return_sequence): m_W_i(W_i), m_U_i(U_i), m_b_i(b_i), @@ -547,11 +547,11 @@ namespace lwtDev { // GRU layer - GRULayer::GRULayer(ActivationConfig activation, - ActivationConfig inner_activation, - MatrixXd W_z, MatrixXd U_z, VectorXd b_z, - MatrixXd W_r, MatrixXd U_r, VectorXd b_r, - MatrixXd W_h, MatrixXd U_h, VectorXd b_h): + GRULayer::GRULayer(const ActivationConfig & activation, + const ActivationConfig & inner_activation, + const MatrixXd & W_z, const MatrixXd & U_z, const VectorXd & b_z, + const MatrixXd & W_r, const MatrixXd & U_r, const VectorXd & b_r, + const MatrixXd & W_h, const MatrixXd & U_h, const VectorXd & b_h): m_W_z(W_z), m_U_z(U_z), m_b_z(b_z), @@ -621,8 +621,8 @@ namespace lwtDev { } MatrixXd BidirectionalLayer::scan( const MatrixXd& x) const{ - MatrixXd forward = m_forward_layer->scan(x); - MatrixXd backward = m_backward_layer->scan(x); + const MatrixXd & forward = m_forward_layer->scan(x); + const MatrixXd & backward = m_backward_layer->scan(x); MatrixXd backward_rev; if (m_return_sequence){ backward_rev = backward.rowwise().reverse(); diff --git a/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx b/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx index d4d5db31542ab15ccd7487f6dacb303eb48da1b3..0e78bf7d357e50f2de792c78d4ffb30d5469dab2 100644 --- a/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx +++ b/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx @@ -25,6 +25,7 @@ #include "tauRecTools/TauWPDecorator.h" #include "tauRecTools/TauIDVarCalculator.h" #include "tauRecTools/TauJetRNNEvaluator.h" +#include "tauRecTools/TauGNNEvaluator.h" #include "tauRecTools/TauDecayModeNNClassifier.h" #include "tauRecTools/TauVertexedClusterDecorator.h" #include "tauRecTools/TauAODSelector.h" @@ -59,6 +60,7 @@ DECLARE_COMPONENT( TauPi0Selector ) DECLARE_COMPONENT( TauWPDecorator ) DECLARE_COMPONENT( TauIDVarCalculator ) DECLARE_COMPONENT( TauJetRNNEvaluator ) +DECLARE_COMPONENT( TauGNNEvaluator ) DECLARE_COMPONENT( TauDecayModeNNClassifier ) DECLARE_COMPONENT( TauVertexedClusterDecorator ) DECLARE_COMPONENT( TauAODSelector ) diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNN.h b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h new file mode 100644 index 0000000000000000000000000000000000000000..e318762111a0edc19e187a7fddab22032a6738a4 --- /dev/null +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h @@ -0,0 +1,100 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef TAURECTOOLS_TAUGNN_H +#define TAURECTOOLS_TAUGNN_H + +#include "xAODTau/TauJet.h" +#include "xAODCaloEvent/CaloVertexedTopoCluster.h" + +#include "AsgMessaging/AsgMessaging.h" + +#include "FlavorTagDiscriminants/OnnxUtil.h" + +#include <memory> +#include <string> +#include <map> + +namespace TauGNNUtils { + class GNNVarCalc; +} + +namespace FlavorTagDiscriminants{ + class OnnxUtil; +} + +/** + * @brief Wrapper around ONNXUtil to compute the output score of a model + * + * Configures the network and computes the network outputs given the input + * objects. Retrieval of input variables is handled internally. + * + * @author N.M. Tamir + * + */ +class TauGNN : public asg::AsgMessaging { +public: + // Configuration of the weight file structure + struct Config { + std::string input_layer_scalar; + std::string input_layer_tracks; + std::string input_layer_clusters; + std::string output_node_tau; + std::string output_node_jet; + }; + std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil; +public: + TauGNN(const std::string &nnFile, const Config &config); + ~TauGNN(); + + // Output the OnnxUtil tuple + std::tuple< + std::map<std::string, float>, + std::map<std::string, std::vector<char>>, + std::map<std::string, std::vector<float>> > + compute(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const; + + // Compute all input variables and store them in the maps that are passed by reference + bool calculateInputVariables(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::map<std::string, std::map<std::string, double>>& scalarInputs, + std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const; + + // Getter for the variable calculator + const TauGNNUtils::GNNVarCalc* variable_calculator() const { + return m_var_calc.get(); + } + + //Make the output config transparent to external tools + FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config; + +private: + using Inputs = FlavorTagDiscriminants::Inputs; + // Abbreviations for lwtnn + using VariableMap = std::map<std::string, double>; + using VectorMap = std::map<std::string, std::vector<double>>; + + using InputMap = std::map<std::string, VariableMap>; + using InputSequenceMap = std::map<std::string, VectorMap>; + +private: + const Config m_config; + + // Names of the input variables + std::vector<std::string> m_scalar_inputs; + std::vector<std::string> m_track_inputs; + std::vector<std::string> m_cluster_inputs; + // Names passed to the variable calculator + std::vector<std::string> m_scalarCalc_inputs; + std::vector<std::string> m_trackCalc_inputs; + std::vector<std::string> m_clusterCalc_inputs; + + // Variable calculator to calculate input variables on the fly + std::unique_ptr<TauGNNUtils::GNNVarCalc> m_var_calc; +}; + +#endif // TAURECTOOLS_TAUGNN_H diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h new file mode 100644 index 0000000000000000000000000000000000000000..776a6a96bd700081c8b5e7df17f0eba8c7c98040 --- /dev/null +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h @@ -0,0 +1,70 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef TAURECTOOLS_TAUGNNEVALUATOR_H +#define TAURECTOOLS_TAUGNNEVALUATOR_H + +#include "tauRecTools/TauRecToolBase.h" + +#include "xAODTau/TauJet.h" +#include "xAODCaloEvent/CaloVertexedTopoCluster.h" + +#include <memory> + +class TauGNN; + +/** + * @brief Tool to calculate tau identification score from .onnx inputs + * + * The network configuration is supplied in .onnx format. + * Currently runs on a prongness-inclusive model + * Based off of TauJetRNNEvaluator.h format! + * @author N.M. Tamir + * + */ +class TauGNNEvaluator : public TauRecToolBase { +public: + ASG_TOOL_CLASS2(TauGNNEvaluator, TauRecToolBase, ITauToolBase) + + TauGNNEvaluator(const std::string &name = "TauGNNEvaluator"); + virtual ~TauGNNEvaluator(); + + virtual StatusCode initialize() override; + virtual StatusCode execute(xAOD::TauJet &tau) const override; + // Getter for the underlying RNN implementation + const TauGNN* get_gnn() const; + + // Selects tracks to be used as input to the network + StatusCode get_tracks(const xAOD::TauJet &tau, + std::vector<const xAOD::TauTrack *> &out) const; + + // Selects clusters to be used as input to the network + StatusCode get_clusters(const xAOD::TauJet &tau, + std::vector<xAOD::CaloVertexedTopoCluster> &out) const; + +private: + std::string m_output_varname; + std::string m_output_ptau; + std::string m_output_pjet; + std::string m_weightfile; + int m_max_tracks; + int m_max_clusters; + float m_max_cluster_dr; + float m_minTauPt; + bool m_doVertexCorrection; + bool m_doTrackClassification; + bool m_decorateTracks; + + // Configuration of the network file + std::string m_input_layer_scalar; + std::string m_input_layer_tracks; + std::string m_input_layer_clusters; + std::string m_outnode_tau; + std::string m_outnode_jet; + + // Wrappers for lwtnn + std::unique_ptr<TauGNN> m_net; //! +}; + +#endif // TAURECTOOLS_TAUGNNEVALUATOR_H diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..7e4913018524c332970af7c8c55d3be1f64c6e98 --- /dev/null +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h @@ -0,0 +1,317 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef TAURECTOOLS_TAUGNNUTILS_H +#define TAURECTOOLS_TAUGNNUTILS_H + +#include "xAODTau/TauJet.h" +#include "xAODCaloEvent/CaloVertexedTopoCluster.h" +#include "xAODEventInfo/EventInfo.h" +#include "xAODTracking/VertexContainer.h" +#include "AsgTools/AsgTool.h" +#include "AsgMessaging/AsgMessaging.h" +#include <unordered_map> + + +namespace TauGNNUtils { + +/** + * @brief Tool to calculate input variables for the GNN-based tau identification + * + * Used to calculate input variables for (onnx)GNN-based tau identification on + * the fly by providing a mapping between variable names (strings) and + * functions to calculate these variables. + * + * @author C. Deutsch + * @author W. Davey + * @author N.M. Tamir + * + */ +class GNNVarCalc : public asg::AsgMessaging { +public: + // Pointers to calculator functions + using ScalarCalc = bool (*)(const xAOD::TauJet &, double &); + + using TrackCalc = bool (*)(const xAOD::TauJet &, const xAOD::TauTrack &, + double &); + + using ClusterCalc = bool (*)(const xAOD::TauJet &, + const xAOD::CaloVertexedTopoCluster &, double &); + +public: + GNNVarCalc(); + ~GNNVarCalc() = default; + + // Methods to compute the output (vector) based on the variable name + + // Computes high-level ID variables + bool compute(const std::string &name, const xAOD::TauJet &tau, double &out) const; + + // Computes track variables + bool compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + std::vector<double> &out) const; + + // Computes cluster variables + bool compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::vector<double> &out) const; + + // Methods to insert calculator functions into the lookup table + void insert(const std::string &name, ScalarCalc func, const std::vector<std::string>& scalar_vars); + void insert(const std::string &name, TrackCalc func, const std::vector<std::string>& track_vars); + void insert(const std::string &name, ClusterCalc func, const std::vector<std::string>& cluster_vars); + +private: + // Lookup tables + std::unordered_map<std::string, ScalarCalc> m_scalar_map; + std::unordered_map<std::string, TrackCalc> m_track_map; + std::unordered_map<std::string, ClusterCalc> m_cluster_map; +}; + +// Factory function to create a variable calculator populated with default +// variables +std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scalar_vars, + const std::vector<std::string>& track_vars, + const std::vector<std::string>& cluster_vars); + + +namespace Variables { + +// Functions to calculate (scalar) input variables +// Returns a status code indicating success +bool absEta(const xAOD::TauJet &tau, double &out); + +bool centFrac(const xAOD::TauJet &tau, double &out); + +bool isolFrac(const xAOD::TauJet &tau, double &out); + +bool etOverPtLeadTrk(const xAOD::TauJet &tau, double &out); + +bool innerTrkAvgDist(const xAOD::TauJet &tau, double &out); + +bool absipSigLeadTrk(const xAOD::TauJet &tau, double &out); + +bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, double &out); + +bool SumPtTrkFrac(const xAOD::TauJet &tau, double &out); + +bool EMPOverTrkSysP(const xAOD::TauJet &tau, double &out); + +bool ptRatioEflowApprox(const xAOD::TauJet &tau, double &out); + +bool mEflowApprox(const xAOD::TauJet &tau, double &out); + +bool dRmax(const xAOD::TauJet &tau, double &out); + +bool trFlightPathSig(const xAOD::TauJet &tau, double &out); + +bool massTrkSys(const xAOD::TauJet &tau, double &out); + +bool pt(const xAOD::TauJet &tau, double &out); + +bool pt_tau_log(const xAOD::TauJet &tau, double &out); + +bool ptDetectorAxis(const xAOD::TauJet &tau, double &out); + +bool ptIntermediateAxis(const xAOD::TauJet &tau, double &out); + +//functions to calculate input variables needed for the eVeto RNN +bool ptJetSeed_log (const xAOD::TauJet &tau, double &out); +bool absleadTrackEta (const xAOD::TauJet &tau, double &out); +bool leadTrackDeltaEta (const xAOD::TauJet &tau, double &out); +bool leadTrackDeltaPhi (const xAOD::TauJet &tau, double &out); +bool leadTrackProbNNorHT (const xAOD::TauJet &tau, double &out); +bool EMFracFixed (const xAOD::TauJet &tau, double &out); +bool etHotShotWinOverPtLeadTrk (const xAOD::TauJet &tau, double &out); +bool hadLeakFracFixed (const xAOD::TauJet &tau, double &out); +bool PSFrac (const xAOD::TauJet &tau, double &out); +bool ClustersMeanCenterLambda (const xAOD::TauJet &tau, double &out); +bool ClustersMeanEMProbability (const xAOD::TauJet &tau, double &out); +bool ClustersMeanFirstEngDens (const xAOD::TauJet &tau, double &out); +bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, double &out); +bool ClustersMeanSecondLambda (const xAOD::TauJet &tau, double &out); +bool EMPOverTrkSysP (const xAOD::TauJet &tau, double &out); + + +namespace Track { + +// Functions to calculate input variables for each track +// Returns a status code indicating success + +bool pt_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool trackPt( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool trackEta( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool trackPhi( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool pt_tau_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool pt_jetseed_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool d0_abs_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool z0sinThetaTJVA_abs_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool z0sinthetaTJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool z0sinthetaSigTJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool d0TJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool d0SigTJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool dEta( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool dEtaJetSeedAxis( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool dPhi( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool dPhiJetSeedAxis( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nInnermostPixelHits( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nPixelHits( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nSCTHits( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +// trigger variants +bool nIBLHitsAndExp ( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nPixelHitsPlusDeadSensors ( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nSCTHitsPlusDeadSensors ( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool eProbabilityHT( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool eProbabilityNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool eProbabilityNNorHT( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool chargedScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool isolationScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool conversionScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool fakeScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +//Extension - variables for GNTau +bool numberOfInnermostPixelLayerHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelSharedHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelDeadSensors(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfSCTHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfSCTSharedHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfSCTDeadSensors(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfTRTHighThresholdHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfTRTHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool nSiHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool expectInnermostPixelLayerHit(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool expectNextToInnermostPixelLayerHit(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfContribPixelLayers(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelHoles(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool d0_old(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool qOverP(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool theta(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool z0TJVA(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool charge(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool dz0_TV_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt2_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt2_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +} // namespace Track + + +namespace Cluster { + +// Functions to calculate input variables for each cluster +// Returns a status code indicating success + +bool et_log( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool pt_tau_log( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool pt_jetseed_log( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool dEta( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool dPhi( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool SECOND_R( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool SECOND_LAMBDA( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool CENTER_LAMBDA( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool SECOND_LAMBDAOverClustersMeanSecondLambda( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool CENTER_LAMBDAOverClustersMeanCenterLambda( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool FirstEngDensOverClustersMeanFirstEngDens( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +//Extension - Variables for GNTau +bool e( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool et( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool FIRST_ENG_DENS( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool EM_PROBABILITY( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool CENTER_MAG( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); +} // namespace Cluster +} // namespace Variables +} // namespace TauJetGNNUtils + +#endif // TAURECTOOLS_TAUGNNUTILS_H diff --git a/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h b/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h index 802d17df2b021d1cc6ca38a73bb9a6c8a340bdc2..16d1b29d7acd4b28009e00826a76b2b902052d20 100644 --- a/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h +++ b/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h @@ -1,5 +1,5 @@ /* - Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ #ifndef LIGHTWEIGHT_GRAPH_HH_TAURECTOOLS @@ -72,7 +72,7 @@ namespace lwtDev { // define a "default" output, so that calling "compute" with no // output specified doesn't lead to ambiguity. LightweightGraph(const GraphConfig& config, - std::string default_output = ""); + const std::string& default_output = ""); ~LightweightGraph(); LightweightGraph(LightweightGraph&) = delete; diff --git a/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h b/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h index 69ea569fb0f3addeb7f7d4c7a5f3fe4bfd852036..84d55e46c9ad5f1a6875c595e6fb590e7ff2b08b 100644 --- a/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h +++ b/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h @@ -1,5 +1,5 @@ /* - Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ #ifndef STACK_HH_TAURECTOOLS @@ -222,7 +222,7 @@ namespace lwtDev { class EmbeddingLayer : public IRecurrentLayer { public: - EmbeddingLayer(int var_row_index, MatrixXd W); + EmbeddingLayer(int var_row_index, const MatrixXd & W); virtual ~EmbeddingLayer() {}; virtual MatrixXd scan( const MatrixXd&) const override; @@ -236,12 +236,12 @@ namespace lwtDev { class LSTMLayer : public IRecurrentLayer { public: - LSTMLayer(ActivationConfig activation, - ActivationConfig inner_activation, - MatrixXd W_i, MatrixXd U_i, VectorXd b_i, - MatrixXd W_f, MatrixXd U_f, VectorXd b_f, - MatrixXd W_o, MatrixXd U_o, VectorXd b_o, - MatrixXd W_c, MatrixXd U_c, VectorXd b_c, + LSTMLayer(const ActivationConfig & activation, + const ActivationConfig & inner_activation, + const MatrixXd & W_i, const MatrixXd & U_i, const VectorXd & b_i, + const MatrixXd & W_f, const MatrixXd & U_f, const VectorXd & b_f, + const MatrixXd & W_o, const MatrixXd & U_o, const VectorXd & b_o, + const MatrixXd & W_c, const MatrixXd & U_c, const VectorXd & b_c, bool go_backwards, bool return_sequence); @@ -277,11 +277,11 @@ namespace lwtDev { class GRULayer : public IRecurrentLayer { public: - GRULayer(ActivationConfig activation, - ActivationConfig inner_activation, - MatrixXd W_z, MatrixXd U_z, VectorXd b_z, - MatrixXd W_r, MatrixXd U_r, VectorXd b_r, - MatrixXd W_h, MatrixXd U_h, VectorXd b_h); + GRULayer(const ActivationConfig & activation, + const ActivationConfig & inner_activation, + const MatrixXd & W_z, const MatrixXd & U_z, const VectorXd & b_z, + const MatrixXd & W_r, const MatrixXd & U_r, const VectorXd & b_r, + const MatrixXd & W_h, const MatrixXd & U_h, const VectorXd & b_h); virtual ~GRULayer() {}; virtual MatrixXd scan( const MatrixXd&) const override;