From c2b47b9bf87cf7e62b7524da1a3f2b66a7071ca8 Mon Sep 17 00:00:00 2001
From: Bertrand Martin Dit Latour <bertrand.martindl@cern.ch>
Date: Tue, 23 Feb 2021 18:07:53 +0000
Subject: [PATCH] tauRecTools: skip LRTs in tau track classification and
 downstream tools (ATLTAU-1772)

Hello,

In R22, we will have a dedicated reconstruction at xAOD level targeting Long Lived Particles decaying to taus, which will use Large Radius Tracks within 0.4 of the tau axis, and we'll have dedicated LRT tunes for tau algorithms (tau ID, ...).
For the R22 reprocessing already, we plan to associated LRTs with taus, so that LRTs can be retrieved via tau->tracks(xAOD::TauJetParameters::LargeRadiusTrack), which will facilite a lot the navigation at xAOD level.
However, we don't want the LRTs to participate in the standard tau reconstruction, besides the tau-track association.
This MR implements a mechanism to skip Large Radius Tracks in the tau track classifier and all downstream tools, even if LRTs have been associated with taus in TauTrackFinder.

The modification is made in the tau track classifier. If the tool is instructed to ignore LRTs (via tauRecFlags), it will filter out the LRTs possibly associated with the tau, decorate these with default RNN track classification scores, and proceed to classify the non-LRTs. I also fixed a problem with an incorrectly set xAOD::TauJetParameters::unclassified tau track flag.

The RNN tau ID tool now explicitly ignores the tracks that are not classified, so it is not affected by the possible presence of LRTs associated with taus in TauTrackFinder.
I have checked that the RNN ID score is unchanged when we associate LRTs with taus but don't classify them.
More validation will follow.

In TauAlgorithmsHolder, deprecated tools were moved to the end of the file, and will be cleaned up at the next occasion.

Cheers,
Bertrand
---
 .../tauRec/python/TauAlgorithmsHolder.py      | 346 +++++++++---------
 Reconstruction/tauRec/python/tauRecFlags.py   |  14 +-
 .../tauRecTools/Root/TauJetRNNEvaluator.cxx   |  19 +-
 .../Root/TauTrackRNNClassifier.cxx            | 118 +++---
 .../tauRecTools/TauJetRNNEvaluator.h          |   1 +
 .../tauRecTools/TauTrackRNNClassifier.h       |  10 +-
 .../python/TrigTauAlgorithmsHolder.py         |  31 +-
 7 files changed, 294 insertions(+), 245 deletions(-)

diff --git a/Reconstruction/tauRec/python/TauAlgorithmsHolder.py b/Reconstruction/tauRec/python/TauAlgorithmsHolder.py
index d36f039942a..500b1bc470d 100644
--- a/Reconstruction/tauRec/python/TauAlgorithmsHolder.py
+++ b/Reconstruction/tauRec/python/TauAlgorithmsHolder.py
@@ -254,21 +254,6 @@ def getTauCommonCalcVars():
     cached_instances[_name] = TauCommonCalcVars    
     return TauCommonCalcVars
 
-
-#########################################################################
-# Tau Test
-def getTauTestDump():
-    _name = sPrefix + 'TauTestDump'
-    
-    if _name in cached_instances:
-        return cached_instances[_name]
-    
-    from tauRecTools.tauRecToolsConf import TauTestDump
-    TauTestDump = TauTestDump(name = _name)
-    
-    cached_instances[_name] = TauTestDump
-    return TauTestDump
-
 #########################################################################
 # Tau Vertex Variables
 def getTauVertexVariables():
@@ -594,7 +579,7 @@ def getTauTrackFinder(removeDuplicateTracks=True):
                                     tauParticleCache = getParticleCache(),
                                     removeDuplicateCoreTracks = removeDuplicateTracks,
                                     Key_trackPartInputContainer = _DefaultTrackContainer,
-                                    Key_LargeD0TrackInputContainer = _DefaultLargeD0TrackContainer if tauFlags.useLargeD0Tracks else "",
+                                    Key_LargeD0TrackInputContainer = _DefaultLargeD0TrackContainer if tauFlags.associateLRT() else "",
                                     TrackToVertexIPEstimator = getTauTrackToVertexIPEstimator(),
                                     #maxDeltaZ0wrtLeadTrk = 2, #in mm
                                     #removeTracksOutsideZ0wrtLeadTrk = True
@@ -614,7 +599,7 @@ def getTauClusterFinder():
     from JetRec.JetRecFlags import jetFlags
 
     doJetVertexCorrection = False
-    if tauFlags.isStandalone:
+    if tauFlags.isStandalone():
         doJetVertexCorrection = True
     if jetFlags.useVertices() and jetFlags.useTracks():
         doJetVertexCorrection = True
@@ -658,72 +643,8 @@ def getTauCombinedTES():
     cached_instances[_name] = TauCombinedTES
     return TauCombinedTES
     
-#########################################################################
-def getTauTrackClassifier():
-    _name = sPrefix + 'TauTrackClassifier'
-    
-    if _name in cached_instances:
-        return cached_instances[_name]
-    
-    from AthenaCommon.AppMgr import ToolSvc
-    from tauRecTools.tauRecToolsConf import tauRecTools__TauTrackClassifier as TauTrackClassifier
-    from tauRecTools.tauRecToolsConf import tauRecTools__TrackMVABDT as TrackMVABDT
 
