Skip to content
Snippets Groups Projects
Commit c8e18ae8 authored by Bertrand Martin Dit Latour's avatar Bertrand Martin Dit Latour Committed by Tadej Novak
Browse files

DerivationFrameworkTau, tauRec, tauRecTools: optimise muon-tau removal sequence

DerivationFrameworkTau, tauRec, tauRecTools: optimise muon-tau removal sequence

Hello,

This MR is addressing the CPU increase in DAOD_PHYS/LITE reported in ATLASG-2712 coming from the recent addition of GNTau ID.
The main changes are in TauAODRunnerAlg, which first run a tool to remove muon tracks and clusters associated with tau candidate, then reruns most of the tau reconstruction with muon-free inputs.
Now, in the muon-tau removal, if no muon track nor cluster is found near the tau, we discard the tau candidate by effectively removing it from the container. This prevents afterburner tools like GNN tau ID from running over the full container (which so far includes irrelevant tau candidates removed later on by a thinning algorithm), thereby saving CPU. I've checked that the TauJets_MuonRM_TauIDDecorKernel CPU time is reduced, it's no longer visible in the SPOT test summary. The DAOD output is unchanged (checked over 1000 events).

Adding the urgent flag in case it would sill arrive in time for the imminent DAOD bulk prod.

Cheers,
Bertrand
parent 0fe32fe0
No related branches found
No related tags found
29 merge requests!78241Draft: FPGATrackSim: GenScan code refactor,!78236Draft: Switching Streams https://its.cern.ch/jira/browse/ATR-27417,!78056AFP monitoring: new synchronization and cleaning,!78041AFP monitoring: new synchronization and cleaning,!77990Updating TRT chip masks for L1TRT trigger simulation - ATR-28372,!77733Draft: add new HLT NN JVT, augmented with additional tracking information,!77731Draft: Updates to ZDC reconstruction,!77728Draft: updates to ZDC reconstruction,!77522Draft: sTGC Pad Trigger Emulator,!76725ZdcNtuple: Fix cppcheck warning.,!76611L1CaloFEXByteStream: Fix out-of-bounds array accesses.,!76475Punchthrough AF3 implementation in FastG4,!76474Punchthrough AF3 implementation in FastG4,!76343Draft: MooTrackBuilder: Recalibrate NSW hits in refine method,!75729New implementation of ZDC nonlinear FADC correction.,!75703Draft: Update to HI han config for HLT jets,!75184Draft: Update file heavyions_run.config,!74430Draft: Fixing upper bound for Delayed Jet Triggers,!73963Changing the path of the histograms to "Expert" area,!73875updating ID ART reference plots,!73874AtlasCLHEP_RandomGenerators: Fix cppcheck warnings.,!73449Add muon detectors to DarkJetPEBTLA partial event building,!73343Draft: [TrigEgamma] Add photon ringer chains on bootstrap mechanism,!72336Fixed TRT calibration crash,!72176Draft: Improving L1TopoOnline chain that now gets no-empty plots. Activating it by default,!72012Draft: Separate JiveXMLConfig.py into Config files,!71876Fix MET trigger name in MissingETMonitoring,!71820Draft: Adding new TLA End-Of-Fill (EOF) chains and removing obsolete DIPZ chains,!71478DerivationFrameworkTau, tauRec, tauRecTools: optimise muon-tau removal sequence
Showing
with 107 additions and 114 deletions
# Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
# PhysCommonThinningConfig
# Contains the configuration for the thinning for PHYS(LITE)
......@@ -6,7 +6,7 @@
from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
from AthenaConfiguration.ComponentFactory import CompFactory
def PhysCommonThinningCfg(ConfigFlags, StreamName = "StreamDAOD_PHYS", **kwargs):
def PhysCommonThinningCfg(flags, StreamName = "StreamDAOD_PHYS", **kwargs):
"""Configure the common augmentation"""
acc = ComponentAccumulator()
......@@ -21,7 +21,7 @@ def PhysCommonThinningCfg(ConfigFlags, StreamName = "StreamDAOD_PHYS", **kwargs)
if "TrackParticleThinningToolName" in kwargs:
tp_thinning_expression = "InDetTrackParticles.DFCommonTightPrimary && abs(DFCommonInDetTrackZ0AtPV)*sin(InDetTrackParticles.theta) < 3.0*mm && InDetTrackParticles.pt > 10*GeV"
acc.merge(TrackParticleThinningCfg(
ConfigFlags,
flags,
name = kwargs['TrackParticleThinningToolName'],
StreamName = StreamName,
SelectionString = tp_thinning_expression,
......@@ -30,7 +30,7 @@ def PhysCommonThinningCfg(ConfigFlags, StreamName = "StreamDAOD_PHYS", **kwargs)
# Include inner detector tracks associated with muons
if "MuonTPThinningToolName" in kwargs:
acc.merge(MuonTrackParticleThinningCfg(
ConfigFlags,
flags,
name = kwargs['MuonTPThinningToolName'],
StreamName = StreamName,
MuonKey = "Muons",
......@@ -38,8 +38,8 @@ def PhysCommonThinningCfg(ConfigFlags, StreamName = "StreamDAOD_PHYS", **kwargs)
# Tau-related containers: taus, tau tracks and associated ID tracks, neutral PFOs, secondary vertices
if "TauJetThinningToolName" in kwargs:
tau_thinning_expression = "TauJets.pt >= 13*GeV"
acc.merge(TauThinningCfg(ConfigFlags,
tau_thinning_expression = f"TauJets.pt >= {flags.Tau.MinPtDAOD}"
acc.merge(TauThinningCfg(flags,
name = kwargs['TauJetThinningToolName'],
StreamName = StreamName,
Taus = "TauJets",
......@@ -51,8 +51,7 @@ def PhysCommonThinningCfg(ConfigFlags, StreamName = "StreamDAOD_PHYS", **kwargs)
if "TauJets_MuonRMThinningToolName" in kwargs:
tau_murm_thinning_expression = tau_thinning_expression.replace('TauJets', 'TauJets_MuonRM')
tau_murm_thinning_expression += " && TauJets_MuonRM.ModifiedInAOD"
acc.merge(TauThinningCfg(ConfigFlags,
acc.merge(TauThinningCfg(flags,
name = kwargs['TauJets_MuonRMThinningToolName'],
StreamName = StreamName,
Taus = "TauJets_MuonRM",
......@@ -64,7 +63,7 @@ def PhysCommonThinningCfg(ConfigFlags, StreamName = "StreamDAOD_PHYS", **kwargs)
if "TauJets_EleRMThinningToolName" in kwargs:
tau_erm_thinning_expression = tau_thinning_expression.replace('TauJets', 'TauJets_EleRM')
acc.merge(TauThinningCfg(ConfigFlags,
acc.merge(TauThinningCfg(flags,
name = kwargs['TauJets_EleRMThinningToolName'],
StreamName = StreamName,
Taus = "TauJets_EleRM",
......@@ -77,7 +76,7 @@ def PhysCommonThinningCfg(ConfigFlags, StreamName = "StreamDAOD_PHYS", **kwargs)
# ID tracks associated with high-pt di-tau
if "DiTauTPThinningToolName" in kwargs:
acc.merge(DiTauTrackParticleThinningCfg(
ConfigFlags,
flags,
name = kwargs['DiTauTPThinningToolName'],
StreamName = StreamName,
DiTauKey = "DiTauJets",
......@@ -86,7 +85,7 @@ def PhysCommonThinningCfg(ConfigFlags, StreamName = "StreamDAOD_PHYS", **kwargs)
## Low-pt di-tau thinning
if "DiTauLowPtThinningToolName" in kwargs:
acc.merge(GenericObjectThinningCfg(
ConfigFlags,
flags,
name = kwargs['DiTauLowPtThinningToolName'],
StreamName = StreamName,
ContainerName = "DiTauJetsLowPt",
......@@ -95,7 +94,7 @@ def PhysCommonThinningCfg(ConfigFlags, StreamName = "StreamDAOD_PHYS", **kwargs)
# ID tracks associated with low-pt ditau
if "DiTauLowPtTPThinningToolName" in kwargs:
acc.merge(DiTauTrackParticleThinningCfg(
ConfigFlags,
flags,
name = kwargs['DiTauLowPtTPThinningToolName'],
StreamName = StreamName,
DiTauKey = "DiTauJetsLowPt",
......@@ -105,7 +104,7 @@ def PhysCommonThinningCfg(ConfigFlags, StreamName = "StreamDAOD_PHYS", **kwargs)
# keep calo clusters around electrons
if "ElectronCaloClusterThinningToolName" in kwargs:
acc.merge(CaloClusterThinningCfg(
ConfigFlags,
flags,
name = kwargs['ElectronCaloClusterThinningToolName'],
StreamName = StreamName,
SGKey = "AnalysisElectrons",
......@@ -115,7 +114,7 @@ def PhysCommonThinningCfg(ConfigFlags, StreamName = "StreamDAOD_PHYS", **kwargs)
# keep calo clusters around photons
if "PhotonCaloClusterThinningToolName" in kwargs:
acc.merge(CaloClusterThinningCfg(
ConfigFlags,
flags,
name = kwargs['PhotonCaloClusterThinningToolName'],
StreamName = StreamName,
SGKey = "AnalysisPhotons",
......
# Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
from AthenaConfiguration.ComponentFactory import CompFactory
def AddTauAugmentationCfg(ConfigFlags, **kwargs):
def AddTauAugmentationCfg(flags, **kwargs):
prefix = kwargs["prefix"]
kwargs.setdefault("doVeryLoose", False)
......@@ -14,7 +14,7 @@ def AddTauAugmentationCfg(ConfigFlags, **kwargs):
acc = ComponentAccumulator()
# tau selection relies on RNN electron veto, we must decorate the fixed eveto WPs before applying tau selection
acc.merge(AddTauIDDecorationCfg(ConfigFlags, TauContainerName="TauJets"))
acc.merge(AddTauIDDecorationCfg(flags, TauContainerName="TauJets"))
from DerivationFrameworkTools.DerivationFrameworkToolsConfig import AsgSelectionToolWrapperCfg
from TauAnalysisTools.TauAnalysisToolsConfig import TauSelectionToolCfg
......@@ -22,12 +22,12 @@ def AddTauAugmentationCfg(ConfigFlags, **kwargs):
TauAugmentationTools = []
if kwargs["doVeryLoose"]:
TauSelectorVeryLoose = acc.popToolsAndMerge(TauSelectionToolCfg(ConfigFlags,
TauSelectorVeryLoose = acc.popToolsAndMerge(TauSelectionToolCfg(flags,
name = 'TauSelectorVeryLoose',
ConfigPath = 'TauAnalysisAlgorithms/tau_selection_veryloose.conf'))
acc.addPublicTool(TauSelectorVeryLoose)
TauVeryLooseWrapper = acc.getPrimaryAndMerge(AsgSelectionToolWrapperCfg(ConfigFlags,
TauVeryLooseWrapper = acc.getPrimaryAndMerge(AsgSelectionToolWrapperCfg(flags,
name = "TauVeryLooseWrapper",
AsgSelectionTool = TauSelectorVeryLoose,
StoreGateEntryName = "DFTauVeryLoose",
......@@ -35,12 +35,12 @@ def AddTauAugmentationCfg(ConfigFlags, **kwargs):
TauAugmentationTools.append(TauVeryLooseWrapper)
if kwargs["doLoose"]:
TauSelectorLoose = acc.popToolsAndMerge(TauSelectionToolCfg(ConfigFlags,
TauSelectorLoose = acc.popToolsAndMerge(TauSelectionToolCfg(flags,
name = 'TauSelectorLoose',
ConfigPath = 'TauAnalysisAlgorithms/tau_selection_loose.conf'))
acc.addPublicTool(TauSelectorLoose)
TauLooseWrapper = acc.getPrimaryAndMerge(AsgSelectionToolWrapperCfg(ConfigFlags,
TauLooseWrapper = acc.getPrimaryAndMerge(AsgSelectionToolWrapperCfg(flags,
name = "TauLooseWrapper",
AsgSelectionTool = TauSelectorLoose,
StoreGateEntryName = "DFTauLoose",
......@@ -48,12 +48,12 @@ def AddTauAugmentationCfg(ConfigFlags, **kwargs):
TauAugmentationTools.append(TauLooseWrapper)
if kwargs["doMedium"]:
TauSelectorMedium = acc.popToolsAndMerge(TauSelectionToolCfg(ConfigFlags,
TauSelectorMedium = acc.popToolsAndMerge(TauSelectionToolCfg(flags,
name = 'TauSelectorMedium',
ConfigPath = 'TauAnalysisAlgorithms/tau_selection_medium.conf'))
acc.addPublicTool(TauSelectorMedium)
TauMediumWrapper = acc.getPrimaryAndMerge(AsgSelectionToolWrapperCfg(ConfigFlags,
TauMediumWrapper = acc.getPrimaryAndMerge(AsgSelectionToolWrapperCfg(flags,
name = "TauMediumWrapper",
AsgSelectionTool = TauSelectorMedium,
StoreGateEntryName = "DFTauMedium",
......@@ -61,12 +61,12 @@ def AddTauAugmentationCfg(ConfigFlags, **kwargs):
TauAugmentationTools.append(TauMediumWrapper)
if kwargs["doTight"]:
TauSelectorTight = acc.popToolsAndMerge(TauSelectionToolCfg(ConfigFlags,
TauSelectorTight = acc.popToolsAndMerge(TauSelectionToolCfg(flags,
name = 'TauSelectorTight',
ConfigPath = 'TauAnalysisAlgorithms/tau_selection_tight.conf'))
acc.addPublicTool(TauSelectorTight)
TauTightWrapper = acc.getPrimaryAndMerge(AsgSelectionToolWrapperCfg(ConfigFlags,
TauTightWrapper = acc.getPrimaryAndMerge(AsgSelectionToolWrapperCfg(flags,
name = "TauTightWrapper",
AsgSelectionTool = TauSelectorTight,
StoreGateEntryName = "DFTauTight",
......@@ -81,7 +81,7 @@ def AddTauAugmentationCfg(ConfigFlags, **kwargs):
# Low pT di-taus
def AddDiTauLowPtCfg(ConfigFlags, **kwargs):
def AddDiTauLowPtCfg(flags, **kwargs):
"""Configure the low-pt di-tau building"""
acc = ComponentAccumulator()
......@@ -90,10 +90,10 @@ def AddDiTauLowPtCfg(ConfigFlags, **kwargs):
from JetRecConfig.StandardLargeRJets import AntiKt10LCTopo
from JetRecConfig.StandardJetConstits import stdConstitDic as cst
AntiKt10EMPFlow = AntiKt10LCTopo.clone(inputdef = cst.GPFlow)
acc.merge(JetRecCfg(ConfigFlags,AntiKt10EMPFlow))
acc.merge(JetRecCfg(flags,AntiKt10EMPFlow))
from DiTauRec.DiTauBuilderConfig import DiTauBuilderLowPtCfg
acc.merge(DiTauBuilderLowPtCfg(ConfigFlags, name="DiTauLowPtBuilder"))
acc.merge(DiTauBuilderLowPtCfg(flags, name="DiTauLowPtBuilder"))
return acc
......@@ -103,7 +103,7 @@ def AddTauIDDecorationCfg(flags, **kwargs):
kwargs.setdefault("evetoFix", True)
kwargs.setdefault("DeepSetID", True)
kwargs.setdefault("GNNTauID", True)
kwargs.setdefault("GNNTauID", True)
kwargs.setdefault("TauContainerName", "TauJets")
kwargs.setdefault("prefix", kwargs['TauContainerName'])
......@@ -115,8 +115,11 @@ def AddTauIDDecorationCfg(flags, **kwargs):
if kwargs['evetoFix']:
tools.append( acc.popToolsAndMerge(tauTools.TauWPDecoratorEleRNNFixCfg(flags)) )
if kwargs['DeepSetID']:
# vertex-corrected clusters must be rebuilt for tau ID
if kwargs['DeepSetID'] or kwargs['GNNTauID']:
tools.append( acc.popToolsAndMerge(tauTools.TauVertexedClusterDecoratorCfg(flags)) )
if kwargs['DeepSetID']:
# R22 DeepSet tau ID tune with track RNN scores
tools.append( acc.popToolsAndMerge(tauTools.TauJetDeepSetEvaluatorCfg(flags, version="v1")) )
tools.append( acc.popToolsAndMerge(tauTools.TauWPDecoratorJetDeepSetCfg(flags, version="v1")) )
......
......@@ -21,7 +21,7 @@ namespace DerivationFramework {
// parse the properties of TauWPDecorator tools
for (const auto& tool : m_tauIDTools) {
if ((tool->type() != "TauWPDecorator" )) continue;
if (tool->type() != "TauWPDecorator") continue;
// check whether we must compute eVeto WPs, as this requires the recalculation of a variable
BooleanProperty useAbsEta("UseAbsEta", false);
......@@ -86,23 +86,29 @@ namespace DerivationFramework {
}
const xAOD::VertexContainer* vtxContainer = vtxReadHandle.cptr();
const xAOD::Vertex* pVtx = nullptr;
float sumpt_PV0 = 0., sumpt2_PV0 = 0.;
// Check that PV container exists and is non-empty, find the PV if possible
if(vtxContainer != nullptr && vtxContainer->size()>0) {
ATH_MSG_DEBUG("Found vtx container for decorating taus!");
if (vtxContainer != nullptr && !vtxContainer->empty()) {
auto itrVtx = std::find_if(vtxContainer->begin(), vtxContainer->end(),
[](const xAOD::Vertex* vtx) {
return vtx->vertexType() == xAOD::VxType::PriVtx;
});
pVtx = (itrVtx == vtxContainer->end() ? 0 : *itrVtx);
if(!pVtx){
ATH_MSG_WARNING("No PV found, using the first element instead!");
[](const xAOD::Vertex* vtx) {
return vtx->vertexType() == xAOD::VxType::PriVtx;
});
pVtx = (itrVtx == vtxContainer->end() ? nullptr : *itrVtx);
if (pVtx == nullptr){
ATH_MSG_DEBUG("No PV found, using the first element instead!");
pVtx = vtxContainer->at(0);
}
for (const ElementLink<xAOD::TrackParticleContainer>& trk : pVtx->trackParticleLinks()) {
sumpt_PV0 += (*trk)->pt();
sumpt2_PV0 += std::pow((*trk)->pt(), 2.);
}
}
//Create accessors
static const SG::AuxElement::Decorator<float> acc_trackWidth("trackWidth");
static const SG::AuxElement::Accessor<float> acc_absEtaLead("ABS_ETA_LEAD_TRACK");
static const SG::AuxElement::Accessor<float> acc_dz0_TV_PV0("dz0_TV_PV0");
static const SG::AuxElement::Accessor<float> acc_log_sumpt_TV("log_sumpt_TV");
static const SG::AuxElement::Accessor<float> acc_log_sumpt2_TV("log_sumpt2_TV");
......@@ -110,12 +116,10 @@ namespace DerivationFramework {
static const SG::AuxElement::Accessor<float> acc_log_sumpt2_PV0("log_sumpt2_PV0");
for (const auto tau : *tauContainer) {
float tauTrackBasedWidth = 0;
// equivalent to
// tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedCharged)
float tauTrackBasedWidth = 0.;
// equivalent to tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedCharged)
std::vector<const xAOD::TauTrack *> tauTracks = tau->tracks();
for (const xAOD::TauTrack *trk : tau->tracks(
xAOD::TauJetParameters::TauTrackFlag::classifiedIsolation)) {
for (const xAOD::TauTrack *trk : tau->tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedIsolation)) {
tauTracks.push_back(trk);
}
double sumWeightedDR = 0.;
......@@ -125,7 +129,7 @@ namespace DerivationFramework {
sumWeightedDR += deltaR * track->pt();
ptSum += track->pt();
}
if (ptSum > 0) {
if (ptSum > 0.) {
tauTrackBasedWidth = sumWeightedDR / ptSum;
}
......@@ -134,22 +138,16 @@ namespace DerivationFramework {
// create shallow copy
auto shallowCopy = xAOD::shallowCopyContainer (*tauContainer);
static const SG::AuxElement::Accessor<float> acc_absEtaLead("ABS_ETA_LEAD_TRACK");
for (auto tau : *shallowCopy.first) {
//Add in the TV/PV0 vertex variables needed for some calculators in TauGNNUtils.cxx (for GNTau)
float dz0_TV_PV0 = -999., sumpt_TV = 0., sumpt2_TV = 0., sumpt_PV0 = 0., sumpt2_PV0 = 0.;
if(pVtx!=nullptr) {
float dz0_TV_PV0 = -999., sumpt_TV = 0., sumpt2_TV = 0.;
if (pVtx!=nullptr) {
dz0_TV_PV0 = tau->vertex()->z() - pVtx->z();
for (const ElementLink<xAOD::TrackParticleContainer>& trk : pVtx->trackParticleLinks()) {
sumpt_PV0 += (*trk)->pt();
sumpt2_PV0 += pow((*trk)->pt(), 2.);
}
for (const ElementLink<xAOD::TrackParticleContainer>& trk : tau->vertex()->trackParticleLinks()) {
sumpt_TV += (*trk)->pt();
sumpt2_TV += pow((*trk)->pt(), 2.);
sumpt2_TV += std::pow((*trk)->pt(), 2.);
}
}
acc_dz0_TV_PV0(*tau) = dz0_TV_PV0;
......@@ -157,7 +155,6 @@ namespace DerivationFramework {
acc_log_sumpt2_TV(*tau) = (sumpt2_TV>0.) ? std::log(sumpt2_TV) : 0.;
acc_log_sumpt_PV0(*tau) = (sumpt_PV0>0.) ? std::log(sumpt_PV0) : 0.;
acc_log_sumpt2_PV0(*tau) = (sumpt2_PV0>0.) ? std::log(sumpt2_PV0) : 0.;
//End of vertex variable addition block
// ABS_ETA_LEAD_TRACK is removed from the AOD content and must be redecorated when computing eVeto WPs
// note: this redecoration is not robust against charged track thinning, but charged tracks should never be thinned
......
# Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
import unittest
from AthenaConfiguration.AthConfigFlags import AthConfigFlags
......@@ -48,6 +48,7 @@ def createTauConfigFlags():
tau_cfg.addFlag("Tau.MvaTESConfig", "MvaTES_R23.root")
tau_cfg.addFlag("Tau.MinPt0p", 9.25*Units.GeV)
tau_cfg.addFlag("Tau.MinPt", 6.75*Units.GeV)
tau_cfg.addFlag("Tau.MinPtDAOD", 13*Units.GeV)
tau_cfg.addFlag("Tau.TauJetRNNConfig", ["tauid_rnn_1p_R22_v1.json", "tauid_rnn_2p_R22_v1.json", "tauid_rnn_3p_R22_v1.json"])
tau_cfg.addFlag("Tau.TauJetRNNWPConfig", ["tauid_rnnWP_1p_R22_v0.root", "tauid_rnnWP_2p_R22_v0.root", "tauid_rnnWP_3p_R22_v0.root"])
tau_cfg.addFlag("Tau.TauEleRNNConfig", ["taueveto_rnn_config_1P_r22.json", "taueveto_rnn_config_3P_r22.json"])
......
......@@ -854,6 +854,7 @@ def TauGNNEvaluatorCfg(flags):
MaxTracks = 30,
MaxClusters = 20,
MaxClusterDR = 15.0,
MinTauPt = flags.Tau.MinPtDAOD,
VertexCorrection = True,
DecorateTracks = False,
InputLayerScalar = "tau_vars",
......
/*
Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
*/
#include "TauAODRunnerAlg.h"
......@@ -66,6 +66,7 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const {
xAOD::TauJetContainer *newTauCon = outputTauHandle.ptr();
static const SG::AuxElement::Accessor<ElementLink<xAOD::TauJetContainer>> acc_ori_tau_link("originalTauJet");
static const SG::AuxElement::Accessor<char> acc_modified("ModifiedInAOD");
for (const xAOD::TauJet *tau : *pTauContainer) {
// deep copy the tau container
......@@ -88,6 +89,21 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const {
linkToTauTrack.toContainedElement(*newTauTrkCon, newTauTrk);
newTau->addTauTrackLink(linkToTauTrack);
}
// 'ModifiedInAOD' will be overriden by modification tools for relevant candidates
acc_modified(*newTau) = static_cast<char>(false);
StatusCode sc;
for (const ToolHandle<ITauToolBase> &tool : m_modificationTools) {
ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name());
sc = tool->execute(*newTau);
if (sc.isFailure()) break;
}
// if tau candidate was not modified, remove it from container, track cleanup performed by thinning algorithm downstream
if (!acc_modified(*newTau)) {
newTauCon->pop_back();
}
}
// Read the CaloClusterContainer
......@@ -123,49 +139,34 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const {
ATH_CHECK(vertOutHandle.record(std::make_unique<xAOD::VertexContainer>(), std::make_unique<xAOD::VertexAuxContainer>()));
xAOD::VertexContainer* pSecVtxContainer = vertOutHandle.ptr();
int n_tau_modified = 0;
static const SG::AuxElement::Accessor<char> acc_modified("ModifiedInAOD");
for (xAOD::TauJet *pTau : *newTauCon) {
// Loop stops when Failure indicated by one of the tools
StatusCode sc;
//add a identifier of if the tau is modifed by the mod tools
acc_modified(*pTau) = static_cast<char>(false);
// iterate over the copy
for (const ToolHandle<ITauToolBase> &tool : m_modificationTools) {
for (const ToolHandle<ITauToolBase> &tool : m_officialTools) {
ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name());
sc = tool->execute(*pTau);
if (tool->type() == "TauPi0ClusterCreator")
sc = tool->executePi0ClusterCreator(*pTau, *neutralPFOContainer, *hadronicClusterPFOContainer, *pi0ClusterContainer);
else if (tool->type() == "TauVertexVariables")
sc = tool->executeVertexVariables(*pTau, *pSecVtxContainer);
else if (tool->type() == "TauPi0ClusterScaler")
sc = tool->executePi0ClusterScaler(*pTau, *neutralPFOContainer, *chargedPFOContainer);
else if (tool->type() == "TauPi0ScoreCalculator")
sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer);
else if (tool->type() == "TauPi0Selector")
sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer);
else if (tool->type() == "PanTau::PanTauProcessor")
sc = tool->executePanTau(*pTau, *pi0Container, *neutralPFOContainer);
else if (tool->type() == "tauRecTools::TauTrackRNNClassifier")
sc = tool->executeTrackClassifier(*pTau, *newTauTrkCon);
else
sc = tool->execute(*pTau);
if (sc.isFailure()) break;
}
if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked modification tools.");
// if tau is not modified by the above tools, never mind running the tools afterward
if (static_cast<bool>(isTauModified(pTau))) {
n_tau_modified++;
for (const ToolHandle<ITauToolBase> &tool : m_officialTools) {
ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name());
if (tool->type() == "TauPi0ClusterCreator")
sc = tool->executePi0ClusterCreator(*pTau, *neutralPFOContainer, *hadronicClusterPFOContainer, *pi0ClusterContainer);
else if (tool->type() == "TauVertexVariables")
sc = tool->executeVertexVariables(*pTau, *pSecVtxContainer);
else if (tool->type() == "TauPi0ClusterScaler")
sc = tool->executePi0ClusterScaler(*pTau, *neutralPFOContainer, *chargedPFOContainer);
else if (tool->type() == "TauPi0ScoreCalculator")
sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer);
else if (tool->type() == "TauPi0Selector")
sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer);
else if (tool->type() == "PanTau::PanTauProcessor")
sc = tool->executePanTau(*pTau, *pi0Container, *neutralPFOContainer);
else if (tool->type() == "tauRecTools::TauTrackRNNClassifier")
sc = tool->executeTrackClassifier(*pTau, *newTauTrkCon);
else
sc = tool->execute(*pTau);
if (sc.isFailure()) break;
}
if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked official tools.");
}
if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked official tools.");
}
ATH_MSG_VERBOSE("The tau candidate container has been modified by the rest of the tools");
ATH_MSG_DEBUG(n_tau_modified << " / " << pTauContainer->size() <<" taus were modified");
ATH_MSG_DEBUG(newTauCon->size() << " / " << pTauContainer->size() <<" taus were modified");
return StatusCode::SUCCESS;
}
......
......@@ -4,15 +4,12 @@
#include "tauRecTools/TauGNN.h"
#include "FlavorTagDiscriminants/OnnxUtil.h"
#include "lwtnn/parse_json.hh"
#include "PathResolver/PathResolver.h"
#include <algorithm>
#include <fstream>
#include "lwtnn/LightweightGraph.hh"
#include "lwtnn/Exceptions.hh"
//#include "lwtnn/parse_json.hh"
#include "tauRecTools/TauGNNUtils.h"
TauGNN::TauGNN(const std::string &nnFile, const Config &config):
......@@ -106,8 +103,8 @@ std::tuple<
std::map<std::string, std::vector<char>>,
std::map<std::string, std::vector<float>> >
TauGNN::compute(const xAOD::TauJet &tau,
const std::vector<const xAOD::TauTrack *> &tracks,
const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
const std::vector<const xAOD::TauTrack *> &tracks,
const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
InputMap scalarInputs;
InputSequenceMap vectorInputs;
std::map<std::string, Inputs> gnn_input;
......
......@@ -25,6 +25,7 @@ TauGNNEvaluator::TauGNNEvaluator(const std::string &name):
declareProperty("VertexCorrection", m_doVertexCorrection = true);
declareProperty("DecorateTracks", m_decorateTracks = false);
declareProperty("TrackClassification", m_doTrackClassification = true);
declareProperty("MinTauPt", m_minTauPt = 0.);
// Naming conventions for the network weight files:
declareProperty("InputLayerScalar", m_input_layer_scalar = "tau_vars");
......@@ -82,8 +83,9 @@ StatusCode TauGNNEvaluator::execute(xAOD::TauJet &tau) const {
output(tau) = -1111.0f;
out_ptau(tau) = -1111.0f;
out_pjet(tau) = -1111.0f;
//Skip execution for low-pT taus to save resources
if(tau.pt()<13000) {
if (tau.pt() < m_minTauPt) {
return StatusCode::SUCCESS;
}
......@@ -129,7 +131,7 @@ StatusCode TauGNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<cons
std::vector<const xAOD::TauTrack*>::iterator it = tracks.begin();
while(it != tracks.end()) {
if((*it)->flag(xAOD::TauJetParameters::unclassified)) {
it = tracks.erase(it);
it = tracks.erase(it);
}
else {
++it;
......@@ -151,8 +153,7 @@ StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xA
TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection);
std::vector<xAOD::CaloVertexedTopoCluster> vertexedClusterList = tau.vertexedClusters();
for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : vertexedClusterList) {
for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : tau.vertexedClusters()) {
TLorentzVector clusterP4 = vertexedCluster.p4();
if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue;
......
/*
Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
*/
#include "tauRecTools/TauGNNUtils.h"
......@@ -32,6 +32,7 @@ bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau,
const std::vector<const xAOD::TauTrack *> &tracks,
std::vector<double> &out) const {
out.clear();
out.reserve(tracks.size());
// Retrieve calculator function
TrackCalc func = nullptr;
......@@ -57,6 +58,7 @@ bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau,
const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
std::vector<double> &out) const {
out.clear();
out.reserve(clusters.size());
// Retrieve calculator function
ClusterCalc func = nullptr;
......
......@@ -16,11 +16,6 @@
#include <string>
#include <map>
// Forward declaration
namespace lwt {
class LightweightGraph;
}
namespace TauGNNUtils {
class GNNVarCalc;
}
......@@ -74,10 +69,6 @@ public:
return m_var_calc.get();
}
explicit operator bool() const {
return static_cast<bool>(m_graph);
}
//Make the output config transparent to external tools
FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config;
......@@ -92,7 +83,6 @@ private:
private:
const Config m_config;
std::unique_ptr<const lwt::LightweightGraph> m_graph;
// Names of the input variables
std::vector<std::string> m_scalar_inputs;
......
......@@ -51,6 +51,7 @@ private:
std::size_t m_max_tracks;
std::size_t m_max_clusters;
float m_max_cluster_dr;
float m_minTauPt;
bool m_doVertexCorrection;
bool m_doTrackClassification;
bool m_decorateTracks;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment