From d389879de004fbbd0f8d85a0dfc43db8cf96d5ee Mon Sep 17 00:00:00 2001
From: Dan Guest <daniel.hay.guest@cern.ch>
Date: Fri, 2 Apr 2021 18:17:46 +0000
Subject: [PATCH] Restructure FtagRun3DerivationConfig slightly

The idea is to return a component accumulator for each jet collection, so that
we can merge all of them together. There are still some points that need to be
cleaned up a bit, e.g. the calib service is just scheduled once, not in the
accumulator. Also it crashes right now...
---
 .../python/FtagRun3DerivationConfig.py        | 140 +++++++++++-------
 .../share/PHYSVAL.py                          |   4 +-
 .../BTagging/python/HighLevelBTagAlgConfig.py |   2 +-
 3 files changed, 87 insertions(+), 59 deletions(-)

diff --git a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkFlavourTag/python/FtagRun3DerivationConfig.py b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkFlavourTag/python/FtagRun3DerivationConfig.py
index 82be2e3127dc..62b9af5c61f3 100644
--- a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkFlavourTag/python/FtagRun3DerivationConfig.py
+++ b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkFlavourTag/python/FtagRun3DerivationConfig.py
@@ -5,20 +5,46 @@ from AthenaCommon.AthenaCommonFlags import jobproperties as jps
 
 from GaudiKernel.Configurable import WARNING
 
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaCommon.Configurable import Configurable
+from AthenaConfiguration.ComponentAccumulator import conf2toConfigurable
+from AthenaConfiguration.ComponentFactory import CompFactory
 
+# for backward compatability
+def FtagJetCollection(jetcol, seq, OutputLevel=WARNING):
+    FtagJetCollections([jetcol], seq, OutputLevel)
 
+# this should be able to tag a few collections
+def FtagJetCollections(jetcols, seq, OutputLevel=WARNING):
 
-def FtagJetCollection(jetcol, seq, OutputLevel=WARNING):
-    
+    Configurable.configurableRun3Behavior=1
+    from AthenaConfiguration.AllConfigFlags import ConfigFlags as cfgFlags
 
-    from AthenaCommon.AppMgr import athCondSeq
+    taggerlist = ['IP2D', 'IP3D', 'SV1', 'SoftMu']
 
-    from AthenaCommon.Configurable import Configurable
+    setupCondDb(cfgFlags, taggerlist)
+
+    acc = ComponentAccumulator()
+
+    if 'AntiKt4EMTopoJets' in jetcols:
+        acc.merge(RenameInputContainerEmTopoHacksCfg('oldAODVersion'))
+
+    for jetcol in jetcols:
+        acc.merge(getFtagComponent(cfgFlags, jetcol, taggerlist, OutputLevel))
+
+    Configurable.configurableRun3Behavior=0
+    algs = findAllAlgorithms(acc.getSequence("AthAlgSeq"))
+    for alg in algs:
+
+        seq += conf2toConfigurable(alg)
+
+    acc.wasMerged()
+
+
+# this returns a component accumulator, which is merged across jet
+# collections in FtagJetCollections above
+def getFtagComponent(cfgFlags, jetcol, taggerlist, OutputLevel=WARNING):
 
-    Configurable.configurableRun3Behavior=1
-    from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
-    from AthenaConfiguration.ComponentFactory import CompFactory
-    from AthenaConfiguration.ComponentAccumulator import conf2toConfigurable
     from BTagging.JetParticleAssociationAlgConfig import JetParticleAssociationAlgCfg
     from BTagging.JetBTaggingAlgConfig import JetBTaggingAlgCfg
     from BTagging.JetSecVertexingAlgConfig import JetSecVertexingAlgCfg
@@ -26,8 +52,6 @@ def FtagJetCollection(jetcol, seq, OutputLevel=WARNING):
     from BTagging.BTagTrackAugmenterAlgConfig import BTagTrackAugmenterAlgCfg
     from BTagging.BTagHighLevelAugmenterAlgConfig import BTagHighLevelAugmenterAlgCfg
     from BTagging.HighLevelBTagAlgConfig import HighLevelBTagAlgCfg