-    import PyUtils.RootUtils as ru
-    ROOT = ru.import_root()
-    import cppyy
-    cppyy.load_library('libxAODTau_cDict.so')
-
-    # =========================================================================
-    _BDT_TTCT_ITFT_0 = TrackMVABDT(name = _name + "_0",
-                                   #InputWeightsPath = "TMVAClassification_BDT.weights.root",
-                                   #Threshold      = -0.005,
-                                   InputWeightsPath = tauFlags.tauRecMVATrackClassificationConfig()[0][0],
-                                   Threshold = tauFlags.tauRecMVATrackClassificationConfig()[0][1],
-                                   ExpectedFlag   = ROOT.xAOD.TauJetParameters.TauTrackFlag.unclassified, 
-                                   SignalType     = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedCharged, 
-                                   BackgroundType = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedIsolation,
-                                   calibFolder = tauFlags.tauRecToolsCVMFSPath(), 
-                                   )
-    ToolSvc += _BDT_TTCT_ITFT_0
-    cached_instances[_BDT_TTCT_ITFT_0.name] = _BDT_TTCT_ITFT_0
-    
-    _BDT_TTCT_ITFT_0_0 = TrackMVABDT(name = _name + "_0_0",
-                                     #InputWeightsPath = "TMVAClassification_BDT_0.weights.root",
-                                     #Threshold      = -0.0074,
-                                     InputWeightsPath = tauFlags.tauRecMVATrackClassificationConfig()[1][0],
-                                     Threshold = tauFlags.tauRecMVATrackClassificationConfig()[1][1],
-                                     ExpectedFlag   = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedCharged,
-                                     SignalType     = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedCharged,
-                                     BackgroundType = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedConversion,
-                                     calibFolder = tauFlags.tauRecToolsCVMFSPath(),
-                                     )
-    ToolSvc += _BDT_TTCT_ITFT_0_0
-    cached_instances[_BDT_TTCT_ITFT_0_0.name] = _BDT_TTCT_ITFT_0_0
-    
-    _BDT_TTCT_ITFT_0_1 = TrackMVABDT(name = _name + "_0_1",
-                                     #InputWeightsPath = "TMVAClassification_BDT_1.weights.root",
-                                     #Threshold      = 0.0005,
-                                     InputWeightsPath = tauFlags.tauRecMVATrackClassificationConfig()[2][0],
-                                     Threshold = tauFlags.tauRecMVATrackClassificationConfig()[2][1],
-                                     ExpectedFlag   = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedIsolation, 
-                                     SignalType     = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedIsolation, 
-                                     BackgroundType = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedFake,
-                                     calibFolder = tauFlags.tauRecToolsCVMFSPath(),
-                                     )
-    ToolSvc += _BDT_TTCT_ITFT_0_1
-    cached_instances[_BDT_TTCT_ITFT_0_1.name] = _BDT_TTCT_ITFT_0_1
-
-    # create tool alg
-    myTauTrackClassifier = TauTrackClassifier( name = _name,
-                                               Classifiers = [_BDT_TTCT_ITFT_0, _BDT_TTCT_ITFT_0_0, _BDT_TTCT_ITFT_0_1] )
-    #ToolSvc += TauTrackClassifier #only add to tool service sub tools to your tool, the main tool will be added via TauRecConfigured
-    cached_instances[_name] = myTauTrackClassifier 
-
-    return myTauTrackClassifier
-
-########################################################################                                                                                                             
-#            
+########################################################################
 def getTauTrackRNNClassifier():
     _name = sPrefix + 'TauTrackRNNClassifier'
     
@@ -737,100 +658,26 @@ def getTauTrackRNNClassifier():
     import cppyy
     cppyy.load_library('libxAODTau_cDict')
 
-    _RNN= TrackRNN(name = _name + "_0",
-                   InputWeightsPath = tauFlags.tauRecRNNTrackClassificationConfig()[0],
-                   calibFolder = tauFlags.tauRecToolsCVMFSPath(), 
+    _RNN = TrackRNN(name = _name + "_0",
+                    InputWeightsPath = tauFlags.tauRecRNNTrackClassificationConfig()[0],
+                    calibFolder = tauFlags.tauRecToolsCVMFSPath(), 
                    )
-
     ToolSvc += _RNN
     cached_instances[_RNN.name] = _RNN
-    
+
+    _classifyLRT = True
+    if tauFlags.associateLRT() and not tauFlags.classifyLRT():
+        _classifyLRT = False
+
     # create tool alg
     myTauTrackClassifier = TauTrackRNNClassifier( name = _name,
-                                               Classifiers = [_RNN] )
-    cached_instances[_name] = myTauTrackClassifier 
+                                                  Classifiers = [_RNN],
+                                                  classifyLRT = _classifyLRT )
 
+    cached_instances[_name] = myTauTrackClassifier 
     return myTauTrackClassifier
 
-########################################################################                                                                                                             
-#
-def getTauWPDecoratorJetRNN():
-    import PyUtils.RootUtils as ru
-    ROOT = ru.import_root()
-    import cppyy
-    cppyy.load_library('libxAODTau_cDict')
-
-    _name = sPrefix + 'TauWPDecoratorJetRNN'
-    from tauRecTools.tauRecToolsConf import TauWPDecorator
-    myTauWPDecorator = TauWPDecorator( name=_name,
-                                       flatteningFile1Prong = "rnnid_mc16d_flat_1p.root",
-                                       flatteningFile3Prong = "rnnid_mc16d_flat_3p.root",
-                                       CutEnumVals =
-                                       [ ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigVeryLoose, ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigLoose,
-                                         ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigMedium, ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigTight ],
-                                       SigEff1P = [0.95, 0.85, 0.75, 0.60],
-                                       SigEff3P = [0.95, 0.75, 0.60, 0.45],
-                                       ScoreName = "RNNJetScore",
-                                       NewScoreName = "RNNJetScoreSigTrans",
-                                       DefineWPs = True,
-                                       )
-    cached_instances[_name] = myTauWPDecorator
-    return myTauWPDecorator
-
-
-#                                                                                                                                                                                  
-def getTauWPDecoratorJetBDT():
-    import PyUtils.RootUtils as ru
-    ROOT = ru.import_root()
-    import cppyy
-    cppyy.load_library('libxAODTau_cDict')
-
-    _name = sPrefix + 'TauWPDecoratorJetBDT'
-    from tauRecTools.tauRecToolsConf import TauWPDecorator
-    myTauWPDecorator = TauWPDecorator( name=_name,
-                                       flatteningFile1Prong = "FlatJetBDT1Pv2.root", #update
-                                       flatteningFile3Prong = "FlatJetBDT3Pv2.root", #update
-                                       CutEnumVals = 
-                                       [ ROOT.xAOD.TauJetParameters.IsTauFlag.JetBDTSigVeryLoose, ROOT.xAOD.TauJetParameters.IsTauFlag.JetBDTSigLoose,
-                                         ROOT.xAOD.TauJetParameters.IsTauFlag.JetBDTSigMedium, ROOT.xAOD.TauJetParameters.IsTauFlag.JetBDTSigTight ],
-                                       SigEff1P = [0.95, 0.85, 0.75, 0.60],
-                                       SigEff3P = [0.95, 0.75, 0.60, 0.45],
-                                       ScoreName = "BDTJetScore",
-                                       NewScoreName = "BDTJetScoreSigTrans",
-                                       DefineWPs = True,
-                                       )
-    cached_instances[_name] = myTauWPDecorator
-    return myTauWPDecorator
-
-
-# 
-def getTauWPDecoratorEleBDT():
-    import PyUtils.RootUtils as ru
-    ROOT = ru.import_root()
-    import cppyy
-    cppyy.load_library('libxAODTau_cDict')
-
-    _name = sPrefix + 'TauWPDecoratorEleBDT'
-    from tauRecTools.tauRecToolsConf import TauWPDecorator
-    TauScoreFlatteningTool = TauWPDecorator( name=_name,
-                                             flatteningFile1Prong = "EleBDTFlat1P.root",#update
-                                             flatteningFile3Prong = "EleBDTFlat3P.root",#update                                             
-                                             UseEleBDT = True ,
-                                             ScoreName = "BDTEleScore",
-                                             NewScoreName = "BDTEleScoreSigTrans", #dynamic
-                                             DefineWPs = True,
-                                             CutEnumVals = 
-                                             [ROOT.xAOD.TauJetParameters.IsTauFlag.EleBDTLoose, 
-                                              ROOT.xAOD.TauJetParameters.IsTauFlag.EleBDTMedium, 
-                                              ROOT.xAOD.TauJetParameters.IsTauFlag.EleBDTTight],
-                                             SigEff1P = [0.95, 0.85, 0.75],
-                                             SigEff3P = [0.95, 0.85, 0.75],
-                                             ) 
-    cached_instances[_name] = TauScoreFlatteningTool
-    return TauScoreFlatteningTool
-
 
