From 3350b7f9dfa8dea7a58a54f2cb3c4f3a6b214dd8 Mon Sep 17 00:00:00 2001
From: Vakho Tsulaia <vakhtang.tsulaia@cern.ch>
Date: Fri, 17 Dec 2021 01:08:48 +0100
Subject: [PATCH] Implemented CA-based configuration for AthenaMP

Introduced AthenaMPConfig.py script with the CA-based configuration for AthenaMP.
Also introduced a bunch of MP-specific configuration flags, and a new flag
for setting MaxFilesOpen property of PoolSvc, which needs to be modified
when AthenaMP runs with the Shared Reader
---
 .../python/AllConfigFlags.py                  |  26 +++
 .../python/AthConfigFlags.py                  |   4 +
 .../python/MainServicesConfig.py              |   7 +
 Control/AthenaMP/python/AthenaMPConfig.py     | 217 ++++++++++++++++++
 .../AthenaPoolCnvSvc/python/PoolReadConfig.py |   2 +-
 Database/IOVDbSvc/python/IOVDbSvcConfig.py    |   2 +-
 6 files changed, 256 insertions(+), 2 deletions(-)
 create mode 100644 Control/AthenaMP/python/AthenaMPConfig.py

diff --git a/Control/AthenaConfiguration/python/AllConfigFlags.py b/Control/AthenaConfiguration/python/AllConfigFlags.py
index cd4e9ca8e4bb..1871b5fb00a4 100644
--- a/Control/AthenaConfiguration/python/AllConfigFlags.py
+++ b/Control/AthenaConfiguration/python/AllConfigFlags.py
@@ -67,6 +67,7 @@ def _createCfgFlags():
     acf.addFlag('Concurrency.NumProcs', 0)
     acf.addFlag('Concurrency.NumThreads', 0 )
     acf.addFlag('Concurrency.NumConcurrentEvents', lambda prevFlags : prevFlags.Concurrency.NumThreads)
+    acf.addFlag('Concurrency.DebugWorkers', False )
 
     acf.addFlag('Scheduler.CheckDependencies', True)
     acf.addFlag('Scheduler.ShowDataDeps', True)
@@ -74,6 +75,28 @@ def _createCfgFlags():
     acf.addFlag('Scheduler.ShowControlFlow', True)
     acf.addFlag('Scheduler.EnableVerboseViews', True)
 
+    acf.addFlag('MP.WorkerTopDir', 'athenaMP_workers')
+    acf.addFlag('MP.OutputReportFile', 'AthenaMPOutputs')
+    acf.addFlag('MP.Strategy', 'SharedQueue')
+    acf.addFlag('MP.CollectSubprocessLogs', False)
+    acf.addFlag('MP.PollingInterval', 100)
+    acf.addFlag('MP.EventsBeforeFork', 0)
+    acf.addFlag('MP.EventRangeChannel', 'EventService_EventRanges')
+    acf.addFlag('MP.EvtRangeScattererCaching', False)
+    acf.addFlag('MP.MemSamplingInterval', 0)
+    """ Size of event chunks in the shared queue
+        if chunk_size==-1, chunk size is set to auto_flush for files compressed with LZMA
+        if chunk_size==-2, chunk size is set to auto_flush for files compressed with LZMA or ZLIB
+        if chunk_size==-3, chunk size is set to auto_flush for files compressed with LZMA, ZLIB, or LZ4
+        if chunk_size<=-4, chunk size is set to auto_flush
+    """
+    acf.addFlag('MP.ChunkSize', -1)
+    acf.addFlag('MP.ReadEventOrders', False)
+    acf.addFlag('MP.EventOrdersFile', 'athenamp_eventorders.txt')
+    acf.addFlag('MP.UseSharedReader', False)
+    acf.addFlag('MP.UseSharedWriter', False)
+    acf.addFlag('MP.UseParallelCompression', True)
+
     acf.addFlag('Common.MsgSourceLength',50) #Length of the source-field in the format str of MessageSvc
     acf.addFlag('Common.isOnline', False ) #  Job runs in an online environment (access only to resources available at P1) # former global.isOnline
     acf.addFlag('Common.useOnlineLumi', lambda prevFlags : prevFlags.Common.isOnline ) #  Use online version of luminosity. ??? Should just use isOnline?
@@ -170,6 +193,9 @@ def _createCfgFlags():
     acf.addFlag("IOVDb.RunToTimestampDict", lambda prevFlags: getRunToTimestampDict())
     acf.addFlag("IOVDb.DBConnection", lambda prevFlags : "sqlite://;schema=mycool.db;dbname=" + prevFlags.IOVDb.DatabaseInstance)
 