-    from AthenaConfiguration.AllConfigFlags import ConfigFlags as cfgFlags
-
 
     jetcol_name_without_Jets = jetcol.replace('Jets','')
     BTaggingCollection = cfgFlags.BTagging.OutputFiles.Prefix + jetcol_name_without_Jets
@@ -35,60 +59,35 @@ def FtagJetCollection(jetcol, seq, OutputLevel=WARNING):
     kwargs = {}
     kwargs['Release'] = '22'
 
-
     cfgFlags.Input.Files = jps.AthenaCommonFlags.FilesInput.get_Value()
 
-    
     acc = ComponentAccumulator()
 
-
-    taggerlist = ['IP2D', 'IP3D', 'SV1', 'SoftMu']
-
-
-    CalibrationChannelAliases = ["AntiKt4EMPFlow->AntiKt4EMPFlow,AntiKt4EMTopo,AntiKt4TopoEM,AntiKt4LCTopo"]
-
-    grades= cfgFlags.BTagging.Grades
-
-    RNNIPConfig = {'rnnip':''}
-
-    JetTagCalibCondAlg=CompFactory.Analysis.JetTagCalibCondAlg
-    jettagcalibcondalg = "JetTagCalibCondAlg"
-    readkeycalibpath = "/GLOBAL/BTagCalib/RUN12"
-    connSchema = "GLOBAL_OFL"
-    if not cfgFlags.Input.isMC:
-        readkeycalibpath = readkeycalibpath.replace("/GLOBAL/BTagCalib","/GLOBAL/Onl/BTagCalib")
-        connSchema = "GLOBAL"
-    histoskey = "JetTagCalibHistosKey"
-    from IOVDbSvc.CondDB import conddb
-
-    conddb.addFolder(connSchema, readkeycalibpath, className='CondAttrListCollection')
-    JetTagCalib = JetTagCalibCondAlg(jettagcalibcondalg, ReadKeyCalibPath=readkeycalibpath, HistosKey = histoskey, taggers = taggerlist,
-        channelAliases = CalibrationChannelAliases, IP2D_TrackGradePartitions = grades, RNNIP_NetworkConfig = RNNIPConfig)
-
-    athCondSeq+=conf2toConfigurable( JetTagCalib, indent="  " )
-    
     acc.merge(JetParticleAssociationAlgCfg(cfgFlags, jetcol_name_without_Jets, "InDetTrackParticles", 'BTagTrackToJetAssociator', **kwargs))
 
     SecVertexingAndAssociators = {'JetFitter':'BTagTrackToJetAssociator','SV1':'BTagTrackToJetAssociator'}
     for k, v in SecVertexingAndAssociators.items():
 
         acc.merge(JetSecVtxFindingAlgCfg(cfgFlags, jetcol_name_without_Jets, "PrimaryVertices", k, v))
-    
+
         acc.merge(JetSecVertexingAlgCfg(cfgFlags, BTaggingCollection, jetcol_name_without_Jets, "PrimaryVertices", k, v))
 
-    
+
     acc.merge( JetBTaggingAlgCfg(cfgFlags, BTaggingCollection = BTaggingCollection, JetCollection = jetcol_name_without_Jets, PrimaryVertexCollectionName="PrimaryVertices", TaggerList = taggerlist, SVandAssoc = SecVertexingAndAssociators) )
-    
 
 
     postTagDL2JetToTrainingMap={
         'AntiKt4EMPFlow': [
-        #'BTagging/201903/smt/antikt4empflow/network.json',
-        'BTagging/201903/rnnip/antikt4empflow/network.json',
-        'BTagging/201903/dl1r/antikt4empflow/network.json',
-        'BTagging/201903/dl1/antikt4empflow/network.json',
-        #'BTagging/201903/dl1rmu/antikt4empflow/network.json',
+            'BTagging/201903/rnnip/antikt4empflow/network.json',
+            'BTagging/201903/dl1r/antikt4empflow/network.json',
+            'BTagging/201903/dl1/antikt4empflow/network.json',
+        ],
+        'AntiKt4EMTopo': [
+            'BTagging/201903/rnnip/antikt4empflow/network.json',
+            'BTagging/201903/dl1r/antikt4empflow/network.json',
+            'BTagging/201903/dl1/antikt4empflow/network.json',
         ]
+
     }
 
     acc.merge(BTagTrackAugmenterAlgCfg(cfgFlags))