-#
 def getTauJetRNNEvaluator():
     _name = sPrefix + 'TauJetRNN'
     from tauRecTools.tauRecToolsConf import TauJetRNNEvaluator
@@ -853,18 +700,29 @@ def getTauJetRNNEvaluator():
     return myTauJetRNNEvaluator
 
 
-def getTauJetBDTEvaluator(_n, weightsFile="", minNTracks=0, maxNTracks=10000, outputVarName="BDTJetScore", minAbsTrackEta=-1, maxAbsTrackEta=-1):
-    _name = sPrefix + _n
-    from tauRecTools.tauRecToolsConf import TauJetBDTEvaluator
-    myTauJetBDTEvaluator = TauJetBDTEvaluator(name=_name,
-                                              weightsFile=weightsFile, #update config?
-                                              minNTracks=minNTracks,
-                                              maxNTracks=maxNTracks,
-                                              minAbsTrackEta=minAbsTrackEta,
-                                              maxAbsTrackEta=maxAbsTrackEta,
-                                              outputVarName=outputVarName)
-    cached_instances[_name] = myTauJetBDTEvaluator
-    return myTauJetBDTEvaluator
+def getTauWPDecoratorJetRNN():
+    import PyUtils.RootUtils as ru
+    ROOT = ru.import_root()
+    import cppyy
+    cppyy.load_library('libxAODTau_cDict')
+
+    _name = sPrefix + 'TauWPDecoratorJetRNN'
+    from tauRecTools.tauRecToolsConf import TauWPDecorator
+    myTauWPDecorator = TauWPDecorator( name=_name,
+                                       flatteningFile1Prong = "rnnid_mc16d_flat_1p.root",
+                                       flatteningFile3Prong = "rnnid_mc16d_flat_3p.root",
+                                       CutEnumVals =
+                                       [ ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigVeryLoose, ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigLoose,
+                                         ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigMedium, ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigTight ],
+                                       SigEff1P = [0.95, 0.85, 0.75, 0.60],
+                                       SigEff3P = [0.95, 0.75, 0.60, 0.45],
+                                       ScoreName = "RNNJetScore",
+                                       NewScoreName = "RNNJetScoreSigTrans",
+                                       DefineWPs = True,
+                                       )
+    cached_instances[_name] = myTauWPDecorator
+    return myTauWPDecorator
+
 
 def getTauIDVarCalculator():
     _name = sPrefix + 'TauIDVarCalculator'
@@ -979,3 +837,133 @@ def getTVATool():
     cached_instances[_name] = TVATool
     return TVATool
 