+#PoolSvc Flags:
+    acf.addFlag("PoolSvc.MaxFilesOpen", lambda prevFlags : 2 if prevFlags.MP.UseSharedReader else 0)
+
 
     def __bfield():
         from MagFieldConfig.BFieldConfigFlags import createBFieldConfigFlags
diff --git a/Control/AthenaConfiguration/python/AthConfigFlags.py b/Control/AthenaConfiguration/python/AthConfigFlags.py
index 619b2963d378..fd89e8104f1e 100644
--- a/Control/AthenaConfiguration/python/AthConfigFlags.py
+++ b/Control/AthenaConfiguration/python/AthConfigFlags.py
@@ -404,6 +404,7 @@ class AthConfigFlags(object):
         parser.add_argument("-l", "--loglevel", default=None, help="logging level (ALL, VERBOSE, DEBUG,INFO, WARNING, ERROR, or FATAL")
         parser.add_argument("--configOnly", type=str, default=None, help="Stop after configuration phase (may not be respected by all diver scripts)")
         parser.add_argument("--threads", type=int, default=0, help="Run with given number of threads")
+        parser.add_argument("--nprocs", type=int, default=0, help="Run AthenaMP with given number of worker processes")
 
         return parser
 
@@ -447,6 +448,9 @@ class AthConfigFlags(object):
         if args.threads:
             self.Concurrency.NumThreads = args.threads
 
+        if args.nprocs:
+            self.Concurrency.NumProcs = args.nprocs
+
         #All remaining arguments are assumed to be key=value pairs to set arbitrary flags:
 
 
diff --git a/Control/AthenaConfiguration/python/MainServicesConfig.py b/Control/AthenaConfiguration/python/MainServicesConfig.py
index 0ad0b8a6e89a..4874247fa706 100644
--- a/Control/AthenaConfiguration/python/MainServicesConfig.py
+++ b/Control/AthenaConfiguration/python/MainServicesConfig.py
@@ -29,6 +29,9 @@ def MainServicesCfg(cfgFlags, LoopMgr='AthenaEventLoopMgr'):
             raise Exception("Requested Concurrency.NumThreads>0 and Concurrency.NumConcurrentEvents==0, which will not process events!")
         LoopMgr = "AthenaHiveEventLoopMgr"
 
+    if cfgFlags.Concurrency.NumProcs>0:
+        LoopMgr = "AthMpEvtLoopMgr"
+
     ########################################################################
     # Core components needed for serial and threaded jobs
     cfg=MainServicesMiniCfg(loopMgr=LoopMgr, masterSequence='AthMasterSeq')
@@ -93,6 +96,10 @@ def MainServicesCfg(cfgFlags, LoopMgr='AthenaEventLoopMgr'):
     if cfgFlags.Exec.DebugStage != "":
         cfg.setDebugStage(cfgFlags.Exec.DebugStage)
 
+    if cfgFlags.Concurrency.NumProcs>0:
+        from AthenaMP.AthenaMPConfig import AthenaMPCfg
+        mploop = AthenaMPCfg(cfgFlags)
+        cfg.merge(mploop)
 
     ########################################################################
     # Additional components needed for threaded jobs only
