diff --git a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/DerivationFrameworkTau/TauIDDecoratorWrapper.h b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/DerivationFrameworkTau/TauIDDecoratorWrapper.h index 9033f475eb45ca3ab0ca92da1214a9bacf60b3b3..87641570482e06e05ddf590d77076db8ad2f211e 100644 --- a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/DerivationFrameworkTau/TauIDDecoratorWrapper.h +++ b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/DerivationFrameworkTau/TauIDDecoratorWrapper.h @@ -12,6 +12,7 @@ #include "StoreGate/ReadHandleKey.h" #include "StoreGate/WriteDecorHandleKeyArray.h" #include "xAODTau/TauJetContainer.h" +#include "xAODTracking/VertexContainer.h" #include <string> #include <vector> @@ -32,6 +33,7 @@ namespace DerivationFramework { private: SG::ReadHandleKey<xAOD::TauJetContainer> m_tauContainerKey { this, "TauContainerName", "TauJets", "Input tau container key" }; + SG::ReadHandleKey<xAOD::VertexContainer> m_vtxContainerKey { this, "VertexContainerName", "PrimaryVertices", "Input PV container key" }; SG::WriteDecorHandleKeyArray<xAOD::TauJetContainer> m_decorKeys{ this, "DecorationKeys", {}, "List of decorations added to the tau"}; ToolHandleArray<TauRecToolBase> m_tauIDTools { this, "TauIDTools", {}, "" }; diff --git a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauCommonConfig.py b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauCommonConfig.py index 10873c78b1ad38086574a987e33a1135b172787d..c6a3495fd9bee888576c4bfd31b244ab14ec18ec 100644 --- a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauCommonConfig.py +++ b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauCommonConfig.py @@ -103,6 +103,7 @@ def AddTauIDDecorationCfg(flags, **kwargs): kwargs.setdefault("evetoFix", True) kwargs.setdefault("DeepSetID", True) + kwargs.setdefault("GNNTauID", True) kwargs.setdefault("TauContainerName", "TauJets") kwargs.setdefault("prefix", kwargs['TauContainerName']) @@ -123,6 +124,11 @@ def AddTauIDDecorationCfg(flags, **kwargs): tools.append( acc.popToolsAndMerge(tauTools.TauJetDeepSetEvaluatorCfg(flags, version="v2")) ) tools.append( acc.popToolsAndMerge(tauTools.TauWPDecoratorJetDeepSetCfg(flags, version="v2")) ) + if kwargs['GNNTauID']: + # Add in GNTau! + tools.append( acc.popToolsAndMerge(tauTools.TauGNNEvaluatorCfg(flags)) ) + tools.append( acc.popToolsAndMerge(tauTools.TauWPDecoratorGNNCfg(flags)) ) + if tools: for tool in tools: acc.addPublicTool(tool) diff --git a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauJetsCPContent.py b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauJetsCPContent.py index a7067535a37e47e013e8c0a8a1dadae420fcff57..05602ed2898a3b5fa785e9f5ace959018baf1429 100644 --- a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauJetsCPContent.py +++ b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauJetsCPContent.py @@ -4,7 +4,7 @@ TauJetsCPContent = [ "TauJets", - "TauJetsAux.pt.eta.phi.m.ptFinalCalib.etaFinalCalib.ptTauEnergyScale.etaTauEnergyScale.charge.nChargedTracks.nIsolatedTracks.nAllTracks.isTauFlags.PanTau_DecayMode.NNDecayMode.NNDecayModeProb_1p0n.NNDecayModeProb_1p1n.NNDecayModeProb_1pXn.NNDecayModeProb_3p0n.RNNJetScore.RNNJetScoreSigTrans.JetDeepSetScore.JetDeepSetScoreTrans.JetDeepSetVeryLoose.JetDeepSetLoose.JetDeepSetMedium.JetDeepSetTight.JetDeepSetScore_v2.JetDeepSetScoreTrans_v2.JetDeepSetVeryLoose_v2.JetDeepSetLoose_v2.JetDeepSetMedium_v2.JetDeepSetTight_v2.RNNEleScore.RNNEleScoreSigTrans_v1.EleRNNLoose_v1.EleRNNMedium_v1.EleRNNTight_v1.tauTrackLinks.vertexLink.secondaryVertexLink.neutralPFOLinks.pi0PFOLinks.truthParticleLink.truthJetLink.trackWidth.centFrac.etOverPtLeadTrk.innerTrkAvgDist.absipSigLeadTrk.SumPtTrkFrac.EMPOverTrkSysP.ptRatioEflowApprox.mEflowApprox.dRmax.trFlightPathSig.massTrkSys.leadTrackProbNNorHT.EMFracFixed.etHotShotWinOverPtLeadTrk.hadLeakFracFixed.PSFrac.ClustersMeanCenterLambda.ClustersMeanFirstEngDens.ClustersMeanPresamplerFrac", + "TauJetsAux.pt.eta.phi.m.ptFinalCalib.etaFinalCalib.ptTauEnergyScale.etaTauEnergyScale.charge.nChargedTracks.nIsolatedTracks.nAllTracks.isTauFlags.PanTau_DecayMode.NNDecayMode.NNDecayModeProb_1p0n.NNDecayModeProb_1p1n.NNDecayModeProb_1pXn.NNDecayModeProb_3p0n.RNNJetScore.RNNJetScoreSigTrans.JetDeepSetScore.JetDeepSetScoreTrans.JetDeepSetVeryLoose.JetDeepSetLoose.JetDeepSetMedium.JetDeepSetTight.JetDeepSetScore_v2.JetDeepSetScoreTrans_v2.JetDeepSetVeryLoose_v2.JetDeepSetLoose_v2.JetDeepSetMedium_v2.JetDeepSetTight_v2.RNNEleScore.RNNEleScoreSigTrans_v1.EleRNNLoose_v1.EleRNNMedium_v1.EleRNNTight_v1.tauTrackLinks.vertexLink.secondaryVertexLink.neutralPFOLinks.pi0PFOLinks.truthParticleLink.truthJetLink.trackWidth.centFrac.etOverPtLeadTrk.innerTrkAvgDist.absipSigLeadTrk.SumPtTrkFrac.EMPOverTrkSysP.ptRatioEflowApprox.mEflowApprox.dRmax.trFlightPathSig.massTrkSys.leadTrackProbNNorHT.EMFracFixed.etHotShotWinOverPtLeadTrk.hadLeakFracFixed.PSFrac.ClustersMeanCenterLambda.ClustersMeanFirstEngDens.ClustersMeanPresamplerFrac.GNTauScore.GNTauScoreSigTrans_v0.GNTauVL_v0.GNTauL_v0.GNTauM_v0.GNTauT_v0", "TauTracks", "TauTracksAux.pt.eta.phi.flagSet.trackLinks.rnn_chargedScore.rnn_isolationScore.rnn_conversionScore", "InDetTrackParticles", diff --git a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/src/TauIDDecoratorWrapper.cxx b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/src/TauIDDecoratorWrapper.cxx index ee95314f5f9a86879b35760ffd088eea7a625c12..4f218cf68678ca47ad3cfec87897d5a3af90d229 100644 --- a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/src/TauIDDecoratorWrapper.cxx +++ b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/src/TauIDDecoratorWrapper.cxx @@ -21,7 +21,7 @@ namespace DerivationFramework { // parse the properties of TauWPDecorator tools for (const auto& tool : m_tauIDTools) { - if (tool->type() != "TauWPDecorator") continue; + if ((tool->type() != "TauWPDecorator" )) continue; // check whether we must compute eVeto WPs, as this requires the recalculation of a variable BooleanProperty useAbsEta("UseAbsEta", false); @@ -57,6 +57,7 @@ namespace DerivationFramework { // initialize read/write handle keys ATH_CHECK( m_tauContainerKey.initialize() ); + ATH_CHECK( m_vtxContainerKey.initialize() ); ATH_CHECK( m_decorKeys.initialize() ); return StatusCode::SUCCESS; @@ -77,30 +78,59 @@ namespace DerivationFramework { } const xAOD::TauJetContainer* tauContainer = tauJetsReadHandle.cptr(); - static const SG::AuxElement::Decorator<float> acc_trackWidth("trackWidth"); - - for (const auto tau : *tauContainer) { - float tauTrackBasedWidth = 0; - // equivalent to - // tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedCharged) - std::vector<const xAOD::TauTrack *> tauTracks = tau->tracks(); - for (const xAOD::TauTrack *trk : tau->tracks( - xAOD::TauJetParameters::TauTrackFlag::classifiedIsolation)) { - tauTracks.push_back(trk); - } - double sumWeightedDR = 0.; - double ptSum = 0.; - for (const xAOD::TauTrack *track : tauTracks) { - double deltaR = tau->p4().DeltaR(track->p4()); - sumWeightedDR += deltaR * track->pt(); - ptSum += track->pt(); + // retrieve PrimaryVertices container + SG::ReadHandle<xAOD::VertexContainer> vtxReadHandle(m_vtxContainerKey); + if (!vtxReadHandle.isValid()) { + ATH_MSG_ERROR ("Could not retrieve VertexContainer with key " << vtxReadHandle.key()); + return StatusCode::FAILURE; } - if (ptSum > 0) { - tauTrackBasedWidth = sumWeightedDR / ptSum; + const xAOD::VertexContainer* vtxContainer = vtxReadHandle.cptr(); + const xAOD::Vertex* pVtx = nullptr; + + // Check that PV container exists and is non-empty, find the PV if possible + if(vtxContainer != nullptr && vtxContainer->size()>0) { + ATH_MSG_DEBUG("Found vtx container for decorating taus!"); + auto itrVtx = std::find_if(vtxContainer->begin(), vtxContainer->end(), + [](const xAOD::Vertex* vtx) { + return vtx->vertexType() == xAOD::VxType::PriVtx; + }); + pVtx = (itrVtx == vtxContainer->end() ? 0 : *itrVtx); + if(!pVtx){ + ATH_MSG_WARNING("No PV found, using the first element instead!"); + pVtx = vtxContainer->at(0); + } } + + //Create accessors + static const SG::AuxElement::Decorator<float> acc_trackWidth("trackWidth"); + static const SG::AuxElement::Accessor<float> acc_dz0_TV_PV0("dz0_TV_PV0"); + static const SG::AuxElement::Accessor<float> acc_log_sumpt_TV("log_sumpt_TV"); + static const SG::AuxElement::Accessor<float> acc_log_sumpt2_TV("log_sumpt2_TV"); + static const SG::AuxElement::Accessor<float> acc_log_sumpt_PV0("log_sumpt_PV0"); + static const SG::AuxElement::Accessor<float> acc_log_sumpt2_PV0("log_sumpt2_PV0"); + + for (const auto tau : *tauContainer) { + float tauTrackBasedWidth = 0; + // equivalent to + // tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedCharged) + std::vector<const xAOD::TauTrack *> tauTracks = tau->tracks(); + for (const xAOD::TauTrack *trk : tau->tracks( + xAOD::TauJetParameters::TauTrackFlag::classifiedIsolation)) { + tauTracks.push_back(trk); + } + double sumWeightedDR = 0.; + double ptSum = 0.; + for (const xAOD::TauTrack *track : tauTracks) { + double deltaR = tau->p4().DeltaR(track->p4()); + sumWeightedDR += deltaR * track->pt(); + ptSum += track->pt(); + } + if (ptSum > 0) { + tauTrackBasedWidth = sumWeightedDR / ptSum; + } - acc_trackWidth(*tau) = tauTrackBasedWidth; - } + acc_trackWidth(*tau) = tauTrackBasedWidth; + } // create shallow copy auto shallowCopy = xAOD::shallowCopyContainer (*tauContainer); @@ -108,6 +138,26 @@ namespace DerivationFramework { static const SG::AuxElement::Accessor<float> acc_absEtaLead("ABS_ETA_LEAD_TRACK"); for (auto tau : *shallowCopy.first) { + + //Add in the TV/PV0 vertex variables needed for some calculators in TauGNNUtils.cxx (for GNTau) + float dz0_TV_PV0 = -999., sumpt_TV = 0., sumpt2_TV = 0., sumpt_PV0 = 0., sumpt2_PV0 = 0.; + if(pVtx!=nullptr) { + dz0_TV_PV0 = tau->vertex()->z() - pVtx->z(); + for (const ElementLink<xAOD::TrackParticleContainer>& trk : pVtx->trackParticleLinks()) { + sumpt_PV0 += (*trk)->pt(); + sumpt2_PV0 += pow((*trk)->pt(), 2.); + } + for (const ElementLink<xAOD::TrackParticleContainer>& trk : tau->vertex()->trackParticleLinks()) { + sumpt_TV += (*trk)->pt(); + sumpt2_TV += pow((*trk)->pt(), 2.); + } + } + acc_dz0_TV_PV0(*tau) = dz0_TV_PV0; + acc_log_sumpt_TV(*tau) = (sumpt_TV>0.) ? std::log(sumpt_TV) : 0.; + acc_log_sumpt2_TV(*tau) = (sumpt2_TV>0.) ? std::log(sumpt2_TV) : 0.; + acc_log_sumpt_PV0(*tau) = (sumpt_PV0>0.) ? std::log(sumpt_PV0) : 0.; + acc_log_sumpt2_PV0(*tau) = (sumpt2_PV0>0.) ? std::log(sumpt2_PV0) : 0.; + //End of vertex variable addition block // ABS_ETA_LEAD_TRACK is removed from the AOD content and must be redecorated when computing eVeto WPs // note: this redecoration is not robust against charged track thinning, but charged tracks should never be thinned diff --git a/Reconstruction/tauRec/python/TauConfigFlags.py b/Reconstruction/tauRec/python/TauConfigFlags.py index f4c7ff650c38985a3f7eb8936ba7d2879c78156b..d17807f2bf9667ee81c67d1a18046181b2430c6d 100644 --- a/Reconstruction/tauRec/python/TauConfigFlags.py +++ b/Reconstruction/tauRec/python/TauConfigFlags.py @@ -60,6 +60,9 @@ def createTauConfigFlags(): # R22 DeepSet tau ID tune without track RNN scores, for now define a second set of flags, but ultimately we'll choose one and drop the other tau_cfg.addFlag("Tau.TauJetDeepSetConfig_v2", ["tauid_1p_R22_dpst_noTrackScore.json", "tauid_2p_R22_dpst_noTrackScore.json", "tauid_3p_R22_dpst_noTrackScore.json"]) tau_cfg.addFlag("Tau.TauJetDeepSetWP_v2", ["model_1p_R22_dpst_noTrackScore.root", "model_2p_R22_dpst_noTrackScore.root", "model_3p_R22_dpst_noTrackScore.root"]) + # GNTau ID tune file (need to add another version for noAux) + tau_cfg.addFlag("Tau.TauGNNConfig", ["GNTau_noAux_simplified.onnx"]) + tau_cfg.addFlag("Tau.TauGNNWP_v0", ["GNTauNA_flat_model_1p.root", "GNTauNA_flat_model_2p.root", "GNTauNA_flat_model_3p.root"]) # PanTau config flags from PanTauAlgs.PanTauConfigFlags import createPanTauConfigFlags diff --git a/Reconstruction/tauRec/python/TauToolHolder.py b/Reconstruction/tauRec/python/TauToolHolder.py index 1a9b19d3f770dda57a20555413f8fe7924be9ec7..338a1ad7a67942987cf2913c9f355d28c9324c74 100644 --- a/Reconstruction/tauRec/python/TauToolHolder.py +++ b/Reconstruction/tauRec/python/TauToolHolder.py @@ -858,6 +858,54 @@ def TauWPDecoratorJetDeepSetCfg(flags, version=None): result.setPrivateTools(myTauWPDecorator) return result +def TauGNNEvaluatorCfg(flags): + result = ComponentAccumulator() + _name = flags.Tau.ActiveConfig.prefix + 'TauGNN' + + TauGNNEvaluator = CompFactory.getComp("TauGNNEvaluator") + GNNConf = flags.Tau.TauGNNConfig + myTauGNNEvaluator = TauGNNEvaluator(name = _name, + NetworkFile = GNNConf[0], + OutputVarname = "GNTauScore", + OutputPTau = "GNTauProbTau", + OutputPJet = "GNTauProbJet", + MaxTracks = 30, + MaxClusters = 20, + MaxClusterDR = 15.0, + VertexCorrection = True, + DecorateTracks = False, + InputLayerScalar = "tau_vars", + InputLayerTracks = "track_vars", + InputLayerClusters = "cluster_vars", + NodeNameTau="GN2TauNoAux_pb", + NodeNameJet="GN2TauNoAux_pu") + + result.setPrivateTools(myTauGNNEvaluator) + return result + +def TauWPDecoratorGNNCfg(flags): + result = ComponentAccumulator() + _name = flags.Tau.ActiveConfig.prefix + 'TauWPDecoratorGNN' + + TauWPDecorator = CompFactory.getComp("TauWPDecorator") + WPConf = flags.Tau.TauGNNWP_v0 + decorWPNames = ["GNTauVL_v0", "GNTauL_v0", "GNTauM_v0", "GNTauT_v0"] + scoreName = "GNTauScore" + newScoreName = "GNTauScoreSigTrans_v0" + myTauWPDecorator = TauWPDecorator(name=_name, + flatteningFile1Prong = WPConf[0], + flatteningFile2Prong = WPConf[1], + flatteningFile3Prong = WPConf[2], + DecorWPNames = decorWPNames, + DecorWPCutEffs1P = [0.95, 0.85, 0.75, 0.60], + DecorWPCutEffs2P = [0.95, 0.75, 0.60, 0.45], + DecorWPCutEffs3P = [0.95, 0.75, 0.60, 0.45], + ScoreName = scoreName, + NewScoreName = newScoreName, + DefineWPs = True) + result.setPrivateTools(myTauWPDecorator) + return result + def TauEleRNNEvaluatorCfg(flags): result = ComponentAccumulator() _name = flags.Tau.ActiveConfig.prefix + 'TauEleRNN' diff --git a/Reconstruction/tauRecTools/CMakeLists.txt b/Reconstruction/tauRecTools/CMakeLists.txt index 40a34801c23a244adb2a9be46ffa39c086df3fe2..87b3c12ef094a880e7a73652436f1f74dc33a39f 100644 --- a/Reconstruction/tauRecTools/CMakeLists.txt +++ b/Reconstruction/tauRecTools/CMakeLists.txt @@ -8,6 +8,7 @@ find_package( Boost ) find_package( Eigen ) find_package( ROOT COMPONENTS Core Tree Hist RIO ) find_package( lwtnn ) +find_package( onnxruntime ) # Optional dependencies. set( extra_public_libs ) @@ -28,12 +29,12 @@ atlas_add_library( tauRecToolsLib tauRecTools/*.h Root/*.cxx tauRecTools/lwtnn/*.h Root/lwtnn/*.cxx PUBLIC_HEADERS tauRecTools INCLUDE_DIRS ${Boost_INCLUDE_DIRS} ${EIGEN_INCLUDE_DIRS} - ${LWTNN_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} + ${LWTNN_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} ${ONNXRUNTIME_INCLUDE_DIRS} LINK_LIBRARIES ${Boost_LIBRARIES} ${EIGEN_LIBRARIES} ${LWTNN_LIBRARIES} - ${ROOT_LIBRARIES} + ${ROOT_LIBRARIES} ${ONNXRUNTIME_LIBRARIES} CxxUtils AthLinks AsgMessagingLib AsgDataHandlesLib AsgTools xAODCaloEvent xAODEventInfo xAODJet xAODParticleEvent xAODPFlow xAODTau xAODTracking xAODEventShape - MVAUtils ${extra_public_libs} + MVAUtils FlavorTagDiscriminants ${extra_public_libs} PRIVATE_LINK_LIBRARIES CaloGeoHelpers FourMomUtils PathResolver ) atlas_add_dictionary( tauRecToolsDict @@ -46,5 +47,5 @@ if( NOT XAOD_STANDALONE ) INCLUDE_DIRS ${Boost_INCLUDE_DIRS} LINK_LIBRARIES ${Boost_LIBRARIES} AsgTools AsgMessagingLib xAODBase xAODCaloEvent xAODJet xAODPFlow xAODTau FourMomUtils tauRecToolsLib PFlowUtilsLib - ${extra_private_libs} ) + FlavorTagDiscriminants ${extra_private_libs} ) endif() diff --git a/Reconstruction/tauRecTools/Root/TauGNN.cxx b/Reconstruction/tauRecTools/Root/TauGNN.cxx new file mode 100644 index 0000000000000000000000000000000000000000..65f13c4c2f6bdde1fd3fd0901579612eb2e44bb4 --- /dev/null +++ b/Reconstruction/tauRecTools/Root/TauGNN.cxx @@ -0,0 +1,206 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "tauRecTools/TauGNN.h" +#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "PathResolver/PathResolver.h" + +#include <algorithm> +#include <fstream> + +#include "lwtnn/LightweightGraph.hh" +#include "lwtnn/Exceptions.hh" +//#include "lwtnn/parse_json.hh" + +#include "tauRecTools/TauGNNUtils.h" + +TauGNN::TauGNN(const std::string &nnFile, const Config &config): + asg::AsgMessaging("TauGNN"), + m_onnxUtil(nullptr) + { + //==================================================// + // This part is ported from FTagDiscriminant GNN.cxx// + //==================================================// + + m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile); + + // get the configuration of the model outputs + FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig(); + + //Let's see the output! + for (const auto& out_node: gnn_output_config) { + if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name); + } + + //Get model config (for inputs) + auto lwtnn_config = m_onnxUtil->getLwtConfig(); + + //===================================================// + // This part is ported from tauRecTools TauJetRNN.cxx// + //===================================================// + + // Search for input layer names specified in 'config' + auto node_is_scalar = [&config](const lwt::InputNodeConfig &in_node) { + return in_node.name == config.input_layer_scalar; + }; + auto node_is_track = [&config](const lwt::InputNodeConfig &in_node) { + return in_node.name == config.input_layer_tracks; + }; + auto node_is_cluster = [&config](const lwt::InputNodeConfig &in_node) { + return in_node.name == config.input_layer_clusters; + }; + + auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(), + lwtnn_config.inputs.cend(), + node_is_scalar); + + auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(), + lwtnn_config.input_sequences.cend(), + node_is_track); + + auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(), + lwtnn_config.input_sequences.cend(), + node_is_cluster); + + // Check which input layers were found + auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend(); + auto has_track_node = track_node != lwtnn_config.input_sequences.cend(); + auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend(); + if(!has_scalar_node) ATH_MSG_WARNING("No scalar node with name "<<config.input_layer_scalar<<" found!"); + if(!has_track_node) ATH_MSG_WARNING("No track node with name "<<config.input_layer_tracks<<" found!"); + if(!has_cluster_node) ATH_MSG_WARNING("No cluster node with name "<<config.input_layer_clusters<<" found!"); + + // Fill the variable names of each input layer into the corresponding vector + if (has_scalar_node) { + for (const auto &in : scalar_node->variables) { + std::string name = in.name; + m_scalarCalc_inputs.push_back(name); + } + } + + if (has_track_node) { + for (const auto &in : track_node->variables) { + std::string name = in.name; + m_trackCalc_inputs.push_back(name); + } + } + + if (has_cluster_node) { + for (const auto &in : cluster_node->variables) { + std::string name = in.name; + m_clusterCalc_inputs.push_back(name); + } + } + // Load the variable calculator + m_var_calc = TauGNNUtils::get_calculator(m_scalarCalc_inputs, m_trackCalc_inputs, m_clusterCalc_inputs); + ATH_MSG_INFO("TauGNN object initialized successfully!"); +} + +TauGNN::~TauGNN() {} + +std::tuple< + std::map<std::string, float>, + std::map<std::string, std::vector<char>>, + std::map<std::string, std::vector<float>> > +TauGNN::compute(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const { + InputMap scalarInputs; + InputSequenceMap vectorInputs; + std::map<std::string, input_pair> gnn_input; + ATH_MSG_DEBUG("Starting compute..."); + //Prepare input variables + if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) { + ATH_MSG_FATAL("Failed calculateInputVariables"); + throw StatusCode::FAILURE; + } + + // Add TauJet-level features to the input + std::vector<float> tau_feats; + for (const auto &varname : m_scalarCalc_inputs) { + tau_feats.push_back(static_cast<float>(scalarInputs[m_config.input_layer_scalar][varname])); + } + std::vector<int64_t> tau_feats_dim = {1, static_cast<int64_t>(tau_feats.size())}; + input_pair tau_info (tau_feats, tau_feats_dim); + gnn_input.insert({"tau_vars", tau_info}); + + //Add track-level features to the input + std::vector<float> trk_feats; + int num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_tracks][m_trackCalc_inputs.at(0)].size()); + int num_node_vars=static_cast<int>(m_trackCalc_inputs.size()); + trk_feats.resize(num_nodes * num_node_vars); + int var_idx=0; + for (const auto &varname : m_trackCalc_inputs) { + for (int node_idx=0; node_idx<num_nodes; node_idx++){ + trk_feats.at(node_idx*num_node_vars + var_idx) + = static_cast<float>(vectorInputs[m_config.input_layer_tracks][varname].at(node_idx)); + } + var_idx++; + } + std::vector<int64_t> trk_feats_dim = {num_nodes, num_node_vars}; + input_pair trk_info (trk_feats, trk_feats_dim); + gnn_input.insert({"track_vars", trk_info}); + + //Add cluster-level features to the input + std::vector<float> cls_feats; + num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_clusters][m_clusterCalc_inputs.at(0)].size()); + num_node_vars=static_cast<int>(m_clusterCalc_inputs.size()); + cls_feats.resize(num_nodes * num_node_vars); + var_idx=0; + for (const auto &varname : m_clusterCalc_inputs) { + for (int node_idx=0; node_idx<num_nodes; node_idx++){ + cls_feats.at(node_idx*num_node_vars + var_idx) + = static_cast<float>(vectorInputs[m_config.input_layer_clusters][varname].at(node_idx)); + } + var_idx++; + } + std::vector<int64_t> cls_feats_dim = {num_nodes, num_node_vars}; + input_pair cls_info (cls_feats, cls_feats_dim); + gnn_input.insert({"cluster_vars", cls_info}); + + //RUN THE INFERENCE!!! + ATH_MSG_DEBUG("Prepared inputs, running inference..."); + auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + ATH_MSG_DEBUG("Finished compute!"); + return std::make_tuple(out_f, out_vc, out_vf); +} + +bool TauGNN::calculateInputVariables(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::map<std::string, std::map<std::string, double>>& scalarInputs, + std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const { + scalarInputs.clear(); + vectorInputs.clear(); + // Populate input (sequence) map with input variables + for (const auto &varname : m_scalarCalc_inputs) { + if (!m_var_calc->compute(varname, tau, + scalarInputs[m_config.input_layer_scalar][varname])) { + ATH_MSG_WARNING("Error computing '" << varname + << "' returning default"); + return false; + } + } + + for (const auto &varname : m_trackCalc_inputs) { + if (!m_var_calc->compute(varname, tau, tracks, + vectorInputs[m_config.input_layer_tracks][varname])) { + ATH_MSG_WARNING("Error computing '" << varname + << "' returning default"); + return false; + } + } + + for (const auto &varname : m_clusterCalc_inputs) { + if (!m_var_calc->compute(varname, tau, clusters, + vectorInputs[m_config.input_layer_clusters][varname])) { + ATH_MSG_WARNING("Error computing '" << varname + << "' returning default"); + return false; + } + } + return true; +} diff --git a/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx new file mode 100644 index 0000000000000000000000000000000000000000..0be8819d79ec52275dfb7c251872439bec964221 --- /dev/null +++ b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx @@ -0,0 +1,175 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "tauRecTools/TauGNNEvaluator.h" +#include "tauRecTools/TauGNN.h" +#include "tauRecTools/HelperFunctions.h" + +#include "PathResolver/PathResolver.h" + +#include <algorithm> + + +TauGNNEvaluator::TauGNNEvaluator(const std::string &name): + TauRecToolBase(name), + m_net(nullptr){ + + declareProperty("NetworkFile", m_weightfile = ""); + declareProperty("OutputVarname", m_output_varname = "GNTauScore"); + declareProperty("OutputPTau", m_output_ptau = "GNTauProbTau"); + declareProperty("OutputPJet", m_output_pjet = "GNTauProbJet"); + declareProperty("MaxTracks", m_max_tracks = 30); + declareProperty("MaxClusters", m_max_clusters = 20); + declareProperty("MaxClusterDR", m_max_cluster_dr = 1.0f); + declareProperty("VertexCorrection", m_doVertexCorrection = true); + declareProperty("DecorateTracks", m_decorateTracks = false); + declareProperty("TrackClassification", m_doTrackClassification = true); + + // Naming conventions for the network weight files: + declareProperty("InputLayerScalar", m_input_layer_scalar = "tau_vars"); + declareProperty("InputLayerTracks", m_input_layer_tracks = "track_vars"); + declareProperty("InputLayerClusters", m_input_layer_clusters = "cluster_vars"); + declareProperty("NodeNameTau", m_outnode_tau = "GN2TauNoAux_pb"); + declareProperty("NodeNameJet", m_outnode_jet = "GN2TauNoAux_pu"); + } + +TauGNNEvaluator::~TauGNNEvaluator() {} + +StatusCode TauGNNEvaluator::initialize() { + ATH_MSG_INFO("Initializing TauGNNEvaluator"); + + std::string weightfile(""); + + // Use PathResolver to search for the weight files + if (!m_weightfile.empty()) { + weightfile = find_file(m_weightfile); + if (weightfile.empty()) { + ATH_MSG_ERROR("Could not find network weights: " << m_weightfile); + return StatusCode::FAILURE; + } else { + ATH_MSG_INFO("Using network config: " << weightfile); + } + } + + // Set the layer and node names in the weight file + TauGNN::Config config; + config.input_layer_scalar = m_input_layer_scalar; + config.input_layer_tracks = m_input_layer_tracks; + config.input_layer_clusters = m_input_layer_clusters; + config.output_node_tau = m_outnode_tau; + config.output_node_jet = m_outnode_jet; + + // Load the weights and create the network + if (!weightfile.empty()) { + m_net = std::make_unique<TauGNN>(weightfile, config); + if (!m_net) { + ATH_MSG_ERROR("No network configured."); + return StatusCode::FAILURE; + } + } + + return StatusCode::SUCCESS; +} + +StatusCode TauGNNEvaluator::execute(xAOD::TauJet &tau) const { + // Output variable Decorators + const SG::AuxElement::Accessor<float> output(m_output_varname); + const SG::AuxElement::Accessor<float> out_ptau(m_output_ptau); + const SG::AuxElement::Accessor<float> out_pjet(m_output_pjet); + const SG::AuxElement::Decorator<char> out_trkclass("GNTau_TrackClass"); + // Set default score and overwrite later + output(tau) = -1111.0f; + out_ptau(tau) = -1111.0f; + out_pjet(tau) = -1111.0f; + //Skip execution for low-pT taus to save resources + if(tau.pt()<13000) { + return StatusCode::SUCCESS; + } + + // Get input objects + std::vector<const xAOD::TauTrack *> tracks; + ATH_CHECK(get_tracks(tau, tracks)); + std::vector<xAOD::CaloVertexedTopoCluster> clusters; + ATH_CHECK(get_clusters(tau, clusters)); + + // Truncate tracks + int numTracksMax = std::min(m_max_tracks, tracks.size()); + std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax); + // Evaluate networks + if (m_net) { + auto [out_f, out_vc, out_vf] = m_net->compute(tau, trackVec, clusters); + output(tau)=std::log10(1/(1-out_f.at(m_outnode_tau))); + out_ptau(tau)=out_f.at(m_outnode_tau); + out_pjet(tau)=out_f.at(m_outnode_jet); + if (m_decorateTracks){ + for(unsigned int i=0;i<tracks.size();i++){ + if(i<out_vc.at("track_class").size()){out_trkclass(*tracks.at(i))=out_vc.at("track_class").at(i);} + else{out_trkclass(*tracks.at(i))='9';} //Dummy value for tracks outside range of out_vc + } + } + } + + return StatusCode::SUCCESS; +} + +const TauGNN* TauGNNEvaluator::get_gnn() const { + return m_net.get(); +} + + +StatusCode TauGNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const { + std::vector<const xAOD::TauTrack*> tracks = tau.allTracks(); + + // Skip unclassified tracks: + // - the track is a LRT and classifyLRT = false + // - the track is not among the MaxNtracks highest-pt tracks in the track classifier + // - track classification is not run (trigger) + if(m_doTrackClassification) { + std::vector<const xAOD::TauTrack*>::iterator it = tracks.begin(); + while(it != tracks.end()) { + if((*it)->flag(xAOD::TauJetParameters::unclassified)) { + it = tracks.erase(it); + } + else { + ++it; + } + } + } + + // Sort by descending pt + auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) { + return lhs->pt() > rhs->pt(); + }; + std::sort(tracks.begin(), tracks.end(), cmp_pt); + out = std::move(tracks); + + return StatusCode::SUCCESS; +} + +StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const { + + TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection); + + std::vector<xAOD::CaloVertexedTopoCluster> vertexedClusterList = tau.vertexedClusters(); + for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : vertexedClusterList) { + TLorentzVector clusterP4 = vertexedCluster.p4(); + if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue; + + clusters.push_back(vertexedCluster); + } + + // Sort by descending et + auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs, + const xAOD::CaloVertexedTopoCluster& rhs) { + return lhs.p4().Et() > rhs.p4().Et(); + }; + std::sort(clusters.begin(), clusters.end(), et_cmp); + + // Truncate clusters + if (clusters.size() > m_max_clusters) { + clusters.resize(m_max_clusters, clusters[0]); + } + + return StatusCode::SUCCESS; +} diff --git a/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx new file mode 100644 index 0000000000000000000000000000000000000000..8da2fe5ea6ec2056d79a99a00b21908155379630 --- /dev/null +++ b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx @@ -0,0 +1,914 @@ +/* + Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration +*/ + +#include "tauRecTools/TauGNNUtils.h" +#include "tauRecTools/HelperFunctions.h" +#include <algorithm> +#include <iostream> +#define GeV 1000 + +namespace TauGNNUtils { + +GNNVarCalc::GNNVarCalc() : asg::AsgMessaging("TauGNNUtils::GNNVarCalc") { +} + +bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau, + double &out) const { + // Retrieve calculator function + ScalarCalc func = nullptr; + try { + func = m_scalar_map.at(name); + } catch (const std::out_of_range &e) { + ATH_MSG_ERROR("Variable '" << name << "' not defined"); + throw; + } + + // Calculate variable + return func(tau, out); +} + +bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + std::vector<double> &out) const { + out.clear(); + + // Retrieve calculator function + TrackCalc func = nullptr; + try { + func = m_track_map.at(name); + } catch (const std::out_of_range &e) { + ATH_MSG_ERROR("Variable '" << name << "' not defined"); + throw; + } + + // Calculate variables for selected tracks + bool success = true; + double value; + for (const auto *const trk : tracks) { + success = success && func(tau, *trk, value); + out.push_back(value); + } + + return success; +} + +bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::vector<double> &out) const { + out.clear(); + + // Retrieve calculator function + ClusterCalc func = nullptr; + try { + func = m_cluster_map.at(name); + } catch (const std::out_of_range &e) { + ATH_MSG_ERROR("Variable '" << name << "' not defined"); + throw; + } + + // Calculate variables for selected clusters + bool success = true; + double value; + for (const xAOD::CaloVertexedTopoCluster& cluster : clusters) { + success = success && func(tau, cluster, value); + out.push_back(value); + } + + return success; +} + +void GNNVarCalc::insert(const std::string &name, ScalarCalc func, const std::vector<std::string>& scalar_vars) { + if (std::find(scalar_vars.begin(), scalar_vars.end(), name) == scalar_vars.end()) { + return; + } + if (!func) { + throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert"); + } + m_scalar_map[name] = func; +} + +void GNNVarCalc::insert(const std::string &name, TrackCalc func, const std::vector<std::string>& track_vars) { + if (std::find(track_vars.begin(), track_vars.end(), name) == track_vars.end()) { + return; + } + if (!func) { + throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert"); + } + m_track_map[name] = func; +} + +void GNNVarCalc::insert(const std::string &name, ClusterCalc func, const std::vector<std::string>& cluster_vars) { + if (std::find(cluster_vars.begin(), cluster_vars.end(), name) == cluster_vars.end()) { + return; + } + if (!func) { + throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert"); + } + m_cluster_map[name] = func; +} + +std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scalar_vars, + const std::vector<std::string>& track_vars, + const std::vector<std::string>& cluster_vars) { + auto calc = std::make_unique<GNNVarCalc>(); + + // Scalar variable calculator functions + calc->insert("absEta", Variables::absEta, scalar_vars); + calc->insert("isolFrac", Variables::isolFrac, scalar_vars); + calc->insert("centFrac", Variables::centFrac, scalar_vars); + calc->insert("etOverPtLeadTrk", Variables::etOverPtLeadTrk, scalar_vars); + calc->insert("innerTrkAvgDist", Variables::innerTrkAvgDist, scalar_vars); + calc->insert("absipSigLeadTrk", Variables::absipSigLeadTrk, scalar_vars); + calc->insert("SumPtTrkFrac", Variables::SumPtTrkFrac, scalar_vars); + calc->insert("sumEMCellEtOverLeadTrkPt", Variables::sumEMCellEtOverLeadTrkPt, scalar_vars); + calc->insert("EMPOverTrkSysP", Variables::EMPOverTrkSysP, scalar_vars); + calc->insert("ptRatioEflowApprox", Variables::ptRatioEflowApprox, scalar_vars); + calc->insert("mEflowApprox", Variables::mEflowApprox, scalar_vars); + calc->insert("dRmax", Variables::dRmax, scalar_vars); + calc->insert("trFlightPathSig", Variables::trFlightPathSig, scalar_vars); + calc->insert("massTrkSys", Variables::massTrkSys, scalar_vars); + calc->insert("pt", Variables::pt, scalar_vars); + calc->insert("pt_tau_log", Variables::pt_tau_log, scalar_vars); + calc->insert("ptDetectorAxis", Variables::ptDetectorAxis, scalar_vars); + calc->insert("ptIntermediateAxis", Variables::ptIntermediateAxis, scalar_vars); + //---added for the eVeto + calc->insert("ptJetSeed_log", Variables::ptJetSeed_log, scalar_vars); + calc->insert("absleadTrackEta", Variables::absleadTrackEta, scalar_vars); + calc->insert("leadTrackDeltaEta", Variables::leadTrackDeltaEta, scalar_vars); + calc->insert("leadTrackDeltaPhi", Variables::leadTrackDeltaPhi, scalar_vars); + calc->insert("leadTrackProbNNorHT", Variables::leadTrackProbNNorHT, scalar_vars); + calc->insert("EMFracFixed", Variables::EMFracFixed, scalar_vars); + calc->insert("etHotShotWinOverPtLeadTrk", Variables::etHotShotWinOverPtLeadTrk, scalar_vars); + calc->insert("hadLeakFracFixed", Variables::hadLeakFracFixed, scalar_vars); + calc->insert("PSFrac", Variables::PSFrac, scalar_vars); + calc->insert("ClustersMeanCenterLambda", Variables::ClustersMeanCenterLambda, scalar_vars); + calc->insert("ClustersMeanFirstEngDens", Variables::ClustersMeanFirstEngDens, scalar_vars); + calc->insert("ClustersMeanPresamplerFrac", Variables::ClustersMeanPresamplerFrac, scalar_vars); + + // Track variable calculator functions + calc->insert("pt_log", Variables::Track::pt_log, track_vars); + calc->insert("trackPt", Variables::Track::trackPt, track_vars); + calc->insert("trackEta", Variables::Track::trackEta, track_vars); + calc->insert("trackPhi", Variables::Track::trackPhi, track_vars); + calc->insert("pt_tau_log", Variables::Track::pt_tau_log, track_vars); + calc->insert("pt_jetseed_log", Variables::Track::pt_jetseed_log, track_vars); + calc->insert("d0_abs_log", Variables::Track::d0_abs_log, track_vars); + calc->insert("z0sinThetaTJVA_abs_log", Variables::Track::z0sinThetaTJVA_abs_log, track_vars); + calc->insert("z0sinthetaTJVA", Variables::Track::z0sinthetaTJVA, track_vars); + calc->insert("z0sinthetaSigTJVA", Variables::Track::z0sinthetaSigTJVA, track_vars); + calc->insert("d0TJVA", Variables::Track::d0TJVA, track_vars); + calc->insert("d0SigTJVA", Variables::Track::d0SigTJVA, track_vars); + calc->insert("dEta", Variables::Track::dEta, track_vars); + calc->insert("dPhi", Variables::Track::dPhi, track_vars); + calc->insert("nInnermostPixelHits", Variables::Track::nInnermostPixelHits, track_vars); + calc->insert("nPixelHits", Variables::Track::nPixelHits, track_vars); + calc->insert("nSCTHits", Variables::Track::nSCTHits, track_vars); + calc->insert("nIBLHitsAndExp", Variables::Track::nIBLHitsAndExp, track_vars); + calc->insert("nPixelHitsPlusDeadSensors", Variables::Track::nPixelHitsPlusDeadSensors, track_vars); + calc->insert("nSCTHitsPlusDeadSensors", Variables::Track::nSCTHitsPlusDeadSensors, track_vars); + calc->insert("eProbabilityHT", Variables::Track::eProbabilityHT, track_vars); + calc->insert("eProbabilityNN", Variables::Track::eProbabilityNN, track_vars); + calc->insert("eProbabilityNNorHT", Variables::Track::eProbabilityNNorHT, track_vars); + calc->insert("chargedScoreRNN", Variables::Track::chargedScoreRNN, track_vars); + calc->insert("isolationScoreRNN", Variables::Track::isolationScoreRNN, track_vars); + calc->insert("conversionScoreRNN", Variables::Track::conversionScoreRNN, track_vars); + calc->insert("fakeScoreRNN", Variables::Track::fakeScoreRNN, track_vars); + //Extension - variables for GNTau + calc->insert("numberOfInnermostPixelLayerHits", Variables::Track::numberOfInnermostPixelLayerHits, track_vars); + calc->insert("numberOfPixelHits", Variables::Track::numberOfPixelHits, track_vars); + calc->insert("numberOfPixelSharedHits", Variables::Track::numberOfPixelSharedHits, track_vars); + calc->insert("numberOfPixelDeadSensors", Variables::Track::numberOfPixelDeadSensors, track_vars); + calc->insert("numberOfSCTHits", Variables::Track::numberOfSCTHits, track_vars); + calc->insert("numberOfSCTSharedHits", Variables::Track::numberOfSCTSharedHits, track_vars); + calc->insert("numberOfSCTDeadSensors", Variables::Track::numberOfSCTDeadSensors, track_vars); + calc->insert("numberOfTRTHighThresholdHits", Variables::Track::numberOfTRTHighThresholdHits, track_vars); + calc->insert("numberOfTRTHits", Variables::Track::numberOfTRTHits, track_vars); + calc->insert("nSiHits", Variables::Track::nSiHits, track_vars); + calc->insert("expectInnermostPixelLayerHit", Variables::Track::expectInnermostPixelLayerHit, track_vars); + calc->insert("expectNextToInnermostPixelLayerHit", Variables::Track::expectNextToInnermostPixelLayerHit, track_vars); + calc->insert("numberOfContribPixelLayers", Variables::Track::numberOfContribPixelLayers, track_vars); + calc->insert("numberOfPixelHoles", Variables::Track::numberOfPixelHoles, track_vars); + calc->insert("d0_old", Variables::Track::d0_old, track_vars); + calc->insert("qOverP", Variables::Track::qOverP, track_vars); + calc->insert("theta", Variables::Track::theta, track_vars); + calc->insert("z0TJVA", Variables::Track::z0TJVA, track_vars); + calc->insert("charge", Variables::Track::charge, track_vars); + calc->insert("dz0_TV_PV0", Variables::Track::dz0_TV_PV0, track_vars); + calc->insert("log_sumpt_TV", Variables::Track::log_sumpt_TV, track_vars); + calc->insert("log_sumpt2_TV", Variables::Track::log_sumpt2_TV, track_vars); + calc->insert("log_sumpt_PV0", Variables::Track::log_sumpt_PV0, track_vars); + calc->insert("log_sumpt2_PV0", Variables::Track::log_sumpt2_PV0, track_vars); + + // Cluster variable calculator functions + calc->insert("et_log", Variables::Cluster::et_log, cluster_vars); + calc->insert("pt_tau_log", Variables::Cluster::pt_tau_log, cluster_vars); + calc->insert("pt_jetseed_log", Variables::Cluster::pt_jetseed_log, cluster_vars); + calc->insert("dEta", Variables::Cluster::dEta, cluster_vars); + calc->insert("dPhi", Variables::Cluster::dPhi, cluster_vars); + calc->insert("SECOND_R", Variables::Cluster::SECOND_R, cluster_vars); + calc->insert("SECOND_LAMBDA", Variables::Cluster::SECOND_LAMBDA, cluster_vars); + calc->insert("CENTER_LAMBDA", Variables::Cluster::CENTER_LAMBDA, cluster_vars); + //---added for the eVeto + calc->insert("SECOND_LAMBDAOverClustersMeanSecondLambda", Variables::Cluster::SECOND_LAMBDAOverClustersMeanSecondLambda, cluster_vars); + calc->insert("CENTER_LAMBDAOverClustersMeanCenterLambda", Variables::Cluster::CENTER_LAMBDAOverClustersMeanCenterLambda, cluster_vars); + calc->insert("FirstEngDensOverClustersMeanFirstEngDens" , Variables::Cluster::FirstEngDensOverClustersMeanFirstEngDens, cluster_vars); + + //Extension - Variables for GNTau + calc->insert("e", Variables::Cluster::e, cluster_vars); + calc->insert("et", Variables::Cluster::et, cluster_vars); + calc->insert("FIRST_ENG_DENS", Variables::Cluster::FIRST_ENG_DENS, cluster_vars); + calc->insert("EM_PROBABILITY", Variables::Cluster::EM_PROBABILITY, cluster_vars); + calc->insert("CENTER_MAG", Variables::Cluster::CENTER_MAG, cluster_vars); + return calc; +} + + +namespace Variables { +using TauDetail = xAOD::TauJetParameters::Detail; + +bool absEta(const xAOD::TauJet &tau, double &out) { + out = std::abs(tau.eta()); + return true; +} + +bool centFrac(const xAOD::TauJet &tau, double &out) { + float centFrac; + const auto success = tau.detail(TauDetail::centFrac, centFrac); + //out = std::min(centFrac, 1.0f); + out = centFrac; + return success; +} + +bool isolFrac(const xAOD::TauJet &tau, double &out) { + float isolFrac; + const auto success = tau.detail(TauDetail::isolFrac, isolFrac); + //out = std::min(isolFrac, 1.0f); + out = isolFrac; + return success; +} + +bool etOverPtLeadTrk(const xAOD::TauJet &tau, double &out) { + float etOverPtLeadTrk; + const auto success = tau.detail(TauDetail::etOverPtLeadTrk, etOverPtLeadTrk); + out = etOverPtLeadTrk; + return success; +} + +bool innerTrkAvgDist(const xAOD::TauJet &tau, double &out) { + float innerTrkAvgDist; + const auto success = tau.detail(TauDetail::innerTrkAvgDist, innerTrkAvgDist); + out = innerTrkAvgDist; + return success; +} + +bool absipSigLeadTrk(const xAOD::TauJet &tau, double &out) { + float ipSigLeadTrk = (tau.nTracks()>0) ? tau.track(0)->d0SigTJVA() : 0.; + //out = std::min(std::abs(ipSigLeadTrk), 30.0f); + out = std::abs(ipSigLeadTrk); + return true; +} + +bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, double &out) { + float sumEMCellEtOverLeadTrkPt; + const auto success = tau.detail(TauDetail::sumEMCellEtOverLeadTrkPt, sumEMCellEtOverLeadTrkPt); + out = sumEMCellEtOverLeadTrkPt; + return success; +} + +bool SumPtTrkFrac(const xAOD::TauJet &tau, double &out) { + float SumPtTrkFrac; + const auto success = tau.detail(TauDetail::SumPtTrkFrac, SumPtTrkFrac); + out = SumPtTrkFrac; + return success; +} + +bool EMPOverTrkSysP(const xAOD::TauJet &tau, double &out) { + float EMPOverTrkSysP; + const auto success = tau.detail(TauDetail::EMPOverTrkSysP, EMPOverTrkSysP); + out = EMPOverTrkSysP; + return success; +} + +bool ptRatioEflowApprox(const xAOD::TauJet &tau, double &out) { + float ptRatioEflowApprox; + const auto success = tau.detail(TauDetail::ptRatioEflowApprox, ptRatioEflowApprox); + //out = std::min(ptRatioEflowApprox, 4.0f); + out = ptRatioEflowApprox; + return success; +} + +bool mEflowApprox(const xAOD::TauJet &tau, double &out) { + float mEflowApprox; + const auto success = tau.detail(TauDetail::mEflowApprox, mEflowApprox); + out = mEflowApprox; + return success; +} + +bool dRmax(const xAOD::TauJet &tau, double &out) { + float dRmax; + const auto success = tau.detail(TauDetail::dRmax, dRmax); + out = dRmax; + return success; +} + +bool trFlightPathSig(const xAOD::TauJet &tau, double &out) { + float trFlightPathSig; + const auto success = tau.detail(TauDetail::trFlightPathSig, trFlightPathSig); + out = trFlightPathSig; + return success; +} + +bool massTrkSys(const xAOD::TauJet &tau, double &out) { + float massTrkSys; + const auto success = tau.detail(TauDetail::massTrkSys, massTrkSys); + out = massTrkSys; + return success; +} + +bool pt(const xAOD::TauJet &tau, double &out) { + out = tau.pt(); + return true; +} + +bool pt_tau_log(const xAOD::TauJet &tau, double &out) { + out = std::log10(std::max(tau.pt() / GeV, 1e-6)); + return true; +} + +bool ptDetectorAxis(const xAOD::TauJet &tau, double &out) { + out = tau.ptDetectorAxis(); + return true; +} + +bool ptIntermediateAxis(const xAOD::TauJet &tau, double &out) { + out = tau.ptIntermediateAxis(); + return true; +} + +bool ptJetSeed_log(const xAOD::TauJet &tau, double &out) { + out = std::log10(std::max(tau.ptJetSeed(), 1e-3)); + return true; +} + +bool absleadTrackEta(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_absEtaLeadTrack("ABS_ETA_LEAD_TRACK"); + float absEtaLeadTrack = acc_absEtaLeadTrack(tau); + out = std::max(0.f, absEtaLeadTrack); + return true; +} + +bool leadTrackDeltaEta(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_absDeltaEta("TAU_ABSDELTAETA"); + float absDeltaEta = acc_absDeltaEta(tau); + out = std::max(0.f, absDeltaEta); + return true; +} + +bool leadTrackDeltaPhi(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_absDeltaPhi("TAU_ABSDELTAPHI"); + float absDeltaPhi = acc_absDeltaPhi(tau); + out = std::max(0.f, absDeltaPhi); + return true; +} + +bool leadTrackProbNNorHT(const xAOD::TauJet &tau, double &out){ + auto tracks = tau.allTracks(); + + // Sort tracks in descending pt order + if (!tracks.empty()) { + auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) { + return lhs->pt() > rhs->pt(); + }; + std::sort(tracks.begin(), tracks.end(), cmp_pt); + + const xAOD::TauTrack* tauLeadTrack = tracks.at(0); + const xAOD::TrackParticle* xTrackParticle = tauLeadTrack->track(); + float eProbabilityHT = xTrackParticle->summaryValue(eProbabilityHT, xAOD::eProbabilityHT); + static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN"); + float eProbabilityNN = acc_eProbabilityNN(*xTrackParticle); + out = (tauLeadTrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT; + } + else { + out = 0.; + } + return true; +} + +bool EMFracFixed(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_emFracFixed("EMFracFixed"); + float emFracFixed = acc_emFracFixed(tau); + out = std::max(emFracFixed, 0.0f); + return true; +} + +bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_etHotShotWinOverPtLeadTrk("etHotShotWinOverPtLeadTrk"); + float etHotShotWinOverPtLeadTrk = acc_etHotShotWinOverPtLeadTrk(tau); + out = std::max(etHotShotWinOverPtLeadTrk, 1e-6f); + return true; +} + +bool hadLeakFracFixed(const xAOD::TauJet &tau, double &out){ + static const SG::AuxElement::ConstAccessor<float> acc_hadLeakFracFixed("hadLeakFracFixed"); + float hadLeakFracFixed = acc_hadLeakFracFixed(tau); + out = std::max(0.f, hadLeakFracFixed); + return true; +} + +bool PSFrac(const xAOD::TauJet &tau, double &out){ + float PSFrac; + const auto success = tau.detail(TauDetail::PSSFraction, PSFrac); + out = std::max(0.f,PSFrac); + return success; +} + +bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, double &out){ + float ClustersMeanCenterLambda; + const auto success = tau.detail(TauDetail::ClustersMeanCenterLambda, ClustersMeanCenterLambda); + out = std::max(0.f, ClustersMeanCenterLambda); + return success; +} + +bool ClustersMeanEMProbability(const xAOD::TauJet &tau, double &out){ + float ClustersMeanEMProbability; + const auto success = tau.detail(TauDetail::ClustersMeanEMProbability, ClustersMeanEMProbability); + out = std::max(0.f, ClustersMeanEMProbability); + return success; +} + +bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, double &out){ + float ClustersMeanFirstEngDens; + const auto success = tau.detail(TauDetail::ClustersMeanFirstEngDens, ClustersMeanFirstEngDens); + out = std::max(-10.f, ClustersMeanFirstEngDens); + return success; +} + +bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, double &out){ + float ClustersMeanPresamplerFrac; + const auto success = tau.detail(TauDetail::ClustersMeanPresamplerFrac, ClustersMeanPresamplerFrac); + out = std::max(0.f, ClustersMeanPresamplerFrac); + return success; +} + +bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, double &out){ + float ClustersMeanSecondLambda; + const auto success = tau.detail(TauDetail::ClustersMeanSecondLambda, ClustersMeanSecondLambda); + out = std::max(0.f, ClustersMeanSecondLambda); + return success; +} + +namespace Track { + +bool pt_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = std::log10(track.pt()); + return true; +} + +bool trackPt(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.pt(); + return true; +} + +bool trackEta(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.eta(); + return true; +} + +bool trackPhi(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.phi(); + return true; +} + +bool pt_tau_log(const xAOD::TauJet &tau, const xAOD::TauTrack& /*track*/, double &out) { + out = std::log10(std::max(tau.pt(), 1e-6)); + return true; +} + +bool pt_jetseed_log(const xAOD::TauJet &tau, const xAOD::TauTrack& /*track*/, double &out) { + out = std::log10(tau.ptJetSeed()); + return true; +} + +bool d0_abs_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = std::log10(std::abs(track.d0TJVA()) + 1e-6); + return true; +} + +bool z0sinThetaTJVA_abs_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = std::log10(std::abs(track.z0sinthetaTJVA()) + 1e-6); + return true; +} + +bool z0sinthetaTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.z0sinthetaTJVA(); + return true; +} + +bool z0sinthetaSigTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.z0sinthetaSigTJVA(); + return true; +} + +bool d0TJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.d0TJVA(); + return true; +} + +bool d0SigTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.d0SigTJVA(); + return true; +} + +bool dEta(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { + out = track.eta() - tau.eta(); + return true; +} + +bool dPhi(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) { + out = track.p4().DeltaPhi(tau.p4()); + return true; +} + +bool nInnermostPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t inner_pixel_hits; + const auto success = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits); + out = inner_pixel_hits; + return success; +} + +bool nPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t pixel_hits; + const auto success = track.track()->summaryValue(pixel_hits, xAOD::numberOfPixelHits); + out = pixel_hits; + return success; +} + +bool nSCTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t sct_hits; + const auto success = track.track()->summaryValue(sct_hits, xAOD::numberOfSCTHits); + out = sct_hits; + return success; +} + +// same as in tau track classification for trigger +bool nIBLHitsAndExp(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t inner_pixel_hits, inner_pixel_exp; + const auto success1 = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits); + const auto success2 = track.track()->summaryValue(inner_pixel_exp, xAOD::expectInnermostPixelLayerHit); + out = inner_pixel_exp ? inner_pixel_hits : 1.; + return success1 && success2; +} + +bool nPixelHitsPlusDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t pixel_hits, pixel_dead; + const auto success1 = track.track()->summaryValue(pixel_hits, xAOD::numberOfPixelHits); + const auto success2 = track.track()->summaryValue(pixel_dead, xAOD::numberOfPixelDeadSensors); + out = pixel_hits + pixel_dead; + return success1 && success2; +} + +bool nSCTHitsPlusDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t sct_hits, sct_dead; + const auto success1 = track.track()->summaryValue(sct_hits, xAOD::numberOfSCTHits); + const auto success2 = track.track()->summaryValue(sct_dead, xAOD::numberOfSCTDeadSensors); + out = sct_hits + sct_dead; + return success1 && success2; +} + +bool eProbabilityHT(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + float eProbabilityHT; + const auto success = track.track()->summaryValue(eProbabilityHT, xAOD::eProbabilityHT); + out = eProbabilityHT; + return success; +} + +bool eProbabilityNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN"); + out = acc_eProbabilityNN(track); + return true; +} + +bool eProbabilityNNorHT(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + auto atrack = track.track(); + float eProbabilityHT = atrack->summaryValue(eProbabilityHT, xAOD::eProbabilityHT); + static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN"); + float eProbabilityNN = acc_eProbabilityNN(*atrack); + out = (atrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT; + return true; +} + +bool chargedScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_chargedScoreRNN("rnn_chargedScore"); + out = acc_chargedScoreRNN(track); + return true; +} + +bool isolationScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_isolationScoreRNN("rnn_isolationScore"); + out = acc_isolationScoreRNN(track); + return true; +} + +bool conversionScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_conversionScoreRNN("rnn_conversionScore"); + out = acc_conversionScoreRNN(track); + return true; +} + +bool fakeScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_fakeScoreRNN("rnn_fakeScore"); + out = acc_fakeScoreRNN(track); + return true; +} + +//Extension - variables for GNTau +bool numberOfInnermostPixelLayerHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfInnermostPixelLayerHits); + out = trk_val; + return success; +} + +bool numberOfPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelHits); + out = trk_val; + return success; +} + +bool numberOfPixelSharedHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelSharedHits); + out = trk_val; + return success; +} + +bool numberOfPixelDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelDeadSensors); + out = trk_val; + return success; +} + +bool numberOfSCTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTHits); + out = trk_val; + return success; +} + +bool numberOfSCTSharedHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTSharedHits); + out = trk_val; + return success; +} + +bool numberOfSCTDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTDeadSensors); + out = trk_val; + return success; +} + +bool numberOfTRTHighThresholdHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfTRTHighThresholdHits); + out = trk_val; + return success; +} + +bool numberOfTRTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfTRTHits); + out = trk_val; + return success; +} + +bool nSiHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t pix_hit = 0;uint8_t pix_dead = 0;uint8_t sct_hit = 0;uint8_t sct_dead = 0; + const auto success1 = track.track()->summaryValue(pix_hit, xAOD::numberOfPixelHits); + const auto success2 = track.track()->summaryValue(pix_dead, xAOD::numberOfPixelDeadSensors); + const auto success3 = track.track()->summaryValue(sct_hit, xAOD::numberOfSCTHits); + const auto success4 = track.track()->summaryValue(sct_dead, xAOD::numberOfSCTDeadSensors); + out = pix_hit + pix_dead + sct_hit + sct_dead; + return success1 && success2 && success3 && success4; +} + +bool expectInnermostPixelLayerHit(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::expectInnermostPixelLayerHit); + out = trk_val; + return success; +} + +bool expectNextToInnermostPixelLayerHit(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::expectNextToInnermostPixelLayerHit); + out = trk_val; + return success; +} + +bool numberOfContribPixelLayers(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfContribPixelLayers); + out = trk_val; + return success; +} + +bool numberOfPixelHoles(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + uint8_t trk_val = 0; + const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelHoles); + out = trk_val; + return success; +} + +bool d0_old(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->d0(); + //out = trk_val; + return true; +} + +bool qOverP(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->qOverP(); + return true; +} + +bool theta(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->theta(); + return true; +} + +bool z0TJVA(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out) { + out = track.track()->z0() + track.track()->vz() - tau.vertex()->z(); + return true; +} + +bool charge(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) { + out = track.track()->charge(); + return true; +} + +bool dz0_TV_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out = 0.; + static const SG::AuxElement::ConstAccessor<float> acc_dz0TVPV0("dz0_TV_PV0"); + if (tau.isAvailable<float>("dz0_TV_PV0")){out = acc_dz0TVPV0(tau);} + return true; +} + +bool log_sumpt_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumptTV("log_sumpt_TV"); + if (tau.isAvailable<float>("log_sumpt_TV")){out=acc_logsumptTV(tau);} + return true; +} + +bool log_sumpt2_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumpt2TV("log_sumpt2_TV"); + if (tau.isAvailable<float>("log_sumpt2_TV")){out=acc_logsumpt2TV(tau);} + return true; +} + +bool log_sumpt_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumptPV0("log_sumpt_PV0"); + if (tau.isAvailable<float>("log_sumpt_PV0")){out=acc_logsumptPV0(tau);} + return true; +} + +bool log_sumpt2_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) { + out=0.; + static const SG::AuxElement::ConstAccessor<float> acc_logsumpt2PV0("log_sumpt2_PV0"); + if (tau.isAvailable<float>("log_sumpt2_PV0")){out=acc_logsumpt2PV0(tau);} + return true; +} + +} // namespace Track + + +namespace Cluster { +using MomentType = xAOD::CaloCluster::MomentType; + +bool et_log(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = std::log10(cluster.p4().Et()); + return true; +} + +bool pt_tau_log(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster& /*cluster*/, double &out) { + out = std::log10(std::max(tau.pt(), 1e-6)); + return true; +} + +bool pt_jetseed_log(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster& /*cluster*/, double &out) { + out = std::log10(tau.ptJetSeed()); + return true; +} + +bool dEta(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.eta() - tau.eta(); + return true; +} + +bool dPhi(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.p4().DeltaPhi(tau.p4()); + return true; +} + +bool SECOND_R(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_R, out); + return success; +} + +bool SECOND_LAMBDA(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_LAMBDA, out); + return success; +} + +bool CENTER_LAMBDA(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + const auto success = cluster.clust().retrieveMoment(MomentType::CENTER_LAMBDA, out); + return success; +} + +bool SECOND_LAMBDAOverClustersMeanSecondLambda(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanSecondLambda("ClustersMeanSecondLambda"); + float ClustersMeanSecondLambda = acc_ClustersMeanSecondLambda(tau); + double secondLambda(0); + const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_LAMBDA, secondLambda); + out = (ClustersMeanSecondLambda != 0.) ? secondLambda/ClustersMeanSecondLambda : 0.; + return success; +} + +bool CENTER_LAMBDAOverClustersMeanCenterLambda(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanCenterLambda("ClustersMeanCenterLambda"); + float ClustersMeanCenterLambda = acc_ClustersMeanCenterLambda(tau); + double centerLambda(0); + const auto success = cluster.clust().retrieveMoment(MomentType::CENTER_LAMBDA, centerLambda); + if (ClustersMeanCenterLambda == 0.){ + out = 250.; + }else { + out = centerLambda/ClustersMeanCenterLambda; + } + + out = std::min(out, 250.); + + return success; +} + + +bool FirstEngDensOverClustersMeanFirstEngDens(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + // the ClustersMeanFirstEngDens is the log10 of the energy weighted average of the First_ENG_DENS + // divided by ETot to make it dimension-less, + // so we need to evaluate the difference of log10(clusterFirstEngDens/clusterTotalEnergy) and the ClustersMeanFirstEngDens + double clusterFirstEngDens = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::FIRST_ENG_DENS, clusterFirstEngDens); + if (clusterFirstEngDens < 1e-6) clusterFirstEngDens = 1e-6; + + static const SG::AuxElement::ConstAccessor<float> acc_ClusterTotalEnergy("ClusterTotalEnergy"); + float clusterTotalEnergy = acc_ClusterTotalEnergy(tau); + if (clusterTotalEnergy < 1e-6) clusterTotalEnergy = 1e-6; + + static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanFirstEngDens("ClustersMeanFirstEngDens"); + float clustersMeanFirstEngDens = acc_ClustersMeanFirstEngDens(tau); + + out = std::log10(clusterFirstEngDens/clusterTotalEnergy) - clustersMeanFirstEngDens; + + return status; +} + +//Extension - Variables for GNTau +bool e(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.p4().E(); + return true; +} + +bool et(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + out = cluster.p4().Et(); + return true; +} + +bool FIRST_ENG_DENS(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + double clusterFirstEngDens = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::FIRST_ENG_DENS, clusterFirstEngDens); + out = clusterFirstEngDens; + return status; +} + +bool EM_PROBABILITY(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + double clusterEMprob = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::EM_PROBABILITY, clusterEMprob); + out = clusterEMprob; + return status; +} + +bool CENTER_MAG(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) { + double clusterCenterMag = 0.0; + bool status = cluster.clust().retrieveMoment(MomentType::CENTER_MAG, clusterCenterMag); + out = clusterCenterMag; + return status; +} + +} // namespace Cluster +} // namespace Variables +} // namespace TauGNNUtils diff --git a/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx b/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx index d4d5db31542ab15ccd7487f6dacb303eb48da1b3..0e78bf7d357e50f2de792c78d4ffb30d5469dab2 100644 --- a/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx +++ b/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx @@ -25,6 +25,7 @@ #include "tauRecTools/TauWPDecorator.h" #include "tauRecTools/TauIDVarCalculator.h" #include "tauRecTools/TauJetRNNEvaluator.h" +#include "tauRecTools/TauGNNEvaluator.h" #include "tauRecTools/TauDecayModeNNClassifier.h" #include "tauRecTools/TauVertexedClusterDecorator.h" #include "tauRecTools/TauAODSelector.h" @@ -59,6 +60,7 @@ DECLARE_COMPONENT( TauPi0Selector ) DECLARE_COMPONENT( TauWPDecorator ) DECLARE_COMPONENT( TauIDVarCalculator ) DECLARE_COMPONENT( TauJetRNNEvaluator ) +DECLARE_COMPONENT( TauGNNEvaluator ) DECLARE_COMPONENT( TauDecayModeNNClassifier ) DECLARE_COMPONENT( TauVertexedClusterDecorator ) DECLARE_COMPONENT( TauAODSelector ) diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNN.h b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h new file mode 100644 index 0000000000000000000000000000000000000000..f2e39c02cbbc1444547c103d3b6dbcf8c059615d --- /dev/null +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h @@ -0,0 +1,110 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef TAURECTOOLS_TAUGNN_H +#define TAURECTOOLS_TAUGNN_H + +#include "xAODTau/TauJet.h" +#include "xAODCaloEvent/CaloVertexedTopoCluster.h" + +#include "AsgMessaging/AsgMessaging.h" + +#include "FlavorTagDiscriminants/OnnxUtil.h" + +#include <memory> +#include <string> +#include <map> + +// Forward declaration +namespace lwt { + class LightweightGraph; +} + +namespace TauGNNUtils { + class GNNVarCalc; +} + +namespace FlavorTagDiscriminants{ + class OnnxUtil; +} + +/** + * @brief Wrapper around ONNXUtil to compute the output score of a model + * + * Configures the network and computes the network outputs given the input + * objects. Retrieval of input variables is handled internally. + * + * @author N.M. Tamir + * + */ +class TauGNN : public asg::AsgMessaging { +public: + // Configuration of the weight file structure + struct Config { + std::string input_layer_scalar; + std::string input_layer_tracks; + std::string input_layer_clusters; + std::string output_node_tau; + std::string output_node_jet; + }; + std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil; +public: + TauGNN(const std::string &nnFile, const Config &config); + ~TauGNN(); + + // Output the OnnxUtil tuple + std::tuple< + std::map<std::string, float>, + std::map<std::string, std::vector<char>>, + std::map<std::string, std::vector<float>> > + compute(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const; + + // Compute all input variables and store them in the maps that are passed by reference + bool calculateInputVariables(const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::map<std::string, std::map<std::string, double>>& scalarInputs, + std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const; + + // Getter for the variable calculator + const TauGNNUtils::GNNVarCalc* variable_calculator() const { + return m_var_calc.get(); + } + + explicit operator bool() const { + return static_cast<bool>(m_graph); + } + + //Make the output config transparent to external tools + FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config; + +private: + using input_pair = FlavorTagDiscriminants::input_pair; + // Abbreviations for lwtnn + using VariableMap = std::map<std::string, double>; + using VectorMap = std::map<std::string, std::vector<double>>; + + using InputMap = std::map<std::string, VariableMap>; + using InputSequenceMap = std::map<std::string, VectorMap>; + +private: + const Config m_config; + std::unique_ptr<const lwt::LightweightGraph> m_graph; + + // Names of the input variables + std::vector<std::string> m_scalar_inputs; + std::vector<std::string> m_track_inputs; + std::vector<std::string> m_cluster_inputs; + // Names passed to the variable calculator + std::vector<std::string> m_scalarCalc_inputs; + std::vector<std::string> m_trackCalc_inputs; + std::vector<std::string> m_clusterCalc_inputs; + + // Variable calculator to calculate input variables on the fly + std::unique_ptr<TauGNNUtils::GNNVarCalc> m_var_calc; +}; + +#endif // TAURECTOOLS_TAUGNN_H diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h new file mode 100644 index 0000000000000000000000000000000000000000..4e5cebf53ad5e99907784d96586788ed55ceec28 --- /dev/null +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h @@ -0,0 +1,69 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef TAURECTOOLS_TAUGNNEVALUATOR_H +#define TAURECTOOLS_TAUGNNEVALUATOR_H + +#include "tauRecTools/TauRecToolBase.h" + +#include "xAODTau/TauJet.h" +#include "xAODCaloEvent/CaloVertexedTopoCluster.h" + +#include <memory> + +class TauGNN; + +/** + * @brief Tool to calculate tau identification score from .onnx inputs + * + * The network configuration is supplied in .onnx format. + * Currently runs on a prongness-inclusive model + * Based off of TauJetRNNEvaluator.h format! + * @author N.M. Tamir + * + */ +class TauGNNEvaluator : public TauRecToolBase { +public: + ASG_TOOL_CLASS2(TauGNNEvaluator, TauRecToolBase, ITauToolBase) + + TauGNNEvaluator(const std::string &name = "TauGNNEvaluator"); + virtual ~TauGNNEvaluator(); + + virtual StatusCode initialize() override; + virtual StatusCode execute(xAOD::TauJet &tau) const override; + // Getter for the underlying RNN implementation + const TauGNN* get_gnn() const; + + // Selects tracks to be used as input to the network + StatusCode get_tracks(const xAOD::TauJet &tau, + std::vector<const xAOD::TauTrack *> &out) const; + + // Selects clusters to be used as input to the network + StatusCode get_clusters(const xAOD::TauJet &tau, + std::vector<xAOD::CaloVertexedTopoCluster> &out) const; + +private: + std::string m_output_varname; + std::string m_output_ptau; + std::string m_output_pjet; + std::string m_weightfile; + std::size_t m_max_tracks; + std::size_t m_max_clusters; + float m_max_cluster_dr; + bool m_doVertexCorrection; + bool m_doTrackClassification; + bool m_decorateTracks; + + // Configuration of the network file + std::string m_input_layer_scalar; + std::string m_input_layer_tracks; + std::string m_input_layer_clusters; + std::string m_outnode_tau; + std::string m_outnode_jet; + + // Wrappers for lwtnn + std::unique_ptr<TauGNN> m_net; //! +}; + +#endif // TAURECTOOLS_TAUGNNEVALUATOR_H diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..28a983dd5e8a84ca17a2b61fa1b72428fc504232 --- /dev/null +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h @@ -0,0 +1,311 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef TAURECTOOLS_TAUGNNUTILS_H +#define TAURECTOOLS_TAUGNNUTILS_H + +#include "xAODTau/TauJet.h" +#include "xAODCaloEvent/CaloVertexedTopoCluster.h" +#include "xAODEventInfo/EventInfo.h" +#include "xAODTracking/VertexContainer.h" +#include "AsgTools/AsgTool.h" +#include "AsgMessaging/AsgMessaging.h" +#include <unordered_map> + + +namespace TauGNNUtils { + +/** + * @brief Tool to calculate input variables for the GNN-based tau identification + * + * Used to calculate input variables for (onnx)GNN-based tau identification on + * the fly by providing a mapping between variable names (strings) and + * functions to calculate these variables. + * + * @author C. Deutsch + * @author W. Davey + * @author N.M. Tamir + * + */ +class GNNVarCalc : public asg::AsgMessaging { +public: + // Pointers to calculator functions + using ScalarCalc = bool (*)(const xAOD::TauJet &, double &); + + using TrackCalc = bool (*)(const xAOD::TauJet &, const xAOD::TauTrack &, + double &); + + using ClusterCalc = bool (*)(const xAOD::TauJet &, + const xAOD::CaloVertexedTopoCluster &, double &); + +public: + GNNVarCalc(); + ~GNNVarCalc() = default; + + // Methods to compute the output (vector) based on the variable name + + // Computes high-level ID variables + bool compute(const std::string &name, const xAOD::TauJet &tau, double &out) const; + + // Computes track variables + bool compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<const xAOD::TauTrack *> &tracks, + std::vector<double> &out) const; + + // Computes cluster variables + bool compute(const std::string &name, const xAOD::TauJet &tau, + const std::vector<xAOD::CaloVertexedTopoCluster> &clusters, + std::vector<double> &out) const; + + // Methods to insert calculator functions into the lookup table + void insert(const std::string &name, ScalarCalc func, const std::vector<std::string>& scalar_vars); + void insert(const std::string &name, TrackCalc func, const std::vector<std::string>& track_vars); + void insert(const std::string &name, ClusterCalc func, const std::vector<std::string>& cluster_vars); + +private: + // Lookup tables + std::unordered_map<std::string, ScalarCalc> m_scalar_map; + std::unordered_map<std::string, TrackCalc> m_track_map; + std::unordered_map<std::string, ClusterCalc> m_cluster_map; +}; + +// Factory function to create a variable calculator populated with default +// variables +std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scalar_vars, + const std::vector<std::string>& track_vars, + const std::vector<std::string>& cluster_vars); + + +namespace Variables { + +// Functions to calculate (scalar) input variables +// Returns a status code indicating success +bool absEta(const xAOD::TauJet &tau, double &out); + +bool centFrac(const xAOD::TauJet &tau, double &out); + +bool isolFrac(const xAOD::TauJet &tau, double &out); + +bool etOverPtLeadTrk(const xAOD::TauJet &tau, double &out); + +bool innerTrkAvgDist(const xAOD::TauJet &tau, double &out); + +bool absipSigLeadTrk(const xAOD::TauJet &tau, double &out); + +bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, double &out); + +bool SumPtTrkFrac(const xAOD::TauJet &tau, double &out); + +bool EMPOverTrkSysP(const xAOD::TauJet &tau, double &out); + +bool ptRatioEflowApprox(const xAOD::TauJet &tau, double &out); + +bool mEflowApprox(const xAOD::TauJet &tau, double &out); + +bool dRmax(const xAOD::TauJet &tau, double &out); + +bool trFlightPathSig(const xAOD::TauJet &tau, double &out); + +bool massTrkSys(const xAOD::TauJet &tau, double &out); + +bool pt(const xAOD::TauJet &tau, double &out); + +bool pt_tau_log(const xAOD::TauJet &tau, double &out); + +bool ptDetectorAxis(const xAOD::TauJet &tau, double &out); + +bool ptIntermediateAxis(const xAOD::TauJet &tau, double &out); + +//functions to calculate input variables needed for the eVeto RNN +bool ptJetSeed_log (const xAOD::TauJet &tau, double &out); +bool absleadTrackEta (const xAOD::TauJet &tau, double &out); +bool leadTrackDeltaEta (const xAOD::TauJet &tau, double &out); +bool leadTrackDeltaPhi (const xAOD::TauJet &tau, double &out); +bool leadTrackProbNNorHT (const xAOD::TauJet &tau, double &out); +bool EMFracFixed (const xAOD::TauJet &tau, double &out); +bool etHotShotWinOverPtLeadTrk (const xAOD::TauJet &tau, double &out); +bool hadLeakFracFixed (const xAOD::TauJet &tau, double &out); +bool PSFrac (const xAOD::TauJet &tau, double &out); +bool ClustersMeanCenterLambda (const xAOD::TauJet &tau, double &out); +bool ClustersMeanEMProbability (const xAOD::TauJet &tau, double &out); +bool ClustersMeanFirstEngDens (const xAOD::TauJet &tau, double &out); +bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, double &out); +bool ClustersMeanSecondLambda (const xAOD::TauJet &tau, double &out); +bool EMPOverTrkSysP (const xAOD::TauJet &tau, double &out); + + +namespace Track { + +// Functions to calculate input variables for each track +// Returns a status code indicating success + +bool pt_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool trackPt( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool trackEta( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool trackPhi( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool pt_tau_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool pt_jetseed_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool d0_abs_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool z0sinThetaTJVA_abs_log( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool z0sinthetaTJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool z0sinthetaSigTJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool d0TJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool d0SigTJVA( + const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +bool dEta( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool dPhi( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nInnermostPixelHits( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nPixelHits( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nSCTHits( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +// trigger variants +bool nIBLHitsAndExp ( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nPixelHitsPlusDeadSensors ( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool nSCTHitsPlusDeadSensors ( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool eProbabilityHT( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool eProbabilityNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool eProbabilityNNorHT( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool chargedScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool isolationScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool conversionScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +bool fakeScoreRNN( + const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out); + +//Extension - variables for GNTau +bool numberOfInnermostPixelLayerHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelSharedHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelDeadSensors(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfSCTHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfSCTSharedHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfSCTDeadSensors(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfTRTHighThresholdHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfTRTHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool nSiHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool expectInnermostPixelLayerHit(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool expectNextToInnermostPixelLayerHit(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfContribPixelLayers(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool numberOfPixelHoles(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool d0_old(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool qOverP(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool theta(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool z0TJVA(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool charge(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool dz0_TV_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt2_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); +bool log_sumpt2_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out); + +} // namespace Track + + +namespace Cluster { + +// Functions to calculate input variables for each cluster +// Returns a status code indicating success + +bool et_log( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool pt_tau_log( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool pt_jetseed_log( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool dEta( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool dPhi( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool SECOND_R( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool SECOND_LAMBDA( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool CENTER_LAMBDA( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool SECOND_LAMBDAOverClustersMeanSecondLambda( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool CENTER_LAMBDAOverClustersMeanCenterLambda( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool FirstEngDensOverClustersMeanFirstEngDens( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +//Extension - Variables for GNTau +bool e( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool et( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool FIRST_ENG_DENS( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool EM_PROBABILITY( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); + +bool CENTER_MAG( + const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out); +} // namespace Cluster +} // namespace Variables +} // namespace TauJetGNNUtils + +#endif // TAURECTOOLS_TAUGNNUTILS_H diff --git a/Tools/WorkflowTestRunner/python/References.py b/Tools/WorkflowTestRunner/python/References.py index 8aa92ba7fa1b9c35fcad19e43224854eca676ce8..4dd63adb04203f730449d850cc0bb8a609862577 100644 --- a/Tools/WorkflowTestRunner/python/References.py +++ b/Tools/WorkflowTestRunner/python/References.py @@ -29,14 +29,14 @@ references_map = { "q452": "v9", "q454": "v15", # Derivations - "data_PHYS_Run2": "v20", + "data_PHYS_Run2": "v21", "data_PHYSLITE_Run2": "v2", - "data_PHYS_Run3": "v19", + "data_PHYS_Run3": "v20", "data_PHYSLITE_Run3": "v2", - "mc_PHYS_Run2": "v24", + "mc_PHYS_Run2": "v25", "mc_PHYSLITE_Run2": "v3", - "mc_PHYS_Run3": "v25", + "mc_PHYS_Run3": "v26", "mc_PHYSLITE_Run3": "v3", - "af3_PHYS_Run3": "v6", + "af3_PHYS_Run3": "v7", "af3_PHYSLITE_Run3": "v3", }