@@ -98,23 +97,52 @@ def FtagJetCollection(jetcol, seq, OutputLevel=WARNING):
     for jsonfile in postTagDL2JetToTrainingMap[jetcol_name_without_Jets]:
         acc.merge(HighLevelBTagAlgCfg(cfgFlags, BTaggingCollection=BTaggingCollection, TrackCollection='InDetTrackParticles', NNFile=jsonfile) )
 
+    return acc
 
-    Configurable.configurableRun3Behavior=0
-
-
-    algs = findAllAlgorithms(acc.getSequence("AthAlgSeq"))
-    
-    for alg in algs:
+# this probably only has to happen once
+def setupCondDb(cfgFlags, taggerlist):
+    from AthenaCommon.AppMgr import athCondSeq
+    CalibrationChannelAliases = ["AntiKt4EMPFlow->AntiKt4EMPFlow,AntiKt4EMTopo,AntiKt4TopoEM,AntiKt4LCTopo"]
+    grades= cfgFlags.BTagging.Grades
+    RNNIPConfig = {'rnnip':''}
 
-        seq += conf2toConfigurable(alg)
+    JetTagCalibCondAlg=CompFactory.Analysis.JetTagCalibCondAlg
+    jettagcalibcondalg = "JetTagCalibCondAlg"
+    readkeycalibpath = "/GLOBAL/BTagCalib/RUN12"
+    connSchema = "GLOBAL_OFL"
+    if not cfgFlags.Input.isMC:
+        readkeycalibpath = readkeycalibpath.replace("/GLOBAL/BTagCalib","/GLOBAL/Onl/BTagCalib")
+        connSchema = "GLOBAL"
+    histoskey = "JetTagCalibHistosKey"
+    from IOVDbSvc.CondDB import conddb
 
-    acc.wasMerged()
+    conddb.addFolder(connSchema, readkeycalibpath, className='CondAttrListCollection')
+    JetTagCalib = JetTagCalibCondAlg(jettagcalibcondalg, ReadKeyCalibPath=readkeycalibpath, HistosKey = histoskey, taggers = taggerlist,
+        channelAliases = CalibrationChannelAliases, IP2D_TrackGradePartitions = grades, RNNIP_NetworkConfig = RNNIPConfig)
 
-    
-    return
+    athCondSeq+=conf2toConfigurable( JetTagCalib, indent="  " )
 
 
 
+# Valerio's magic hacks for emtopo
+def RenameInputContainerEmTopoHacksCfg(suffix):
 
+    acc=ComponentAccumulator()
 
+    #Delete BTagging container read from input ESD
+    AddressRemappingSvc, ProxyProviderSvc=CompFactory.getComps("AddressRemappingSvc","ProxyProviderSvc",)
+    AddressRemappingSvc = AddressRemappingSvc("AddressRemappingSvc")
+    AddressRemappingSvc.TypeKeyRenameMaps += ['xAOD::JetAuxContainer#AntiKt4EMTopoJets.BTagTrackToJetAssociator->AntiKt4EMTopoJets.BTagTrackToJetAssociator_' + suffix]
+    AddressRemappingSvc.TypeKeyRenameMaps += ['xAOD::JetAuxContainer#AntiKt4EMTopoJets.JFVtx->AntiKt4EMTopoJets.JFVtx_' + suffix]
+    AddressRemappingSvc.TypeKeyRenameMaps += ['xAOD::JetAuxContainer#AntiKt4EMTopoJets.SecVtx->AntiKt4EMTopoJets.SecVtx_' + suffix]
 