diff --git a/Control/AthenaMP/python/AthenaMPConfig.py b/Control/AthenaMP/python/AthenaMPConfig.py
new file mode 100644
index 000000000000..4cc535807b7c
--- /dev/null
+++ b/Control/AthenaMP/python/AthenaMPConfig.py
@@ -0,0 +1,217 @@
+# Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.ComponentFactory import CompFactory
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaConfiguration.MainServicesConfig import MainServicesCfg
+from AthenaConfiguration.AllConfigFlags import ConfigFlags, GetFileMD
+from AthenaConfiguration.Enums import ProductionStep
+
+from AthenaCommon.Logging import log as msg
+
+import os, shutil
+
+def AthenaMPCfg(configFlags):
+    """Return a ComponentAccumulator with AthMpEvtLoopMgr as primary service, the
+    Gaudi FileMgr and the AthenaMP tools selected by configFlags.MP.Strategy.
+
+    Side effects: sets the XRD fork-handler environment variables and saves a copy
+    of PoolFileCatalog.xml if present. Raises Exception for pileup jobs running
+    SharedQueue/RoundRobin with Concurrency.NumThreads > 0 (mixed MP+MT).
+    """
+    # Assigning to os.environ calls putenv() and also keeps os.environ in sync
+    os.environ['XRD_ENABLEFORKHANDLERS'] = '1'
+    os.environ['XRD_RUNFORKHANDLER'] = '1'
+
+    result=ComponentAccumulator()
+
+    # Configure MP Event Loop Manager
+    mpevtloop = CompFactory.AthMpEvtLoopMgr()
+
+    mpevtloop.NWorkers=configFlags.Concurrency.NumProcs
+    mpevtloop.Strategy=configFlags.MP.Strategy
+    mpevtloop.WorkerTopDir = configFlags.MP.WorkerTopDir
+    mpevtloop.OutputReportFile = configFlags.MP.OutputReportFile
+    mpevtloop.CollectSubprocessLogs = configFlags.MP.CollectSubprocessLogs
+    mpevtloop.PollingInterval = configFlags.MP.PollingInterval
+    mpevtloop.MemSamplingInterval = configFlags.MP.MemSamplingInterval
+    mpevtloop.IsPileup = configFlags.Common.ProductionStep in [ProductionStep.PileUpPresampling, ProductionStep.Overlay]
+    mpevtloop.EventsBeforeFork = 0 if configFlags.MP.Strategy == 'EventService' else configFlags.MP.EventsBeforeFork
+
+    # Configure Gaudi File Manager
+    result.addService(CompFactory.FileMgr(LogFile="FileManagerLog"))
+
+    # Save PoolFileCatalog.xml if exists in the run directory
+    # The saved file will be copied over to workers' run directories just after forking
+    if os.path.isfile('PoolFileCatalog.xml'):
+        shutil.copyfile('PoolFileCatalog.xml','PoolFileCatalog.xml.AthenaMP-saved')
+
+    # Compute event chunk size
+    chunk_size = getChunkSize(configFlags)
+
+    # Configure Strategy
+    debug_worker = configFlags.Concurrency.DebugWorkers
+    event_range_channel = configFlags.MP.EventRangeChannel
+    use_shared_reader = configFlags.MP.UseSharedReader
+    use_shared_writer = configFlags.MP.UseSharedWriter
+    use_parallel_compression = configFlags.MP.UseParallelCompression
+
+    if configFlags.MP.Strategy in ('SharedQueue', 'RoundRobin'):
+        if use_shared_reader:
+            AthenaSharedMemoryTool = CompFactory.AthenaSharedMemoryTool
+
+            if configFlags.Input.Format == 'BS':
+                evSel=CompFactory.EventSelectorByteStream("EventSelector")
+                from ByteStreamCnvSvc.ByteStreamConfig import ByteStreamReadCfg
+                result.merge(ByteStreamReadCfg(configFlags))
+            else:
+                evSel=CompFactory.EventSelectorAthenaPool("EventSelector")
+
+                # AthenaPoolCnvSvc
+                apcs=CompFactory.AthenaPoolCnvSvc()
+                apcs.InputStreamingTool = AthenaSharedMemoryTool("InputStreamingTool",
+                                                                 SharedMemoryName="InputStream"+str(os.getpid()),
+                                                                 UseMultipleSegments=True)
+                result.addService(apcs)
+
+                from AthenaPoolCnvSvc.PoolReadConfig import PoolReadCfg
+                result.merge(PoolReadCfg(configFlags))
+
+            evSel.SharedMemoryTool = AthenaSharedMemoryTool("EventStreamingTool",
+                                                            SharedMemoryName="EventStream"+str(os.getpid()))
+            result.addService(evSel)
+
+        if use_shared_writer:
+            if any((configFlags.Output.doWriteESD,
+                    configFlags.Output.doWriteAOD,
+                    configFlags.Output.doWriteRDO)) or configFlags.Output.HITSFileName!='':
+                AthenaSharedMemoryTool = CompFactory.AthenaSharedMemoryTool
+
+                apcs=CompFactory.AthenaPoolCnvSvc()
+                apcs.OutputStreamingTool += [ AthenaSharedMemoryTool("OutputStreamingTool_0",
+                                                                     SharedMemoryName="OutputStream"+str(os.getpid())) ]
+                apcs.ParallelCompression = use_parallel_compression
+                result.addService(apcs)
+
+                from AthenaPoolCnvSvc.PoolWriteConfig import PoolWriteCfg
+                result.merge(PoolWriteCfg(configFlags))
+
+        queue_provider = CompFactory.SharedEvtQueueProvider(UseSharedReader=use_shared_reader,
+                                                            IsPileup=mpevtloop.IsPileup,
+                                                            EventsBeforeFork=mpevtloop.EventsBeforeFork,
+                                                            ChunkSize=chunk_size)
+        if configFlags.Concurrency.NumThreads > 0:
+            if mpevtloop.IsPileup:
+                raise Exception('Running pileup digitization in mixed MP+MT currently not supported')
+            # NOTE(review): same class as the non-MT branch, fewer properties - confirm it is the intended MP+MT consumer
+            queue_consumer = CompFactory.SharedEvtQueueConsumer(UseSharedWriter=use_shared_writer,
+                                                                EventsBeforeFork=mpevtloop.EventsBeforeFork,
+                                                                Debug=debug_worker)
+        else:
+            queue_consumer = CompFactory.SharedEvtQueueConsumer(UseSharedReader=use_shared_reader,
+                                                                UseSharedWriter=use_shared_writer,
+                                                                IsPileup=mpevtloop.IsPileup,
+                                                                IsRoundRobin=(configFlags.MP.Strategy=='RoundRobin'),
+                                                                EventsBeforeFork=mpevtloop.EventsBeforeFork,
+                                                                ReadEventOrders=configFlags.MP.ReadEventOrders,
+                                                                EventOrdersFile=configFlags.MP.EventOrdersFile,
+                                                                Debug=debug_worker)
+        mpevtloop.Tools += [ queue_provider, queue_consumer ]
+
+        if use_shared_writer:
+            shared_writer = CompFactory.SharedWriterTool(MotherProcess=(mpevtloop.EventsBeforeFork>0))
+            mpevtloop.Tools += [ shared_writer ]
+
+    elif configFlags.MP.Strategy=='FileScheduling':
+        mpevtloop.Tools += [ CompFactory.FileSchedulingTool(IsPileup=mpevtloop.IsPileup,
+                                                            Debug=debug_worker) ]
+
+    elif configFlags.MP.Strategy=='EventService':
+        channelScatterer2Processor = "AthenaMP_Scatterer2Processor"
+        channelProcessor2EvtSel = "AthenaMP_Processor2EvtSel"
+
+        mpevtloop.Tools += [ CompFactory.EvtRangeScatterer(ProcessorChannel = channelScatterer2Processor,
+                                                           EventRangeChannel = event_range_channel,
+                                                           DoCaching=configFlags.MP.EvtRangeScattererCaching) ]
+        mpevtloop.Tools += [ CompFactory.EvtRangeProcessor(IsPileup=mpevtloop.IsPileup,
+                                                           Channel2Scatterer = channelScatterer2Processor,
+                                                           Channel2EvtSel = channelProcessor2EvtSel,
+                                                           Debug=debug_worker) ]
+
+    else:
+        msg.warning("Unknown strategy %s. No MP tools will be configured", configFlags.MP.Strategy)
+
+    result.addService(mpevtloop, primary=True)
+
+    return result
+
+def getChunkSize(configFlags) -> int:
+    """Return the size of event chunks in the shared queue (default: 1).
+
+    MP.ChunkSize > 0 is used as is. Otherwise, unless Input.Files is the generic
+    placeholder or the Shared Reader is on: -1/-2/-3 use the file's auto_flush if it is
+    compressed with LZMA / LZMA, ZLIB / LZMA, ZLIB, LZ4; <= -4 always does; 0 only warns.
+    """
+    chunk_size = 1
+    if configFlags.MP.ChunkSize > 0:
+        chunk_size = configFlags.MP.ChunkSize
+        msg.info('Chunk size set to %i', chunk_size)
+    elif configFlags.Input.Files != ["_ATHENA_GENERIC_INPUTFILE_NAME_"]:
+        md = GetFileMD(configFlags.Input.Files)
+        if configFlags.MP.UseSharedReader:  # Don't use auto flush for shared reader
+            msg.info('Shared Reader in use, chunk_size set to default (%i)', chunk_size)
+        elif configFlags.MP.ChunkSize == -1:  # auto flush only for LZMA
+            if md.get('file_comp_alg',-1) == 2:
+                chunk_size = md.get('auto_flush',-1)
+                msg.info('Chunk size set to auto flush (%i)', chunk_size)
+            else:
+                msg.info('LZMA algorithm not in use, chunk_size set to default (%i)', chunk_size)
+        elif configFlags.MP.ChunkSize == -2:  # auto flush only for LZMA or ZLIB
+            if md.get('file_comp_alg',-1) in [1,2]:
+                chunk_size = md.get('auto_flush',-1)
+                msg.info('Chunk size set to auto flush (%i)', chunk_size)
+            else:
+                msg.info('LZMA nor ZLIB in use, chunk_size set to default (%i)', chunk_size)
+        elif configFlags.MP.ChunkSize == -3:  # auto flush only for LZMA, ZLIB or LZ4
+            if md.get('file_comp_alg',-1) in [1,2,4]:
+                chunk_size = md.get('auto_flush',-1)
+                msg.info('Chunk size set to auto flush (%i)', chunk_size)
+            else:
+                msg.info('LZMA, ZLIB nor LZ4 in use, chunk_size set to default (%i)', chunk_size)
+        elif configFlags.MP.ChunkSize <= -4:  # auto flush regardless of compression algorithm
+            chunk_size = md.get('auto_flush',-1)
+            msg.info('Chunk size set to auto flush (%i)', chunk_size)
+        else:
+            msg.warning('Invalid ChunkSize, Chunk Size set to default (%i)', chunk_size)
+    return chunk_size
+
+
+if __name__=="__main__":
+    # Self-test: run AthenaMP with two worker processes over ten events
+    # -----------------  Hello World Example (disabled) -------
+    # ConfigFlags.Exec.MaxEvents=10
+    # ConfigFlags.Concurrency.NumProcs=2
+
+    # cfg=MainServicesCfg(ConfigFlags)
+
+    # from AthExHelloWorld.HelloWorldConfig import HelloWorldCfg
+    # cfg.merge(HelloWorldCfg())
+
+    # cfg.run()
+    # -----------------  Hello World Example (disabled) -------
+
+    # -----------------  Example with input file --------------
+    from AthenaConfiguration.TestDefaults import defaultTestFiles
+    ConfigFlags.Input.Files = defaultTestFiles.ESD
+    ConfigFlags.Exec.MaxEvents=10
+    ConfigFlags.Concurrency.NumProcs=2
+    # Switch Configurable behaviour ahead of the MainServicesCfg() call below
+    from AthenaCommon.Configurable import Configurable
+    Configurable.configurableRun3Behavior=1
+    # NumProcs>0 makes MainServicesCfg select AthMpEvtLoopMgr and merge AthenaMPCfg
+    cfg=MainServicesCfg(ConfigFlags)
+    from AthenaPoolCnvSvc.PoolReadConfig import EventSelectorAthenaPoolCfg
+    cfg.merge(EventSelectorAthenaPoolCfg(ConfigFlags))
+    cfg.run()
+    # -----------------  Example with input file --------------
+
+    msg.info('All OK!')
diff --git a/Database/AthenaPOOL/AthenaPoolCnvSvc/python/PoolReadConfig.py b/Database/AthenaPOOL/AthenaPoolCnvSvc/python/PoolReadConfig.py
index af2daac390a4..417f3b603038 100644
--- a/Database/AthenaPOOL/AthenaPoolCnvSvc/python/PoolReadConfig.py
+++ b/Database/AthenaPOOL/AthenaPoolCnvSvc/python/PoolReadConfig.py
@@ -66,7 +66,7 @@ def PoolReadCfg(configFlags):
     
     StoreGateSvc=CompFactory.StoreGateSvc
 