+
+
+# deprecated in R22
+
+def getTauTrackClassifier():
+    _name = sPrefix + 'TauTrackClassifier'
+    
+    if _name in cached_instances:
+        return cached_instances[_name]
+    
+    from AthenaCommon.AppMgr import ToolSvc
+    from tauRecTools.tauRecToolsConf import tauRecTools__TauTrackClassifier as TauTrackClassifier
+    from tauRecTools.tauRecToolsConf import tauRecTools__TrackMVABDT as TrackMVABDT
+
+    import PyUtils.RootUtils as ru
+    ROOT = ru.import_root()
+    import cppyy
+    cppyy.load_library('libxAODTau_cDict.so')
+
+    # =========================================================================
+    _BDT_TTCT_ITFT_0 = TrackMVABDT(name = _name + "_0",
+                                   #InputWeightsPath = "TMVAClassification_BDT.weights.root",
+                                   #Threshold      = -0.005,
+                                   InputWeightsPath = tauFlags.tauRecMVATrackClassificationConfig()[0][0],
+                                   Threshold = tauFlags.tauRecMVATrackClassificationConfig()[0][1],
+                                   ExpectedFlag   = ROOT.xAOD.TauJetParameters.TauTrackFlag.unclassified, 
+                                   SignalType     = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedCharged, 
+                                   BackgroundType = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedIsolation,
+                                   calibFolder = tauFlags.tauRecToolsCVMFSPath(), 
+                                   )
+    ToolSvc += _BDT_TTCT_ITFT_0
+    cached_instances[_BDT_TTCT_ITFT_0.name] = _BDT_TTCT_ITFT_0
+    
+    _BDT_TTCT_ITFT_0_0 = TrackMVABDT(name = _name + "_0_0",
+                                     #InputWeightsPath = "TMVAClassification_BDT_0.weights.root",
+                                     #Threshold      = -0.0074,
+                                     InputWeightsPath = tauFlags.tauRecMVATrackClassificationConfig()[1][0],
+                                     Threshold = tauFlags.tauRecMVATrackClassificationConfig()[1][1],
+                                     ExpectedFlag   = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedCharged,
+                                     SignalType     = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedCharged,
+                                     BackgroundType = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedConversion,
+                                     calibFolder = tauFlags.tauRecToolsCVMFSPath(),
+                                     )
+    ToolSvc += _BDT_TTCT_ITFT_0_0
+    cached_instances[_BDT_TTCT_ITFT_0_0.name] = _BDT_TTCT_ITFT_0_0
+    
+    _BDT_TTCT_ITFT_0_1 = TrackMVABDT(name = _name + "_0_1",
+                                     #InputWeightsPath = "TMVAClassification_BDT_1.weights.root",
+                                     #Threshold      = 0.0005,
+                                     InputWeightsPath = tauFlags.tauRecMVATrackClassificationConfig()[2][0],
+                                     Threshold = tauFlags.tauRecMVATrackClassificationConfig()[2][1],
+                                     ExpectedFlag   = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedIsolation, 
+                                     SignalType     = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedIsolation, 
+                                     BackgroundType = ROOT.xAOD.TauJetParameters.TauTrackFlag.classifiedFake,
+                                     calibFolder = tauFlags.tauRecToolsCVMFSPath(),
+                                     )
+    ToolSvc += _BDT_TTCT_ITFT_0_1
+    cached_instances[_BDT_TTCT_ITFT_0_1.name] = _BDT_TTCT_ITFT_0_1
+
+    # create tool alg
+    myTauTrackClassifier = TauTrackClassifier( name = _name,
+                                               Classifiers = [_BDT_TTCT_ITFT_0, _BDT_TTCT_ITFT_0_0, _BDT_TTCT_ITFT_0_1] )
+    #ToolSvc += TauTrackClassifier #only add to tool service sub tools to your tool, the main tool will be added via TauRecConfigured
+    cached_instances[_name] = myTauTrackClassifier 
+
+    return myTauTrackClassifier
+
+
+def getTauJetBDTEvaluator(_n, weightsFile="", minNTracks=0, maxNTracks=10000, outputVarName="BDTJetScore", minAbsTrackEta=-1, maxAbsTrackEta=-1):
+    _name = sPrefix + _n
+    from tauRecTools.tauRecToolsConf import TauJetBDTEvaluator
+    myTauJetBDTEvaluator = TauJetBDTEvaluator(name=_name,
+                                              weightsFile=weightsFile,
+                                              minNTracks=minNTracks,
+                                              maxNTracks=maxNTracks,
+                                              minAbsTrackEta=minAbsTrackEta,
+                                              maxAbsTrackEta=maxAbsTrackEta,
+                                              outputVarName=outputVarName)
+    cached_instances[_name] = myTauJetBDTEvaluator
+    return myTauJetBDTEvaluator
+
+
+def getTauWPDecoratorJetBDT():
+    import PyUtils.RootUtils as ru
+    ROOT = ru.import_root()
+    import cppyy
+    cppyy.load_library('libxAODTau_cDict')
+
+    _name = sPrefix + 'TauWPDecoratorJetBDT'
+    from tauRecTools.tauRecToolsConf import TauWPDecorator
+    myTauWPDecorator = TauWPDecorator( name=_name,
+                                       flatteningFile1Prong = "FlatJetBDT1Pv2.root",
+                                       flatteningFile3Prong = "FlatJetBDT3Pv2.root",
+                                       CutEnumVals = 
+                                       [ ROOT.xAOD.TauJetParameters.IsTauFlag.JetBDTSigVeryLoose, ROOT.xAOD.TauJetParameters.IsTauFlag.JetBDTSigLoose,
+                                         ROOT.xAOD.TauJetParameters.IsTauFlag.JetBDTSigMedium, ROOT.xAOD.TauJetParameters.IsTauFlag.JetBDTSigTight ],
+                                       SigEff1P = [0.95, 0.85, 0.75, 0.60],
+                                       SigEff3P = [0.95, 0.75, 0.60, 0.45],
+                                       ScoreName = "BDTJetScore",
+                                       NewScoreName = "BDTJetScoreSigTrans",
+                                       DefineWPs = True,
+                                       )
+    cached_instances[_name] = myTauWPDecorator
+    return myTauWPDecorator
+
+
+def getTauWPDecoratorEleBDT():
+    import PyUtils.RootUtils as ru
+    ROOT = ru.import_root()
+    import cppyy
+    cppyy.load_library('libxAODTau_cDict')
+
+    _name = sPrefix + 'TauWPDecoratorEleBDT'
+    from tauRecTools.tauRecToolsConf import TauWPDecorator
+    TauScoreFlatteningTool = TauWPDecorator( name=_name,
+                                             flatteningFile1Prong = "EleBDTFlat1P.root",
+                                             flatteningFile3Prong = "EleBDTFlat3P.root",
+                                             UseEleBDT = True ,
+                                             ScoreName = "BDTEleScore",
+                                             NewScoreName = "BDTEleScoreSigTrans",
+                                             DefineWPs = True,
+                                             CutEnumVals = 
+                                             [ROOT.xAOD.TauJetParameters.IsTauFlag.EleBDTLoose, 
+                                              ROOT.xAOD.TauJetParameters.IsTauFlag.EleBDTMedium, 
+                                              ROOT.xAOD.TauJetParameters.IsTauFlag.EleBDTTight],
+                                             SigEff1P = [0.95, 0.85, 0.75],
+                                             SigEff3P = [0.95, 0.85, 0.75],
+                                             ) 
+    cached_instances[_name] = TauScoreFlatteningTool
+    return TauScoreFlatteningTool
diff --git a/Reconstruction/tauRec/python/tauRecFlags.py b/Reconstruction/tauRec/python/tauRecFlags.py
index cf7bbe7b8d2..79acd883490 100644
--- a/Reconstruction/tauRec/python/tauRecFlags.py
+++ b/Reconstruction/tauRec/python/tauRecFlags.py
@@ -58,9 +58,15 @@ class doTJVA(JobProperty):
     allowedTypes=['bool']
     StoredValue=True
 