+    AddressRemappingSvc.TypeKeyRenameMaps += ['xAOD::JetAuxContainer#AntiKt4EMTopoJets.btaggingLink->AntiKt4EMTopoJets.btaggingLink_' + suffix]
+    AddressRemappingSvc.TypeKeyRenameMaps += ['xAOD::BTaggingContainer#BTagging_AntiKt4EMTopo->BTagging_AntiKt4EMTopo_' + suffix]
+    AddressRemappingSvc.TypeKeyRenameMaps += ['xAOD::BTaggingAuxContainer#BTagging_AntiKt4EMTopoAux.->BTagging_AntiKt4EMTopo_' + suffix+"Aux."]
+    AddressRemappingSvc.TypeKeyRenameMaps += ['xAOD::VertexContainer#BTagging_AntiKt4EMTopoSecVtx->BTagging_AntiKt4EMTopoSecVtx_' + suffix]
+    AddressRemappingSvc.TypeKeyRenameMaps += ['xAOD::VertexAuxContainer#BTagging_AntiKt4EMTopoSecVtxAux.->BTagging_AntiKt4EMTopoSecVtx_' + suffix+"Aux."]
+    AddressRemappingSvc.TypeKeyRenameMaps += ['xAOD::BTagVertexContainer#BTagging_AntiKt4EMTopoJFVtx->BTagging_AntiKt4EMTopoJFVtx_' + suffix]
+    AddressRemappingSvc.TypeKeyRenameMaps += ['xAOD::BTagVertexAuxContainer#BTagging_AntiKt4EMTopoJFVtxAux.->BTagging_AntiKt4EMTopoJFVtx_' + suffix+"Aux."]
+    acc.addService(AddressRemappingSvc)
+    acc.addService(ProxyProviderSvc(ProviderNames = [ "AddressRemappingSvc" ]))
+    return acc
diff --git a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkPhysicsValidation/share/PHYSVAL.py b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkPhysicsValidation/share/PHYSVAL.py
index 1ed4ba6d8ccb..c69d2713927f 100644
--- a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkPhysicsValidation/share/PHYSVAL.py
+++ b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkPhysicsValidation/share/PHYSVAL.py
@@ -161,8 +161,8 @@ SeqPHYSVAL += CfgMgr.DerivationFramework__DerivationKernel("PHYSVALKernel")
 # FLAVOUR TAGGING   
 #====================================================================
 
-from DerivationFrameworkFlavourTag.FtagRun3DerivationConfig import FtagJetCollection
-FtagJetCollection('AntiKt4EMPFlowJets',SeqPHYSVAL)
+from DerivationFrameworkFlavourTag.FtagRun3DerivationConfig import FtagJetCollections
+FtagJetCollections(['AntiKt4EMPFlowJets'],SeqPHYSVAL)
 
 
 #====================================================================
diff --git a/PhysicsAnalysis/JetTagging/JetTagAlgs/BTagging/python/HighLevelBTagAlgConfig.py b/PhysicsAnalysis/JetTagging/JetTagAlgs/BTagging/python/HighLevelBTagAlgConfig.py
index 2225616ba335..e7af664a77fd 100644
--- a/PhysicsAnalysis/JetTagging/JetTagAlgs/BTagging/python/HighLevelBTagAlgConfig.py
+++ b/PhysicsAnalysis/JetTagging/JetTagAlgs/BTagging/python/HighLevelBTagAlgConfig.py
@@ -46,7 +46,7 @@ def HighLevelBTagAlgCfg(ConfigFlags, BTaggingCollection, TrackCollection, NNFile
     options['BTaggingCollectionName'] = BTaggingCollection
     options['TrackContainer'] = TrackCollection
     options['JetDecorator'] = dl2
-    options['name'] = Name.lower()
+    options['name'] = '_'.join([Name.lower(), BTaggingCollection])
 
     # -- create the association algorithm
     acc.addEventAlgo(Analysis__HighLevelBTagAlg(**options))
-- 
GitLab