-    result.addService(PoolSvc(MaxFilesOpen=0))
+    result.addService(PoolSvc(MaxFilesOpen=configFlags.PoolSvc.MaxFilesOpen))
     apcs=AthenaPoolCnvSvc()
     apcs.InputPoolAttributes += ["DatabaseName = '*'; ContainerName = 'CollectionTree'; TREE_CACHE = '-1'"]
     result.addService(apcs)
diff --git a/Database/IOVDbSvc/python/IOVDbSvcConfig.py b/Database/IOVDbSvc/python/IOVDbSvcConfig.py
index e982b04edd5d..e689ec714809 100644
--- a/Database/IOVDbSvc/python/IOVDbSvcConfig.py
+++ b/Database/IOVDbSvc/python/IOVDbSvcConfig.py
@@ -54,7 +54,7 @@ def IOVDbSvcCfg(configFlags):
     
     PoolSvc=CompFactory.PoolSvc
     poolSvc=PoolSvc()
-    poolSvc.MaxFilesOpen=0
+    poolSvc.MaxFilesOpen=configFlags.PoolSvc.MaxFilesOpen
     poolSvc.ReadCatalog=["apcfile:poolcond/PoolFileCatalog.xml",
                          "apcfile:poolcond/PoolCat_oflcond.xml",
                          ]
-- 
GitLab