-class useLargeD0Tracks(JobProperty):
-    """ Use LRT tracks in tau track finding """
-    statusOn=False
+class associateLRT(JobProperty):
+    """ associate Large Radius Tracks with tau in TauTrackFinder """
+    statusOn=True
+    allowedTypes=['bool']
+    StoredValue=False
+
+class classifyLRT(JobProperty):
+    """ classify Large Radius Tracks in tau track classifier """
+    statusOn=True
     allowedTypes=['bool']
     StoredValue=False
 
@@ -256,7 +262,7 @@ class tauRecFlags(JobPropertyContainer):
 jobproperties.add_Container(tauRecFlags)
 
 # I want always the following flags in the Rec container  
-_list_tau=[Enabled,doTauRec,isStandalone,tauRecSeedJetCollection,tauRecToolsCVMFSPath,doTJVA,useLargeD0Tracks,removeDuplicateCoreTracks,tauRecMVATrackClassification,tauRecRNNTrackClassification,tauRecMVATrackClassificationConfig,tauRecRNNTrackClassificationConfig,tauRecDecayModeNNClassifierConfig,tauRecCalibrateLCConfig,tauRecMvaTESConfig,tauRecCombinedTESConfig,tauRecTauJetRNNConfig,tauRecTauEleRNNConfig,tauRecSeedMinPt,tauRecSeedMaxEta,tauRecMinPt,tauRecMaxNTracks,tauRecToolsDevToolList,tauRecToolsDevToolListProcessor,doRunTauDiscriminant,doPanTau,doPi0,pi0EtCuts,pi0MVACuts_1prong,pi0MVACuts_mprong,shotPtCut_1Photon,shotPtCut_2Photons,useOldVertexFitterAPI]
+_list_tau=[Enabled,doTauRec,isStandalone,tauRecSeedJetCollection,tauRecToolsCVMFSPath,doTJVA,associateLRT,classifyLRT,removeDuplicateCoreTracks,tauRecMVATrackClassification,tauRecRNNTrackClassification,tauRecMVATrackClassificationConfig,tauRecRNNTrackClassificationConfig,tauRecDecayModeNNClassifierConfig,tauRecCalibrateLCConfig,tauRecMvaTESConfig,tauRecCombinedTESConfig,tauRecTauJetRNNConfig,tauRecTauEleRNNConfig,tauRecSeedMinPt,tauRecSeedMaxEta,tauRecMinPt,tauRecMaxNTracks,tauRecToolsDevToolList,tauRecToolsDevToolListProcessor,doRunTauDiscriminant,doPanTau,doPi0,pi0EtCuts,pi0MVACuts_1prong,pi0MVACuts_mprong,shotPtCut_1Photon,shotPtCut_2Photons,useOldVertexFitterAPI]
 for j in _list_tau: 
     jobproperties.tauRecFlags.add_JobProperty(j)
 del _list_tau
diff --git a/Reconstruction/tauRecTools/Root/TauJetRNNEvaluator.cxx b/Reconstruction/tauRecTools/Root/TauJetRNNEvaluator.cxx
index 097803c5ecb..886729e2371 100644
--- a/Reconstruction/tauRecTools/Root/TauJetRNNEvaluator.cxx
+++ b/Reconstruction/tauRecTools/Root/TauJetRNNEvaluator.cxx
@@ -27,6 +27,7 @@ TauJetRNNEvaluator::TauJetRNNEvaluator(const std::string &name):
   declareProperty("MaxClusters", m_max_clusters = 6);
   declareProperty("MaxClusterDR", m_max_cluster_dr = 1.0f);
   declareProperty("VertexCorrection", m_doVertexCorrection = true);
+  declareProperty("TrackClassification", m_doTrackClassification = true);
 
   // Naming conventions for the network weight files:
   declareProperty("InputLayerScalar", m_input_layer_scalar = "scalar");
@@ -182,7 +183,23 @@ const TauJetRNN* TauJetRNNEvaluator::get_rnn_3p() const {
 }
 
 StatusCode TauJetRNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const {
-  auto tracks = tau.allTracks();
+  std::vector<const xAOD::TauTrack*> tracks = tau.allTracks();
+
+  // Skip unclassified tracks:
+  // - the track is a LRT and classifyLRT = false
+  // - the track is not among the MaxNtracks highest-pt tracks in the track classifier
+  // - track classification is not run (trigger)
+  if(m_doTrackClassification) {
+    std::vector<const xAOD::TauTrack*>::iterator it = tracks.begin();
+    while(it != tracks.end()) {
+      if((*it)->flag(xAOD::TauJetParameters::unclassified)) {
+	it = tracks.erase(it);
+      }
+      else {
+	++it;
+      }
+    }
+  }
 
   // Sort by descending pt
   auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
diff --git a/Reconstruction/tauRecTools/Root/TauTrackRNNClassifier.cxx b/Reconstruction/tauRecTools/Root/TauTrackRNNClassifier.cxx
index 9d23e5432ee..00f2c9eb150 100644
--- a/Reconstruction/tauRecTools/Root/TauTrackRNNClassifier.cxx
+++ b/Reconstruction/tauRecTools/Root/TauTrackRNNClassifier.cxx
@@ -23,8 +23,9 @@ using namespace tauRecTools;
 //==============================================================================
 
 //______________________________________________________________________________
-TauTrackRNNClassifier::TauTrackRNNClassifier(const std::string& sName)
-  : TauRecToolBase(sName) {
+TauTrackRNNClassifier::TauTrackRNNClassifier(const std::string& name)
+  : TauRecToolBase(name) {
+  declareProperty("classifyLRT", m_classifyLRT = true);
 }
 
 //______________________________________________________________________________
@@ -35,11 +36,9 @@ TauTrackRNNClassifier::~TauTrackRNNClassifier()
 //______________________________________________________________________________
 StatusCode TauTrackRNNClassifier::initialize()
 {
-  ATH_MSG_DEBUG("intialize classifiers");
-
-  for (auto cClassifier : m_vClassifier){
-    ATH_MSG_INFO("TauTrackRNNClassifier tool : " << cClassifier );
-    ATH_CHECK(cClassifier.retrieve());
+  for (auto classifier : m_vClassifier){
+    ATH_MSG_INFO("Intialize TauTrackRNNClassifier tool : " << classifier );
+    ATH_CHECK(classifier.retrieve());
   }
  
   return StatusCode::SUCCESS;
@@ -50,21 +49,41 @@ StatusCode TauTrackRNNClassifier::executeTrackClassifier(xAOD::TauJet& xTau, xAO
 
   std::vector<xAOD::TauTrack*> vTracks = xAOD::TauHelpers::allTauTracksNonConst(&xTau, &tauTrackCon);
 
-  for (xAOD::TauTrack* xTrack : vTracks)
-    {
-      // reset all track flags and set status to unclassified
-      xTrack->setFlag(xAOD::TauJetParameters::classifiedCharged, false);
-      xTrack->setFlag(xAOD::TauJetParameters::classifiedConversion, false);
-      xTrack->setFlag(xAOD::TauJetParameters::classifiedIsolation, false);
-      xTrack->setFlag(xAOD::TauJetParameters::classifiedFake, false);
-      xTrack->setFlag(xAOD::TauJetParameters::unclassified, true);
+  for (xAOD::TauTrack* xTrack : vTracks) {
+    // reset all track flags and set status to unclassified
+    xTrack->setFlag(xAOD::TauJetParameters::classifiedCharged, false);
+    xTrack->setFlag(xAOD::TauJetParameters::classifiedConversion, false);
+    xTrack->setFlag(xAOD::TauJetParameters::classifiedIsolation, false);
+    xTrack->setFlag(xAOD::TauJetParameters::classifiedFake, false);
+    xTrack->setFlag(xAOD::TauJetParameters::unclassified, true);
+  }
+
+  // don't classify LRTs even if LRTs were associated with taus in TauTrackFinder
+  if(!m_classifyLRT) {
+    std::vector<xAOD::TauTrack*> vLRTs;
+    std::vector<xAOD::TauTrack*>::iterator it = vTracks.begin(); 
+    while(it != vTracks.end()) {      
+      if((*it)->flag(xAOD::TauJetParameters::LargeRadiusTrack)) {	
+	vLRTs.push_back(*it);
+        it = vTracks.erase(it);
+      }
+      else {
+	++it;
+      }
     }
-  
-  for (auto cClassifier : m_vClassifier) {
-    ATH_CHECK(cClassifier->classifyTracks(vTracks, xTau));
+
+    // decorate LRTs with default RNN scores
+    for (auto classifier : m_vClassifier) {
+      ATH_CHECK(classifier->classifyTracks(vLRTs, xTau, true));
+    }
+  }
+
+  // classify tracks
+  for (auto classifier : m_vClassifier) {
+    ATH_CHECK(classifier->classifyTracks(vTracks, xTau));
   }
 
-  std::vector< ElementLink< xAOD::TauTrackContainer > > &tauTrackLinks(xTau.allTauTrackLinksNonConst());
+  std::vector< ElementLink< xAOD::TauTrackContainer > >& tauTrackLinks(xTau.allTauTrackLinksNonConst());
   std::sort(tauTrackLinks.begin(), tauTrackLinks.end(), sortTracks);
   float charge=0.0;
   for( const xAOD::TauTrack* trk : xTau.tracks(xAOD::TauJetParameters::classifiedCharged) ){
@@ -82,13 +101,15 @@ StatusCode TauTrackRNNClassifier::executeTrackClassifier(xAOD::TauJet& xTau, xAO
 
   //set modifiedIsolationTrack
   for (xAOD::TauTrack* xTrack : vTracks) {
-    if( not xTrack->flag(xAOD::TauJetParameters::classifiedCharged) and 
-	xTrack->flag(xAOD::TauJetParameters::passTrkSelector) ) xTrack->setFlag(xAOD::TauJetParameters::modifiedIsolationTrack, true);
-    else xTrack->setFlag(xAOD::TauJetParameters::modifiedIsolationTrack, false);
+    if( not xTrack->flag(xAOD::TauJetParameters::classifiedCharged) and xTrack->flag(xAOD::TauJetParameters::passTrkSelector) ) {
+      xTrack->setFlag(xAOD::TauJetParameters::modifiedIsolationTrack, true);
+    }
+    else {
+      xTrack->setFlag(xAOD::TauJetParameters::modifiedIsolationTrack, false);
+    }
   }
   xTau.setDetail(xAOD::TauJetParameters::nModifiedIsolationTracks, (int) xTau.nTracks(xAOD::TauJetParameters::modifiedIsolationTrack));
 
-
   return StatusCode::SUCCESS;
 }
 
@@ -97,13 +118,13 @@ StatusCode TauTrackRNNClassifier::executeTrackClassifier(xAOD::TauJet& xTau, xAO
 //==============================================================================
 
 //______________________________________________________________________________
-TrackRNN::TrackRNN(const std::string& sName)
-  : TauRecToolBase(sName)
-  , m_sInputWeightsPath("")
+TrackRNN::TrackRNN(const std::string& name)
+  : TauRecToolBase(name)
+  , m_inputWeightsPath("")
 {
   // for conversion compatibility cast nTracks 
   int nMaxNtracks = 0;
-  declareProperty( "InputWeightsPath", m_sInputWeightsPath );
+  declareProperty( "InputWeightsPath", m_inputWeightsPath );
   declareProperty( "MaxNtracks",  nMaxNtracks);
   m_nMaxNtracks = (unsigned int)nMaxNtracks;
 }
@@ -122,18 +143,30 @@ StatusCode TrackRNN::initialize()
 }
 
 //______________________________________________________________________________
-StatusCode TrackRNN::classifyTracks(std::vector<xAOD::TauTrack*>& vTracks, xAOD::TauJet& xTau) const
+StatusCode TrackRNN::classifyTracks(std::vector<xAOD::TauTrack*>& vTracks, xAOD::TauJet& xTau, bool skipTracks) const
 {
-  if(vTracks.size() == 0)
+  if(vTracks.size() == 0) {
     return StatusCode::SUCCESS;
-
-  std::sort(vTracks.begin(), vTracks.end(), [](const xAOD::TauTrack * a, const xAOD::TauTrack * b) {return a->pt() > b->pt();});
+  }
 
   static const SG::AuxElement::Accessor<float> idScoreCharged("rnn_chargedScore");
   static const SG::AuxElement::Accessor<float> idScoreIso("rnn_isolationScore");
   static const SG::AuxElement::Accessor<float> idScoreConv("rnn_conversionScore");
   static const SG::AuxElement::Accessor<float> idScoreFake("rnn_fakeScore");
 
+  // don't classify tracks, set default decorations
+  if(skipTracks) {
+    for(xAOD::TauTrack* track : vTracks) {
+      idScoreCharged(*track) = 0.;
+      idScoreConv(*track) = 0.;
+      idScoreIso(*track) = 0.;
+      idScoreFake(*track) = 0.;
+    }
+    return StatusCode::SUCCESS;
+  }
+
+  std::sort(vTracks.begin(), vTracks.end(), [](const xAOD::TauTrack * a, const xAOD::TauTrack * b) {return a->pt() > b->pt();});
+
   VectorMap valueMap;
   ATH_CHECK(calulateVars(vTracks, xTau, valueMap));
 
@@ -167,11 +200,15 @@ StatusCode TrackRNN::classifyTracks(std::vector<xAOD::TauTrack*>& vTracks, xAOD:
     idScoreIso(*vTracks[i]) = vClassProb[2];
     idScoreFake(*vTracks[i]) = vClassProb[3];
 
-    int iMaxIndex = 3; // for safty reasons set this to FT to circumvent bias
+    int iMaxIndex = 3; // for safety reasons set this to FT to circumvent bias
     for (unsigned int j = 0; j < vClassProb.size(); ++j){
       if(vClassProb[j] > vClassProb[iMaxIndex]) iMaxIndex = j;
     }
 
+    if(iMaxIndex < 4) {
+      vTracks[i]->setFlag(xAOD::TauJetParameters::unclassified, false);
+    }
+
     if(iMaxIndex == 3){
       vTracks[i]->setFlag(xAOD::TauJetParameters::classifiedFake, true);
     }else if(iMaxIndex == 0){
@@ -180,8 +217,6 @@ StatusCode TrackRNN::classifyTracks(std::vector<xAOD::TauTrack*>& vTracks, xAOD:
       vTracks[i]->setFlag(xAOD::TauJetParameters::classifiedConversion, true);
     }else if(iMaxIndex == 2){
       vTracks[i]->setFlag(xAOD::TauJetParameters::classifiedIsolation, true);
-    }else if(iMaxIndex == 4){
-      vTracks[i]->setFlag(xAOD::TauJetParameters::unclassified, true);
     }
   }
   
@@ -198,10 +233,10 @@ StatusCode TrackRNN::classifyTracks(std::vector<xAOD::TauTrack*>& vTracks, xAOD:
 //______________________________________________________________________________
 StatusCode TrackRNN::addWeightsFile()
 {
-  std::string sInputWeightsPath = find_file(m_sInputWeightsPath);
-  ATH_MSG_DEBUG("InputWeightsPath: " << sInputWeightsPath);
+  std::string inputWeightsPath = find_file(m_inputWeightsPath);
+  ATH_MSG_DEBUG("InputWeightsPath: " << inputWeightsPath);
 
-  std::ifstream nn_config_istream(sInputWeightsPath);
+  std::ifstream nn_config_istream(inputWeightsPath);
   
   lwtDev::GraphConfig NNconfig = lwtDev::parse_json_graph(nn_config_istream);
   
@@ -221,8 +256,9 @@ StatusCode TrackRNN::calulateVars(const std::vector<xAOD::TauTrack*>& vTracks, c
   // initialize map with values
   valueMap.clear();
   unsigned int n_timeSteps = vTracks.size();
-  if(m_nMaxNtracks > 0 && n_timeSteps>m_nMaxNtracks)
+  if(m_nMaxNtracks > 0 && n_timeSteps > m_nMaxNtracks) {
     n_timeSteps = m_nMaxNtracks;
+  }
 
   valueMap["log(trackPt)"] = std::vector<double>(n_timeSteps);
   valueMap["log(jetSeedPt)"] = std::vector<double>(n_timeSteps);
@@ -270,9 +306,6 @@ StatusCode TrackRNN::calulateVars(const std::vector<xAOD::TauTrack*>& vTracks, c
       uint8_t iTracksNSCTDeadSensors = 0; ATH_CHECK( xTrackParticle->summaryValue(iTracksNSCTDeadSensors, xAOD::numberOfSCTDeadSensors) );
       uint8_t iTracksNTRTHighThresholdHits = 0; ATH_CHECK( xTrackParticle->summaryValue( iTracksNTRTHighThresholdHits, xAOD::numberOfTRTHighThresholdHits) );
       uint8_t iTracksNTRTHits = 0; ATH_CHECK( xTrackParticle->summaryValue( iTracksNTRTHits, xAOD::numberOfTRTHits) );
-      //uint8_t iNumberOfContribPixelLayers = 0; ATH_CHECK( xTrackParticle->summaryValue(iNumberOfContribPixelLayers, xAOD::numberOfContribPixelLayers) );
-      //uint8_t iNumberOfPixelHoles = 0; ATH_CHECK( xTrackParticle->summaryValue(iNumberOfPixelHoles, xAOD::numberOfPixelHoles) );
-      //uint8_t iNumberOfSCTHoles = 0; ATH_CHECK( xTrackParticle->summaryValue(iNumberOfSCTHoles, xAOD::numberOfSCTHoles) );
 
       float fTracksEProbabilityHT; ATH_CHECK( xTrackParticle->summaryValue( fTracksEProbabilityHT, xAOD::eProbabilityHT) );
   
@@ -296,8 +329,9 @@ StatusCode TrackRNN::calulateVars(const std::vector<xAOD::TauTrack*>& vTracks, c
       valueMap["charge"][i] = fTrackCharge;
 
       ++i;
-      if(m_nMaxNtracks > 0 && i >= m_nMaxNtracks)
+      if(m_nMaxNtracks > 0 && i >= m_nMaxNtracks) {
 	break;
+      }
     }
 
   return StatusCode::SUCCESS;
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauJetRNNEvaluator.h b/Reconstruction/tauRecTools/tauRecTools/TauJetRNNEvaluator.h
index 3275d6aa099..a4e13523e0f 100644
--- a/Reconstruction/tauRecTools/tauRecTools/TauJetRNNEvaluator.h
+++ b/Reconstruction/tauRecTools/tauRecTools/TauJetRNNEvaluator.h
@@ -57,6 +57,7 @@ private:
     std::size_t m_max_clusters;
     float m_max_cluster_dr;
     bool m_doVertexCorrection;
+    bool m_doTrackClassification;
 
     // Configuration of the weight file
     std::string m_input_layer_scalar;
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauTrackRNNClassifier.h b/Reconstruction/tauRecTools/tauRecTools/TauTrackRNNClassifier.h
index 82c8d0c7f7b..ea2152c22c5 100644
--- a/Reconstruction/tauRecTools/tauRecTools/TauTrackRNNClassifier.h
+++ b/Reconstruction/tauRecTools/tauRecTools/TauTrackRNNClassifier.h
@@ -49,7 +49,7 @@ public:
 
   ASG_TOOL_CLASS2( TauTrackRNNClassifier, TauRecToolBase, ITauToolBase )
 
-  TauTrackRNNClassifier(const std::string& sName="TauTrackRNNClassifier");
+  TauTrackRNNClassifier(const std::string& name="TauTrackRNNClassifier");
   ~TauTrackRNNClassifier();
 
   // retrieve all track classifier sub tools
@@ -59,6 +59,8 @@ public:
 
  private:
   ToolHandleArray<TrackRNN> m_vClassifier {this, "Classifiers", {}};
+  bool m_classifyLRT;
+
 }; // class TauTrackRNNClassifier
   
 //______________________________________________________________________________
@@ -72,7 +74,7 @@ class TrackRNN
   
   public:
   
-  TrackRNN(const std::string& sName);
+  TrackRNN(const std::string& name);
   ~TrackRNN();
 
   // configure the MVA object and build a general map to store variables
@@ -82,7 +84,7 @@ class TrackRNN
   
   // executes MVA object to get the BDT score, makes the decision and resets
   // classification flags
-  StatusCode classifyTracks(std::vector<xAOD::TauTrack*>& vTracks, xAOD::TauJet& xTau) const;
+  StatusCode classifyTracks(std::vector<xAOD::TauTrack*>& vTracks, xAOD::TauJet& xTau, bool skipTracks=false) const;
   
 private:
   // set BDT input variables in the corresponding map entries
@@ -99,7 +101,7 @@ private:
   
 private:
   // configurable variables
-  std::string m_sInputWeightsPath; 
+  std::string m_inputWeightsPath; 
   unsigned int m_nMaxNtracks;
 
 private:
diff --git a/Trigger/TrigAlgorithms/TrigTauRec/python/TrigTauAlgorithmsHolder.py b/Trigger/TrigAlgorithms/TrigTauRec/python/TrigTauAlgorithmsHolder.py
index 33662b5b7ee..e68a0738582 100644
--- a/Trigger/TrigAlgorithms/TrigTauRec/python/TrigTauAlgorithmsHolder.py
+++ b/Trigger/TrigAlgorithms/TrigTauRec/python/TrigTauAlgorithmsHolder.py
@@ -720,7 +720,7 @@ def getTauJetBDTEvaluator(suffix="TauJetBDT", weightsFile="", calibFolder="", mi
 ########################################################################
 # TauJetRNNEvaluator
 def getTauJetRNNEvaluator(NetworkFile0P="", NetworkFile1P="", NetworkFile3P="", OutputVarname="RNNJetScore", 
-                          MaxTracks=10, MaxClusters=6, MaxClusterDR=1.0, 
+                          MaxTracks=10, MaxClusters=6, MaxClusterDR=1.0, TrackClassification=False,
                           InputLayerScalar="scalar", InputLayerTracks="tracks", InputLayerClusters="clusters", 
                           OutputLayer="rnnid_output", OutputNode="sig_prob"):
 
@@ -731,20 +731,21 @@ def getTauJetRNNEvaluator(NetworkFile0P="", NetworkFile1P="", NetworkFile3P="",
 
     from AthenaCommon.AppMgr import ToolSvc
     from tauRecTools.tauRecToolsConf import TauJetRNNEvaluator
-    TauJetRNNEvaluator = TauJetRNNEvaluator(name=_name,
-                                      NetworkFile0P=NetworkFile0P,
-                                      NetworkFile1P=NetworkFile1P,
-                                      NetworkFile3P=NetworkFile3P,
-                                      OutputVarname=OutputVarname,
-                                      MaxTracks=MaxTracks,
-                                      MaxClusters=MaxClusters,
-                                      MaxClusterDR=MaxClusterDR,
-                                      VertexCorrection=doVertexCorrection,
-                                      InputLayerScalar=InputLayerScalar,
-                                      InputLayerTracks=InputLayerTracks,
-                                      InputLayerClusters=InputLayerClusters,
-                                      OutputLayer=OutputLayer,
-                                      OutputNode=OutputNode)
+    TauJetRNNEvaluator = TauJetRNNEvaluator(name = _name,
+                                            NetworkFile0P = NetworkFile0P,
+                                            NetworkFile1P = NetworkFile1P,
+                                            NetworkFile3P = NetworkFile3P,
+                                            OutputVarname = OutputVarname,
+                                            MaxTracks = MaxTracks,
+                                            MaxClusters = MaxClusters,
+                                            MaxClusterDR = MaxClusterDR,
+                                            VertexCorrection = doVertexCorrection,
+                                            TrackClassification = TrackClassification,
+                                            InputLayerScalar = InputLayerScalar,
+                                            InputLayerTracks = InputLayerTracks,
+                                            InputLayerClusters = InputLayerClusters,
+                                            OutputLayer = OutputLayer,
+                                            OutputNode = OutputNode)
 
     ToolSvc += TauJetRNNEvaluator
     cached_instances[_name] = TauJetRNNEvaluator
-- 
GitLab