From 61a8f24f380b9db27b6b49ab0ca5708fb09a6434 Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Mon, 8 Jan 2024 11:56:38 -0800
Subject: [PATCH 01/18] Add the Athena ONNX Runtime packages

---
 Control/AthOnnx/AthOnnxComps/CMakeLists.txt   |  17 ++
 .../src/ATLAS_CHECK_THREAD_SAFETY             |   1 +
 .../src/OnnxRuntimeSessionTool.cxx            | 195 ++++++++++++++++++
 .../AthOnnxComps/src/OnnxRuntimeSessionTool.h |  71 +++++++
 .../src/OnnxRuntimeSessionToolCUDA.cxx        |  58 ++++++
 .../src/OnnxRuntimeSessionToolCUDA.h          |  40 ++++
 .../AthOnnxComps/src/OnnxRuntimeSvc.cxx       |  43 ++++
 .../AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.h |  58 ++++++
 .../src/components/AthOnnxComps_entries.cxx   |  11 +
 Control/AthOnnx/AthOnnxConfig/CMakeLists.txt  |   8 +
 .../AthOnnxConfig/python/OnnxRuntimeFlags.py  |  29 +++
 .../python/OnnxRuntimeSessionConfig.py        |  27 +++
 .../python/OnnxRuntimeSvcConfig.py            |  11 +
 .../AthOnnx/AthOnnxConfig/python/__init__.py  |   1 +
 .../ATLAS_CHECK_THREAD_SAFETY                 |   1 +
 .../IOnnxRuntimeSessionTool.h                 | 137 ++++++++++++
 .../IOnnxRuntimeSessionTool.ipp               |  49 +++++
 .../AthOnnxInterfaces/IOnnxRuntimeSvc.h       |  41 ++++
 .../AthOnnx/AthOnnxInterfaces/CMakeLists.txt  |  12 ++
 .../AthOnnxUtils/ATLAS_CHECK_THREAD_SAFETY    |   1 +
 .../AthOnnxUtils/AthOnnxUtils/OnnxUtils.h     |  77 +++++++
 Control/AthOnnx/AthOnnxUtils/CMakeLists.txt   |  16 ++
 .../AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx    |  81 ++++++++
 23 files changed, 985 insertions(+)
 create mode 100644 Control/AthOnnx/AthOnnxComps/CMakeLists.txt
 create mode 100644 Control/AthOnnx/AthOnnxComps/src/ATLAS_CHECK_THREAD_SAFETY
 create mode 100644 Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.cxx
 create mode 100644 Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.h
 create mode 100644 Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
 create mode 100644 Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h
 create mode 100644 Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.cxx
 create mode 100644 Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.h
 create mode 100644 Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx
 create mode 100644 Control/AthOnnx/AthOnnxConfig/CMakeLists.txt
 create mode 100644 Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeFlags.py
 create mode 100644 Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSessionConfig.py
 create mode 100644 Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSvcConfig.py
 create mode 100644 Control/AthOnnx/AthOnnxConfig/python/__init__.py
 create mode 100644 Control/AthOnnx/AthOnnxInterfaces/ATLAS_CHECK_THREAD_SAFETY
 create mode 100644 Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
 create mode 100644 Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.ipp
 create mode 100644 Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSvc.h
 create mode 100644 Control/AthOnnx/AthOnnxInterfaces/CMakeLists.txt
 create mode 100644 Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/ATLAS_CHECK_THREAD_SAFETY
 create mode 100644 Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
 create mode 100644 Control/AthOnnx/AthOnnxUtils/CMakeLists.txt
 create mode 100644 Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx

diff --git a/Control/AthOnnx/AthOnnxComps/CMakeLists.txt b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
new file mode 100644
index 000000000000..fc9fb47fc120
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+# Declare the package's name.
+atlas_subdir( AthOnnxComps )
+
+# External dependencies.
+find_package( onnxruntime )
+
+# Component(s) in the package.
+atlas_add_component( AthOnnxComps
+   src/*.h
+   src/*.cxx 
+   src/components/*.cxx
+   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS} 
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib 
+   AthOnnxInterfaces AthenaBaseComps GaudiKernel AthOnnxruntimeServiceLib AthOnnxUtilsLib
+)
diff --git a/Control/AthOnnx/AthOnnxComps/src/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnx/AthOnnxComps/src/ATLAS_CHECK_THREAD_SAFETY
new file mode 100644
index 000000000000..6c51dc7f3797
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/ATLAS_CHECK_THREAD_SAFETY
@@ -0,0 +1 @@
+Control/AthOnnx/AthOnnxComps
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.cxx
new file mode 100644
index 000000000000..101b3edae642
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.cxx
@@ -0,0 +1,195 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "OnnxRuntimeSessionTool.h"
+#include "AthOnnxUtils/OnnxUtils.h"
+
+AthOnnx::OnnxRuntimeSessionTool::OnnxRuntimeSessionTool(
+  const std::string& type, const std::string& name, const IInterface* parent )
+  : base_class( type, name, parent )
+{
+  declareInterface<IOnnxRuntimeSessionTool>(this);
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionTool::initialize()
+{
+    // Get the Onnx Runtime service.
+    ATH_CHECK(m_onnxRuntimeSvc.retrieve());
+
+    ATH_CHECK(createModel());
+
+    return StatusCode::SUCCESS;
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionTool::finalize()
+{
+    StatusCode sc = AlgTool::finalize();
+    return sc;
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionTool::createModel()
+{
+    ATH_CHECK(createSession());
+    ATH_CHECK(getNodeInfo());
+
+    return StatusCode::SUCCESS;
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionTool::createSession() 
+{
+    // Create the session options.
+    // TODO: Make this configurable.
+    // other threading options: https://onnxruntime.ai/docs/performance/tune-performance/threading.html
+
+    if (m_modelFileName.empty()) {
+        ATH_MSG_ERROR("Model file name is empty");
+        return StatusCode::FAILURE;
+    }
+
+    Ort::SessionOptions sessionOptions;
+    sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_ALL );
+    sessionOptions.DisablePerSessionThreads();  // use global thread pool.
+
+    // Create the session.
+    m_session =  std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
+
+    return StatusCode::SUCCESS;
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionTool::getNodeInfo()
+{
+    if (m_session == nullptr) {
+        ATH_MSG_ERROR("Session is not created");
+        return StatusCode::FAILURE;
+    }
+
+    // obtain the model information
+    m_numInputs = m_session->GetInputCount();
+    m_numOutputs = m_session->GetOutputCount();
+
+    AthOnnx::GetInputNodeInfo(m_session, m_inputShapes, m_inputNodeNames);
+    AthOnnx::GetOutputNodeInfo(m_session, m_outputShapes, m_outputNodeNames);
+
+    return StatusCode::SUCCESS;
+}
+
+
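+// Overwrite any dynamic (-1) leading dimension of the input and output shapes
+// with the user-supplied batch size.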
+void AthOnnx::OnnxRuntimeSessionTool::setBatchSize(int64_t batchSize)
+{
+    if (batchSize <= 0) {
+        ATH_MSG_ERROR("Batch size should be positive");
+        return;
+    }
+
+    for (auto& shape : m_inputShapes) {
+        if (shape[0] == -1) {
+            shape[0] = batchSize;
+        }
+    }
+    
+    for (auto& shape : m_outputShapes) {
+        if (shape[0] == -1) {
+            shape[0] = batchSize;
+        }
+    }
+}
+
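+// Derive the batch size from the total input data size. For models with a
+// dynamic batch dimension the stored tensor size is negative, so the batch
+// size is inputDataSize / |tensorSize|; for fixed-size models -1 is returned,
+// since the batch size cannot be inferred from the data.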
+int64_t AthOnnx::OnnxRuntimeSessionTool::getBatchSize(int64_t inputDataSize, int idx) const
+{
+    auto tensorSize = AthOnnx::GetTensorSize(m_inputShapes[idx]);
+    if (tensorSize < 0) {
+        return inputDataSize / std::abs(tensorSize);
+    } else {
+        return -1;
+    }
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionTool::inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const
+{
+    assert (inputTensors.size() == static_cast<size_t>(m_numInputs));
+    assert (outputTensors.size() == static_cast<size_t>(m_numOutputs));
+
+    // Run the model.
+    AthOnnx::InferenceWithIOBinding(
+            m_session, 
+            m_inputNodeNames, inputTensors, 
+            m_outputNodeNames, outputTensors);
+
+    return StatusCode::SUCCESS;
+}
+
+const std::vector<const char*>& AthOnnx::OnnxRuntimeSessionTool::getInputNodeNames() const
+{
+    return m_inputNodeNames;
+}
+
+const std::vector<const char*>& AthOnnx::OnnxRuntimeSessionTool::getOutputNodeNames() const
+{
+    return m_outputNodeNames;
+}
+
+const std::vector<std::vector<int64_t> >& AthOnnx::OnnxRuntimeSessionTool::getInputShape() const
+{
+    return m_inputShapes;
+}
+
+const std::vector<std::vector<int64_t> >& AthOnnx::OnnxRuntimeSessionTool::getOutputShapes() const
+{
+    return m_outputShapes;
+} 
+
+int AthOnnx::OnnxRuntimeSessionTool::getNumInputs() const
+{
+    return m_numInputs;
+}
+
+int AthOnnx::OnnxRuntimeSessionTool::getNumOutputs() const
+{
+    return m_numOutputs;
+}
+
+int64_t AthOnnx::OnnxRuntimeSessionTool::getInputSize(int idx) const
+{
+    return AthOnnx::GetTensorSize(m_inputShapes[idx]);
+}
+
+int64_t AthOnnx::OnnxRuntimeSessionTool::getOutputSize(int idx) const
+{
+    return AthOnnx::GetTensorSize(m_outputShapes[idx]);
+}
+
+void AthOnnx::OnnxRuntimeSessionTool::printModelInfo() const
+{
+    ATH_MSG_INFO("Model file name: " << m_modelFileName.value());
+    ATH_MSG_INFO("Number of inputs: " << m_numInputs);
+    ATH_MSG_INFO("Number of outputs: " << m_numOutputs);
+
+    ATH_MSG_INFO("Input node names: ");
+    for (const auto& name : m_inputNodeNames) {
+        ATH_MSG_INFO("\t" << name);
+    }
+
+    ATH_MSG_INFO("Output node names: ");
+    for (const auto& name : m_outputNodeNames) {
+        ATH_MSG_INFO("\t" << name);
+    }
+
+    ATH_MSG_INFO("Input shapes: ");
+    for (const auto& shape : m_inputShapes) {
+        std::string shapeStr = "\t";
+        for (const auto& dim : shape) {
+            shapeStr += std::to_string(dim) + " ";
+        }
+        ATH_MSG_INFO(shapeStr);
+    }
+
+    ATH_MSG_INFO("Output shapes: ");
+    for (const auto& shape : m_outputShapes) {
+        std::string shapeStr = "\t";
+        for (const auto& dim : shape) {
+            shapeStr += std::to_string(dim) + " ";
+        }
+        ATH_MSG_INFO(shapeStr);
+    }
+}
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.h
new file mode 100644
index 000000000000..56d040d1d8cb
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.h
@@ -0,0 +1,71 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef OnnxRuntimeSessionTool_H
+#define OnnxRuntimeSessionTool_H
+
+#include "AthenaBaseComps/AthAlgTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
+#include "GaudiKernel/ServiceHandle.h"
+
+namespace AthOnnx {
+    // @class OnnxRuntimeSessionTool
+    // 
+    // @brief Tool to create Onnx Runtime session with CPU backend
+    //
+    // @author Xiangyang Ju <xiangyang.ju@cern.ch>
+    class OnnxRuntimeSessionTool :  public extends<AthAlgTool, IOnnxRuntimeSessionTool>
+    {
+        public:
+        /// Standard constructor
+        OnnxRuntimeSessionTool( const std::string& type,
+                                const std::string& name,
+                                const IInterface* parent );
+        virtual ~OnnxRuntimeSessionTool() = default;
+
+        /// Initialize the tool
+        virtual StatusCode initialize() override;
+        /// Finalize the tool
+        virtual StatusCode finalize() override;
+
+        virtual void setBatchSize(int64_t batchSize) override final;
+        virtual int64_t getBatchSize(int64_t inputDataSize, int idx = 0) const override final;
+
+        virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const override final;
+
+        virtual const std::vector<const char*>& getInputNodeNames() const override final;
+        virtual const std::vector<const char*>& getOutputNodeNames() const override final;
+
+        virtual const std::vector<std::vector<int64_t> >& getInputShape() const override final;
+        virtual const std::vector<std::vector<int64_t> >& getOutputShapes() const override final;
+
+        virtual int getNumInputs() const override final;
+        virtual int getNumOutputs() const override final;
+        virtual int64_t getInputSize(int idx = 0) const override final;
+        virtual int64_t getOutputSize(int idx = 0) const override final;
+
+        virtual void printModelInfo() const override final;
+
+        protected:
+        OnnxRuntimeSessionTool() = delete;
+        OnnxRuntimeSessionTool(const OnnxRuntimeSessionTool&) = delete;
+        OnnxRuntimeSessionTool& operator=(const OnnxRuntimeSessionTool&) = delete;
+
+        StatusCode createModel();
+        virtual StatusCode createSession();
+        StatusCode getNodeInfo();
+
+        ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
+        StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
+
+        std::unique_ptr<Ort::Session> m_session;
+        std::vector<const char*> m_inputNodeNames;
+        std::vector<const char*> m_outputNodeNames;
+        // m_numInputs, m_numOutputs, m_inputShapes and m_outputShapes are
+        // inherited from IOnnxRuntimeSessionTool; redeclaring them here would
+        // shadow the members used by the interface's addInput/addOutput.
+    };
+} // namespace AthOnnx
+
+#endif
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
new file mode 100644
index 000000000000..66b05273ad8d
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
@@ -0,0 +1,58 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "OnnxRuntimeSessionToolCUDA.h"
+#include "AthOnnxUtils/OnnxUtils.h"
+#include <limits>
+
+AthOnnx::OnnxRuntimeSessionToolCUDA::OnnxRuntimeSessionToolCUDA(
+  const std::string& type, const std::string& name, const IInterface* parent): AthOnnx::OnnxRuntimeSessionTool(type, name, parent)
+{
+
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::createSession() 
+{
+    if (m_modelFileName.empty()) {
+        ATH_MSG_ERROR("Model file name is empty");
+        return StatusCode::FAILURE;
+    }
+
+    ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
+
+    Ort::SessionOptions sessionOptions;
+    sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
+    sessionOptions.DisablePerSessionThreads();    // use global thread pool.
+
+    // TODO: add more cuda options to the interface
+    // https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cc
+    // Options: https://onnxruntime.ai/docs/api/c/struct_ort_c_u_d_a_provider_options.html
+    OrtCUDAProviderOptions cuda_options;
+    cuda_options.device_id = m_deviceId;
+    cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
+    cuda_options.gpu_mem_limit = std::numeric_limits<size_t>::max();
+
+    // memory arena options for cuda memory shrinkage
+    // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/utils.cc#L7
+    if (m_enableMemoryShrinkage) {
+        Ort::ArenaCfg arena_cfg{0, 0, 1024, 0};
+        // other options are not available in this release.
+        // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/test_inference.cc#L2802C21-L2802C21
+        // arena_cfg.max_mem = 0;   // let ORT pick default max memory
+        // arena_cfg.arena_extend_strategy = 0;   // 0: kNextPowerOfTwo, 1: kSameAsRequested
+        // arena_cfg.initial_chunk_size_bytes = 1024;
+        // arena_cfg.max_dead_bytes_per_chunk = 0;
+        // arena_cfg.initial_growth_chunk_size_bytes = 256;
+        // arena_cfg.max_power_of_two_extend_bytes = 1L << 24;
+
+        // Release ownership: cuda_options keeps a raw pointer, and the local
+        // Ort::ArenaCfg would otherwise free the config at scope exit.
+        cuda_options.default_memory_arena_cfg = arena_cfg.release();
+    }
+
+    sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
+
+    // Create the session.
+    m_session =  std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
+
+    return StatusCode::SUCCESS;
+}
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h
new file mode 100644
index 000000000000..7c147350b0be
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h
@@ -0,0 +1,40 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef OnnxRuntimeSessionToolCUDA_H
+#define OnnxRuntimeSessionToolCUDA_H
+
+#include "AthenaBaseComps/AthAlgTool.h"
+#include "GaudiKernel/ServiceHandle.h"
+#include "OnnxRuntimeSessionTool.h"
+
+namespace AthOnnx {
+    // @class OnnxRuntimeSessionToolCUDA
+    // 
+    // @brief Tool to create Onnx Runtime session with CUDA backend
+    //
+    // @author Xiangyang Ju <xiangyang.ju@cern.ch>
+    class OnnxRuntimeSessionToolCUDA :  public OnnxRuntimeSessionTool
+    {
+        public:
+        /// Standard constructor
+        OnnxRuntimeSessionToolCUDA( const std::string& type,
+                                const std::string& name,
+                                const IInterface* parent );
+        virtual ~OnnxRuntimeSessionToolCUDA() = default;      
+
+        protected:
+        OnnxRuntimeSessionToolCUDA() = delete;
+        OnnxRuntimeSessionToolCUDA(const OnnxRuntimeSessionToolCUDA&) = delete;
+        OnnxRuntimeSessionToolCUDA& operator=(const OnnxRuntimeSessionToolCUDA&) = delete;
+
+        virtual StatusCode createSession() override final;
+
+        private:
+        IntegerProperty m_deviceId{this, "DeviceId", 0};
+        BooleanProperty m_enableMemoryShrinkage{this, "EnableMemoryShrinkage", false};
+        
+    };
+} // namespace AthOnnx
+
+#endif
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.cxx
new file mode 100644
index 000000000000..c45e2dba6faa
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.cxx
@@ -0,0 +1,43 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+// Local include(s).
+#include "OnnxRuntimeSvc.h"
+#include <core/session/onnxruntime_c_api.h>
+
+namespace AthOnnx {
+  OnnxRuntimeSvc::OnnxRuntimeSvc(const std::string& name, ISvcLocator* svc) :
+      asg::AsgService(name, svc)
+   {
+     declareServiceInterface<AthOnnx::IOnnxRuntimeSvc>();
+   }
+   StatusCode OnnxRuntimeSvc::initialize() {
+
+      // Create the environment object.
+      Ort::ThreadingOptions tp_options;
+      tp_options.SetGlobalIntraOpNumThreads(1);
+      tp_options.SetGlobalInterOpNumThreads(1);
+
+      m_env = std::make_unique< Ort::Env >(
+            tp_options, ORT_LOGGING_LEVEL_WARNING, name().c_str());
+      ATH_MSG_DEBUG( "Ort::Env object created" );
+
+      // Return gracefully.
+      return StatusCode::SUCCESS;
+   }
+
+   StatusCode OnnxRuntimeSvc::finalize() {
+
+      // Delete the environment object.
+      m_env.reset();
+      ATH_MSG_DEBUG( "Ort::Env object deleted" );
+
+      // Return gracefully.
+      return StatusCode::SUCCESS;
+   }
+
+   Ort::Env& OnnxRuntimeSvc::env() const {
+
+      return *m_env;
+   }
+
+} // namespace AthOnnx
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.h
new file mode 100644
index 000000000000..9f90bc998f65
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.h
@@ -0,0 +1,58 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef ATHONNXCOMPS_ONNXRUNTIMESVC_H
+#define ATHONNXCOMPS_ONNXRUNTIMESVC_H
+
+// Local include(s).
+#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
+
+// Framework include(s).
+#include <AsgServices/AsgService.h>
+
+// ONNX include(s).
+#include <core/session/onnxruntime_cxx_api.h>
+
+// System include(s).
+#include <memory>
+
+namespace AthOnnx {
+
+   /// Service implementing @c AthOnnx::IOnnxRuntimeSvc
+   ///
+   /// This is a very simple implementation, just managing the lifetime
+   /// of some Onnx Runtime C++ objects.
+   ///
+   /// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>
+   ///
+   class OnnxRuntimeSvc : public asg::AsgService, virtual public IOnnxRuntimeSvc {
+
+   public:
+
+      /// @name Function(s) inherited from @c Service
+      /// @{
+      OnnxRuntimeSvc (const std::string& name, ISvcLocator* svc);
+
+      /// Function initialising the service
+      virtual StatusCode initialize() override;
+      /// Function finalising the service
+      virtual StatusCode finalize() override;
+
+      /// @}
+
+      /// @name Function(s) inherited from @c AthOnnx::IOnnxRuntimeSvc
+      /// @{
+
+      /// Return the Onnx Runtime environment object
+      virtual Ort::Env& env() const override;
+
+      /// @}
+
+   private:
+      /// Global runtime environment for Onnx Runtime
+      std::unique_ptr< Ort::Env > m_env;
+
+   }; // class OnnxRuntimeSvc
+
+} // namespace AthOnnx
+
+#endif // ATHONNXCOMPS_ONNXRUNTIMESVC_H
diff --git a/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx b/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx
new file mode 100644
index 000000000000..9f9ae59ef5f6
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx
@@ -0,0 +1,11 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+// Local include(s).
+#include "../OnnxRuntimeSvc.h"
+#include "../OnnxRuntimeSessionTool.h"
+#include "../OnnxRuntimeSessionToolCUDA.h"
+
+// Declare the package's components.
+DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSvc )
+DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSessionTool )
+DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSessionToolCUDA )
diff --git a/Control/AthOnnx/AthOnnxConfig/CMakeLists.txt b/Control/AthOnnx/AthOnnxConfig/CMakeLists.txt
new file mode 100644
index 000000000000..b8492033a884
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxConfig/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+# Declare the package name:
+atlas_subdir( AthOnnxConfig )
+
+# Install python modules.
+atlas_install_python_modules( python/*.py POST_BUILD_CMD ${ATLAS_FLAKE8} )
diff --git a/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeFlags.py b/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeFlags.py
new file mode 100644
index 000000000000..8df14ce012c1
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeFlags.py
@@ -0,0 +1,29 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.AthConfigFlags import AthConfigFlags
+from AthenaConfiguration.Enums import FlagEnum
+
+class OnnxRuntimeExecutionProvider(FlagEnum):
+    CPU = 'CPU'
+    CUDA = 'CUDA'
+    # Possible future backends. Uncomment when implemented.
+    # DML = 'DML'
+    # DNNL = 'DNNL'
+    # NUPHAR = 'NUPHAR'
+    # OPENVINO = 'OPENVINO'
+    # ROCM = 'ROCM'
+    # TENSORRT = 'TENSORRT'
+    # VITISAI = 'VITISAI'
+    # VULKAN = 'VULKAN'
+
+def createOnnxRuntimeFlags():
+    icf = AthConfigFlags()
+
+    icf.addFlag("AthOnnx.ExecutionProvider", "CPU")
+
+    return icf
+
+if __name__ == "__main__":
+
+    flags = createOnnxRuntimeFlags()
+    flags.dump()
diff --git a/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSessionConfig.py b/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSessionConfig.py
new file mode 100644
index 000000000000..2fc06c26f07f
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSessionConfig.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaConfiguration.ComponentFactory import CompFactory
+from AthOnnxConfig.OnnxRuntimeFlags import OnnxRuntimeExecutionProvider as OrtEP
+from typing import Optional
+
+def OnnxRuntimeSessionToolCfg(flags,
+                              model_fname: Optional[str] = None,
+                              execution_provider: Optional[str] = None, 
+                              name="OnnxRuntimeSessionTool", **kwargs):
+    """"Configure OnnxRuntimeSessionTool in Control/AthOnnx/AthOnnxComps/src"""
+    
+    acc = ComponentAccumulator()
+
+    if model_fname is None:
+        raise ValueError("model_fname must be specified")
+    
+    execution_provider = flags.AthOnnx.ExecutionProvider if execution_provider is None else execution_provider
+    if execution_provider == OrtEP.CPU.name:
+        acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionTool(name, ModelFileName=model_fname, **kwargs))
+    elif execution_provider == OrtEP.CUDA.name:
+        acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCUDA(name, ModelFileName=model_fname,  **kwargs))
+    else:
+        raise ValueError("Unknown OnnxRuntime Execution Provider: %s" % execution_provider)
+
+    return acc
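+
+# Minimal usage sketch (illustrative, from a parent ComponentAccumulator `acc`
+# with a locked `flags` object; the model path is a placeholder):
+#   tool = acc.popToolsAndMerge(OnnxRuntimeSessionToolCfg(flags, model_fname="model.onnx"))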
diff --git a/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSvcConfig.py b/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSvcConfig.py
new file mode 100644
index 000000000000..e006fe8232c4
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSvcConfig.py
@@ -0,0 +1,11 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaConfiguration.ComponentFactory import CompFactory
+
+def OnnxRuntimeSvcCfg(flags, name="OnnxRuntimeSvc", **kwargs):
+    acc = ComponentAccumulator()
+    
+    acc.addService(CompFactory.AthOnnx.OnnxRuntimeSvc(name, **kwargs), primary=True, create=True)
+
+    return acc
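+
+# Usage sketch (illustrative):
+#   svc = acc.getPrimaryAndMerge(OnnxRuntimeSvcCfg(flags))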
diff --git a/Control/AthOnnx/AthOnnxConfig/python/__init__.py b/Control/AthOnnx/AthOnnxConfig/python/__init__.py
new file mode 100644
index 000000000000..87ae30225ac7
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxConfig/python/__init__.py
@@ -0,0 +1 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
diff --git a/Control/AthOnnx/AthOnnxInterfaces/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnx/AthOnnxInterfaces/ATLAS_CHECK_THREAD_SAFETY
new file mode 100644
index 000000000000..9524c9b32dce
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/ATLAS_CHECK_THREAD_SAFETY
@@ -0,0 +1 @@
+Control/AthOnnx/AthOnnxInterfaces
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
new file mode 100644
index 000000000000..315c775e3b6a
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
@@ -0,0 +1,137 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+#ifndef ATHONNX_IONNXRUNTIMESESSIONTOOL_H
+#define ATHONNX_IONNXRUNTIMESESSIONTOOL_H
+
+// Gaudi include(s).
+#include "GaudiKernel/IAlgTool.h"
+#include <memory>
+#include <cstdlib>
+#include <functional>
+#include <numeric>
+
+#include <core/session/onnxruntime_cxx_api.h>
+
+
+namespace AthOnnx {
+    /**
+     * @class IOnnxRuntimeSessionTool
+     * @brief Interface class for creating Onnx Runtime sessions.
+     * @details Interface class for creating Onnx Runtime sessions.
+     * It is thread safe, and it supports models with a varying number of inputs
+     * and outputs as well as models with a dynamic batch size. It defines a standardized procedure to 
+     * perform Onnx Runtime inference. The procedure is as follows, assuming the tool `m_onnxTool` is created and initialized:
+     *    1. create input tensors from the input data: 
+     *      ```c++
+     *         std::vector<Ort::Value> inputTensors;
+     *         std::vector<float> inputData_1;   // The input data is filled by users, possibly from the event information. 
+     *         int64_t batchSize = m_onnxTool->getBatchSize(inputData_1.size(), 0);  // The batch size is determined by the input data size to support dynamic batch size.
+     *         m_onnxTool->addInput(inputTensors, inputData_1, 0, batchSize);
+     *         std::vector<int64_t> inputData_2;  // Some models may have multiple inputs. Add inputs one by one.
+     *         int64_t batchSize_2 = m_onnxTool->getBatchSize(inputData_2.size(), 1);
+     *         m_onnxTool->addInput(inputTensors, inputData_2, 1, batchSize_2);
+     *     ```
+     *    2. create output tensors:
+     *      ```c++
+     *          std::vector<Ort::Value> outputTensors;
+     *          std::vector<float> outputData;   // The output data will be filled by the onnx session.
+     *          m_onnxTool->addOutput(outputTensors, outputData, 0, batchSize);
+     *      ```
+     *   3. perform inference:
+     *     ```c++
+     *        m_onnxTool->inference(inputTensors, outputTensors);
+     *    ```
+     *   4. Model outputs will be automatically filled to outputData.
+     * 
+     * 
+     * @author Xiangyang Ju <xju@cern.ch>
+     */
+    class IOnnxRuntimeSessionTool : virtual public IAlgTool 
+    {
+        public:
+
+        virtual ~IOnnxRuntimeSessionTool() = default;
+        
+        // @name InterfaceID
+        DeclareInterfaceID(IOnnxRuntimeSessionTool, 1, 0);
+
+        /**
+         * @brief set batch size. 
+         * @details If the model has dynamic batch size, 
+         *          the batchSize value will be set to both input shapes and output shapes
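+         *          For example (illustrative): an input shape of {-1, 28, 28}
+         *          becomes {16, 28, 28} after setBatchSize(16).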
+         */ 
+        virtual void setBatchSize(int64_t batchSize) = 0;
+
+        /**
+         * @brief methods for determining batch size from the data size
+         * @param dataSize the size of the input data, like std::vector<T>::size()
+         * @param idx the index of the input node
+         * @return the batch size, i.e. dataSize divided by the product of the remaining (non-batch) dimensions.
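+         *         e.g. (illustrative) dataSize = 2352 with an input shape of
+         *         {-1, 28, 28} gives a batch size of 2352 / 784 = 3.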
+         */ 
+        virtual int64_t getBatchSize(int64_t dataSize, int idx = 0) const = 0;
+
+        /**
+         * @brief add the input data to the input tensors
+         * @param inputTensors the input tensor container
+         * @param data the input data
+         * @param idx the index of the input node
+         * @param batchSize the batch size
+         * @return StatusCode::SUCCESS if the input data is added successfully
+         */
+        template <typename T>
+        StatusCode addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const;
+
+        /**
+         * @brief add the output data to the output tensors
+         * @param outputTensors the output tensor container
+         * @param data the output data
+         * @param idx the index of the output node
+         * @param batchSize the batch size
+         * @return StatusCode::SUCCESS if the output data is added successfully
+         */
+        template <typename T>
+        StatusCode addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const;
+
+        /**
+         * @brief perform inference
+         * @param inputTensors the input tensor container
+         * @param outputTensors the output tensor container
+         * @return StatusCode::SUCCESS if the inference is performed successfully
+         */
+        virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const = 0;
+
+        /// @brief get the input node names
+        virtual const std::vector<const char*>& getInputNodeNames() const = 0;
+        /// @brief get the output node names
+        virtual const std::vector<const char*>& getOutputNodeNames() const = 0;
+
+        virtual const std::vector<std::vector<int64_t> >& getInputShape() const = 0;
+        virtual const std::vector<std::vector<int64_t> >& getOutputShapes() const = 0;
+
+        /// @brief get the number of input nodes
+        virtual int getNumInputs() const = 0;
+        /// @brief get the number of output nodes
+        virtual int getNumOutputs() const = 0;
+
+        /** 
+         * @brief get the size of the input/output tensor
+         * @param idx the index of the input/output node
+         * @return the size of the input/output tensor
+        */
+        virtual int64_t getInputSize(int idx = 0) const = 0;
+        virtual int64_t getOutputSize(int idx = 0) const = 0;
+
+        virtual void printModelInfo() const = 0;
+
+        protected:
+        int m_numInputs{0};
+        int m_numOutputs{0};
+        std::vector<std::vector<int64_t> > m_inputShapes;
+        std::vector<std::vector<int64_t> > m_outputShapes;
+
+        private:
+        template <typename T>
+        Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const;
+
+    };
+
+    #include "IOnnxRuntimeSessionTool.ipp"
+} // namespace AthOnnx
+
+#endif
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.ipp b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.ipp
new file mode 100644
index 000000000000..e46dfed831dd
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.ipp
@@ -0,0 +1,49 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+template <typename T>
+Ort::Value AthOnnx::IOnnxRuntimeSessionTool::createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const
+{
+    std::vector<int64_t> dataShapeCopy = dataShape;
+
+    if (batchSize > 0) {
+        for (auto& shape: dataShapeCopy) {
+            if (shape == -1) {
+                shape = batchSize;
+                break;
+            }
+        }
+    }
+
+    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+    return Ort::Value::CreateTensor<T>(
+        memoryInfo, data.data(), data.size(), dataShapeCopy.data(), dataShapeCopy.size()
+    );
+}
+
+template <typename T>
+StatusCode AthOnnx::IOnnxRuntimeSessionTool::addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
+{
+    if (idx >= m_numInputs || idx < 0) {
+        // ATH_MSG_ERROR("Need " << m_numInputs << " tensors; but adding "<< idx << " tensor.");
+        return StatusCode::FAILURE;
+    }
+
+    inputTensors.push_back(std::move(createTensor(data, m_inputShapes[idx], batchSize)));
+    return StatusCode::SUCCESS;
+}
+
+template <typename T>
+StatusCode AthOnnx::IOnnxRuntimeSessionTool::addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
+{
+    if (idx >= m_numOutputs || idx < 0) {
+        // ATH_MSG_ERROR("Need " << m_numOutputs << " tensors; but adding "<< idx << " tensor.");
+        return StatusCode::FAILURE;
+    }
+    auto tensorSize = std::accumulate(m_outputShapes[idx].begin(), m_outputShapes[idx].end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
+    if (tensorSize < 0) {   // a negative product means the shape has a dynamic (batch) dimension
+        tensorSize = std::abs(tensorSize) * batchSize;
+    }
+    data.resize(tensorSize);
+    outputTensors.push_back(std::move(createTensor(data, m_outputShapes[idx], batchSize)));
+    return StatusCode::SUCCESS;
+}
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSvc.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSvc.h
new file mode 100644
index 000000000000..b546392b55b4
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSvc.h
@@ -0,0 +1,41 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef ATHONNX_IONNXRUNTIMESVC_H
+#define ATHONNX_IONNXRUNTIMESVC_H
+
+// Gaudi include(s).
+#include <AsgServices/IAsgService.h>
+
+// Onnx include(s).
+#include <core/session/onnxruntime_cxx_api.h>
+
+
+/// Namespace holding all of the Onnx Runtime code
+namespace AthOnnx {
+
+   /// Service used for managing global objects used by Onnx Runtime
+   ///
+   /// In order to allow multiple clients to use Onnx Runtime at the same
+   /// time, this service is used to manage the objects that must only
+   /// be created once in the Athena process.
+   ///
+   /// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>
+   ///
+   class IOnnxRuntimeSvc : virtual public asg::IAsgService{
+
+   public:
+      /// Virtual destructor, to make vtable happy
+      virtual ~IOnnxRuntimeSvc() = default;
+
+      /// Declare the interface that this class provides
+      DeclareInterfaceID (AthOnnx::IOnnxRuntimeSvc, 1, 0);
+
+      /// Return the Onnx Runtime environment object
+      virtual Ort::Env& env() const = 0;
+
+   }; // class IOnnxRuntimeSvc
+
+} // namespace AthOnnx
+
+#endif
diff --git a/Control/AthOnnx/AthOnnxInterfaces/CMakeLists.txt b/Control/AthOnnx/AthOnnxInterfaces/CMakeLists.txt
new file mode 100644
index 000000000000..72b84c814183
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/CMakeLists.txt
@@ -0,0 +1,12 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+# Declare the package name:
+atlas_subdir( AthOnnxInterfaces )
+
+# Component(s) in the package:
+atlas_add_library( AthOnnxInterfaces
+                   AthOnnxInterfaces/*.h
+                   INTERFACE
+                   PUBLIC_HEADERS AthOnnxInterfaces
+                   LINK_LIBRARIES AsgTools AthLinks GaudiKernel 
+                  )
diff --git a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/ATLAS_CHECK_THREAD_SAFETY
new file mode 100644
index 000000000000..c5b26fc5a994
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/ATLAS_CHECK_THREAD_SAFETY
@@ -0,0 +1 @@
+Control/AthOnnx/AthOnnxUtils
diff --git a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
new file mode 100644
index 000000000000..8ab6cf2cb8d5
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
@@ -0,0 +1,77 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef ATHONNXUTILS_ONNXUTILS_H
+#define ATHONNXUTILS_ONNXUTILS_H
+
+#include <memory>
+#include <vector>
+
+// Onnx Runtime include(s).
+#include <core/session/onnxruntime_cxx_api.h>
+
+namespace AthOnnx {
+
+// @author Xiangyang Ju <xiangyang.ju@cern.ch>
+
+// @brief Convert a vector of vectors to a single vector.
+// @param features The vector of vectors to be flattened.
+// @return A single vector containing all the elements of the input vector of vectors.
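+// @note e.g. (illustrative) flattenNestedVectors<int>({{1, 2}, {3, 4, 5}})
+// returns {1, 2, 3, 4, 5}.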
+template<typename T>
+inline std::vector<T> flattenNestedVectors( const std::vector<std::vector<T>>& features) {
+  // 1. Compute the total size required.
+  int total_size = 0;
+  for (const auto& feature : features) total_size += feature.size();
+  
+  std::vector<T> flatten1D;
+  flatten1D.reserve(total_size);
+
+  for (const auto& feature : features)
+    for (const auto& elem : feature)
+      flatten1D.push_back(elem);
+
+  return flatten1D;
+}
+
+// @brief Get the input data shape and node names (in the computational graph) from the onnx model
+// @param session The onnx session.
+// @param dataShape The shape of the input data. Note that there may be multiple inputs.
+// @param nodeNames The names of the input nodes in the computational graph.
+// the dataShape and nodeNames will be updated.
+void GetInputNodeInfo(
+    const std::unique_ptr< Ort::Session >& session,
+    std::vector<std::vector<int64_t> >& dataShape, 
+    std::vector<const char*>& nodeNames);
+
+// @brief Get the output data shape and node names (in the computational graph) from the onnx model
+// @param session The onnx session.
+// @param dataShape The shape of the output data.
+// @param nodeNames The names of the output nodes in the computational graph.
+// the dataShape and nodeNames will be updated.
+void GetOutputNodeInfo(
+    const std::unique_ptr< Ort::Session >& session,
+    std::vector<std::vector<int64_t> >& dataShape, 
+    std::vector<const char*>& nodeNames);
+
+// Helper function to get node info.
+void GetNodeInfo(
+    const std::unique_ptr< Ort::Session >& session,
+    std::vector<std::vector<int64_t> >& dataShape, 
+    std::vector<const char*>& nodeNames,
+    bool isInput
+);
+
+// @brief Count the total number of elements in a tensor,
+// useful for reserving space for the output data.
+int64_t GetTensorSize(const std::vector<int64_t>& dataShape);
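+// e.g. (illustrative) GetTensorSize({3, 28, 28}) = 2352, while a dynamic shape
+// {-1, 28, 28} yields -784, signalling an undetermined batch dimension.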
+
+// Inference with IO binding. Better for performance, particularly for GPUs.
+// See https://onnxruntime.ai/docs/performance/tune-performance/iobinding.html
+void InferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session, 
+    const std::vector<const char*>& inputNames,
+    const std::vector<Ort::Value>& inputData,
+    const std::vector<const char*>& outputNames,
+    const std::vector<Ort::Value>& outputData
+); 
+
+} // namespace AthOnnx
+#endif
diff --git a/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt b/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt
new file mode 100644
index 000000000000..4134a24f81e6
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+# Declare the package's name.
+atlas_subdir( AthOnnxUtils )
+
+# External dependencies.
+find_package( onnxruntime )
+
+# Component(s) in the package.
+atlas_add_library( AthOnnxUtilsLib
+   AthOnnxUtils/*.h 
+   src/*.cxx
+   PUBLIC_HEADERS AthOnnxUtils
+   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthOnnxruntimeServiceLib AthenaKernel GaudiKernel AsgServicesLib 
+)
diff --git a/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
new file mode 100644
index 000000000000..1de3fe27f341
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
@@ -0,0 +1,81 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#include "AthOnnxUtils/OnnxUtils.h"
+#include <cassert>
+
+namespace AthOnnx {
+
+void GetNodeInfo(
+    const std::unique_ptr< Ort::Session >& session,
+    std::vector<std::vector<int64_t> >& dataShape, 
+    std::vector<const char*>& nodeNames,
+    bool isInput
+){
+    dataShape.clear();
+    nodeNames.clear();
+
+    size_t numNodes = isInput? session->GetInputCount(): session->GetOutputCount();
+    dataShape.reserve(numNodes);
+    nodeNames.reserve(numNodes);
+
+    Ort::AllocatorWithDefaultOptions allocator;
+    for( std::size_t i = 0; i < numNodes; i++ ) {
+        Ort::TypeInfo typeInfo = isInput? session->GetInputTypeInfo(i): session->GetOutputTypeInfo(i);
+        auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
+        dataShape.emplace_back(tensorInfo.GetShape());
+
+        char* nodeName = isInput? session->GetInputNameAllocated(i, allocator).release() : session->GetOutputNameAllocated(i, allocator).release();
+        nodeNames.push_back(nodeName);
+     }
+}
+
+void GetInputNodeInfo(
+    const std::unique_ptr< Ort::Session >& session,
+    std::vector<std::vector<int64_t> >& dataShape, 
+    std::vector<const char*>& nodeNames
+){
+    GetNodeInfo(session, dataShape, nodeNames, true);
+}
+
+void GetOutputNodeInfo(
+    const std::unique_ptr< Ort::Session >& session,
+    std::vector<std::vector<int64_t> >& dataShape,
+    std::vector<const char*>& nodeNames
+) {
+    GetNodeInfo(session, dataShape, nodeNames, false);
+}
+
+void InferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session, 
+    const std::vector<const char*>& inputNames,
+    const std::vector<Ort::Value>& inputData,
+    const std::vector<const char*>& outputNames,
+    const std::vector<Ort::Value>& outputData){
+    
+    if (inputNames.empty()) {
+        throw std::runtime_error("Onnxruntime input data maping cannot be empty");
+    }
+    assert(inputNames.size() == inputData.size());
+
+    Ort::IoBinding iobinding(*session);
+    for(size_t idx = 0; idx < inputNames.size(); ++idx){
+        iobinding.BindInput(inputNames[idx], inputData[idx]);
+    }
+
+    for(size_t idx = 0; idx < outputNames.size(); ++idx){
+        iobinding.BindOutput(outputNames[idx], outputData[idx]);
+    }
+
+    session->Run(Ort::RunOptions{nullptr}, iobinding);
+}
+
+int64_t GetTensorSize(const std::vector<int64_t>& dataShape){
+    int64_t size = 1;
+    for (const auto& dim : dataShape) {
+        size *= dim;
+    }
+    return size;
+}
+
+} // namespace AthOnnx
-- 
GitLab


From c3bf27d97fc36123140d42945d64b7a6dfe108a4 Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Mon, 8 Jan 2024 12:29:01 -0800
Subject: [PATCH 02/18] Adapt to the new ONNX package

---
 .../python/AllConfigFlags.py                  |   6 +
 .../AthExOnnxRuntime/CMakeLists.txt           |  23 +-
 .../share/AthExOnnxRuntime_CA.py              |  43 ++++
 .../AthExOnnxRuntime/src/EvaluateModel.cxx    | 209 ++++--------------
 .../AthExOnnxRuntime/src/EvaluateModel.h      |  32 +--
 .../components/AthExOnnxRuntime_entries.cxx   |   2 +-
 InnerDetector/InDetGNNTracking/CMakeLists.txt |   1 +
 .../src/SiGNNTrackFinderTool.cxx              |  89 ++------
 .../src/SiGNNTrackFinderTool.h                |  16 +-
 .../src/SiSPGNNTrackMaker.cxx                 |   2 +-
 .../InDetRecToolInterfaces/IGNNTrackFinder.h  |   2 +-
 11 files changed, 152 insertions(+), 273 deletions(-)
 create mode 100644 Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_CA.py

diff --git a/Control/AthenaConfiguration/python/AllConfigFlags.py b/Control/AthenaConfiguration/python/AllConfigFlags.py
index 3f0d9a734fa9..0692a026d890 100644
--- a/Control/AthenaConfiguration/python/AllConfigFlags.py
+++ b/Control/AthenaConfiguration/python/AllConfigFlags.py
@@ -488,6 +488,12 @@ def initConfigFlags():
         return createLLPDFConfigFlags()
     _addFlagsCategory(acf, "Derivation.LLP", __llpDerivation, 'DerivationFrameworkLLP' )
 
+    # onnxruntime flags
+    def __onnxruntime():
+        from AthOnnxConfig.OnnxRuntimeFlags import createOnnxRuntimeFlags
+        return createOnnxRuntimeFlags()
+    _addFlagsCategory(acf, "AthOnnx", __onnxruntime, 'AthOnnxConfig')
+
     # For AnalysisBase, pick up things grabbed in Athena by the functions above
     if not isGaudiEnv():
         def EDMVersion(flags):
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
index 29281dfe4a6e..ee1986f4646b 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
+++ b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
@@ -1,28 +1,17 @@
-# Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 
 # Declare the package's name.
 atlas_subdir( AthExOnnxRuntime )
 
-# Component(s) in the package.
-atlas_add_library( AthExOnnxRuntimeLib
-   INTERFACE
-   PUBLIC_HEADERS AthExOnnxRuntime
-   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
-   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} GaudiKernel )
+find_package( onnxruntime )
 
+# Component(s) in the package.
 atlas_add_component( AthExOnnxRuntime
    src/*.h src/*.cxx src/components/*.cxx
    INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
-   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthExOnnxRuntimeLib AthenaBaseComps GaudiKernel PathResolver AthOnnxruntimeServiceLib AthOnnxruntimeUtilsLib)
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthenaBaseComps GaudiKernel PathResolver 
+   AthOnnxInterfaces AthOnnxUtilsLib AsgServicesLib
+)
 
 # Install files from the package.
 atlas_install_joboptions( share/*.py )
-
-# Set up tests for the package.
-atlas_add_test( AthExOnnxRuntimeJob_serial
-   SCRIPT athena.py AthExOnnxRuntime/AthExOnnxRuntime_jobOptions.py
-   POST_EXEC_SCRIPT nopost.sh )
-
-atlas_add_test( AthExOnnxRuntimeJob_mt
-   SCRIPT athena.py --threads=2 AthExOnnxRuntime/AthExOnnxRuntime_jobOptions.py
-   POST_EXEC_SCRIPT nopost.sh )
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_CA.py b/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_CA.py
new file mode 100644
index 000000000000..4a2de47d9df9
--- /dev/null
+++ b/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_CA.py
@@ -0,0 +1,43 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaConfiguration.ComponentFactory import CompFactory
+
+
+def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
+    acc = ComponentAccumulator()
+    
+    model_fname = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-02/MNIST_testModel.onnx"
+    kwargs.setdefault("OutputLevel", Constants.DEBUG)
+
+    from AthOnnxConfig.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg
+    kwargs.setdefault("OnnxRuntimeSessionTool", acc.popToolsAndMerge(
+        OnnxRuntimeSessionToolCfg(flags, 
+                                  model_fname, 
+#                                  execution_provider="CPU",  # optionally override flags.AthOnnx.ExecutionProvider, default is CPU
+                                  **kwargs)
+    ))
+
+    kwargs.setdefault("BatchSize", 3)
+    acc.addEventAlgo(CompFactory.AthOnnx.EvaluateModel(name, **kwargs))
+
+    return acc
+
+if __name__ == "__main__":
+    from AthenaCommon import Constants
+    from AthenaCommon.Logging import log as msg
+    from AthenaConfiguration.AllConfigFlags import initConfigFlags
+    from AthenaConfiguration.MainServicesConfig import MainServicesCfg
+
+    msg.setLevel(Constants.DEBUG)
+
+    flags = initConfigFlags()
+    flags.AthOnnx.ExecutionProvider = "CPU"
+    flags.lock()
+
+    acc = MainServicesCfg(flags)
+    acc.merge(AthExOnnxRuntimeExampleCfg(flags))
+    acc.printConfig(withDetails=True, summariseProps=True)
+
+    sc = acc.run(maxEvents=2)
+    msg.info(sc.isSuccess())
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
index fb1b6f25e625..29aff6d1c530 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
@@ -1,14 +1,17 @@
-// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 
 // Local include(s).
 #include "EvaluateModel.h"
 #include <tuple>
+#include <fstream>
+#include <chrono>
+#include <arpa/inet.h>
 
 // Framework include(s).
 #include "PathResolver/PathResolver.h"
-#include "AthOnnxruntimeUtils/OnnxUtils.h"
+#include "AthOnnxUtils/OnnxUtils.h"
 
-namespace AthONNX {
+namespace AthOnnx {
 
    //*******************************************************************
    // for reading MNIST images
@@ -66,195 +69,81 @@ namespace AthONNX {
     }
 
    StatusCode EvaluateModel::initialize() {
+    // Fetch tools
+    ATH_CHECK( m_onnxTool.retrieve() );
+    m_onnxTool->printModelInfo();
 
-      // Access the service.
-      //ATH_CHECK( m_svc.retrieve() );
+    // change the batch size
+   //  ATH_MSG_INFO("Setting batch size to "<<m_batchSize);
+   //  m_onnxTool->setBatchSize(m_batchSize);
+   //  m_onnxTool->printModelInfo();
 
       /*****
        The combination of no. of batches and batch size shouldn't cross 
        the total smple size which is 10000 for this example
       *****/         
-      if(m_doBatches && (m_numberOfBatches*m_sizeOfBatch)>10000){
+      if(m_batchSize > 10000){
         ATH_MSG_INFO("The total no. of sample crossed the no. of available sample ....");
 	return StatusCode::FAILURE;
-      }
-      // Find the model file.
-      const std::string modelFileName =
-         PathResolverFindCalibFile( m_modelFileName );
+       }
+     // Read the input pixel file.
       const std::string pixelFileName =
          PathResolverFindCalibFile( m_pixelFileName );
-      const std::string labelFileName =
-         PathResolverFindCalibFile( m_labelFileName );
-      ATH_MSG_INFO( "Using model file: " << modelFileName );
       ATH_MSG_INFO( "Using pixel file: " << pixelFileName );
-      ATH_MSG_INFO( "Using pixel file: " << labelFileName );
-      // Set up the ONNX Runtime session.
   
-      m_session = AthONNX::CreateORTSession(modelFileName, m_useCUDA);
-
-      if (m_useCUDA) {
-         ATH_MSG_INFO( "Created the ONNX Runtime session on CUDA" );
-      } else {
-         ATH_MSG_INFO( "Created the ONNX Runtime session on CPUs" );
-      }
-      
       m_input_tensor_values_notFlat = read_mnist_pixel_notFlat(pixelFileName);
-      std::vector<std::vector<float>> c = m_input_tensor_values_notFlat[0];
-      m_output_tensor_values = read_mnist_label(labelFileName);    
-      // Return gracefully.
+      ATH_MSG_INFO("Total no. of samples: "<<m_input_tensor_values_notFlat.size());
+    
       return StatusCode::SUCCESS;
-   }
+}
 
    StatusCode EvaluateModel::execute( const EventContext& /*ctx*/ ) const {
-     
-     Ort::AllocatorWithDefaultOptions allocator;
-  
-     /************************** Input Nodes *****************************/
-     /*********************************************************************/
-     
-     std::tuple<std::vector<int64_t>, std::vector<const char*> > inputInfo = AthONNX::GetInputNodeInfo(m_session);
-     std::vector<int64_t> input_node_dims = std::get<0>(inputInfo);
-     std::vector<const char*> input_node_names = std::get<1>(inputInfo);
-     
-     for( std::size_t i = 0; i < input_node_names.size(); i++ ) {
-        // print input node names
-        ATH_MSG_DEBUG("Input "<<i<<" : "<<" name= "<<input_node_names[i]);
    
-        // print input shapes/dims
-        ATH_MSG_DEBUG("Input "<<i<<" : num_dims= "<<input_node_dims.size());
-        for (std::size_t j = 0; j < input_node_dims.size(); j++){
-           ATH_MSG_DEBUG("Input "<<i<<" : dim "<<j<<"= "<<input_node_dims[j]);
-          }
-       }
-
-     /************************** Output Nodes *****************************/
-     /*********************************************************************/
-     
-     std::tuple<std::vector<int64_t>, std::vector<const char*> > outputInfo = AthONNX::GetOutputNodeInfo(m_session);
-     std::vector<int64_t> output_node_dims = std::get<0>(outputInfo);
-     std::vector<const char*> output_node_names = std::get<1>(outputInfo);
-
-     for( std::size_t i = 0; i < output_node_names.size(); i++ ) {
-        // print input node names
-        ATH_MSG_DEBUG("Output "<<i<<" : "<<" name= "<<output_node_names[i]);
-
-        // print input shapes/dims
-        ATH_MSG_DEBUG("Output "<<i<<" : num_dims= "<<output_node_dims.size());
-        for (std::size_t j = 0; j < output_node_dims.size(); j++){
-           ATH_MSG_DEBUG("Output "<<i<<" : dim "<<j<<"= "<<output_node_dims[j]);
-          }
-       }
-    /************************* Score if input is not a batch ********************/
-    /****************************************************************************/
-     if(m_doBatches == false){
-
-        /**************************************************************************************
-         * input_node_dims[0] = -1; -1 needs to be replaced by the batch size; for no batch is 1 
-         * input_node_dims[1] = 28
-         * input_node_dims[2] = 28
-        ****************************************************************************************/
-
-     	input_node_dims[0] = 1;
-     	output_node_dims[0] = 1;
- 
-       /***************** Choose an example sample randomly ****************************/  
-     	std::vector<std::vector<float>> input_tensor_values = m_input_tensor_values_notFlat[m_testSample];
-        std::vector<float> flatten = AthONNX::FlattenInput_multiD_1D(input_tensor_values);
-        // Output label of corresponding m_input_tensor_values[m_testSample]; e.g 0, 1, 2, 3 etc
-        int output_tensor_values = m_output_tensor_values[m_testSample];
-       
-        // For a check that the sample dimension is fully flatten (1x28x28 = 784)
-        ATH_MSG_DEBUG("Size of Flatten Input tensor: "<<flatten.size());
+   // prepare inputs
+   std::vector<float> inputData;
+   for (int ibatch = 0; ibatch < m_batchSize; ibatch++){
+      const std::vector<std::vector<float> >& imageData = m_input_tensor_values_notFlat[ibatch];
+      std::vector<float> flatten = AthOnnx::flattenNestedVectors(imageData);
+      inputData.insert(inputData.end(), flatten.begin(), flatten.end());
+   }
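+   // Each MNIST image is 28x28 pixels, so inputData now holds
+   // m_batchSize * 784 floats, flattened in batch order.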
 
-     	/************** Create input tensor object from input data values to feed into your model *********************/
-        
-        Ort::Value input_tensor = AthONNX::TensorCreator(flatten, input_node_dims );
+   int64_t batchSize = m_onnxTool->getBatchSize(inputData.size());
+   ATH_MSG_INFO("Batch size is " << batchSize << ".");
+   assert(batchSize == m_batchSize);
 
-        /********* Convert 784 elements long flattened 1D array to 3D (1, 28, 28) onnx compatible tensor ************/
-        ATH_MSG_DEBUG("Input tensor size after converted to Ort tensor: "<<input_tensor.GetTensorTypeAndShapeInfo().GetShape());     	
-        // Makes sure input tensor has same dimensions as input layer of the model
-        assert(input_tensor.IsTensor()&&
-     		input_tensor.GetTensorTypeAndShapeInfo().GetShape() == input_node_dims);
+   // bind the input data to the input tensor
+   std::vector<Ort::Value> inputTensors;
+   ATH_CHECK( m_onnxTool->addInput(inputTensors, inputData, 0, batchSize) );
 
-     	/********* Score model by feeding input tensor and get output tensor in return *****************************/
+   // reserve space for output data and bind it to the output tensor
+   std::vector<float> outputScores;
+   std::vector<Ort::Value> outputTensors;
+   ATH_CHECK( m_onnxTool->addOutput(outputTensors, outputScores, 0, batchSize) );
 
-        float* floatarr = AthONNX::Inference(m_session, input_node_names, input_tensor, output_node_names); 
+   // run the inference
+   // the output scores will be written into outputScores.
+   ATH_CHECK( m_onnxTool->inference(inputTensors, outputTensors) );
 
-     	// show  true label for the test input
-     	ATH_MSG_INFO("Label for the input test data  = "<<output_tensor_values);
+   ATH_MSG_INFO("Predicted classes for the input test data:");
+   for(int ibatch = 0; ibatch < m_batchSize; ibatch++){
      	float max = -999;
     	int max_index = 0;
      	for (int i = 0; i < 10; i++){
-       		ATH_MSG_DEBUG("Score for class "<<i<<" = "<<floatarr[i]);
-       		if (max<floatarr[i]){
-          		max = floatarr[i];
-          		max_index = i;
+            int index = i + ibatch * 10;
+            ATH_MSG_DEBUG("Score for class "<< i <<" = "<< outputScores[index] << " in batch " << ibatch);
+            if (max < outputScores[index]){
+               max = outputScores[index];
+               max_index = index;
        		}
      	}
-     	ATH_MSG_INFO("Class: "<<max_index<<" has the highest score: "<<floatarr[max_index]);
-     
-     } // m_doBatches == false codition ends   
-    /************************* Score if input is a batch ********************/
-    /****************************************************************************/
-    else {
-        /**************************************************************************************
-         Similar scoring structure like non batch execution but 1st demention needs to be replaced by batch size
-         for this example lets take 3 batches with batch size 5
-         * input_node_dims[0] = 5
-         * input_node_dims[1] = 28
-         * input_node_dims[2] = 28
-        ****************************************************************************************/
-        
-     	input_node_dims[0] = m_sizeOfBatch;
-     	output_node_dims[0] = m_sizeOfBatch;   
-     
-	/************************** process multiple batches ********************************/
-        int l =0; /****** variable for distributing rows in m_input_tensor_values_notFlat equally into batches*****/ 
-     	for (int i = 0; i < m_numberOfBatches; i++) {
-     		ATH_MSG_DEBUG("Processing batch #" << i);
-      		std::vector<float> batch_input_tensor_values;
-      		for (int j = l; j < l+m_sizeOfBatch; j++) {
-                         
-                        std::vector<float> flattened_input = AthONNX::FlattenInput_multiD_1D(m_input_tensor_values_notFlat[j]);
-                        /******************For each batch we need a flattened (5 x 28 x 28 = 3920) 1D array******************************/
-        		batch_input_tensor_values.insert(batch_input_tensor_values.end(), flattened_input.begin(), flattened_input.end());
-        	}   
-
-                Ort::Value batch_input_tensors = AthONNX::TensorCreator(batch_input_tensor_values, input_node_dims );
-
-		// Get pointer to output tensor float values
+      ATH_MSG_INFO("Class: "<<max_index<<" has the highest score: "<<outputScores[max_index] << " in batch " << ibatch);
+   }
 
-                float* floatarr = AthONNX::Inference(m_session, input_node_names, batch_input_tensors, output_node_names);
-     		// show  true label for the test input
-		for(int i = l; i<l+m_sizeOfBatch; i++){
-     			ATH_MSG_INFO("Label for the input test data  = "<<m_output_tensor_values[i]);
-                	int k = (i-l)*10;
-                	float max = -999;
-                	int max_index = 0;
-     			for (int j =k ; j < k+10; j++){
-       				ATH_MSG_INFO("Score for class "<<j-k<<" = "<<floatarr[j]);
-       				if (max<floatarr[j]){
-          				max = floatarr[j];
-          				max_index = j;
-       				}
-     			}
-    	       	ATH_MSG_INFO("Class: "<<max_index-k<<" has the highest score: "<<floatarr[max_index]);
-       		} 
-           l = l+m_sizeOfBatch;
-          }
-     } // else/m_doBatches == True codition ends
-    // Return gracefully.
       return StatusCode::SUCCESS;
    }
    StatusCode EvaluateModel::finalize() {
 
-      // Delete the session object.
-      m_session.reset();
-
-      // Return gracefully.
       return StatusCode::SUCCESS;
    }
 
-} // namespace AthONNX
-
-
+} // namespace AthOnnx
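The rewritten algorithm above compresses the client side of inference into four tool calls. For reference, a minimal sketch of the same pattern in another client (assumptions: m_onnxTool is an already-retrieved ToolHandle<AthOnnx::IOnnxRuntimeSessionTool> configured with an MNIST-style model whose first input has a dynamic batch dimension; the 784-float input is illustrative):

    // Sketch only: one flattened 28x28 image as input.
    std::vector<float> inputData(784, 0.f);
    const int64_t batchSize = m_onnxTool->getBatchSize(inputData.size());

    std::vector<Ort::Value> inputTensors;
    ATH_CHECK( m_onnxTool->addInput(inputTensors, inputData, 0, batchSize) );

    std::vector<float> outputScores;   // filled by inference()
    std::vector<Ort::Value> outputTensors;
    ATH_CHECK( m_onnxTool->addOutput(outputTensors, outputScores, 0, batchSize) );

    ATH_CHECK( m_onnxTool->inference(inputTensors, outputTensors) );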
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
index aaae60dc2858..d1a1f3a3894a 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
@@ -4,24 +4,21 @@
 #define ATHEXONNXRUNTIME_EVALUATEMODEL_H
 
 // Local include(s).
-#include "AthOnnxruntimeService/IONNXRuntimeSvc.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
+
 // Framework include(s).
 #include "AthenaBaseComps/AthReentrantAlgorithm.h"
 #include "GaudiKernel/ServiceHandle.h"
 
-// ONNX Runtime include(s).
+// Onnx Runtime include(s).
 #include <core/session/onnxruntime_cxx_api.h>
 
 // System include(s).
 #include <memory>
 #include <string>
-#include <iostream> 
-#include <fstream>
-#include <arpa/inet.h>
 #include <vector>
-#include <iterator>
 
-namespace AthONNX {
+namespace AthOnnx {
 
    /// Algorithm demonstrating the usage of the ONNX Runtime C++ API
    ///
@@ -53,31 +50,22 @@ namespace AthONNX {
       /// @{
 
       /// Name of the model file to load
-      Gaudi::Property< std::string > m_modelFileName{ this, "ModelFileName",
-         "dev/MLTest/2020-03-02/MNIST_testModel.onnx",
-         "Name of the model file to load" };
       Gaudi::Property< std::string > m_pixelFileName{ this, "InputDataPixel",
          "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte",
          "Name of the input pixel file to load" };
-      Gaudi::Property< std::string > m_labelFileName{ this, "InputDataLabel",
-         "dev/MLTest/2020-03-31/t10k-labels-idx1-ubyte",
-         "Name of the label file to load" };
-      Gaudi::Property<int> m_testSample {this, "TestSample", 0, "A Random Test Sample"};
 
       /// The following properties need to be considered if the .onnx model is evaluated in batch mode
-      Gaudi::Property<bool> m_doBatches {this, "DoBatches", false, "Processing events by batches"};
-      Gaudi::Property<int> m_numberOfBatches {this, "NumberOfBatches", 1, "No. of batches to be passed"};
-      Gaudi::Property<int> m_sizeOfBatch {this, "SizeOfBatch", 1, "No. of elements/example in a batch"};
+      Gaudi::Property<int> m_batchSize {this, "BatchSize", 1, "No. of elements/example in a batch"};
 
-      // If runs on CUDA
-      Gaudi::Property<bool> m_useCUDA {this, "UseCUDA", false, "Use CUDA"};
+      /// Tool handle for the ONNX inference session
+      ToolHandle< IOnnxRuntimeSessionTool >  m_onnxTool{
+         this, "OnnxRuntimeSessionTool", "AthOnnx::OnnxRuntimeSessionTool"
+      };
       
-      std::unique_ptr< Ort::Session > m_session;
       std::vector<std::vector<std::vector<float>>> m_input_tensor_values_notFlat;
-      std::vector<int> m_output_tensor_values;
 
    }; // class EvaluateModel
 
-} // namespace AthONNX
+} // namespace AthOnnx
 
 #endif // ATHEXONNXRUNTIME_EVALUATEMODEL_H
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx b/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx
index cb25a4bfd9d2..c89bab7da136 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx
@@ -3,4 +3,4 @@
 // Local include(s).
 #include "../EvaluateModel.h"
 // Declare the package's components.
-DECLARE_COMPONENT( AthONNX::EvaluateModel )
+DECLARE_COMPONENT( AthOnnx::EvaluateModel )
diff --git a/InnerDetector/InDetGNNTracking/CMakeLists.txt b/InnerDetector/InDetGNNTracking/CMakeLists.txt
index 2bcff0db22e5..ca9e45890dec 100644
--- a/InnerDetector/InDetGNNTracking/CMakeLists.txt
+++ b/InnerDetector/InDetGNNTracking/CMakeLists.txt
@@ -18,6 +18,7 @@ atlas_add_component( InDetGNNTracking
     PixelReadoutGeometryLib SCT_ReadoutGeometry InDetSimData
     InDetPrepRawData TrkTrack TrkRIO_OnTrack InDetSimEvent
     AtlasHepMCLib InDetRIO_OnTrack InDetRawData TrkTruthData
+    AthOnnxInterfaces AthOnnxUtilsLib
 )
 
 atlas_install_python_modules( python/*.py )
diff --git a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
index a48e9ea21a01..e07013addd91 100644
--- a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
+++ b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
@@ -7,7 +7,7 @@
 
 // Framework include(s).
 #include "PathResolver/PathResolver.h"
-#include "AthOnnxruntimeUtils/OnnxUtils.h"
+#include "AthOnnxUtils/OnnxUtils.h"
 #include <cmath>
 
 InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool(
@@ -18,20 +18,12 @@ InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool(
   }
 
 StatusCode InDet::SiGNNTrackFinderTool::initialize() {
-  initTrainedModels();
+  ATH_CHECK( m_embedSessionTool.retrieve() );
+  ATH_CHECK( m_filterSessionTool.retrieve() );
+  ATH_CHECK( m_gnnSessionTool.retrieve() );
   return StatusCode::SUCCESS;
 }
 
-void InDet::SiGNNTrackFinderTool::initTrainedModels() {
-  std::string embedModelPath(m_inputMLModuleDir + "/torchscript/embedding.onnx");
-  std::string filterModelPath(m_inputMLModuleDir + "/torchscript/filtering.onnx");
-  std::string gnnModelPath(m_inputMLModuleDir + "/torchscript/gnn.onnx");
-
-  m_embedSession = AthONNX::CreateORTSession(embedModelPath, m_useCUDA);
-  m_filterSession = AthONNX::CreateORTSession(filterModelPath, m_useCUDA);
-  m_gnnSession = AthONNX::CreateORTSession(gnnModelPath, m_useCUDA);
-}
-
 StatusCode InDet::SiGNNTrackFinderTool::finalize() {
   StatusCode sc = AlgTool::finalize();
   return sc;
@@ -59,7 +51,7 @@ MsgStream& InDet::SiGNNTrackFinderTool::dumpevent( MsgStream& out ) const
   return out;
 }
 
-void InDet::SiGNNTrackFinderTool::getTracks (
+StatusCode InDet::SiGNNTrackFinderTool::getTracks(
   const std::vector<const Trk::SpacePoint*>& spacepoints,
   std::vector<std::vector<uint32_t> >& tracks) const
 {
@@ -84,34 +76,19 @@ void InDet::SiGNNTrackFinderTool::getTracks (
     spacepointIDs.push_back(sp_idx++);
   }
 
-    Ort::AllocatorWithDefaultOptions allocator;
-    auto memoryInfo = Ort::MemoryInfo::CreateCpu(
-        OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
-
     // ************
     // Embedding
     // ************
 
     std::vector<int64_t> eInputShape{numSpacepoints, spacepointFeatures};
-
-    std::vector<const char*> eInputNames{"sp_features"};
     std::vector<Ort::Value> eInputTensor;
-    eInputTensor.push_back(
-        Ort::Value::CreateTensor<float>(
-            memoryInfo, inputValues.data(), inputValues.size(),
-            eInputShape.data(), eInputShape.size())
-    );
+    ATH_CHECK( m_embedSessionTool->addInput(eInputTensor, inputValues, 0, numSpacepoints) );
 
-    std::vector<float> eOutputData(numSpacepoints * m_embeddingDim);
-    std::vector<const char*> eOutputNames{"embedding_output"};
-    std::vector<int64_t> eOutputShape{numSpacepoints, m_embeddingDim};
     std::vector<Ort::Value> eOutputTensor;
-    eOutputTensor.push_back(
-        Ort::Value::CreateTensor<float>(
-            memoryInfo, eOutputData.data(), eOutputData.size(),
-            eOutputShape.data(), eOutputShape.size())
-    );
-    AthONNX::InferenceWithIOBinding(m_embedSession, eInputNames, eInputTensor, eOutputNames, eOutputTensor);
+    std::vector<float> eOutputData;
+    ATH_CHECK( m_embedSessionTool->addOutput(eOutputTensor, eOutputData, 0, numSpacepoints) );
+
+    ATH_CHECK( m_embedSessionTool->inference(eInputTensor, eOutputTensor) );
 
     // ************
     // Building Edges
@@ -123,29 +100,18 @@ void InDet::SiGNNTrackFinderTool::getTracks (
     // ************
     // Filtering
     // ************
-    std::vector<const char*> fInputNames{"f_nodes", "f_edges"};
     std::vector<Ort::Value> fInputTensor;
     fInputTensor.push_back(
         std::move(eInputTensor[0])
     );
+    ATH_CHECK( m_filterSessionTool->addInput(fInputTensor, edgeList, 1, numEdges) );
     std::vector<int64_t> fEdgeShape{2, numEdges};
-    fInputTensor.push_back(
-        Ort::Value::CreateTensor<int64_t>(
-            memoryInfo, edgeList.data(), edgeList.size(),
-            fEdgeShape.data(), fEdgeShape.size())
-    );
 
-    // filtering outputs
-    std::vector<const char*> fOutputNames{"f_edge_score"};
-    std::vector<float> fOutputData(numEdges);
-    std::vector<int64_t> fOutputShape{numEdges, 1};
+    std::vector<float> fOutputData;
     std::vector<Ort::Value> fOutputTensor;
-    fOutputTensor.push_back(
-        Ort::Value::CreateTensor<float>(
-            memoryInfo, fOutputData.data(), fOutputData.size(), 
-            fOutputShape.data(), fOutputShape.size())
-    );
-    AthONNX::InferenceWithIOBinding(m_filterSession, fInputNames, fInputTensor, fOutputNames, fOutputTensor);
+    ATH_CHECK( m_filterSessionTool->addOutput(fOutputTensor, fOutputData, 0, numEdges) );
+
+    ATH_CHECK( m_filterSessionTool->inference(fInputTensor, fOutputTensor) );
 
     // apply sigmoid to the filtering output data
     // and remove edges with score < filterCut
@@ -167,28 +133,18 @@ void InDet::SiGNNTrackFinderTool::getTracks (
     // ************
     // GNN
     // ************
-    std::vector<const char*> gInputNames{"g_nodes", "g_edges"};
     std::vector<Ort::Value> gInputTensor;
     gInputTensor.push_back(
         std::move(fInputTensor[0])
     );
-    std::vector<int64_t> gEdgeShape{2, numEdgesAfterF};
-    gInputTensor.push_back(
-        Ort::Value::CreateTensor<int64_t>(
-            memoryInfo, edgesAfterFiltering.data(), edgesAfterFiltering.size(),
-            gEdgeShape.data(), gEdgeShape.size())
-    );
+    ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, edgesAfterFiltering, 1, numEdgesAfterF) );
+    
     // gnn outputs
-    std::vector<const char*> gOutputNames{"gnn_edge_score"};
-    std::vector<float> gOutputData(numEdgesAfterF);
-    std::vector<int64_t> gOutputShape{numEdgesAfterF};
+    std::vector<float> gOutputData;
     std::vector<Ort::Value> gOutputTensor;
-    gOutputTensor.push_back(
-        Ort::Value::CreateTensor<float>(
-            memoryInfo, gOutputData.data(), gOutputData.size(), 
-            gOutputShape.data(), gOutputShape.size())
-    );
-    AthONNX::InferenceWithIOBinding(m_gnnSession, gInputNames, gInputTensor, gOutputNames, gOutputTensor);
+    ATH_CHECK( m_gnnSessionTool->addOutput(gOutputTensor, gOutputData, 0, numEdgesAfterF) );
+
+    ATH_CHECK( m_gnnSessionTool->inference(gInputTensor, gOutputTensor) );
     // apply sigmoid to the gnn output data
     for(auto& v : gOutputData){
         v = 1.f / (1.f + std::exp(-v));
@@ -200,7 +156,7 @@ void InDet::SiGNNTrackFinderTool::getTracks (
     std::vector<int32_t> trackLabels(numSpacepoints);
     weaklyConnectedComponents<int64_t,float,int32_t>(numSpacepoints, rowIndices, colIndices, gOutputData, trackLabels);
 
-    if (trackLabels.size() == 0)  return;
+    if (trackLabels.size() == 0)  return StatusCode::SUCCESS;
 
     tracks.clear();
 
@@ -225,5 +181,6 @@ void InDet::SiGNNTrackFinderTool::getTracks (
             existTrkIdx++;
         }
     }
+    return StatusCode::SUCCESS;
 }
 
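The filtering stage above keeps only edges whose sigmoid-transformed score passes the cut. A self-contained sketch of that selection, assuming the column-major 2 x numEdges edge layout implied by fEdgeShape and an illustrative filterCut value:

    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    int main() {
        // 2 x numEdges, flattened: first row = source nodes, second row = targets
        std::vector<int64_t> edgeList{0, 1, 2,  1, 2, 3};
        std::vector<float> fOutputData{2.0f, -1.0f, 0.5f};  // raw filter logits
        const float filterCut = 0.2f;                       // illustrative cut
        const std::size_t numEdges = fOutputData.size();

        // collect the indices of edges passing the sigmoid cut
        std::vector<std::size_t> kept;
        for (std::size_t i = 0; i < numEdges; ++i) {
            const float score = 1.f / (1.f + std::exp(-fOutputData[i]));
            if (score >= filterCut) kept.push_back(i);
        }

        // rebuild a flattened 2 x numEdgesAfterF edge list
        std::vector<int64_t> edgesAfterFiltering;
        edgesAfterFiltering.reserve(2 * kept.size());
        for (auto i : kept) edgesAfterFiltering.push_back(edgeList[i]);             // sources
        for (auto i : kept) edgesAfterFiltering.push_back(edgeList[numEdges + i]);  // targets
        return edgesAfterFiltering.size() == 2 * kept.size() ? 0 : 1;
    }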
diff --git a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h
index 031de24d7608..3e70de690ee9 100644
--- a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h
+++ b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h
@@ -14,7 +14,7 @@
 #include "InDetRecToolInterfaces/IGNNTrackFinder.h"
 
 // ONNX Runtime include(s).
-#include "AthOnnxruntimeService/IONNXRuntimeSvc.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
 #include <core/session/onnxruntime_cxx_api.h>
 
 class MsgStream;
@@ -46,7 +46,7 @@ namespace InDet{
      * 
      * @return 
      */
-    virtual void getTracks(
+    virtual StatusCode getTracks(
       const std::vector<const Trk::SpacePoint*>& spacepoints,
       std::vector<std::vector<uint32_t> >& tracks) const override;
 
@@ -74,9 +74,15 @@ namespace InDet{
     MsgStream&    dumpevent     (MsgStream&    out) const;
 
     private:
-    std::unique_ptr< Ort::Session > m_embedSession;
-    std::unique_ptr< Ort::Session > m_filterSession;
-    std::unique_ptr< Ort::Session > m_gnnSession;
+    ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_embedSessionTool {
+      this, "EmbedSessionTool", "AthOnnx::OnnxRuntimeSessionTool"
+    };
+    ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_filterSessionTool {
+      this, "FilterSessionTool", "AthOnnx::OnnxRuntimeSessionTool"
+    };
+    ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_gnnSessionTool {
+      this, "GNNSessionTool", "AthOnnx::OnnxRuntimeSessionTool"
+    };
 
   };
 
diff --git a/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx b/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx
index 121a710a66ee..d7caaa467f6f 100644
--- a/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx
+++ b/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx
@@ -68,7 +68,7 @@ StatusCode InDet::SiSPGNNTrackMaker::execute(const EventContext& ctx) const
   getData(m_SpacePointsSCTKey);
 
   std::vector<std::vector<uint32_t> > TT;
-  m_gnnTrackFinder->getTracks(spacePoints, TT);
+  ATH_CHECK(m_gnnTrackFinder->getTracks(spacePoints, TT));
 
 
   ATH_MSG_DEBUG("Obtained " << TT.size() << " Tracks");
diff --git a/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h b/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h
index b16efa8f42e9..52bb639c4f6a 100644
--- a/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h
+++ b/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h
@@ -38,7 +38,7 @@ namespace InDet {
      * @param tracks a list of track candidates in terms of spacepoint indices.
      * @return 
     */
-    virtual void getTracks(
+    virtual StatusCode getTracks(
       const std::vector<const Trk::SpacePoint*>& spacepoints,
       std::vector<std::vector<uint32_t> >& tracks) const=0;
 
-- 
GitLab


From 0ac0c9a247c79b12b941561809c69e0fb8592fab Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Tue, 9 Jan 2024 09:53:42 -0800
Subject: [PATCH 03/18] remove redundant packages

---
 .../ATLAS_CHECK_THREAD_SAFETY                 |   1 -
 .../AthOnnxruntimeServiceDict.h               |  12 --
 .../AthOnnxruntimeService/IONNXRuntimeSvc.h   |  41 -----
 .../AthOnnxruntimeService/ONNXRuntimeSvc.h    |  58 ------
 .../AthOnnxruntimeService/selection.xml       |   6 -
 Control/AthOnnxruntimeService/CMakeLists.txt  |  30 ----
 Control/AthOnnxruntimeService/README.md       |   7 -
 .../Root/ONNXRuntimeSvc.cxx                   |  38 ----
 .../AthOnnxruntimeService_entries.cxx         |   8 -
 .../ATLAS_CHECK_THREAD_SAFETY                 |   1 -
 .../AthOnnxruntimeUtils/OnnxUtils.h           | 168 ------------------
 Control/AthOnnxruntimeUtils/CMakeLists.txt    |  14 --
 12 files changed, 384 deletions(-)
 delete mode 100644 Control/AthOnnxruntimeService/AthOnnxruntimeService/ATLAS_CHECK_THREAD_SAFETY
 delete mode 100644 Control/AthOnnxruntimeService/AthOnnxruntimeService/AthOnnxruntimeServiceDict.h
 delete mode 100644 Control/AthOnnxruntimeService/AthOnnxruntimeService/IONNXRuntimeSvc.h
 delete mode 100644 Control/AthOnnxruntimeService/AthOnnxruntimeService/ONNXRuntimeSvc.h
 delete mode 100644 Control/AthOnnxruntimeService/AthOnnxruntimeService/selection.xml
 delete mode 100644 Control/AthOnnxruntimeService/CMakeLists.txt
 delete mode 100644 Control/AthOnnxruntimeService/README.md
 delete mode 100644 Control/AthOnnxruntimeService/Root/ONNXRuntimeSvc.cxx
 delete mode 100644 Control/AthOnnxruntimeService/src/components/AthOnnxruntimeService_entries.cxx
 delete mode 100644 Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/ATLAS_CHECK_THREAD_SAFETY
 delete mode 100644 Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/OnnxUtils.h
 delete mode 100644 Control/AthOnnxruntimeUtils/CMakeLists.txt

diff --git a/Control/AthOnnxruntimeService/AthOnnxruntimeService/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnxruntimeService/AthOnnxruntimeService/ATLAS_CHECK_THREAD_SAFETY
deleted file mode 100644
index 8b969b1a5cf5..000000000000
--- a/Control/AthOnnxruntimeService/AthOnnxruntimeService/ATLAS_CHECK_THREAD_SAFETY
+++ /dev/null
@@ -1 +0,0 @@
-Control/AthOnnxruntimeService
diff --git a/Control/AthOnnxruntimeService/AthOnnxruntimeService/AthOnnxruntimeServiceDict.h b/Control/AthOnnxruntimeService/AthOnnxruntimeService/AthOnnxruntimeServiceDict.h
deleted file mode 100644
index e8453362a499..000000000000
--- a/Control/AthOnnxruntimeService/AthOnnxruntimeService/AthOnnxruntimeServiceDict.h
+++ /dev/null
@@ -1,12 +0,0 @@
-/*
-  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
-*/
-
-
-#ifndef ATHONNXRUNTIMESERVICE__ATHONNXRUNTIMESERVICE_DICT_H
-#define ATHONNXRUNTIMESERVICE__ATHONNXRUNTIMESERVICE_DICT_H
-
-#include "AthOnnxruntimeService/ONNXRuntimeSvc.h"
-
-#endif
-
diff --git a/Control/AthOnnxruntimeService/AthOnnxruntimeService/IONNXRuntimeSvc.h b/Control/AthOnnxruntimeService/AthOnnxruntimeService/IONNXRuntimeSvc.h
deleted file mode 100644
index cae4f26e3049..000000000000
--- a/Control/AthOnnxruntimeService/AthOnnxruntimeService/IONNXRuntimeSvc.h
+++ /dev/null
@@ -1,41 +0,0 @@
-// Dear emacs, this is -*- c++ -*-
-// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
-#ifndef ATHEXONNXRUNTIME_IONNXRUNTIMESVC_H
-#define ATHEXONNXRUNTIME_IONNXRUNTIMESVC_H
-
-// Gaudi include(s).
-#include <AsgServices/IAsgService.h>
-
-// ONNX include(s).
-#include <core/session/onnxruntime_cxx_api.h>
-
-
-/// Namespace holding all of the ONNX Runtime example code
-namespace AthONNX {
-
-   //class IAsgService
-   /// Service used for managing global objects used by ONNX Runtime
-   ///
-   /// In order to allow multiple clients to use ONNX Runtime at the same
-   /// time, this service is used to manage the objects that must only
-   /// be created once in the Athena process.
-   ///
-   /// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>
-   ///
-   class IONNXRuntimeSvc : virtual public asg::IAsgService{
-
-   public:
-      /// Virtual destructor, to make vtable happy
-      virtual ~IONNXRuntimeSvc() = default;
-
-      /// Declare the interface that this class provides
-      DeclareInterfaceID (AthONNX::IONNXRuntimeSvc, 1, 0);
-
-      /// Return the ONNX Runtime environment object
-      virtual Ort::Env& env() const = 0;
-
-   }; // class IONNXRuntimeSvc
-
-} // namespace AthONNX
-
-#endif // ATHEXONNXRUNTIME_IONNXRUNTIMESVC_H
diff --git a/Control/AthOnnxruntimeService/AthOnnxruntimeService/ONNXRuntimeSvc.h b/Control/AthOnnxruntimeService/AthOnnxruntimeService/ONNXRuntimeSvc.h
deleted file mode 100644
index 475c36f9dfd6..000000000000
--- a/Control/AthOnnxruntimeService/AthOnnxruntimeService/ONNXRuntimeSvc.h
+++ /dev/null
@@ -1,58 +0,0 @@
-// Dear emacs, this is -*- c++ -*-
-// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
-#ifndef ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H
-#define ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H
-
-// Local include(s).
-#include "AthOnnxruntimeService/IONNXRuntimeSvc.h"
-
-// Framework include(s).
-#include <AsgServices/AsgService.h>
-
-// ONNX include(s).
-#include <core/session/onnxruntime_cxx_api.h>
-
-// System include(s).
-#include <memory>
-
-namespace AthONNX {
-
-   /// Service implementing @c AthONNX::IONNXRuntimeSvc
-   ///
-   /// This is a very simple implementation, just managing the lifetime
-   /// of some ONNX Runtime C++ objects.
-   ///
-   /// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>
-   ///
-   class ONNXRuntimeSvc : public asg::AsgService, virtual public IONNXRuntimeSvc {
-
-   public:
-
-      /// @name Function(s) inherited from @c Service
-      /// @{
-      ONNXRuntimeSvc (const std::string& name, ISvcLocator* svc);
-
-      /// Function initialising the service
-      virtual StatusCode initialize() override;
-      /// Function finalising the service
-      virtual StatusCode finalize() override;
-
-      /// @}
-
-      /// @name Function(s) inherited from @c AthONNX::IONNXRuntimeSvc
-      /// @{
-
-      /// Return the ONNX Runtime environment object
-      virtual Ort::Env& env() const override;
-
-      /// @}
-
-   private:
-      /// Global runtime environment for ONNX Runtime
-      std::unique_ptr< Ort::Env > m_env;
-
-   }; // class ONNXRuntimeSvc
-
-} // namespace AthONNX
-
-#endif // ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H
diff --git a/Control/AthOnnxruntimeService/AthOnnxruntimeService/selection.xml b/Control/AthOnnxruntimeService/AthOnnxruntimeService/selection.xml
deleted file mode 100644
index 7ab8a4333797..000000000000
--- a/Control/AthOnnxruntimeService/AthOnnxruntimeService/selection.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-<lcgdict>
-
-   <class name="AthONNX::IONNXRuntimeSvc" />
-   <class name="AthONNX::ONNXRuntimeSvc" />
-
-</lcgdict>
diff --git a/Control/AthOnnxruntimeService/CMakeLists.txt b/Control/AthOnnxruntimeService/CMakeLists.txt
deleted file mode 100644
index ea7ee02c224c..000000000000
--- a/Control/AthOnnxruntimeService/CMakeLists.txt
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
-
-# Declare the package's name.
-atlas_subdir( AthOnnxruntimeService )
-
-# External dependencies.
-find_package( onnxruntime )
-
-# Component(s) in the package.
-atlas_add_library( AthOnnxruntimeServiceLib
-   AthOnnxruntimeService/*.h Root/*.cxx
-   PUBLIC_HEADERS AthOnnxruntimeService
-   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
-   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgServicesLib)
-
-if (XAOD_STANDALONE)
-atlas_add_dictionary( AthOnnxruntimeServiceDict
-   AthOnnxruntimeService/AthOnnxruntimeServiceDict.h
-   AthOnnxruntimeService/selection.xml
-   LINK_LIBRARIES AthOnnxruntimeServiceLib )
-endif ()
-
-if (NOT XAOD_STANDALONE)
-  atlas_add_component( AthOnnxruntimeService
-     src/*.h src/*.cxx src/components/*.cxx
-     INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
-     LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthOnnxruntimeServiceLib AthenaBaseComps GaudiKernel AsgServicesLib)
-endif ()
-
-
diff --git a/Control/AthOnnxruntimeService/README.md b/Control/AthOnnxruntimeService/README.md
deleted file mode 100644
index 7ef7a41b6910..000000000000
--- a/Control/AthOnnxruntimeService/README.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# ONNXRUNTIMESERVICE
-
-This package is meant to accommodate all onnxruntimeService related services
-e.g. `IONNXRuntimeSvc.h` and `ONNXRuntimeSvc.*`
-
-To observe its usecases please check
-`https://gitlab.cern.ch/atlas/athena/-/blob/main/Control/AthenaExamples/AthExOnnxRuntime/src/CxxApiAlgorithm.h#L66`
diff --git a/Control/AthOnnxruntimeService/Root/ONNXRuntimeSvc.cxx b/Control/AthOnnxruntimeService/Root/ONNXRuntimeSvc.cxx
deleted file mode 100644
index 6e37a9074ae9..000000000000
--- a/Control/AthOnnxruntimeService/Root/ONNXRuntimeSvc.cxx
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
-
-// Local include(s).
-#include "AthOnnxruntimeService/ONNXRuntimeSvc.h"
-
-namespace AthONNX {
-  ONNXRuntimeSvc::ONNXRuntimeSvc(const std::string& name, ISvcLocator* svc) :
-      asg::AsgService(name, svc)
-   {
-     declareServiceInterface<AthONNX::IONNXRuntimeSvc>();
-   }
-   StatusCode ONNXRuntimeSvc::initialize() {
-
-      // Create the environment object.
-      m_env = std::make_unique< Ort::Env >( ORT_LOGGING_LEVEL_WARNING,
-                                            name().c_str() );
-      ATH_MSG_DEBUG( "Ort::Env object created" );
-
-      // Return gracefully.
-      return StatusCode::SUCCESS;
-   }
-
-   StatusCode ONNXRuntimeSvc::finalize() {
-
-      // Dekete the environment object.
-      m_env.reset();
-      ATH_MSG_DEBUG( "Ort::Env object deleted" );
-
-      // Return gracefully.
-      return StatusCode::SUCCESS;
-   }
-
-   Ort::Env& ONNXRuntimeSvc::env() const {
-
-      return *m_env;
-   }
-
-} // namespace AthONNX
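For context, the service being removed here (and restored by a later revert in this series) exists so that all clients share a single Ort::Env. A minimal client-side sketch, following the CreateORTSession helper deleted below (the client name string and model path are illustrative, and the code is assumed to run inside an Athena component):

    // Sketch: share the process-wide Ort::Env owned by the service.
    ServiceHandle<AthONNX::IONNXRuntimeSvc> svc("AthONNX::ONNXRuntimeSvc", "MyClient");
    ATH_CHECK( svc.retrieve() );
    Ort::SessionOptions sessionOptions;
    sessionOptions.SetIntraOpNumThreads(1);
    auto session = std::make_unique<Ort::Session>(svc->env(), "model.onnx", sessionOptions);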
diff --git a/Control/AthOnnxruntimeService/src/components/AthOnnxruntimeService_entries.cxx b/Control/AthOnnxruntimeService/src/components/AthOnnxruntimeService_entries.cxx
deleted file mode 100644
index fcc3ac20e682..000000000000
--- a/Control/AthOnnxruntimeService/src/components/AthOnnxruntimeService_entries.cxx
+++ /dev/null
@@ -1,8 +0,0 @@
-// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
-
-// Local include(s).
-#include <AthOnnxruntimeService/ONNXRuntimeSvc.h>
-
-// Declare the package's components.
-DECLARE_COMPONENT( AthONNX::ONNXRuntimeSvc )
-
diff --git a/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/ATLAS_CHECK_THREAD_SAFETY
deleted file mode 100644
index 584871312ff4..000000000000
--- a/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/ATLAS_CHECK_THREAD_SAFETY
+++ /dev/null
@@ -1 +0,0 @@
-Control/AthOnnxruntimeUtils
diff --git a/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/OnnxUtils.h b/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/OnnxUtils.h
deleted file mode 100644
index 19f24b7a83f9..000000000000
--- a/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/OnnxUtils.h
+++ /dev/null
@@ -1,168 +0,0 @@
-// Dear emacs, this is -*- c++ -*-
-// Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
-#ifndef ONNX_UTILS_H
-#define ONNX_UTILS_H
-
-#include <string>
-#include <iostream> 
-#include <fstream>
-#include <arpa/inet.h>
-#include <vector>
-#include <iterator>
-#include <tuple>
-
-// ONNX Runtime include(s).
-#include <core/session/onnxruntime_cxx_api.h>
-// Local include(s).
-#include "AthOnnxruntimeService/IONNXRuntimeSvc.h"
-#include "GaudiKernel/ServiceHandle.h"
-
-namespace AthONNX {
-
-/************************Flattening of Input Data***************************/
-/***************************************************************************/ 
-
- template<typename T>
-
- inline std::vector<T> FlattenInput_multiD_1D( std::vector<std::vector<T>> features){
-    // 1. Compute the total size required.
-    int total_size = 0;
-    for (const auto& feature : features) total_size += feature.size();
-    
-    // 2. Create a vector to hold the data.
-    std::vector<T> Flatten1D;
-    Flatten1D.reserve(total_size);
-
-    // 3. Fill it
-    for (const auto& feature : features)
-      for (const auto& elem : feature)
-        Flatten1D.push_back(elem);
- 
-   return Flatten1D;
-  }
-
-/*********************************Creation of ORT tensor*********************************/
-/****************************************************************************************/
-
- template<typename T>
- inline Ort::Value TensorCreator(std::vector<T>& flattenData, std::vector<int64_t>& input_node_dims ){ 
-    /************** Create input tensor object from input data values to feed into your model *********************/
-    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); 
-    Ort::Value input_tensor = Ort::Value::CreateTensor<T>(memory_info, 
-                                                                  flattenData.data(), 
-                                                                  flattenData.size(),  /*** 1x28x28 = 784 ***/ 
-                                                                  input_node_dims.data(), 
-                                                                  input_node_dims.size());     /*** [1, 28, 28] = 3 ***/
-    return input_tensor;
-   }
-
-
-/*********************************Creation of ORT Session*********************************/
-/*****************************************************************************************/
-
- //template<typename T>
- inline std::unique_ptr< Ort::Session > CreateORTSession(const std::string& modelFile, bool withCUDA=false){
-   
-    // Set up the ONNX Runtime session.
-    Ort::SessionOptions sessionOptions;
-    sessionOptions.SetIntraOpNumThreads( 1 );
-    if (withCUDA) {
-      ;  // does nothing for now until we have a GPU enabled build
-    }
-    sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
-
-    ServiceHandle< IONNXRuntimeSvc > svc("AthONNX::ONNXRuntimeSvc",
-                                              "AthONNX::ONNXRuntimeSvc");
-
-    return std::make_unique<Ort::Session>( svc->env(),
-                                     modelFile.c_str(),
-                                     sessionOptions );
-   }
-
-
-/*********************************Input Node Structure of Model*********************************/
-/***********************************************************************************************/
-
-  inline  std::tuple<std::vector<int64_t>, std::vector<const char*> > GetInputNodeInfo(const std::unique_ptr< Ort::Session >& session){
-    
-    std::vector<int64_t> input_node_dims;
-    size_t num_input_nodes = session->GetInputCount();
-    std::vector<const char*> input_node_names(num_input_nodes);
-    Ort::AllocatorWithDefaultOptions allocator;
-    for( std::size_t i = 0; i < num_input_nodes; i++ ) {
-        char* input_name = session->GetInputNameAllocated(i, allocator).release();
-        input_node_names[i] = input_name;
-        Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
-        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
-   
-        input_node_dims = tensor_info.GetShape();
-     }
-     return std::make_tuple(input_node_dims, input_node_names); 
-  }
-
-/*********************************Output Node Structure of Model*********************************/
-/***********************************************************************************************/
-
-  inline  std::tuple<std::vector<int64_t>, std::vector<const char*> > GetOutputNodeInfo(const std::unique_ptr< Ort::Session >& session){
-     
-     //output nodes
-     std::vector<int64_t> output_node_dims;
-     size_t num_output_nodes = session->GetOutputCount();
-     std::vector<const char*> output_node_names(num_output_nodes);
-     Ort::AllocatorWithDefaultOptions allocator;
-
-      for( std::size_t i = 0; i < num_output_nodes; i++ ) {
-        char* output_name = session->GetOutputNameAllocated(i, allocator).release();
-        output_node_names[i] = output_name;
-
-        Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
-        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
-
-        output_node_dims = tensor_info.GetShape();
-      }
-    return std::make_tuple(output_node_dims, output_node_names);
-   }
-
-
-/*********************************Running Inference through ORT*********************************/
-/***********************************************************************************************/
-  inline float* Inference(const std::unique_ptr< Ort::Session >& session,std::vector<const char*>& input_node_names, Ort::Value& input_tensor, std::vector<const char*>& output_node_names){
-     auto output_tensor =  session->Run(Ort::RunOptions{nullptr},
-                                             input_node_names.data(),
-                                             &input_tensor,
-                                             input_node_names.size(),      /** 1, flatten_input:0 **/
-                                             output_node_names.data(),
-                                             output_node_names.size());    /** 1, dense_1/Softmax:0 **/ 
- 
-     //assert(output_tensor.size() == output_node_names.size() && output_tensor.front().IsTensor());
-     // Get pointer to output tensor float values
-     float* floatarr = output_tensor.front().GetTensorMutableData<float>();
-     return floatarr;
-  }
-
-  void InferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session, 
-    const std::vector<const char*>& inputNames,
-    const std::vector<Ort::Value>& inputData,
-    const std::vector<const char*>& outputNames,
-    const std::vector<Ort::Value>& outputData){
-    
-    if (inputNames.empty()) {
-        throw std::runtime_error("Onnxruntime input data maping cannot be empty");
-    }
-    assert(inputNames.size() == inputData.size());
-
-    Ort::IoBinding iobinding(*session);
-    for(size_t idx = 0; idx < inputNames.size(); ++idx){
-        iobinding.BindInput(inputNames[idx], inputData[idx]);
-    }
-
-
-    for(size_t idx = 0; idx < outputNames.size(); ++idx){
-        iobinding.BindOutput(outputNames[idx], outputData[idx]);
-    }
-
-    session->Run(Ort::RunOptions{nullptr}, iobinding);
-  }
-
-}
-#endif
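The helpers deleted above survive in AthOnnxUtils under new names; for instance, FlattenInput_multiD_1D becomes flattenNestedVectors. A hypothetical standalone sketch of the flattening, mirroring the reserve-and-append behaviour of the removed code:

    #include <cstddef>
    #include <vector>

    template <typename T>
    std::vector<T> flattenNestedVectors(const std::vector<std::vector<T>>& features) {
        std::size_t total = 0;
        for (const auto& f : features) total += f.size();   // compute the total size
        std::vector<T> flat;
        flat.reserve(total);                                // single allocation
        for (const auto& f : features) flat.insert(flat.end(), f.begin(), f.end());
        return flat;
    }

    int main() {
        std::vector<std::vector<float>> image{{1.f, 2.f}, {3.f, 4.f}};
        return flattenNestedVectors(image).size() == 4 ? 0 : 1;
    }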
diff --git a/Control/AthOnnxruntimeUtils/CMakeLists.txt b/Control/AthOnnxruntimeUtils/CMakeLists.txt
deleted file mode 100644
index d8e08a2f21f2..000000000000
--- a/Control/AthOnnxruntimeUtils/CMakeLists.txt
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
-
-# Declare the package's name.
-atlas_subdir( AthOnnxruntimeUtils )
-
-# External dependencies.
-find_package( onnxruntime )
-
-# Component(s) in the package.
-atlas_add_library( AthOnnxruntimeUtilsLib
-   INTERFACE
-   PUBLIC_HEADERS AthOnnxruntimeUtils
-   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
-   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthOnnxruntimeServiceLib AthenaKernel GaudiKernel AsgServicesLib )
-- 
GitLab


From a6fbcebb2bcc1276f33aa00d1406921eaea9979a Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Tue, 9 Jan 2024 12:55:14 -0800
Subject: [PATCH 04/18] address Tadej comments

---
 .../src/OnnxRuntimeSessionTool.cxx            | 15 +++++------
 .../AthOnnxComps/src/OnnxRuntimeSessionTool.h |  5 +---
 .../AthOnnxConfig/python/OnnxRuntimeFlags.py  |  4 +--
 .../python/OnnxRuntimeSessionConfig.py        |  8 +++---
 .../IOnnxRuntimeSessionTool.h                 |  2 +-
 ...onTool.ipp => IOnnxRuntimeSessionTool.icc} |  3 ---
 .../AthOnnxUtils/AthOnnxUtils/OnnxUtils.h     | 10 +++----
 .../AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx    | 14 +++++-----
 .../AthExOnnxRuntime/CMakeLists.txt           |  2 +-
 .../share/AthExOnnxRuntime_jobOptions.py      | 27 -------------------
 .../{share => tests}/AthExOnnxRuntime_CA.py   |  7 ++---
 11 files changed, 32 insertions(+), 65 deletions(-)
 rename Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/{IOnnxRuntimeSessionTool.ipp => IOnnxRuntimeSessionTool.icc} (89%)
 delete mode 100644 Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_jobOptions.py
 rename Control/AthenaExamples/AthExOnnxRuntime/{share => tests}/AthExOnnxRuntime_CA.py (82%)

diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.cxx
index 101b3edae642..9a1fd88e2c0d 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.cxx
@@ -24,8 +24,7 @@ StatusCode AthOnnx::OnnxRuntimeSessionTool::initialize()
 
 StatusCode AthOnnx::OnnxRuntimeSessionTool::finalize()
 {
-    StatusCode sc = AlgTool::finalize();
-    return sc;
+    return StatusCode::SUCCESS;
 }
 
 StatusCode AthOnnx::OnnxRuntimeSessionTool::createModel()
@@ -68,8 +67,8 @@ StatusCode AthOnnx::OnnxRuntimeSessionTool::getNodeInfo()
     m_numInputs = m_session->GetInputCount();
     m_numOutputs = m_session->GetOutputCount();
 
-    AthOnnx::GetInputNodeInfo(m_session, m_inputShapes, m_inputNodeNames);
-    AthOnnx::GetOutputNodeInfo(m_session, m_outputShapes, m_outputNodeNames);
+    AthOnnx::getInputNodeInfo(m_session, m_inputShapes, m_inputNodeNames);
+    AthOnnx::getOutputNodeInfo(m_session, m_outputShapes, m_outputNodeNames);
 
     return StatusCode::SUCCESS;
 }
@@ -97,7 +96,7 @@ void AthOnnx::OnnxRuntimeSessionTool::setBatchSize(int64_t batchSize)
 
 int64_t AthOnnx::OnnxRuntimeSessionTool::getBatchSize(int64_t inputDataSize, int idx) const
 {
-    auto tensorSize = AthOnnx::GetTensorSize(m_inputShapes[idx]);
+    auto tensorSize = AthOnnx::getTensorSize(m_inputShapes[idx]);
     if (tensorSize < 0) {
         return inputDataSize / abs(tensorSize);
     } else {
@@ -111,7 +110,7 @@ StatusCode AthOnnx::OnnxRuntimeSessionTool::inference(std::vector<Ort::Value>& i
     assert (outputTensors.size() == m_numOutputs);
 
     // Run the model.
-    AthOnnx::InferenceWithIOBinding(
+    AthOnnx::inferenceWithIOBinding(
             m_session, 
             m_inputNodeNames, inputTensors, 
             m_outputNodeNames, outputTensors);
@@ -151,12 +150,12 @@ int AthOnnx::OnnxRuntimeSessionTool::getNumOutputs() const
 
 int64_t AthOnnx::OnnxRuntimeSessionTool::getInputSize(int idx) const
 {
-    return AthOnnx::GetTensorSize(m_inputShapes[idx]);
+    return AthOnnx::getTensorSize(m_inputShapes[idx]);
 }
 
 int64_t AthOnnx::OnnxRuntimeSessionTool::getOutputSize(int idx) const
 {
-    return AthOnnx::GetTensorSize(m_outputShapes[idx]);
+    return AthOnnx::getTensorSize(m_outputShapes[idx]);
 }
 
 void AthOnnx::OnnxRuntimeSessionTool::printModelInfo() const
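getBatchSize above exploits a sign convention: a dynamic (-1) dimension makes the flattened tensor size negative, so the batch size can be inferred from the data volume. A standalone sketch of the arithmetic with illustrative numbers:

    #include <cstdint>
    #include <cstdlib>
    #include <vector>

    int64_t getTensorSize(const std::vector<int64_t>& shape) {
        int64_t size = 1;
        for (auto dim : shape) size *= dim;   // a -1 (dynamic) dim flips the sign
        return size;
    }

    int main() {
        const std::vector<int64_t> inputShape{-1, 28, 28};    // dynamic batch dim
        const int64_t tensorSize = getTensorSize(inputShape); // -784
        const int64_t inputDataSize = 1568;                   // two flattened images
        const int64_t batchSize = tensorSize < 0
            ? inputDataSize / std::abs(tensorSize)            // dynamic: infer from data
            : -1;                                             // fixed shape: sentinel
        return batchSize == 2 ? 0 : 1;
    }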
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.h
index 56d040d1d8cb..fd8d954ddd11 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.h
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.h
@@ -59,12 +59,9 @@ namespace AthOnnx {
         StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
 
         std::unique_ptr<Ort::Session> m_session;
-        int m_numInputs;
-        int m_numOutputs;
         std::vector<const char*> m_inputNodeNames;
         std::vector<const char*> m_outputNodeNames;
-        // std::vector<std::vector<int64_t> > m_inputShapes;
-        // std::vector<std::vector<int64_t> > m_outputShapes;
+
     };
 } // namespace AthOnnx
 
diff --git a/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeFlags.py b/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeFlags.py
index 8df14ce012c1..6545bae01a6a 100644
--- a/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeFlags.py
+++ b/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeFlags.py
@@ -3,7 +3,7 @@
 from AthenaConfiguration.AthConfigFlags import AthConfigFlags
 from AthenaConfiguration.Enums import FlagEnum
 
-class OnnxRuntimeExecutionProvider(FlagEnum):
+class OnnxRuntimeType(FlagEnum):
     CPU = 'CPU'
     CUDA = 'CUDA'
 # possible future backends. Uncomment when implemented.
@@ -19,7 +19,7 @@ class OnnxRuntimeExecutionProvider(FlagEnum):
 def createOnnxRuntimeFlags():
     icf = AthConfigFlags()
 
-    icf.addFlag("AthOnnx.ExecutionProvider", "CPU")
+    icf.addFlag("AthOnnx.ExecutionProvider", OnnxRuntimeType.CPU, enum=OnnxRuntimeType)
 
     return icf
 
diff --git a/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSessionConfig.py b/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSessionConfig.py
index 2fc06c26f07f..779969b80af3 100644
--- a/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSessionConfig.py
+++ b/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSessionConfig.py
@@ -2,12 +2,12 @@
 
 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
 from AthenaConfiguration.ComponentFactory import CompFactory
-from AthOnnxConfig.OnnxRuntimeFlags import OnnxRuntimeExecutionProvider as OrtEP
+from AthOnnxConfig.OnnxRuntimeFlags import OnnxRuntimeType
 from typing import Optional
 
 def OnnxRuntimeSessionToolCfg(flags,
                              model_fname: Optional[str] = None, 
-                              execution_provider: Optional[str] = None, 
+                              execution_provider: Optional[OnnxRuntimeType] = None, 
                               name="OnnxRuntimeSessionTool", **kwargs):
     """"Configure OnnxRuntimeSessionTool in Control/AthOnnx/AthOnnxComps/src"""
     
@@ -17,9 +17,9 @@ def OnnxRuntimeSessionToolCfg(flags,
         raise ValueError("model_fname must be specified")
     
     execution_provider = flags.AthOnnx.ExecutionProvider if execution_provider is None else execution_provider
-    if execution_provider == OrtEP.CPU.name:
+    if execution_provider is OnnxRuntimeType.CPU:
         acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionTool(name, ModelFileName=model_fname, **kwargs))
-    elif execution_provider == OrtEP.CUDA.name:
+    elif execution_provider is OnnxRuntimeType.CUDA:
         acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCUDA(name, ModelFileName=model_fname,  **kwargs))
     else:
         raise ValueError("Unknown OnnxRuntime Execution Provider: %s" % execution_provider)
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
index 315c775e3b6a..c7936c16a828 100644
--- a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
@@ -131,7 +131,7 @@ namespace AthOnnx {
 
     };
 
-    #include "IOnnxRuntimeSessionTool.ipp"
+    #include "IOnnxRuntimeSessionTool.icc"
 } // namespace AthOnnx
 
 #endif
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.ipp b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.icc
similarity index 89%
rename from Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.ipp
rename to Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.icc
index e46dfed831dd..6e858c525c48 100644
--- a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.ipp
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.icc
@@ -1,5 +1,4 @@
 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
-
 template <typename T>
 Ort::Value AthOnnx::IOnnxRuntimeSessionTool::createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const
 {
@@ -24,7 +23,6 @@ template <typename T>
 StatusCode AthOnnx::IOnnxRuntimeSessionTool::addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
 {
     if (idx >= m_numInputs || idx < 0) {
-        // ATH_MSG_ERROR("Need " << m_numInputs << " tensors; but adding "<< idx << " tensor.");
         return StatusCode::FAILURE;
     }
 
@@ -36,7 +34,6 @@ template <typename T>
 StatusCode AthOnnx::IOnnxRuntimeSessionTool::addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
 {
     if (idx >= m_numOutputs || idx < 0) {
-        // ATH_MSG_ERROR("Need " << m_numOutputs << " tensors; but adding "<< idx << " tensor.");
         return StatusCode::FAILURE;
     }
     auto tensorSize = std::accumulate(m_outputShapes[idx].begin(), m_outputShapes[idx].end(), 1, std::multiplies<int64_t>());
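addOutput in the .icc above sizes the output buffer from the node shape via std::accumulate. A small standalone sketch of that computation, with the dynamic dimension resolved to an assumed batch size first:

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main() {
        std::vector<int64_t> shape{-1, 10};   // illustrative output-node shape
        const int64_t batchSize = 4;          // assumed batch size
        for (auto& dim : shape)
            if (dim < 0) dim = batchSize;     // resolve the dynamic dimension
        const int64_t nElements = std::accumulate(
            shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
        std::vector<float> outputData(nElements);  // buffer later bound via addOutput
        return outputData.size() == 40 ? 0 : 1;
    }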
diff --git a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
index 8ab6cf2cb8d5..702a658e475b 100644
--- a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
+++ b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
@@ -37,7 +37,7 @@ inline std::vector<T> flattenNestedVectors( const std::vector<std::vector<T>>& f
 // @param dataShape The shape of the input data. Note that there may be multiple inputs.
 // @param nodeNames The names of the input nodes in the computational graph.
 // the dataShape and nodeNames will be updated.
-void GetInputNodeInfo(
+void getInputNodeInfo(
     const std::unique_ptr< Ort::Session >& session,
     std::vector<std::vector<int64_t> >& dataShape, 
     std::vector<const char*>& nodeNames);
@@ -47,13 +47,13 @@ void GetInputNodeInfo(
 // @param dataShape The shape of the output data.
 // @param nodeNames The names of the output nodes in the computational graph.
 // the dataShape and nodeNames will be updated.
-void GetOutputNodeInfo(
+void getOutputNodeInfo(
     const std::unique_ptr< Ort::Session >& session,
     std::vector<std::vector<int64_t> >& dataShape, 
     std::vector<const char*>& nodeNames);
 
 // Helper function to get node info
-void GetNodeInfo(
+void getNodeInfo(
     const std::unique_ptr< Ort::Session >& session,
     std::vector<std::vector<int64_t> >& dataShape, 
     std::vector<const char*>& nodeNames,
@@ -62,11 +62,11 @@ void GetNodeInfo(
 
 // @brief Count the total number of elements in a tensor.
 // This is useful for reserving space for the output data.
-int64_t GetTensorSize(const std::vector<int64_t>& dataShape);
+int64_t getTensorSize(const std::vector<int64_t>& dataShape);
 
 // Inference with IO binding. Better for performance, particularly for GPUs.
 // See https://onnxruntime.ai/docs/performance/tune-performance/iobinding.html
-void InferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session, 
+void inferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session, 
     const std::vector<const char*>& inputNames,
     const std::vector<Ort::Value>& inputData,
     const std::vector<const char*>& outputNames,
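inferenceWithIOBinding wraps the pattern recommended in the linked ONNX Runtime documentation; condensed to its core (a sketch matching the implementation kept in OnnxUtils.cxx; the surrounding variables are assumed to be set up as in the signature above):

    // Bind pre-allocated Ort::Value buffers once, then run the session.
    Ort::IoBinding iobinding(*session);
    for (std::size_t i = 0; i < inputNames.size(); ++i)
        iobinding.BindInput(inputNames[i], inputData[i]);
    for (std::size_t i = 0; i < outputNames.size(); ++i)
        iobinding.BindOutput(outputNames[i], outputData[i]);
    session->Run(Ort::RunOptions{nullptr}, iobinding);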
diff --git a/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
index 1de3fe27f341..0eec51cf0cba 100644
--- a/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
+++ b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
@@ -5,7 +5,7 @@
 
 namespace AthOnnx {
 
-void GetNodeInfo(
+void getNodeInfo(
     const std::unique_ptr< Ort::Session >& session,
     std::vector<std::vector<int64_t> >& dataShape, 
     std::vector<const char*>& nodeNames,
@@ -29,23 +29,23 @@ void GetNodeInfo(
      }
 }
 
-void GetInputNodeInfo(
+void getInputNodeInfo(
     const std::unique_ptr< Ort::Session >& session,
     std::vector<std::vector<int64_t> >& dataShape, 
     std::vector<const char*>& nodeNames
 ){
-    GetNodeInfo(session, dataShape, nodeNames, true);
+    getNodeInfo(session, dataShape, nodeNames, true);
 }
 
-void GetOutputNodeInfo(
+void getOutputNodeInfo(
     const std::unique_ptr< Ort::Session >& session,
     std::vector<std::vector<int64_t> >& dataShape,
     std::vector<const char*>& nodeNames
 ) {
-    GetNodeInfo(session, dataShape, nodeNames, false);
+    getNodeInfo(session, dataShape, nodeNames, false);
 }
 
-void InferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session, 
+void inferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session, 
     const std::vector<const char*>& inputNames,
     const std::vector<Ort::Value>& inputData,
     const std::vector<const char*>& outputNames,
@@ -69,7 +69,7 @@ void InferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session,
     session->Run(Ort::RunOptions{nullptr}, iobinding);
 }
 
-int64_t GetTensorSize(const std::vector<int64_t>& dataShape){
+int64_t getTensorSize(const std::vector<int64_t>& dataShape){
     int64_t size = 1;
     for (const auto& dim : dataShape) {
             size *= dim;
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
index ee1986f4646b..24c3ea72cdc2 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
+++ b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
@@ -14,4 +14,4 @@ atlas_add_component( AthExOnnxRuntime
 )
 
 # Install files from the package.
-atlas_install_joboptions( share/*.py )
+atlas_install_joboptions( tests/*.py )
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_jobOptions.py b/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_jobOptions.py
deleted file mode 100644
index a35ab1f40e45..000000000000
--- a/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_jobOptions.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
-
-# Set up / access the algorithm sequence.
-from AthenaCommon.AlgSequence import AlgSequence
-algSequence = AlgSequence()
-
-# Set up the job.
-from AthExOnnxRuntime.AthExOnnxRuntimeConf import AthONNX__EvaluateModel
-from AthOnnxruntimeService.AthOnnxruntimeServiceConf import AthONNX__ONNXRuntimeSvc
-
-from AthenaCommon.AppMgr import ServiceMgr
-ServiceMgr += AthONNX__ONNXRuntimeSvc( OutputLevel = DEBUG )
-algSequence += AthONNX__EvaluateModel("AthONNX")
-
-# Get a	random no. between 0 to	10k for	test sample
-from random import randint
-
-AthONNX = algSequence.AthONNX
-AthONNX.TestSample = randint(0, 9999)
-AthONNX.DoBatches = False
-AthONNX.NumberOfBatches = 1
-AthONNX.SizeOfBatch = 1
-AthONNX.OutputLevel = DEBUG
-AthONNX.UseCUDA = False
-
-# Run for 10 "events".
-theApp.EvtMax = 2
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_CA.py b/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_CA.py
similarity index 82%
rename from Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_CA.py
rename to Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_CA.py
index 4a2de47d9df9..16c0745cac23 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_CA.py
+++ b/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_CA.py
@@ -2,6 +2,8 @@
 
 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
 from AthenaConfiguration.ComponentFactory import CompFactory
+from AthenaCommon import Constants
+from AthOnnxConfig.OnnxRuntimeFlags import OnnxRuntimeType 
 
 
 def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
@@ -14,7 +16,7 @@ def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
     kwargs.setdefault("OnnxRuntimeSessionTool", acc.popToolsAndMerge(
         OnnxRuntimeSessionToolCfg(flags, 
                                   model_fname, 
-#                                  execution_provider="CPU",  # optionally override flags.AthOnnx.ExecutionProvider, default is CPU
+                                #  execution_provider=OnnxRuntimeType.CUDA,  # optionally override flags.AthOnnx.ExecutionProvider, default is CPU
                                   **kwargs)
     ))
 
@@ -24,7 +26,6 @@ def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
     return acc
 
 if __name__ == "__main__":
-    from AthenaCommon import Constants
     from AthenaCommon.Logging import log as msg
     from AthenaConfiguration.AllConfigFlags import initConfigFlags
     from AthenaConfiguration.MainServicesConfig import MainServicesCfg
@@ -32,7 +33,7 @@ if __name__ == "__main__":
     msg.setLevel(Constants.DEBUG)
 
     flags = initConfigFlags()
-    flags.AthOnnx.ExecutionProvider = "CPU"
+    flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU
     flags.lock()
 
     acc = MainServicesCfg(flags)
-- 
GitLab


From d0928e2a952b51876fa5e9e930ddfc7924828ca6 Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Tue, 9 Jan 2024 13:04:22 -0800
Subject: [PATCH 05/18] additional comments

---
 .../AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx    | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
index 29aff6d1c530..c151b858355c 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
@@ -73,11 +73,6 @@ namespace AthOnnx {
     ATH_CHECK( m_onnxTool.retrieve() );
     m_onnxTool->printModelInfo();
 
-    // change the batch size
-   //  ATH_MSG_INFO("Setting batch size to "<<m_batchSize);
-   //  m_onnxTool->setBatchSize(m_batchSize);
-   //  m_onnxTool->printModelInfo();
-
       /*****
        The combination of no. of batches and batch size shouldn't exceed
        the total sample size, which is 10000 for this example
-- 
GitLab


From cc52c20cd4e1d0bad22f46a43c3838b6d268ad95 Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Tue, 9 Jan 2024 15:22:09 -0800
Subject: [PATCH 06/18] Revert "remove redundant packages"

This reverts commit 0ac0c9a247c79b12b941561809c69e0fb8592fab.
---
 .../ATLAS_CHECK_THREAD_SAFETY                 |   1 +
 .../AthOnnxruntimeServiceDict.h               |  12 ++
 .../AthOnnxruntimeService/IONNXRuntimeSvc.h   |  41 +++++
 .../AthOnnxruntimeService/ONNXRuntimeSvc.h    |  58 ++++++
 .../AthOnnxruntimeService/selection.xml       |   6 +
 Control/AthOnnxruntimeService/CMakeLists.txt  |  30 ++++
 Control/AthOnnxruntimeService/README.md       |   7 +
 .../Root/ONNXRuntimeSvc.cxx                   |  38 ++++
 .../AthOnnxruntimeService_entries.cxx         |   8 +
 .../ATLAS_CHECK_THREAD_SAFETY                 |   1 +
 .../AthOnnxruntimeUtils/OnnxUtils.h           | 168 ++++++++++++++++++
 Control/AthOnnxruntimeUtils/CMakeLists.txt    |  14 ++
 12 files changed, 384 insertions(+)
 create mode 100644 Control/AthOnnxruntimeService/AthOnnxruntimeService/ATLAS_CHECK_THREAD_SAFETY
 create mode 100644 Control/AthOnnxruntimeService/AthOnnxruntimeService/AthOnnxruntimeServiceDict.h
 create mode 100644 Control/AthOnnxruntimeService/AthOnnxruntimeService/IONNXRuntimeSvc.h
 create mode 100644 Control/AthOnnxruntimeService/AthOnnxruntimeService/ONNXRuntimeSvc.h
 create mode 100644 Control/AthOnnxruntimeService/AthOnnxruntimeService/selection.xml
 create mode 100644 Control/AthOnnxruntimeService/CMakeLists.txt
 create mode 100644 Control/AthOnnxruntimeService/README.md
 create mode 100644 Control/AthOnnxruntimeService/Root/ONNXRuntimeSvc.cxx
 create mode 100644 Control/AthOnnxruntimeService/src/components/AthOnnxruntimeService_entries.cxx
 create mode 100644 Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/ATLAS_CHECK_THREAD_SAFETY
 create mode 100644 Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/OnnxUtils.h
 create mode 100644 Control/AthOnnxruntimeUtils/CMakeLists.txt

diff --git a/Control/AthOnnxruntimeService/AthOnnxruntimeService/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnxruntimeService/AthOnnxruntimeService/ATLAS_CHECK_THREAD_SAFETY
new file mode 100644
index 000000000000..8b969b1a5cf5
--- /dev/null
+++ b/Control/AthOnnxruntimeService/AthOnnxruntimeService/ATLAS_CHECK_THREAD_SAFETY
@@ -0,0 +1 @@
+Control/AthOnnxruntimeService
diff --git a/Control/AthOnnxruntimeService/AthOnnxruntimeService/AthOnnxruntimeServiceDict.h b/Control/AthOnnxruntimeService/AthOnnxruntimeService/AthOnnxruntimeServiceDict.h
new file mode 100644
index 000000000000..e8453362a499
--- /dev/null
+++ b/Control/AthOnnxruntimeService/AthOnnxruntimeService/AthOnnxruntimeServiceDict.h
@@ -0,0 +1,12 @@
+/*
+  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
+*/
+
+
+#ifndef ATHONNXRUNTIMESERVICE__ATHONNXRUNTIMESERVICE_DICT_H
+#define ATHONNXRUNTIMESERVICE__ATHONNXRUNTIMESERVICE_DICT_H
+
+#include "AthOnnxruntimeService/ONNXRuntimeSvc.h"
+
+#endif
+
diff --git a/Control/AthOnnxruntimeService/AthOnnxruntimeService/IONNXRuntimeSvc.h b/Control/AthOnnxruntimeService/AthOnnxruntimeService/IONNXRuntimeSvc.h
new file mode 100644
index 000000000000..cae4f26e3049
--- /dev/null
+++ b/Control/AthOnnxruntimeService/AthOnnxruntimeService/IONNXRuntimeSvc.h
@@ -0,0 +1,41 @@
+// Dear emacs, this is -*- c++ -*-
+// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
+#ifndef ATHEXONNXRUNTIME_IONNXRUNTIMESVC_H
+#define ATHEXONNXRUNTIME_IONNXRUNTIMESVC_H
+
+// Gaudi include(s).
+#include <AsgServices/IAsgService.h>
+
+// ONNX include(s).
+#include <core/session/onnxruntime_cxx_api.h>
+
+
+/// Namespace holding all of the ONNX Runtime example code
+namespace AthONNX {
+
+   //class IAsgService
+   /// Service used for managing global objects used by ONNX Runtime
+   ///
+   /// In order to allow multiple clients to use ONNX Runtime at the same
+   /// time, this service is used to manage the objects that must only
+   /// be created once in the Athena process.
+   ///
+   /// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>
+   ///
+   class IONNXRuntimeSvc : virtual public asg::IAsgService{
+
+   public:
+      /// Virtual destructor, to make vtable happy
+      virtual ~IONNXRuntimeSvc() = default;
+
+      /// Declare the interface that this class provides
+      DeclareInterfaceID (AthONNX::IONNXRuntimeSvc, 1, 0);
+
+      /// Return the ONNX Runtime environment object
+      virtual Ort::Env& env() const = 0;
+
+   }; // class IONNXRuntimeSvc
+
+} // namespace AthONNX
+
+#endif // ATHEXONNXRUNTIME_IONNXRUNTIMESVC_H
diff --git a/Control/AthOnnxruntimeService/AthOnnxruntimeService/ONNXRuntimeSvc.h b/Control/AthOnnxruntimeService/AthOnnxruntimeService/ONNXRuntimeSvc.h
new file mode 100644
index 000000000000..475c36f9dfd6
--- /dev/null
+++ b/Control/AthOnnxruntimeService/AthOnnxruntimeService/ONNXRuntimeSvc.h
@@ -0,0 +1,58 @@
+// Dear emacs, this is -*- c++ -*-
+// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
+#ifndef ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H
+#define ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H
+
+// Local include(s).
+#include "AthOnnxruntimeService/IONNXRuntimeSvc.h"
+
+// Framework include(s).
+#include <AsgServices/AsgService.h>
+
+// ONNX include(s).
+#include <core/session/onnxruntime_cxx_api.h>
+
+// System include(s).
+#include <memory>
+
+namespace AthONNX {
+
+   /// Service implementing @c AthONNX::IONNXRuntimeSvc
+   ///
+   /// This is a very simple implementation, just managing the lifetime
+   /// of some ONNX Runtime C++ objects.
+   ///
+   /// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>
+   ///
+   class ONNXRuntimeSvc : public asg::AsgService, virtual public IONNXRuntimeSvc {
+
+   public:
+
+      /// @name Function(s) inherited from @c Service
+      /// @{
+      ONNXRuntimeSvc (const std::string& name, ISvcLocator* svc);
+
+      /// Function initialising the service
+      virtual StatusCode initialize() override;
+      /// Function finalising the service
+      virtual StatusCode finalize() override;
+
+      /// @}
+
+      /// @name Function(s) inherited from @c AthONNX::IONNXRuntimeSvc
+      /// @{
+
+      /// Return the ONNX Runtime environment object
+      virtual Ort::Env& env() const override;
+
+      /// @}
+
+   private:
+      /// Global runtime environment for ONNX Runtime
+      std::unique_ptr< Ort::Env > m_env;
+
+   }; // class ONNXRuntimeSvc
+
+} // namespace AthONNX
+
+#endif // ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H
diff --git a/Control/AthOnnxruntimeService/AthOnnxruntimeService/selection.xml b/Control/AthOnnxruntimeService/AthOnnxruntimeService/selection.xml
new file mode 100644
index 000000000000..7ab8a4333797
--- /dev/null
+++ b/Control/AthOnnxruntimeService/AthOnnxruntimeService/selection.xml
@@ -0,0 +1,6 @@
+<lcgdict>
+
+   <class name="AthONNX::IONNXRuntimeSvc" />
+   <class name="AthONNX::ONNXRuntimeSvc" />
+
+</lcgdict>
diff --git a/Control/AthOnnxruntimeService/CMakeLists.txt b/Control/AthOnnxruntimeService/CMakeLists.txt
new file mode 100644
index 000000000000..ea7ee02c224c
--- /dev/null
+++ b/Control/AthOnnxruntimeService/CMakeLists.txt
@@ -0,0 +1,30 @@
+# Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
+
+# Declare the package's name.
+atlas_subdir( AthOnnxruntimeService )
+
+# External dependencies.
+find_package( onnxruntime )
+
+# Component(s) in the package.
+atlas_add_library( AthOnnxruntimeServiceLib
+   AthOnnxruntimeService/*.h Root/*.cxx
+   PUBLIC_HEADERS AthOnnxruntimeService
+   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgServicesLib)
+
+if (XAOD_STANDALONE)
+atlas_add_dictionary( AthOnnxruntimeServiceDict
+   AthOnnxruntimeService/AthOnnxruntimeServiceDict.h
+   AthOnnxruntimeService/selection.xml
+   LINK_LIBRARIES AthOnnxruntimeServiceLib )
+endif ()
+
+if (NOT XAOD_STANDALONE)
+  atlas_add_component( AthOnnxruntimeService
+     src/*.h src/*.cxx src/components/*.cxx
+     INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
+     LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthOnnxruntimeServiceLib AthenaBaseComps GaudiKernel AsgServicesLib)
+endif ()
+
+
diff --git a/Control/AthOnnxruntimeService/README.md b/Control/AthOnnxruntimeService/README.md
new file mode 100644
index 000000000000..7ef7a41b6910
--- /dev/null
+++ b/Control/AthOnnxruntimeService/README.md
@@ -0,0 +1,7 @@
+# ONNXRUNTIMESERVICE
+
+This package hosts the ONNX Runtime related services,
+e.g. `IONNXRuntimeSvc.h` and `ONNXRuntimeSvc.*`.
+
+For an example use case, see
+`https://gitlab.cern.ch/atlas/athena/-/blob/main/Control/AthenaExamples/AthExOnnxRuntime/src/CxxApiAlgorithm.h#L66`
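
A minimal sketch of the use case that README points to: retrieving the service and building a session from its shared `Ort::Env`, in the same way the `CreateORTSession` helper below does. The free-standing wrapper function and the model file name are illustrative only:

```c++
// Sketch only: share the single Ort::Env owned by AthONNX::ONNXRuntimeSvc.
#include "AthOnnxruntimeService/IONNXRuntimeSvc.h"
#include "GaudiKernel/ServiceHandle.h"

#include <memory>
#include <string>

std::unique_ptr<Ort::Session> makeSession(const std::string& modelFile) {
   ServiceHandle<AthONNX::IONNXRuntimeSvc> svc("AthONNX::ONNXRuntimeSvc",
                                               "AthONNX::ONNXRuntimeSvc");
   Ort::SessionOptions opts;
   return std::make_unique<Ort::Session>(svc->env(), modelFile.c_str(), opts);
}
```
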
diff --git a/Control/AthOnnxruntimeService/Root/ONNXRuntimeSvc.cxx b/Control/AthOnnxruntimeService/Root/ONNXRuntimeSvc.cxx
new file mode 100644
index 000000000000..6e37a9074ae9
--- /dev/null
+++ b/Control/AthOnnxruntimeService/Root/ONNXRuntimeSvc.cxx
@@ -0,0 +1,38 @@
+// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
+
+// Local include(s).
+#include "AthOnnxruntimeService/ONNXRuntimeSvc.h"
+
+namespace AthONNX {
+  ONNXRuntimeSvc::ONNXRuntimeSvc(const std::string& name, ISvcLocator* svc) :
+      asg::AsgService(name, svc)
+   {
+     declareServiceInterface<AthONNX::IONNXRuntimeSvc>();
+   }
+   StatusCode ONNXRuntimeSvc::initialize() {
+
+      // Create the environment object.
+      m_env = std::make_unique< Ort::Env >( ORT_LOGGING_LEVEL_WARNING,
+                                            name().c_str() );
+      ATH_MSG_DEBUG( "Ort::Env object created" );
+
+      // Return gracefully.
+      return StatusCode::SUCCESS;
+   }
+
+   StatusCode ONNXRuntimeSvc::finalize() {
+
+      // Delete the environment object.
+      m_env.reset();
+      ATH_MSG_DEBUG( "Ort::Env object deleted" );
+
+      // Return gracefully.
+      return StatusCode::SUCCESS;
+   }
+
+   Ort::Env& ONNXRuntimeSvc::env() const {
+
+      return *m_env;
+   }
+
+} // namespace AthONNX
diff --git a/Control/AthOnnxruntimeService/src/components/AthOnnxruntimeService_entries.cxx b/Control/AthOnnxruntimeService/src/components/AthOnnxruntimeService_entries.cxx
new file mode 100644
index 000000000000..fcc3ac20e682
--- /dev/null
+++ b/Control/AthOnnxruntimeService/src/components/AthOnnxruntimeService_entries.cxx
@@ -0,0 +1,8 @@
+// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
+
+// Local include(s).
+#include <AthOnnxruntimeService/ONNXRuntimeSvc.h>
+
+// Declare the package's components.
+DECLARE_COMPONENT( AthONNX::ONNXRuntimeSvc )
+
diff --git a/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/ATLAS_CHECK_THREAD_SAFETY
new file mode 100644
index 000000000000..584871312ff4
--- /dev/null
+++ b/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/ATLAS_CHECK_THREAD_SAFETY
@@ -0,0 +1 @@
+Control/AthOnnxruntimeUtils
diff --git a/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/OnnxUtils.h b/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/OnnxUtils.h
new file mode 100644
index 000000000000..19f24b7a83f9
--- /dev/null
+++ b/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/OnnxUtils.h
@@ -0,0 +1,168 @@
+// Dear emacs, this is -*- c++ -*-
+// Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
+#ifndef ONNX_UTILS_H
+#define ONNX_UTILS_H
+
+#include <string>
+#include <iostream> 
+#include <fstream>
+#include <arpa/inet.h>
+#include <vector>
+#include <iterator>
+#include <tuple>
+
+// ONNX Runtime include(s).
+#include <core/session/onnxruntime_cxx_api.h>
+// Local include(s).
+#include "AthOnnxruntimeService/IONNXRuntimeSvc.h"
+#include "GaudiKernel/ServiceHandle.h"
+
+namespace AthONNX {
+
+/************************Flattening of Input Data***************************/
+/***************************************************************************/ 
+
+ template<typename T>
+
+ inline std::vector<T> FlattenInput_multiD_1D( std::vector<std::vector<T>> features){
+    // 1. Compute the total size required.
+    int total_size = 0;
+    for (const auto& feature : features) total_size += feature.size();
+    
+    // 2. Create a vector to hold the data.
+    std::vector<T> Flatten1D;
+    Flatten1D.reserve(total_size);
+
+    // 3. Fill it
+    for (const auto& feature : features)
+      for (const auto& elem : feature)
+        Flatten1D.push_back(elem);
+ 
+   return Flatten1D;
+  }
+
+/*********************************Creation of ORT tensor*********************************/
+/****************************************************************************************/
+
+ template<typename T>
+ inline Ort::Value TensorCreator(std::vector<T>& flattenData, std::vector<int64_t>& input_node_dims ){ 
+    /************** Create input tensor object from input data values to feed into your model *********************/
+    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); 
+    Ort::Value input_tensor = Ort::Value::CreateTensor<T>(memory_info, 
+                                                                  flattenData.data(), 
+                                                                  flattenData.size(),  /*** 1x28x28 = 784 ***/ 
+                                                                  input_node_dims.data(), 
+                                                                  input_node_dims.size());     /*** [1, 28, 28] = 3 ***/
+    return input_tensor;
+   }
+
+
+/*********************************Creation of ORT Session*********************************/
+/*****************************************************************************************/
+
+ //template<typename T>
+ inline std::unique_ptr< Ort::Session > CreateORTSession(const std::string& modelFile, bool withCUDA=false){
+   
+    // Set up the ONNX Runtime session.
+    Ort::SessionOptions sessionOptions;
+    sessionOptions.SetIntraOpNumThreads( 1 );
+    if (withCUDA) {
+      ;  // does nothing for now until we have a GPU enabled build
+    }
+    sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
+
+    ServiceHandle< IONNXRuntimeSvc > svc("AthONNX::ONNXRuntimeSvc",
+                                              "AthONNX::ONNXRuntimeSvc");
+
+    return std::make_unique<Ort::Session>( svc->env(),
+                                     modelFile.c_str(),
+                                     sessionOptions );
+   }
+
+
+/*********************************Input Node Structure of Model*********************************/
+/***********************************************************************************************/
+
+  inline  std::tuple<std::vector<int64_t>, std::vector<const char*> > GetInputNodeInfo(const std::unique_ptr< Ort::Session >& session){
+    
+    std::vector<int64_t> input_node_dims;
+    size_t num_input_nodes = session->GetInputCount();
+    std::vector<const char*> input_node_names(num_input_nodes);
+    Ort::AllocatorWithDefaultOptions allocator;
+    for( std::size_t i = 0; i < num_input_nodes; i++ ) {
+        char* input_name = session->GetInputNameAllocated(i, allocator).release();
+        input_node_names[i] = input_name;
+        Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
+        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+   
+        input_node_dims = tensor_info.GetShape();
+     }
+     return std::make_tuple(input_node_dims, input_node_names); 
+  }
+
+/*********************************Output Node Structure of Model*********************************/
+/***********************************************************************************************/
+
+  inline  std::tuple<std::vector<int64_t>, std::vector<const char*> > GetOutputNodeInfo(const std::unique_ptr< Ort::Session >& session){
+     
+     //output nodes
+     std::vector<int64_t> output_node_dims;
+     size_t num_output_nodes = session->GetOutputCount();
+     std::vector<const char*> output_node_names(num_output_nodes);
+     Ort::AllocatorWithDefaultOptions allocator;
+
+      for( std::size_t i = 0; i < num_output_nodes; i++ ) {
+        char* output_name = session->GetOutputNameAllocated(i, allocator).release();
+        output_node_names[i] = output_name;
+
+        Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
+        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+
+        output_node_dims = tensor_info.GetShape();
+      }
+    return std::make_tuple(output_node_dims, output_node_names);
+   }
+
+
+/*********************************Running Inference through ORT*********************************/
+/***********************************************************************************************/
+  inline float* Inference(const std::unique_ptr< Ort::Session >& session,std::vector<const char*>& input_node_names, Ort::Value& input_tensor, std::vector<const char*>& output_node_names){
+     auto output_tensor =  session->Run(Ort::RunOptions{nullptr},
+                                             input_node_names.data(),
+                                             &input_tensor,
+                                             input_node_names.size(),      /** 1, flatten_input:0 **/
+                                             output_node_names.data(),
+                                             output_node_names.size());    /** 1, dense_1/Softmax:0 **/ 
+ 
+     //assert(output_tensor.size() == output_node_names.size() && output_tensor.front().IsTensor());
+     // Get pointer to output tensor float values
+     float* floatarr = output_tensor.front().GetTensorMutableData<float>();
+     return floatarr;
+  }
+
+  inline void InferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session, 
+    const std::vector<const char*>& inputNames,
+    const std::vector<Ort::Value>& inputData,
+    const std::vector<const char*>& outputNames,
+    const std::vector<Ort::Value>& outputData){
+    
+    if (inputNames.empty()) {
+        throw std::runtime_error("Onnxruntime input data mapping cannot be empty");
+    }
+    assert(inputNames.size() == inputData.size());
+
+    Ort::IoBinding iobinding(*session);
+    for(size_t idx = 0; idx < inputNames.size(); ++idx){
+        iobinding.BindInput(inputNames[idx], inputData[idx]);
+    }
+
+
+    for(size_t idx = 0; idx < outputNames.size(); ++idx){
+        iobinding.BindOutput(outputNames[idx], outputData[idx]);
+    }
+
+    session->Run(Ort::RunOptions{nullptr}, iobinding);
+  }
+
+}
+#endif
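
A sketch of how the restored helpers chain together for one inference pass, assuming a model with a single fixed-shape input and output as in the comments above; `"model.onnx"` and the feature vector are placeholders:

```c++
// Sketch only: flatten -> query node info -> build tensor -> run.
#include "AthOnnxruntimeUtils/OnnxUtils.h"

float* runOnce(std::vector<std::vector<float>>& features) {
   auto session = AthONNX::CreateORTSession("model.onnx");

   auto [inputDims, inputNames]   = AthONNX::GetInputNodeInfo(session);
   auto [outputDims, outputNames] = AthONNX::GetOutputNodeInfo(session);
   (void)outputDims;  // only the names are needed by Run()

   std::vector<float> flat = AthONNX::FlattenInput_multiD_1D(features);
   Ort::Value tensor = AthONNX::TensorCreator(flat, inputDims);

   // Caveat: Inference() returns a pointer into a tensor that is destroyed
   // when the call returns, so copy the values out immediately in real code.
   return AthONNX::Inference(session, inputNames, tensor, outputNames);
}
```
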
diff --git a/Control/AthOnnxruntimeUtils/CMakeLists.txt b/Control/AthOnnxruntimeUtils/CMakeLists.txt
new file mode 100644
index 000000000000..d8e08a2f21f2
--- /dev/null
+++ b/Control/AthOnnxruntimeUtils/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
+
+# Declare the package's name.
+atlas_subdir( AthOnnxruntimeUtils )
+
+# External dependencies.
+find_package( onnxruntime )
+
+# Component(s) in the package.
+atlas_add_library( AthOnnxruntimeUtilsLib
+   INTERFACE
+   PUBLIC_HEADERS AthOnnxruntimeUtils
+   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthOnnxruntimeServiceLib AthenaKernel GaudiKernel AsgServicesLib )
-- 
GitLab


From aaba270ee3344353a4b4810a53d0af69ea6fc1f4 Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Mon, 29 Jan 2024 09:07:00 +0100
Subject: [PATCH 07/18] separate tools

---
 .../AthOnnxComps}/OnnxUtils.h                 |   4 +
 Control/AthOnnx/AthOnnxComps/CMakeLists.txt   |  12 ++
 .../python/OnnxRuntimeFlags.py                |   2 +-
 .../python/OnnxRuntimeSessionConfig.py        |   0
 .../python/OnnxRuntimeSvcConfig.py            |   0
 .../python/__init__.py                        |   0
 ...nTool.cxx => OnnxRuntimeInferenceTool.cxx} |   0
 ...ssionTool.h => OnnxRuntimeInferenceTool.h} |   0
 .../src/OnnxRuntimeSessionToolCPU.cxx         |  43 +++++++
 .../src/OnnxRuntimeSessionToolCPU.h           |  44 +++++++
 .../src/OnnxRuntimeSessionToolCUDA.cxx        |  71 ++++++++++-
 .../src/OnnxRuntimeSessionToolCUDA.h          |  28 ++--
 .../src/OnnxUtils.cxx                         |  14 +-
 Control/AthOnnx/AthOnnxConfig/CMakeLists.txt  |   8 --
 .../IOnnxRuntimeInferenceTool.h               | 116 +++++++++++++++++
 ...Tool.icc => IOnnxRuntimeInferenceTool.icc} |   0
 .../IOnnxRuntimeSessionTool.h                 | 120 ++----------------
 .../AthOnnxUtils/ATLAS_CHECK_THREAD_SAFETY    |   1 -
 Control/AthOnnx/AthOnnxUtils/CMakeLists.txt   |  16 ---
 .../AthExOnnxRuntime/CMakeLists.txt           |   6 +
 .../tests/AthExOnnxRuntimeTest.ref            |   0
 ...Runtime_CA.py => AthExOnnxRuntime_test.py} |   0
 22 files changed, 336 insertions(+), 149 deletions(-)
 rename Control/AthOnnx/{AthOnnxUtils/AthOnnxUtils => AthOnnxComps/AthOnnxComps}/OnnxUtils.h (94%)
 rename Control/AthOnnx/{AthOnnxConfig => AthOnnxComps}/python/OnnxRuntimeFlags.py (96%)
 rename Control/AthOnnx/{AthOnnxConfig => AthOnnxComps}/python/OnnxRuntimeSessionConfig.py (100%)
 rename Control/AthOnnx/{AthOnnxConfig => AthOnnxComps}/python/OnnxRuntimeSvcConfig.py (100%)
 rename Control/AthOnnx/{AthOnnxConfig => AthOnnxComps}/python/__init__.py (100%)
 rename Control/AthOnnx/AthOnnxComps/src/{OnnxRuntimeSessionTool.cxx => OnnxRuntimeInferenceTool.cxx} (100%)
 rename Control/AthOnnx/AthOnnxComps/src/{OnnxRuntimeSessionTool.h => OnnxRuntimeInferenceTool.h} (100%)
 create mode 100644 Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx
 create mode 100644 Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h
 rename Control/AthOnnx/{AthOnnxUtils => AthOnnxComps}/src/OnnxUtils.cxx (82%)
 delete mode 100644 Control/AthOnnx/AthOnnxConfig/CMakeLists.txt
 create mode 100644 Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h
 rename Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/{IOnnxRuntimeSessionTool.icc => IOnnxRuntimeInferenceTool.icc} (100%)
 delete mode 100644 Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/ATLAS_CHECK_THREAD_SAFETY
 delete mode 100644 Control/AthOnnx/AthOnnxUtils/CMakeLists.txt
 create mode 100644 Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntimeTest.ref
 rename Control/AthenaExamples/AthExOnnxRuntime/tests/{AthExOnnxRuntime_CA.py => AthExOnnxRuntime_test.py} (100%)

diff --git a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h b/Control/AthOnnx/AthOnnxComps/AthOnnxComps/OnnxUtils.h
similarity index 94%
rename from Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
rename to Control/AthOnnx/AthOnnxComps/AthOnnxComps/OnnxUtils.h
index 702a658e475b..fb78a07abaae 100644
--- a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
+++ b/Control/AthOnnx/AthOnnxComps/AthOnnxComps/OnnxUtils.h
@@ -73,5 +73,9 @@ void inferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session,
     const std::vector<Ort::Value>& outputData
 ); 
 
+// @brief Create a tensor from a vector of data and its shape.
+Ort::Value createTensor(std::vector<float>& data, const std::vector<int64_t>& dataShape);
+
+
 }
 #endif
diff --git a/Control/AthOnnx/AthOnnxComps/CMakeLists.txt b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
index fc9fb47fc120..8c8727e10131 100644
--- a/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
+++ b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
@@ -15,3 +15,15 @@ atlas_add_component( AthOnnxComps
    LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib 
    AthOnnxInterfaces AthenaBaseComps GaudiKernel AthOnnxruntimeServiceLib AthOnnxUtilsLib
 )
+
+# Component(s) in the package.
+atlas_add_library( AthOnnxUtilsLib
+   AthOnnxComps/*.h 
+   src/*.cxx
+   PUBLIC_HEADERS AthOnnxComps
+   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthOnnxruntimeServiceLib AthenaKernel GaudiKernel AsgServicesLib 
+)
+
+# install python modules
+atlas_install_python_modules( python/*.py POST_BUILD_CMD ${ATLAS_FLAKE8} )
diff --git a/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeFlags.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeFlags.py
similarity index 96%
rename from Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeFlags.py
rename to Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeFlags.py
index 6545bae01a6a..6dcff38f38ec 100644
--- a/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeFlags.py
+++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeFlags.py
@@ -19,7 +19,7 @@ class OnnxRuntimeType(FlagEnum):
 def createOnnxRuntimeFlags():
     icf = AthConfigFlags()
 
-    icf.addFlag("AthOnnx.ExecutionProvider", OnnxRuntimeType.CPU, enum=OnnxRuntimeType)
+    icf.addFlag("AthOnnx.ExecutionProvider", OnnxRuntimeType.CPU, type=OnnxRuntimeType)
 
     return icf
 
diff --git a/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSessionConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py
similarity index 100%
rename from Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSessionConfig.py
rename to Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py
diff --git a/Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSvcConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSvcConfig.py
similarity index 100%
rename from Control/AthOnnx/AthOnnxConfig/python/OnnxRuntimeSvcConfig.py
rename to Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSvcConfig.py
diff --git a/Control/AthOnnx/AthOnnxConfig/python/__init__.py b/Control/AthOnnx/AthOnnxComps/python/__init__.py
similarity index 100%
rename from Control/AthOnnx/AthOnnxConfig/python/__init__.py
rename to Control/AthOnnx/AthOnnxComps/python/__init__.py
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
similarity index 100%
rename from Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.cxx
rename to Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
similarity index 100%
rename from Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionTool.h
rename to Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx
new file mode 100644
index 000000000000..9b43b6fe69d8
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx
@@ -0,0 +1,43 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "OnnxRuntimeSessionToolCPU.h"
+
+AthOnnx::OnnxRuntimeSessionToolCPU::OnnxRuntimeSessionToolCPU(
+  const std::string& type, const std::string& name, const IInterface* parent )
+  : base_class( type, name, parent )
+{
+  declareInterface<IOnnxRuntimeSessionTool>(this);
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::initialize()
+{
+    // Get the Onnx Runtime service.
+    ATH_CHECK(m_onnxRuntimeSvc.retrieve());
+
+    return StatusCode::SUCCESS;
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::finalize()
+{
+    StatusCode sc = AlgTool::finalize();
+    return sc;
+}
+
+std::unique_ptr<Ort::Session> AthOnnx::OnnxRuntimeSessionToolCPU::createSession(
+    const std::string& modelFileName) const
+{
+    // Create the session options.
+    // TODO: Make this configurable.
+    // other threading options: https://onnxruntime.ai/docs/performance/tune-performance/threading.html
+    // 1) SetIntraOpNumThreads( 1 );
+    // 2) SetInterOpNumThreads( 1 );
+    // 3) SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
+
+    Ort::SessionOptions sessionOptions;
+    sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_ALL );
+
+    // Create the session.
+    return std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), modelFileName.c_str(), sessionOptions);
+}
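
The TODO above lists the knobs worth exposing; a sketch of what those calls look like if wired up to tool properties (the parameter plumbing is hypothetical, the `Ort::SessionOptions` calls are the real ONNX Runtime API):

```c++
// Sketch only: the threading/optimization options mentioned in the TODO.
Ort::SessionOptions makeOptions(int intraOpThreads, int interOpThreads) {
   Ort::SessionOptions opts;
   opts.SetIntraOpNumThreads(intraOpThreads);  // parallelism within one operator
   opts.SetInterOpNumThreads(interOpThreads);  // parallelism across operators
   opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
   return opts;
}
```
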
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h
new file mode 100644
index 000000000000..d10e2c478ff6
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h
@@ -0,0 +1,44 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef OnnxRuntimeSessionToolCPU_H
+#define OnnxRuntimeSessionToolCPU_H
+
+#include "AthenaBaseComps/AthAlgTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
+#include "GaudiKernel/ServiceHandle.h"
+
+namespace AthOnnx {
+    // @class OnnxRuntimeSessionToolCPU
+    // 
+    // @brief Tool to create Onnx Runtime session with CPU backend
+    //
+    // @author Xiangyang Ju <xiangyang.ju@cern.ch>
+    class OnnxRuntimeSessionToolCPU :  public extends<AthAlgTool, IOnnxRuntimeSessionTool>
+    {
+        public:
+        /// Standard constructor
+        OnnxRuntimeSessionToolCPU( const std::string& type,
+                                const std::string& name,
+                                const IInterface* parent );
+        virtual ~OnnxRuntimeSessionToolCPU() = default;
+
+        /// Initialize the tool
+        virtual StatusCode initialize() override final;
+        /// Finalize the tool
+        virtual StatusCode finalize() override final;
+
+        /// Create Onnx Runtime session
+        virtual std::unique_ptr<Ort::Session> createSession(
+            const std::string& modelFileName) const override final;
+
+        protected:
+        OnnxRuntimeSessionToolCPU() = delete;
+        OnnxRuntimeSessionToolCPU(const OnnxRuntimeSessionToolCPU&) = delete;
+        OnnxRuntimeSessionToolCPU& operator=(const OnnxRuntimeSessionToolCPU&) = delete;
+
+        ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
+    };
+}
+
+#endif
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
index 66b05273ad8d..d1286794b919 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
@@ -1,15 +1,82 @@
 /*
   Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 */
+
+#include "OnnxRuntimeSessionToolCUDA.h"
+
+AthOnnx::OnnxRuntimeSessionToolCUDA::OnnxRuntimeSessionToolCUDA(
+    const std::string& type, const std::string& name, const IInterface* parent )
+    : base_class( type, name, parent )
+{
+  declareInterface<IOnnxRuntimeSessionTool>(this);
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::initialize()
+{
+    // Get the Onnx Runtime service.
+    ATH_CHECK(m_onnxRuntimeSvc.retrieve());
+    return StatusCode::SUCCESS;
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::finalize()
+{
+    StatusCode sc = AlgTool::finalize();
+    return sc;
+}
+
+std::unique_ptr<Ort::Session> AthOnnx::OnnxRuntimeSessionToolCUDA::createSession(
+    const std::string& modelFileName) const
+{
+    ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
+
+    Ort::SessionOptions sessionOptions;
+    sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
+    sessionOptions.DisablePerSessionThreads();    // use global thread pool.
+
+    // TODO: add more cuda options to the interface
+    // https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cc
+    // Options: https://onnxruntime.ai/docs/api/c/struct_ort_c_u_d_a_provider_options.html
+    OrtCUDAProviderOptions cuda_options;
+    cuda_options.device_id = m_deviceId;
+    cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
+    cuda_options.gpu_mem_limit = std::numeric_limits<size_t>::max();
+
+    // memory arena options for cuda memory shrinkage
+    // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/utils.cc#L7
+    if (m_enableMemoryShrinkage) {
+        Ort::ArenaCfg arena_cfg{0, 0, 1024, 0};
+        // other options are not available in this release.
+        // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/test_inference.cc#L2802C21-L2802C21
+        // arena_cfg.max_mem = 0;   // let ORT pick default max memory
+        // arena_cfg.arena_extend_strategy = 0;   // 0: kNextPowerOfTwo, 1: kSameAsRequested
+        // arena_cfg.initial_chunk_size_bytes = 1024;
+        // arena_cfg.max_dead_bytes_per_chunk = 0;
+        // arena_cfg.initial_growth_chunk_size_bytes = 256;
+        // arena_cfg.max_power_of_two_extend_bytes = 1L << 24;
+
+        cuda_options.default_memory_arena_cfg = arena_cfg;
+    }
+
+    sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
+
+    // Create the session.
+    return std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), modelFileName.c_str(), sessionOptions);
+}
+
+
 
 #include "OnnxRuntimeSessionToolCUDA.h"
 #include "AthOnnxUtils/OnnxUtils.h"
 #include <limits>
 
 AthOnnx::OnnxRuntimeSessionToolCUDA::OnnxRuntimeSessionToolCUDA(
-  const std::string& type, const std::string& name, const IInterface* parent): AthOnnx::OnnxRuntimeSessionTool(type, name, parent)
+    const std::string& type, const std::string& name, const IInterface* parent )
+    : base_class( type, name, parent )
 {
-
+  declareInterface<IOnnxRuntimeSessionTool>(this);
 }
 
 StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::createSession() 
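
Condensing the CUDA setup above into one place, with the positional `Ort::ArenaCfg` arguments labeled (the wrapper function is hypothetical; includes of `<limits>`, `<memory>`, `<string>` and the ONNX Runtime header are assumed):

```c++
// Sketch only: CUDA execution provider with a shrinkable memory arena,
// mirroring the options the tool sets above.
std::unique_ptr<Ort::Session> makeCudaSession(Ort::Env& env,
                                              const std::string& modelFile,
                                              int deviceId) {
   Ort::SessionOptions opts;
   opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
   opts.DisablePerSessionThreads();  // use the global thread pool

   OrtCUDAProviderOptions cuda{};
   cuda.device_id = deviceId;
   cuda.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
   cuda.gpu_mem_limit = std::numeric_limits<size_t>::max();

   // ArenaCfg(max_mem, arena_extend_strategy,
   //          initial_chunk_size_bytes, max_dead_bytes_per_chunk)
   Ort::ArenaCfg arena{0, 0, 1024, 0};
   cuda.default_memory_arena_cfg = arena;

   opts.AppendExecutionProvider_CUDA(cuda);
   // arena must stay alive until the session is created, hence one scope.
   return std::make_unique<Ort::Session>(env, modelFile.c_str(), opts);
}
```
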
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h
index 7c147350b0be..0963e99c4100 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h
@@ -4,37 +4,47 @@
 #define OnnxRuntimeSessionToolCUDA_H
 
 #include "AthenaBaseComps/AthAlgTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
 #include "GaudiKernel/ServiceHandle.h"
-#include "OnnxRuntimeSessionTool.h"
 
 namespace AthOnnx {
     // @class OnnxRuntimeSessionToolCUDA
     // 
-    // @brief Tool to create Onnx Runtime session with CPU backend
+    // @brief Tool to create Onnx Runtime session with CUDA backend
     //
     // @author Xiangyang Ju <xiangyang.ju@cern.ch>
-    class OnnxRuntimeSessionToolCUDA :  public OnnxRuntimeSessionTool
+    class OnnxRuntimeSessionToolCUDA :  public extends<AthAlgTool, IOnnxRuntimeSessionTool>
     {
         public:
         /// Standard constructor
         OnnxRuntimeSessionToolCUDA( const std::string& type,
                                 const std::string& name,
                                 const IInterface* parent );
-        virtual ~OnnxRuntimeSessionToolCUDA() = default;      
+        virtual ~OnnxRuntimeSessionToolCUDA() = default;
+
+        /// Initialize the tool
+        virtual StatusCode initialize() override final;
+        /// Finalize the tool
+        virtual StatusCode finalize() override final;
+
+        /// Create Onnx Runtime session
+        virtual std::unique_ptr<Ort::Session> createSession(
+            const std::string& modelFileName) const override final;
 
         protected:
         OnnxRuntimeSessionToolCUDA() = delete;
         OnnxRuntimeSessionToolCUDA(const OnnxRuntimeSessionToolCUDA&) = delete;
         OnnxRuntimeSessionToolCUDA& operator=(const OnnxRuntimeSessionToolCUDA&) = delete;
 
-        protected:
-        StatusCode createSession();
-
         private:
+        /// The device ID to use.
         IntegerProperty m_deviceId{this, "DeviceId", 0};
         BooleanProperty m_enableMemoryShrinkage{this, "EnableMemoryShrinkage", false};
-        
+
+        /// runtime service
+        ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
     };
-} // namespace AthOnnx
+}
 
 #endif
diff --git a/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxUtils.cxx
similarity index 82%
rename from Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
rename to Control/AthOnnx/AthOnnxComps/src/OnnxUtils.cxx
index 0eec51cf0cba..df0cf5ee077f 100644
--- a/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxUtils.cxx
@@ -1,6 +1,6 @@
 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 
-#include "AthOnnxUtils/OnnxUtils.h"
+#include "AthOnnxComps/OnnxUtils.h"
 #include <cassert>
 
 namespace AthOnnx {
@@ -77,5 +77,17 @@ int64_t getTensorSize(const std::vector<int64_t>& dataShape){
     return size;
 }
 
+Ort::Value createTensor(std::vector<float>& data, const std::vector<int64_t>& dataShape)
+{
+    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); 
+
+    return Ort::Value::CreateTensor<float>(
+                                memoryInfo, 
+                                data.data(), 
+                                data.size(),  
+                                dataShape.data(), 
+                                dataShape.size());
+}
+
 
 } // namespace AthOnnx
\ No newline at end of file
diff --git a/Control/AthOnnx/AthOnnxConfig/CMakeLists.txt b/Control/AthOnnx/AthOnnxConfig/CMakeLists.txt
deleted file mode 100644
index b8492033a884..000000000000
--- a/Control/AthOnnx/AthOnnxConfig/CMakeLists.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
-
-# Declare the package name:
-atlas_subdir( AthOnnxConfig )
-
-
-# install python modules
-atlas_install_python_modules( python/*.py POST_BUILD_CMD ${ATLAS_FLAKE8} )
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h
new file mode 100644
index 000000000000..467b25e9d76f
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h
@@ -0,0 +1,116 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+#ifndef AthOnnx_IOnnxRuntimeInferenceTool_H
+#define AthOnnx_IOnnxRuntimeInferenceTool_H
+
+// Gaudi include(s).
+#include "GaudiKernel/IAlgTool.h"
+#include <memory>
+
+#include <core/session/onnxruntime_cxx_api.h>
+
+
+namespace AthOnnx {
+    /**
+     * @class IOnnxRuntimeInferenceTool
+     * @brief Interface class for performing Onnx Runtime inference.
+     * @details It is thread safe, supports models with varying numbers of inputs
+     * and outputs, and supports models with dynamic batch size. It defines a
+     * standardized procedure to perform Onnx Runtime inference. The procedure is
+     * as follows, assuming the tool `m_onnxTool` is created and initialized:
+     *    1. create input tensors from the input data: 
+     *      ```c++
+     *         std::vector<Ort::Value> inputTensors;
+     *         std::vector<float> inputData_1;   // The input data is filled by users, possibly from the event information. 
+     *         int64_t batchSize = m_onnxTool->getBatchSize(inputData_1.size(), 0);  // The batch size is determined by the input data size to support dynamic batch size.
+     *         m_onnxTool->addInput(inputTensors, inputData_1, 0, batchSize);
+     *         std::vector<int64_t> inputData_2;  // Some models may have multiple inputs. Add inputs one by one.
+     *         int64_t batchSize_2 = m_onnxTool->getBatchSize(inputData_2.size(), 1);
+     *         m_onnxTool->addInput(inputTensors, inputData_2, 1, batchSize_2);
+     *     ```
+     *    2. create output tensors:
+     *      ```c++
+     *          std::vector<Ort::Value> outputTensors;
+     *          std::vector<float> outputData;   // The output data will be filled by the onnx session.
+     *          m_onnxTool->addOutput(outputTensors, outputData, 0, batchSize);
+     *      ```
+     *   3. perform inference:
+     *     ```c++
+     *        m_onnxTool->inference(inputTensors, outputTensors);
+     *    ```
+     *   4. The model outputs will be written into outputData automatically.
+     * 
+     * 
+     * @author Xiangyang Ju <xju@cern.ch>
+     */
+    class IOnnxRuntimeInferenceTool : virtual public IAlgTool 
+    {
+        public:
+
+        virtual ~IOnnxRuntimeInferenceTool() = default;
+        
+        // @name InterfaceID
+        DeclareInterfaceID(IOnnxRuntimeInferenceTool, 1, 0);
+
+        /**
+         * @brief set batch size. 
+         * @details If the model has dynamic batch size, 
+         *          the batchSize value will be set to both input shapes and output shapes
+         */ 
+        virtual void setBatchSize(int64_t batchSize) = 0;
+
+        /**
+         * @brief methods for determining batch size from the data size
+         * @param dataSize the size of the input data, like std::vector<T>::size()
+         * @param idx the index of the input node
+         * @return the batch size, which equals dataSize divided by the product of
+         *         the remaining dimensions; e.g. a shape of {-1, 3, 4} with dataSize 120 gives a batch size of 10.
+         */ 
+        virtual int64_t getBatchSize(int64_t dataSize, int idx = 0) const = 0;
+
+        /**
+         * @brief add the input data to the input tensors
+         * @param inputTensors the input tensor container
+         * @param data the input data
+         * @param idx the index of the input node
+         * @param batchSize the batch size
+         * @return StatusCode::SUCCESS if the input data is added successfully
+         */
+        template <typename T>
+        StatusCode addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const;
+
+        /**
+         * @brief add the output data to the output tensors
+         * @param outputTensors the output tensor container
+         * @param data the output data
+         * @param idx the index of the output node
+         * @param batchSize the batch size
+         * @return StatusCode::SUCCESS if the output data is added successfully
+         */
+        template <typename T>
+        StatusCode addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const;
+
+        /**
+         * @brief perform inference
+         * @param inputTensors the input tensor container
+         * @param outputTensors the output tensor container
+         * @return StatusCode::SUCCESS if the inference is performed successfully
+         */
+        virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const = 0;
+
+        virtual void printModelInfo() const = 0;
+
+        protected:
+        int m_numInputs;
+        int m_numOutputs;
+        std::vector<std::vector<int64_t> > m_inputShapes;
+        std::vector<std::vector<int64_t> > m_outputShapes;
+
+        private:
+        template <typename T>
+        Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const;
+
+    };
+
+    #include "IOnnxRuntimeInferenceTool.icc"
+} // namespace AthOnnx
+
+#endif
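
Assembled, the four documented steps read as below inside a client algorithm; `m_onnxTool` is the retrieved `ToolHandle<AthOnnx::IOnnxRuntimeInferenceTool>` assumed by the comment above, and the surrounding method is a sketch:

```c++
// Sketch only: steps 1-4 of the documented inference procedure.
StatusCode evaluate(std::vector<float>& inputData) {
   // 1. input tensors from the input data
   std::vector<Ort::Value> inputTensors;
   int64_t batchSize = m_onnxTool->getBatchSize(inputData.size(), 0);
   ATH_CHECK(m_onnxTool->addInput(inputTensors, inputData, 0, batchSize));

   // 2. output tensors backed by a user-owned buffer
   std::vector<Ort::Value> outputTensors;
   std::vector<float> outputData;
   ATH_CHECK(m_onnxTool->addOutput(outputTensors, outputData, 0, batchSize));

   // 3. run the model
   ATH_CHECK(m_onnxTool->inference(inputTensors, outputTensors));

   // 4. outputData now holds the model outputs
   return StatusCode::SUCCESS;
}
```
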
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.icc b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc
similarity index 100%
rename from Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.icc
rename to Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
index c7936c16a828..1665cf49639f 100644
--- a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
@@ -4,44 +4,17 @@
 
 // Gaudi include(s).
 #include "GaudiKernel/IAlgTool.h"
-#include <memory>
 
 #include <core/session/onnxruntime_cxx_api.h>
 
 
 namespace AthOnnx {
-    /**
-     * @class IOnnxRuntimeSessionTool
-     * @brief Interface class for creating Onnx Runtime sessions.
-     * @details Interface class for creating Onnx Runtime sessions.
-     * It is thread safe, supports models with various number of inputs and outputs,
-     * supports models with dynamic batch size, and usess . It defines a standardized procedure to 
-     * perform Onnx Runtime inference. The procedure is as follows, assuming the tool `m_onnxTool` is created and initialized:
-     *    1. create input tensors from the input data: 
-     *      ```c++
-     *         std::vector<Ort::Value> inputTensors;
-     *         std::vector<float> inputData_1;   // The input data is filled by users, possibly from the event information. 
-     *         int64_t batchSize = m_onnxTool->getBatchSize(inputData_1.size(), 0);  // The batch size is determined by the input data size to support dynamic batch size.
-     *         m_onnxTool->addInput(inputTensors, inputData_1, 0, batchSize);
-     *         std::vector<int64_t> inputData_2;  // Some models may have multiple inputs. Add inputs one by one.
-     *         int64_t batchSize_2 = m_onnxTool->getBatchSize(inputData_2.size(), 1);
-     *         m_onnxTool->addInput(inputTensors, inputData_2, 1, batchSize_2);
-     *     ```
-     *    2. create output tensors:
-     *      ```c++
-     *          std::vector<Ort::Value> outputTensors;
-     *          std::vector<float> outputData;   // The output data will be filled by the onnx session.
-     *          m_onnxTool->addOutput(outputTensors, outputData, 0, batchSize);
-     *      ```
-     *   3. perform inference:
-     *     ```c++
-     *        m_onnxTool->inference(inputTensors, outputTensors);
-     *    ```
-     *   4. Model outputs will be automatically filled to outputData.
-     * 
-     * 
-     * @author Xiangyang Ju <xju@cern.ch>
-     */
+    // @class IOnnxRuntimeSessionTool
+    //
+    // Interface class for creating Onnx Runtime sessions.
+    //
+    // @author Xiangyang Ju <xju@cern.ch>
+    //
     class IOnnxRuntimeSessionTool : virtual public IAlgTool 
     {
         public:
@@ -51,87 +24,12 @@ namespace AthOnnx {
         // @name InterfaceID
         DeclareInterfaceID(IOnnxRuntimeSessionTool, 1, 0);
 
-        /**
-         * @brief set batch size. 
-         * @details If the model has dynamic batch size, 
-         *          the batchSize value will be set to both input shapes and output shapes
-         */ 
-        virtual void setBatchSize(int64_t batchSize) = 0;
-
-        /**
-         * @brief methods for determining batch size from the data size
-         * @param dataSize the size of the input data, like std::vector<T>::size()
-         * @param idx the index of the input node
-         * @return the batch size, which equals to dataSize / size of the rest dimensions.
-         */ 
-        virtual int64_t getBatchSize(int64_t dataSize, int idx = 0) const = 0;
-
-        /**
-         * @brief add the input data to the input tensors
-         * @param inputTensors the input tensor container
-         * @param data the input data
-         * @param idx the index of the input node
-         * @param batchSize the batch size
-         * @return StatusCode::SUCCESS if the input data is added successfully
-         */
-        template <typename T>
-        StatusCode addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const;
-
-        /**
-         * @brief add the output data to the output tensors
-         * @param outputTensors the output tensor container
-         * @param data the output data
-         * @param idx the index of the output node
-         * @param batchSize the batch size
-         * @return StatusCode::SUCCESS if the output data is added successfully
-         */
-        template <typename T>
-        StatusCode addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const;
-
-        /**
-         * @brief perform inference
-         * @param inputTensors the input tensor container
-         * @param outputTensors the output tensor container
-         * @return StatusCode::SUCCESS if the inference is performed successfully
-         */
-        virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const = 0;
-
-        /// @brief get the input node names
-        virtual const std::vector<const char*>& getInputNodeNames() const = 0;
-        /// @brief get the output node names
-        virtual const std::vector<const char*>& getOutputNodeNames() const = 0;
-
-        virtual const std::vector<std::vector<int64_t> >& getInputShape() const = 0;
-        virtual const std::vector<std::vector<int64_t> >& getOutputShapes() const = 0;
-
-        /// @brief get the number of input nodes
-        virtual int getNumInputs() const = 0;
-        /// @brief get the number of output nodes
-        virtual int getNumOutputs() const = 0;
-
-        /** 
-         * @brief get the size of the input/output tensor
-         * @param idx the index of the input/output node
-         * @return the size of the input/output tensor
-        */
-        virtual int64_t getInputSize(int idx = 0) const = 0;
-        virtual int64_t getOutputSize(int idx = 0) const = 0;
-
-        virtual void printModelInfo() const = 0;
-
-        protected:
-        int m_numInputs;
-        int m_numOutputs;
-        std::vector<std::vector<int64_t> > m_inputShapes;
-        std::vector<std::vector<int64_t> > m_outputShapes;
-
-        private:
-        template <typename T>
-        Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const;
+        // Create Onnx Runtime session
+        virtual std::unique_ptr<Ort::Session> createSession(
+                const std::string& modelFileName) const = 0;
 
     };
 
-    #include "IOnnxRuntimeSessionTool.icc"
 } // namespace AthOnnx
 
 #endif
diff --git a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/ATLAS_CHECK_THREAD_SAFETY
deleted file mode 100644
index c5b26fc5a994..000000000000
--- a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/ATLAS_CHECK_THREAD_SAFETY
+++ /dev/null
@@ -1 +0,0 @@
-Control/AthOnnx/AthOnnxUtils
diff --git a/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt b/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt
deleted file mode 100644
index 4134a24f81e6..000000000000
--- a/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
-
-# Declare the package's name.
-atlas_subdir( AthOnnxUtils )
-
-# External dependencies.
-find_package( onnxruntime )
-
-# Component(s) in the package.
-atlas_add_library( AthOnnxUtilsLib
-   AthOnnxUtils/*.h 
-   src/*.cxx
-   PUBLIC_HEADERS AthOnnxUtils
-   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
-   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthOnnxruntimeServiceLib AthenaKernel GaudiKernel AsgServicesLib 
-)
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
index 24c3ea72cdc2..aa38ffcd87b7 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
+++ b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
@@ -15,3 +15,9 @@ atlas_add_component( AthExOnnxRuntime
 
 # Install files from the package.
 atlas_install_joboptions( tests/*.py )
+
+# Test the packages
+atlas_add_test( AthExOnnxRuntimeTest
+   SCRIPT athena.py --CA  AthExOnnxRuntime_test.py 
+   PROPERTIES TIMEOUT 600
+)
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntimeTest.ref b/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntimeTest.ref
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_CA.py b/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_test.py
similarity index 100%
rename from Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_CA.py
rename to Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_test.py
-- 
GitLab


From 244dcc51f2cf515fc1d1e42ad4ad280e0fa758c7 Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Wed, 31 Jan 2024 05:52:56 +0100
Subject: [PATCH 08/18] inference tool

---
 .../src/OnnxRuntimeInferenceTool.cxx          | 93 +++----------------
 .../src/OnnxRuntimeInferenceTool.h            | 38 +++-----
 .../src/OnnxRuntimeSessionToolCUDA.cxx        | 61 +-----------
 3 files changed, 30 insertions(+), 162 deletions(-)

diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
index 9a1fd88e2c0d..a7e36d4f7403 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
@@ -2,61 +2,36 @@
   Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 */
 
-#include "OnnxRuntimeSessionTool.h"
-#include "AthOnnxUtils/OnnxUtils.h"
+#include "OnnxRuntimeInferenceTool.h"
+#include "AthOnnxComps/OnnxUtils.h"
 
-AthOnnx::OnnxRuntimeSessionTool::OnnxRuntimeSessionTool(
+AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool(
   const std::string& type, const std::string& name, const IInterface* parent )
   : base_class( type, name, parent )
 {
-  declareInterface<IOnnxRuntimeSessionTool>(this);
+  declareInterface<IOnnxRuntimeInferenceTool>(this);
 }
 
-StatusCode AthOnnx::OnnxRuntimeSessionTool::initialize()
+StatusCode AthOnnx::OnnxRuntimeInferenceTool::initialize()
 {
     // Get the Onnx Runtime service.
     ATH_CHECK(m_onnxRuntimeSvc.retrieve());
 
-    ATH_CHECK(createModel());
-
-    return StatusCode::SUCCESS;
-}
-
-StatusCode AthOnnx::OnnxRuntimeSessionTool::finalize()
-{
-    return StatusCode::SUCCESS;
-}
+    // Create the session.
+    ATH_CHECK(m_onnxSessionTool.retrieve());
 
-StatusCode AthOnnx::OnnxRuntimeSessionTool::createModel()
-{
-    ATH_CHECK(createSession());
+    m_session = m_onnxSessionTool->createSession(m_modelFileName.value());
     ATH_CHECK(getNodeInfo());
 
     return StatusCode::SUCCESS;
 }
 
-StatusCode AthOnnx::OnnxRuntimeSessionTool::createSession() 
+StatusCode AthOnnx::OnnxRuntimeInferenceTool::finalize()
 {
-    // Create the session options.
-    // TODO: Make this configurable.
-    // other threading options: https://onnxruntime.ai/docs/performance/tune-performance/threading.html
-
-    if (m_modelFileName.empty()) {
-        ATH_MSG_ERROR("Model file name is empty");
-        return StatusCode::FAILURE;
-    }
-
-    Ort::SessionOptions sessionOptions;
-    sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_ALL );
-    sessionOptions.DisablePerSessionThreads();  // use global thread pool.
-
-    // Create the session.
-    m_session =  std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
-
     return StatusCode::SUCCESS;
 }
 
-StatusCode AthOnnx::OnnxRuntimeSessionTool::getNodeInfo()
+StatusCode AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo()
 {
     if (m_session == nullptr) {
         ATH_MSG_ERROR("Session is not created");
@@ -74,7 +49,7 @@ StatusCode AthOnnx::OnnxRuntimeSessionTool::getNodeInfo()
 }
 
 
-void AthOnnx::OnnxRuntimeSessionTool::setBatchSize(int64_t batchSize)
+void AthOnnx::OnnxRuntimeInferenceTool::setBatchSize(int64_t batchSize)
 {
     if (batchSize <= 0) {
         ATH_MSG_ERROR("Batch size should be positive");
@@ -94,7 +69,7 @@ void AthOnnx::OnnxRuntimeSessionTool::setBatchSize(int64_t batchSize)
     }
 }
 
-int64_t AthOnnx::OnnxRuntimeSessionTool::getBatchSize(int64_t inputDataSize, int idx) const
+int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize(int64_t inputDataSize, int idx) const
 {
     auto tensorSize = AthOnnx::getTensorSize(m_inputShapes[idx]);
     if (tensorSize < 0) {
@@ -104,7 +79,7 @@ int64_t AthOnnx::OnnxRuntimeSessionTool::getBatchSize(int64_t inputDataSize, int
     }
 }
 
-StatusCode AthOnnx::OnnxRuntimeSessionTool::inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const
+StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const
 {
     assert (inputTensors.size() == m_numInputs);
     assert (outputTensors.size() == m_numOutputs);
@@ -118,47 +93,7 @@ StatusCode AthOnnx::OnnxRuntimeSessionTool::inference(std::vector<Ort::Value>& i
     return StatusCode::SUCCESS;
 }
 
-const std::vector<const char*>& AthOnnx::OnnxRuntimeSessionTool::getInputNodeNames() const
-{
-    return m_inputNodeNames;
-}
-
-const std::vector<const char*>& AthOnnx::OnnxRuntimeSessionTool::getOutputNodeNames() const
-{
-    return m_outputNodeNames;
-}
-
-const std::vector<std::vector<int64_t> >& AthOnnx::OnnxRuntimeSessionTool::getInputShape() const
-{
-    return m_inputShapes;
-}
-
-const std::vector<std::vector<int64_t> >& AthOnnx::OnnxRuntimeSessionTool::getOutputShapes() const
-{
-    return m_outputShapes;
-} 
-
-int AthOnnx::OnnxRuntimeSessionTool::getNumInputs() const
-{
-    return m_numInputs;
-}
-
-int AthOnnx::OnnxRuntimeSessionTool::getNumOutputs() const
-{
-    return m_numOutputs;
-}
-
-int64_t AthOnnx::OnnxRuntimeSessionTool::getInputSize(int idx) const
-{
-    return AthOnnx::getTensorSize(m_inputShapes[idx]);
-}
-
-int64_t AthOnnx::OnnxRuntimeSessionTool::getOutputSize(int idx) const
-{
-    return AthOnnx::getTensorSize(m_outputShapes[idx]);
-}
-
-void AthOnnx::OnnxRuntimeSessionTool::printModelInfo() const
+void AthOnnx::OnnxRuntimeInferenceTool::printModelInfo() const
 {
     ATH_MSG_INFO("Model file name: " << m_modelFileName.value());
     ATH_MSG_INFO("Number of inputs: " << m_numInputs);
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
index fd8d954ddd11..e6cae97aac4a 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
@@ -1,27 +1,28 @@
 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 
-#ifndef OnnxRuntimeSessionTool_H
-#define OnnxRuntimeSessionTool_H
+#ifndef OnnxRuntimeInferenceTool_H
+#define OnnxRuntimeInferenceTool_H
 
 #include "AthenaBaseComps/AthAlgTool.h"
-#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h"
 #include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
 #include "GaudiKernel/ServiceHandle.h"
+#include "GaudiKernel/ToolHandle.h"
 
 namespace AthOnnx {
-    // @class OnnxRuntimeSessionTool
+    // @class OnnxRuntimeInferenceTool
     // 
-    // @brief Tool to create Onnx Runtime session with CPU backend
+    // @brief Tool to run Onnx Runtime inference, delegating session creation to a session tool
     //
     // @author Xiangyang Ju <xiangyang.ju@cern.ch>
-    class OnnxRuntimeSessionTool :  public extends<AthAlgTool, IOnnxRuntimeSessionTool>
+    class OnnxRuntimeInferenceTool :  public extends<AthAlgTool, IOnnxRuntimeInferenceTool>
     {
         public:
         /// Standard constructor
-        OnnxRuntimeSessionTool( const std::string& type,
+        OnnxRuntimeInferenceTool( const std::string& type,
                                 const std::string& name,
                                 const IInterface* parent );
-        virtual ~OnnxRuntimeSessionTool() = default;
+        virtual ~OnnxRuntimeInferenceTool() = default;
 
         /// Initialize the tool
         virtual StatusCode initialize() override;
@@ -33,29 +34,20 @@ namespace AthOnnx {
 
         virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const override final;
 
-        virtual const std::vector<const char*>& getInputNodeNames() const override final;
-        virtual const std::vector<const char*>& getOutputNodeNames() const override final;
-
-        virtual const std::vector<std::vector<int64_t> >& getInputShape() const override final;
-        virtual const std::vector<std::vector<int64_t> >& getOutputShapes() const override final;
-
-        virtual int getNumInputs() const override final;
-        virtual int getNumOutputs() const override final;
-        virtual int64_t getInputSize(int idx = 0) const override final;
-        virtual int64_t getOutputSize(int idx = 0) const override final;
-
         virtual void printModelInfo() const override final;
 
         protected:
-        OnnxRuntimeSessionTool() = delete;
-        OnnxRuntimeSessionTool(const OnnxRuntimeSessionTool&) = delete;
-        OnnxRuntimeSessionTool& operator=(const OnnxRuntimeSessionTool&) = delete;
+        OnnxRuntimeInferenceTool() = delete;
+        OnnxRuntimeInferenceTool(const OnnxRuntimeInferenceTool&) = delete;
+        OnnxRuntimeInferenceTool& operator=(const OnnxRuntimeInferenceTool&) = delete;
 
-        StatusCode createModel();
-        StatusCode createSession();
+        private:
         StatusCode getNodeInfo();
 
         ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
+        ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{
+            this, "ORTSessionTool", 
+            "AthOnnx::OnnxRuntimeInferenceToolCPU"};
         StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
 
         std::unique_ptr<Ort::Session> m_session;
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
index d1286794b919..2c12b321b0c8 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
@@ -23,8 +23,7 @@ StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::initialize()
 
 StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::finalize()
 {
-    StatusCode sc = AlgTool::finalize();
-    return sc;
+    return StatusCode::SUCCESS;
 }
 
 std::unique_ptr<Ort::Session> AthOnnx::OnnxRuntimeSessionToolCUDA::createSession(
@@ -65,61 +64,3 @@ std::unique_ptr<Ort::Session> AthOnnx::OnnxRuntimeSessionToolCUDA::createSession
     // Create the session.
     return std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), modelFileName.c_str(), sessionOptions);
 }
-
-
-
-#include "OnnxRuntimeSessionToolCUDA.h"
-#include "AthOnnxUtils/OnnxUtils.h"
-#include <limits>
-
-AthOnnx::OnnxRuntimeSessionToolCUDA::OnnxRuntimeSessionToolCUDA(
-    const std::string& type, const std::string& name, const IInterface* parent )
-    : base_class( type, name, parent )
-{
-  declareInterface<IOnnxRuntimeSessionTool>(this);
-}
-
-StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::createSession() 
-{
-    if (m_modelFileName.empty()) {
-        ATH_MSG_ERROR("Model file name is empty");
-        return StatusCode::FAILURE;
-    }
-
-    ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
-
-    Ort::SessionOptions sessionOptions;
-    sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
-    sessionOptions.DisablePerSessionThreads();    // use global thread pool.
-
-    // TODO: add more cuda options to the interface
-    // https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cc
-    // Options: https://onnxruntime.ai/docs/api/c/struct_ort_c_u_d_a_provider_options.html
-    OrtCUDAProviderOptions cuda_options;
-    cuda_options.device_id = m_deviceId;
-    cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
-    cuda_options.gpu_mem_limit = std::numeric_limits<size_t>::max();
-
-    // memorry arena options for cuda memory shrinkage
-    // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/utils.cc#L7
-    if (m_enableMemoryShrinkage) {
-        Ort::ArenaCfg arena_cfg{0, 0, 1024, 0};
-        // other options are not available in this release.
-        // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/test_inference.cc#L2802C21-L2802C21
-        // arena_cfg.max_mem = 0;   // let ORT pick default max memory
-        // arena_cfg.arena_extend_strategy = 0;   // 0: kNextPowerOfTwo, 1: kSameAsRequested
-        // arena_cfg.initial_chunk_size_bytes = 1024;
-        // arena_cfg.max_dead_bytes_per_chunk = 0;
-        // arena_cfg.initial_growth_chunk_size_bytes = 256;
-        // arena_cfg.max_power_of_two_extend_bytes = 1L << 24;
-
-        cuda_options.default_memory_arena_cfg = arena_cfg;
-    }
-
-    sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
-
-    // Create the session.
-    m_session =  std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
-
-    return StatusCode::SUCCESS;
-}
-- 
GitLab

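A note on the batch-size bookkeeping above: getTensorSize() multiplies the dimensions of a node's shape, and ONNX encodes dynamic axes as -1, so a negative product is how getBatchSize() detects that the per-sample size depends on a dynamic (typically batch) axis. A minimal Python sketch of that convention (illustrative only, not ATLAS code):

    from functools import reduce
    import operator

    def tensor_size(shape):
        # product of the dimensions; negative when a dynamic (-1) axis is present
        return reduce(operator.mul, shape, 1)

    assert tensor_size([1, 28, 28]) == 784    # fully specified shape
    assert tensor_size([-1, 28, 28]) == -784  # dynamic leading (batch) axis

    # with a dynamic batch axis, the batch size can be recovered from the
    # flat input length divided by the per-sample tensor size
    assert (3 * 28 * 28) // tensor_size([28, 28]) == 3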

From 9ddb2682da23ff5e2dc4cb307ca1eae51ea797d8 Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Wed, 31 Jan 2024 07:52:47 +0100
Subject: [PATCH 09/18] to inference

---
 Control/AthOnnx/AthOnnxComps/CMakeLists.txt   | 23 +++++++++----------
 .../src/OnnxRuntimeInferenceTool.h            |  4 +++-
 .../AthOnnx/AthOnnxComps/src/OnnxUtils.cxx    |  2 +-
 .../src/components/AthOnnxComps_entries.cxx   |  6 +++--
 .../IOnnxRuntimeInferenceTool.icc             |  6 ++---
 .../AthExOnnxRuntime/CMakeLists.txt           |  2 +-
 .../AthExOnnxRuntime/src/EvaluateModel.cxx    |  2 +-
 .../AthExOnnxRuntime/src/EvaluateModel.h      |  6 ++---
 InnerDetector/InDetGNNTracking/CMakeLists.txt |  2 +-
 .../src/SiGNNTrackFinderTool.cxx              |  2 +-
 .../src/SiGNNTrackFinderTool.h                | 14 +++++------
 11 files changed, 36 insertions(+), 33 deletions(-)

diff --git a/Control/AthOnnx/AthOnnxComps/CMakeLists.txt b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
index 8c8727e10131..3c605e7f87cc 100644
--- a/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
+++ b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
@@ -6,24 +6,23 @@ atlas_subdir( AthOnnxComps )
 # External dependencies.
 find_package( onnxruntime )
 
+# Library in the package.
+atlas_add_library( AthOnnxCompsLib
+   AthOnnxComps/*.h
+   src/*.h
+   src/*.cxx
+   PUBLIC_HEADERS AthOnnxComps
+   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib AthenaBaseComps AthenaKernel GaudiKernel  AthOnnxInterfaces
+)
+
 # Component(s) in the package.
 atlas_add_component( AthOnnxComps
-   src/*.h
-   src/*.cxx 
    src/components/*.cxx
    INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS} 
-   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib 
-   AthOnnxInterfaces AthenaBaseComps GaudiKernel AthOnnxruntimeServiceLib AthOnnxUtilsLib
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthOnnxCompsLib
 )
 
-# Component(s) in the package.
-atlas_add_library( AthOnnxUtilsLib
-   AthOnnxComps/*.h 
-   src/*.cxx
-   PUBLIC_HEADERS AthOnnxComps
-   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
-   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthOnnxruntimeServiceLib AthenaKernel GaudiKernel AsgServicesLib 
-)
 
 # install python modules
 atlas_install_python_modules( python/*.py POST_BUILD_CMD ${ATLAS_FLAKE8} )
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
index e6cae97aac4a..e878d621327f 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
@@ -6,6 +6,7 @@
 #include "AthenaBaseComps/AthAlgTool.h"
 #include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h"
 #include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
 #include "GaudiKernel/ServiceHandle.h"
 #include "GaudiKernel/ToolHandle.h"
 
@@ -47,7 +48,8 @@ namespace AthOnnx {
         ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
         ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{
             this, "ORTSessionTool", 
-            "AthOnnx::OnnxRuntimeInferenceToolCPU"};
+            "AthOnnx::OnnxRuntimeInferenceToolCPU"
+        };
         StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
 
         std::unique_ptr<Ort::Session> m_session;
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxUtils.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxUtils.cxx
index df0cf5ee077f..25550a86ebf7 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxUtils.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxUtils.cxx
@@ -77,7 +77,7 @@ int64_t getTensorSize(const std::vector<int64_t>& dataShape){
     return size;
 }
 
-Ort::Value createTensor(std::vector<float>& data, const std::vector<int64_t>& dataShape) const
+Ort::Value createTensor(std::vector<float>& data, const std::vector<int64_t>& dataShape)
 {
     auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); 
 
diff --git a/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx b/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx
index 9f9ae59ef5f6..45ef514180f8 100644
--- a/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx
@@ -2,10 +2,12 @@
 
 // Local include(s).
 #include "../OnnxRuntimeSvc.h"
-#include "../OnnxRuntimeSessionTool.h"
+#include "../OnnxRuntimeSessionToolCPU.h"
 #include "../OnnxRuntimeSessionToolCUDA.h"
+#include "../OnnxRuntimeInferenceTool.h"
 
 // Declare the package's components.
 DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSvc )
-DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSessionTool )
+DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSessionToolCPU )
 DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSessionToolCUDA )
+DECLARE_COMPONENT( AthOnnx::OnnxRuntimeInferenceTool )
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc
index 6e858c525c48..1c24b11c09d8 100644
--- a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc
@@ -1,6 +1,6 @@
 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 template <typename T>
-Ort::Value AthOnnx::IOnnxRuntimeSessionTool::createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const
+Ort::Value AthOnnx::IOnnxRuntimeInferenceTool::createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const
 {
     std::vector<int64_t> dataShapeCopy = dataShape;
 
@@ -20,7 +20,7 @@ Ort::Value AthOnnx::IOnnxRuntimeSessionTool::createTensor(std::vector<T>& data,
 }
 
 template <typename T>
-StatusCode AthOnnx::IOnnxRuntimeSessionTool::addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
+StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
 {
     if (idx >= m_numInputs || idx < 0) {
         return StatusCode::FAILURE;
@@ -31,7 +31,7 @@ StatusCode AthOnnx::IOnnxRuntimeSessionTool::addInput(std::vector<Ort::Value>& i
 }
 
 template <typename T>
-StatusCode AthOnnx::IOnnxRuntimeSessionTool::addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
+StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
 {
     if (idx >= m_numOutputs || idx < 0) {
         return StatusCode::FAILURE;
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
index aa38ffcd87b7..baa00b73c654 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
+++ b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
@@ -10,7 +10,7 @@ atlas_add_component( AthExOnnxRuntime
    src/*.h src/*.cxx src/components/*.cxx
    INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
    LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthenaBaseComps GaudiKernel PathResolver 
-   AthOnnxInterfaces AthOnnxUtilsLib AsgServicesLib
+   AthOnnxInterfaces AthOnnxCompsLib AsgServicesLib
 )
 
 # Install files from the package.
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
index c151b858355c..f1e007ff9dd1 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
@@ -9,7 +9,7 @@
 
 // Framework include(s).
 #include "PathResolver/PathResolver.h"
-#include "AthOnnxUtils/OnnxUtils.h"
+#include "AthOnnxComps/OnnxUtils.h"
 
 namespace AthOnnx {
 
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
index d1a1f3a3894a..d660543e5e58 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
@@ -4,7 +4,7 @@
 #define ATHEXONNXRUNTIME_EVALUATEMODEL_H
 
 // Local include(s).
-#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h"
 
 // Framework include(s).
 #include "AthenaBaseComps/AthReentrantAlgorithm.h"
@@ -58,8 +58,8 @@ namespace AthOnnx {
       Gaudi::Property<int> m_batchSize {this, "BatchSize", 1, "No. of elements/example in a batch"};
 
       /// Tool handler for onnx inference session
-      ToolHandle< IOnnxRuntimeSessionTool >  m_onnxTool{
-         this, "OnnxRuntimeSessionTool", "AthOnnx::OnnxRuntimeSessionTool"
+      ToolHandle< IOnnxRuntimeInferenceTool >  m_onnxTool{
+         this, "ORTInferenceTool", "AthOnnx::OnnxRuntimeInferenceTool"
       };
       
       std::vector<std::vector<std::vector<float>>> m_input_tensor_values_notFlat;
diff --git a/InnerDetector/InDetGNNTracking/CMakeLists.txt b/InnerDetector/InDetGNNTracking/CMakeLists.txt
index ca9e45890dec..c2cce19dadf8 100644
--- a/InnerDetector/InDetGNNTracking/CMakeLists.txt
+++ b/InnerDetector/InDetGNNTracking/CMakeLists.txt
@@ -18,7 +18,7 @@ atlas_add_component( InDetGNNTracking
     PixelReadoutGeometryLib SCT_ReadoutGeometry InDetSimData
     InDetPrepRawData TrkTrack TrkRIO_OnTrack InDetSimEvent
     AtlasHepMCLib InDetRIO_OnTrack InDetRawData TrkTruthData
-    AthOnnxInterfaces AthOnnxUtilsLib
+    AthOnnxInterfaces AthOnnxCompsLib 
 )
 
 atlas_install_python_modules( python/*.py )
diff --git a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
index e07013addd91..454b5b143ea6 100644
--- a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
+++ b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
@@ -7,7 +7,7 @@
 
 // Framework include(s).
 #include "PathResolver/PathResolver.h"
-#include "AthOnnxUtils/OnnxUtils.h"
+#include "AthOnnxComps/OnnxUtils.h"
 #include <cmath>
 
 InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool(
diff --git a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h
index 3e70de690ee9..3fca4b225e97 100644
--- a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h
+++ b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h
@@ -14,7 +14,7 @@
 #include "InDetRecToolInterfaces/IGNNTrackFinder.h"
 
 // ONNX Runtime include(s).
-#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h"
 #include <core/session/onnxruntime_cxx_api.h>
 
 class MsgStream;
@@ -74,14 +74,14 @@ namespace InDet{
     MsgStream&    dumpevent     (MsgStream&    out) const;
 
     private:
-    ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_embedSessionTool {
-      this, "OnnxRuntimeSessionTool", "AthOnnx::OnnxRuntimeSessionTool"
+    ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool {
+      this, "Embedding", "AthOnnx::OnnxRuntimeInferenceTool"
     };
-    ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_filterSessionTool {
-      this, "OnnxRuntimeSessionTool", "AthOnnx::OnnxRuntimeSessionTool"
+    ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool {
+      this, "Filtering", "AthOnnx::OnnxRuntimeInferenceTool"
     };
-    ToolHandle< AthOnnx::IOnnxRuntimeSessionTool > m_gnnSessionTool {
-      this, "OnnxRuntimeSessionTool", "AthOnnx::OnnxRuntimeSessionTool"
+    ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool {
+      this, "GNN", "AthOnnx::OnnxRuntimeInferenceTool"
     };
 
   };
-- 
GitLab

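Patch 09 above completes the split of responsibilities: OnnxRuntimeInferenceTool keeps the model bookkeeping and delegates Ort::Session creation to a configurable IOnnxRuntimeSessionTool (CPU or CUDA). A hedged Python sketch of how the pieces wire together at this point in the series, using only component and property names visible in the diffs (the instance names and model path are illustrative):

    from AthenaConfiguration.ComponentFactory import CompFactory

    # choose the execution provider by choosing the session tool
    session_tool = CompFactory.AthOnnx.OnnxRuntimeSessionToolCPU("MySessionTool")

    # the inference tool obtains its session through the ORTSessionTool handle
    inference_tool = CompFactory.AthOnnx.OnnxRuntimeInferenceTool(
        "MyInferenceTool",
        ORTSessionTool=session_tool,
        ModelFileName="path/to/model.onnx",  # illustrative path
    )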

From 27c13e14164fded5ad07c23ed64ff84230b195ff Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Wed, 31 Jan 2024 08:48:47 +0100
Subject: [PATCH 10/18] update configurations

---
 .../python/OnnxRuntimeInferenceConfig.py      | 24 +++++++++++++++++++
 .../python/OnnxRuntimeSessionConfig.py        | 11 ++++-----
 .../python/AllConfigFlags.py                  |  4 ++--
 .../tests/AthExOnnxRuntime_test.py            |  8 +++----
 4 files changed, 34 insertions(+), 13 deletions(-)
 create mode 100644 Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py

diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py
new file mode 100644
index 000000000000..db570490aae7
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py
@@ -0,0 +1,24 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaConfiguration.ComponentFactory import CompFactory
+from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
+from typing import Optional
+from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg
+
+def OnnxRuntimeInferenceToolCfg(flags, 
+                                model_fname: Optional[str] = None, 
+                                execution_provider: Optional[OnnxRuntimeType] = None, 
+                                name="OnnxRuntimeInferenceTool", **kwargs):
+    """Configure OnnxRuntimeInferenceTool in Control/AthOnnx/AthOnnxComps/src"""
+
+    acc = ComponentAccumulator()
+
+    if model_fname is None:
+        raise ValueError("model_fname must be specified")
+
+    session_tool = acc.popToolsAndMerge(OnnxRuntimeSessionToolCfg(flags, execution_provider))
+    kwargs["ORTSessionTool"] = session_tool
+    kwargs["ModelFileName"] = model_fname
+    acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeInferenceTool(name, **kwargs))
+    return acc
diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py
index 779969b80af3..1dd398327af4 100644
--- a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py
+++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py
@@ -2,25 +2,22 @@
 
 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
 from AthenaConfiguration.ComponentFactory import CompFactory
-from AthOnnxConfig.OnnxRuntimeFlags import OnnxRuntimeType
+from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
 from typing import Optional
 
 def OnnxRuntimeSessionToolCfg(flags,
-                              model_fname: str = None, 
                               execution_provider: Optional[OnnxRuntimeType] = None, 
                               name="OnnxRuntimeSessionTool", **kwargs):
     """"Configure OnnxRuntimeSessionTool in Control/AthOnnx/AthOnnxComps/src"""
     
     acc = ComponentAccumulator()
-
-    if model_fname is None:
-        raise ValueError("model_fname must be specified")
     
     execution_provider = flags.AthOnnx.ExecutionProvider if execution_provider is None else execution_provider
+    name += execution_provider.name
     if execution_provider is OnnxRuntimeType.CPU:
-        acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionTool(name, ModelFileName=model_fname, **kwargs))
+        acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCPU(name, **kwargs))
     elif execution_provider is OnnxRuntimeType.CUDA:
-        acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCUDA(name, ModelFileName=model_fname,  **kwargs))
+        acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCUDA(name, **kwargs))
     else:
         raise ValueError("Unknown OnnxRuntime Execution Provider: %s" % execution_provider)
 
diff --git a/Control/AthenaConfiguration/python/AllConfigFlags.py b/Control/AthenaConfiguration/python/AllConfigFlags.py
index 61e45547f3b0..18102cacba5a 100644
--- a/Control/AthenaConfiguration/python/AllConfigFlags.py
+++ b/Control/AthenaConfiguration/python/AllConfigFlags.py
@@ -490,9 +490,9 @@ def initConfigFlags():
 
     # onnxruntime flags
     def __onnxruntime():
-        from AthOnnxConfig.OnnxRuntimeFlags import createOnnxRuntimeFlags
+        from AthOnnxComps.OnnxRuntimeFlags import createOnnxRuntimeFlags
         return createOnnxRuntimeFlags()
-    _addFlagsCategory(acf, "AthOnnx", __onnxruntime, 'AthOnnxConfig')
+    _addFlagsCategory(acf, "AthOnnx", __onnxruntime, 'AthOnnxComps')
 
     # For AnalysisBase, pick up things grabbed in Athena by the functions above
     if not isGaudiEnv():
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_test.py b/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_test.py
index 16c0745cac23..4ed92286804c 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_test.py
+++ b/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_test.py
@@ -3,7 +3,7 @@
 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
 from AthenaConfiguration.ComponentFactory import CompFactory
 from AthenaCommon import Constants
-from AthOnnxConfig.OnnxRuntimeFlags import OnnxRuntimeType 
+from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType 
 
 
 def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
@@ -12,9 +12,9 @@ def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
     model_fname = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-02/MNIST_testModel.onnx"
     kwargs.setdefault("OutputLevel", Constants.DEBUG)
 
-    from AthOnnxConfig.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg
-    kwargs.setdefault("OnnxRuntimeSessionTool", acc.popToolsAndMerge(
-        OnnxRuntimeSessionToolCfg(flags, 
+    from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg
+    kwargs.setdefault("ORTInferenceTool", acc.popToolsAndMerge(
+        OnnxRuntimeInferenceToolCfg(flags, 
                                   model_fname, 
                                 #  execution_provider=OnnxRuntimeType.CUDA,  # optionally override flags.AthOnnx.ExecutionProvider, default is CPU
                                   **kwargs)
-- 
GitLab

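Patch 10 makes OnnxRuntimeInferenceToolCfg the single configuration entry point, merging the session-tool configuration internally. A short usage sketch, assuming a standard locked flags container (the model path is the MNIST test file already used by the example):

    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
    from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg

    flags = initConfigFlags()
    flags.lock()

    acc = ComponentAccumulator()
    model = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-02/MNIST_testModel.onnx"
    # execution_provider defaults to flags.AthOnnx.ExecutionProvider (CPU)
    tool = acc.popToolsAndMerge(OnnxRuntimeInferenceToolCfg(flags, model))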

From db85d6cb4935ffe0923986ee36469c478d8a8433 Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Wed, 31 Jan 2024 08:50:46 +0100
Subject: [PATCH 11/18] no ref

---
 .../AthExOnnxRuntime/tests/AthExOnnxRuntimeTest.ref               | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntimeTest.ref

diff --git a/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntimeTest.ref b/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntimeTest.ref
deleted file mode 100644
index e69de29bb2d1..000000000000
-- 
GitLab


From 006da8da7c35c9a6542da77503faaa879e192ef0 Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Thu, 1 Feb 2024 12:27:17 +0100
Subject: [PATCH 12/18] fix testing issue

---
 .../AthExOnnxRuntime/CMakeLists.txt               | 15 ++++++++++-----
 .../{tests => python}/AthExOnnxRuntime_test.py    |  6 ++++--
 2 files changed, 14 insertions(+), 7 deletions(-)
 rename Control/AthenaExamples/AthExOnnxRuntime/{tests => python}/AthExOnnxRuntime_test.py (93%)

diff --git a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
index baa00b73c654..46db4be0c18f 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
+++ b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
@@ -13,11 +13,16 @@ atlas_add_component( AthExOnnxRuntime
    AthOnnxInterfaces AthOnnxCompsLib AsgServicesLib
 )
 
-# Install files from the package.
-atlas_install_joboptions( tests/*.py )
+# Install files from the package:
+atlas_install_python_modules( python/*.py
+                              POST_BUILD_CMD ${ATLAS_FLAKE8} )
 
 # Test the packages
 atlas_add_test( AthExOnnxRuntimeTest
-   SCRIPT athena.py --CA  AthExOnnxRuntime_test.py 
-   PROPERTIES TIMEOUT 600
-)
+   SCRIPT python -m AthExOnnxRuntime.AthExOnnxRuntime_test
+   POST_EXEC_SCRIPT noerror.sh )
+
+atlas_add_test( AthExOnnxRuntimeTest_pkl
+   SCRIPT athena --evtMax 2 --CA test_AthExOnnxRuntimeExampleCfg.pkl
+   DEPENDS AthExOnnxRuntimeTest
+   POST_EXEC_SCRIPT noerror.sh )
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_test.py b/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py
similarity index 93%
rename from Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_test.py
rename to Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py
index 4ed92286804c..7a5967da343e 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/tests/AthExOnnxRuntime_test.py
+++ b/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py
@@ -40,5 +40,7 @@ if __name__ == "__main__":
     acc.merge(AthExOnnxRuntimeExampleCfg(flags))
     acc.printConfig(withDetails=True, summariseProps=True)
 
-    sc = acc.run(maxEvents=2)
-    msg.info(sc.isSuccess())
+    acc.store(open('test_AthExOnnxRuntimeExampleCfg.pkl','wb'))
+
+    import sys
+    sys.exit(acc.run(2).isFailure())
-- 
GitLab

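The reworked test follows the usual two-stage ComponentAccumulator self-test: the Python module builds the configuration, stores it as a pickle and exits with the job status, and a second CTest job replays the stored configuration with athena --CA. Schematically, the replay amounts to the following (a sketch of the idiom, not the actual athena driver):

    import pickle
    import sys

    # the first test wrote this pickle via acc.store(...)
    with open("test_AthExOnnxRuntimeExampleCfg.pkl", "rb") as f:
        acc = pickle.load(f)

    # rerun the same job; a non-zero exit code marks the CTest job as failed
    sys.exit(acc.run(2).isFailure())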

From e40a6155291637fa16fa0fdd7a94c5b033ea7980 Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Thu, 1 Feb 2024 12:59:40 +0100
Subject: [PATCH 13/18] more robust

---
 .../AthenaExamples/AthExOnnxRuntime/CMakeLists.txt  |  5 ++++-
 .../python/AthExOnnxRuntime_test.py                 | 13 ++++++-------
 .../AthExOnnxRuntime/src/EvaluateModel.cxx          |  7 ++-----
 3 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
index 46db4be0c18f..ca3ae078d6c9 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
+++ b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
@@ -20,9 +20,12 @@ atlas_install_python_modules( python/*.py
 # Test the packages
 atlas_add_test( AthExOnnxRuntimeTest
    SCRIPT python -m AthExOnnxRuntime.AthExOnnxRuntime_test
-   POST_EXEC_SCRIPT noerror.sh )
+   PROPERTIES TIMEOUT 100
+   POST_EXEC_SCRIPT noerror.sh 
+   )
 
 atlas_add_test( AthExOnnxRuntimeTest_pkl
    SCRIPT athena --evtMax 2 --CA test_AthExOnnxRuntimeExampleCfg.pkl
    DEPENDS AthExOnnxRuntimeTest
+   PROPERTIES TIMEOUT 100
    POST_EXEC_SCRIPT noerror.sh )
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py b/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py
index 7a5967da343e..6d243fbc6feb 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py
+++ b/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py
@@ -8,19 +8,18 @@ from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
 
 def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
     acc = ComponentAccumulator()
-    
-    model_fname = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-02/MNIST_testModel.onnx"
-    kwargs.setdefault("OutputLevel", Constants.DEBUG)
 
+    model_fname = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-02/MNIST_testModel.onnx"
+    execution_provider = OnnxRuntimeType.CPU
     from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg
     kwargs.setdefault("ORTInferenceTool", acc.popToolsAndMerge(
-        OnnxRuntimeInferenceToolCfg(flags, 
-                                  model_fname, 
-                                #  execution_provider=OnnxRuntimeType.CUDA,  # optionally override flags.AthOnnx.ExecutionProvider, default is CPU
-                                  **kwargs)
+        OnnxRuntimeInferenceToolCfg(flags, model_fname, execution_provider)
     ))
 
+    input_data = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
     kwargs.setdefault("BatchSize", 3)
+    kwargs.setdefault("InputDataPixel", input_data)
+    kwargs.setdefault("OutputLevel", Constants.DEBUG)
     acc.addEventAlgo(CompFactory.AthOnnx.EvaluateModel(name, **kwargs))
 
     return acc
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
index f1e007ff9dd1..fe1f6d165f95 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
@@ -8,7 +8,6 @@
 #include <arpa/inet.h>
 
 // Framework include(s).
-#include "PathResolver/PathResolver.h"
 #include "AthOnnxComps/OnnxUtils.h"
 
 namespace AthOnnx {
@@ -82,11 +81,9 @@ namespace AthOnnx {
 	return StatusCode::FAILURE;
        }
      // read input file, and the target file for comparison.
-      const std::string pixelFileName =
-         PathResolverFindCalibFile( m_pixelFileName );
-      ATH_MSG_INFO( "Using pixel file: " << pixelFileName );
+      ATH_MSG_INFO( "Using pixel file: " << m_pixelFileName );
   
-      m_input_tensor_values_notFlat = read_mnist_pixel_notFlat(pixelFileName);
+      m_input_tensor_values_notFlat = read_mnist_pixel_notFlat(m_pixelFileName);
       ATH_MSG_INFO("Total no. of samples: "<<m_input_tensor_values_notFlat.size());
     
       return StatusCode::SUCCESS;
-- 
GitLab

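Patch 13 reads the MNIST pixel file directly from the configured path instead of resolving it with PathResolver. For reference, the t10k-images-idx3-ubyte file that read_mnist_pixel_notFlat parses uses the standard idx3-ubyte layout; the Python sketch below mirrors common MNIST readers, and the normalisation to [0, 1] is an assumption about the C++ helper rather than a statement of it:

    import struct

    def read_mnist_images(path):
        # 16-byte big-endian header (magic, #images, #rows, #cols), then raw uint8 pixels
        with open(path, "rb") as f:
            magic, n_images, n_rows, n_cols = struct.unpack(">IIII", f.read(16))
            assert magic == 2051, "not an idx3-ubyte image file"
            pixels = f.read(n_images * n_rows * n_cols)
        npix = n_rows * n_cols
        return [[b / 255.0 for b in pixels[i * npix:(i + 1) * npix]]
                for i in range(n_images)]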

From f7b148ad7c4a075242c94408ce5b2a3debdcf7fa Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Tue, 6 Feb 2024 16:19:37 +0100
Subject: [PATCH 14/18] factorize utils

---
 Control/AthOnnx/AthOnnxComps/CMakeLists.txt      | 15 ++-------------
 .../src/OnnxRuntimeInferenceTool.cxx             |  2 +-
 .../AthOnnxUtils}/OnnxUtils.h                    |  0
 Control/AthOnnx/AthOnnxUtils/CMakeLists.txt      | 16 ++++++++++++++++
 .../src/OnnxUtils.cxx                            |  2 +-
 .../AthExOnnxRuntime/CMakeLists.txt              |  2 +-
 .../AthExOnnxRuntime/src/EvaluateModel.cxx       |  2 +-
 InnerDetector/InDetGNNTracking/CMakeLists.txt    |  2 +-
 .../src/SiGNNTrackFinderTool.cxx                 |  2 +-
 9 files changed, 24 insertions(+), 19 deletions(-)
 rename Control/AthOnnx/{AthOnnxComps/AthOnnxComps => AthOnnxUtils/AthOnnxUtils}/OnnxUtils.h (100%)
 create mode 100644 Control/AthOnnx/AthOnnxUtils/CMakeLists.txt
 rename Control/AthOnnx/{AthOnnxComps => AthOnnxUtils}/src/OnnxUtils.cxx (98%)

diff --git a/Control/AthOnnx/AthOnnxComps/CMakeLists.txt b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
index 3c605e7f87cc..f9169beeed1a 100644
--- a/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
+++ b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
@@ -7,22 +7,11 @@ atlas_subdir( AthOnnxComps )
 find_package( onnxruntime )
 
 # Library in the package.
-atlas_add_library( AthOnnxCompsLib
-   AthOnnxComps/*.h
-   src/*.h
+atlas_add_library( AthOnnxComps
    src/*.cxx
-   PUBLIC_HEADERS AthOnnxComps
    INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
-   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib AthenaBaseComps AthenaKernel GaudiKernel  AthOnnxInterfaces
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib AthenaBaseComps AthenaKernel GaudiKernel  AthOnnxInterfaces AthOnnxUtilsLib
 )
 
-# Component(s) in the package.
-atlas_add_component( AthOnnxComps
-   src/components/*.cxx
-   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS} 
-   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthOnnxCompsLib
-)
-
-
 # install python modules
 atlas_install_python_modules( python/*.py POST_BUILD_CMD ${ATLAS_FLAKE8} )
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
index a7e36d4f7403..cb7bccb8bcde 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
@@ -3,7 +3,7 @@
 */
 
 #include "OnnxRuntimeInferenceTool.h"
-#include "AthOnnxComps/OnnxUtils.h"
+#include "AthOnnxUtils/OnnxUtils.h"
 
 AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool(
   const std::string& type, const std::string& name, const IInterface* parent )
diff --git a/Control/AthOnnx/AthOnnxComps/AthOnnxComps/OnnxUtils.h b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
similarity index 100%
rename from Control/AthOnnx/AthOnnxComps/AthOnnxComps/OnnxUtils.h
rename to Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
diff --git a/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt b/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt
new file mode 100644
index 000000000000..029f7e5a1ce8
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+# Declare the package's name.
+atlas_subdir( AthOnnxUtils )
+
+# External dependencies.
+find_package( onnxruntime )
+
+# Component(s) in the package.
+atlas_add_library( AthOnnxUtilsLib
+   AthOnnxUtils/*.h 
+   src/*.cxx
+   PUBLIC_HEADERS AthOnnxUtils
+   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthenaKernel GaudiKernel AsgServicesLib 
+)
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxUtils.cxx b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
similarity index 98%
rename from Control/AthOnnx/AthOnnxComps/src/OnnxUtils.cxx
rename to Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
index 25550a86ebf7..cbf0b3663341 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxUtils.cxx
+++ b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
@@ -1,6 +1,6 @@
 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 
-#include "AthOnnxComps/OnnxUtils.h"
+#include "AthOnnxUtils/OnnxUtils.h"
 #include <cassert>
 
 namespace AthOnnx {
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
index ca3ae078d6c9..02249a193fcf 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
+++ b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
@@ -10,7 +10,7 @@ atlas_add_component( AthExOnnxRuntime
    src/*.h src/*.cxx src/components/*.cxx
    INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
    LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthenaBaseComps GaudiKernel PathResolver 
-   AthOnnxInterfaces AthOnnxCompsLib AsgServicesLib
+   AthOnnxInterfaces AthOnnxUtilsLib AsgServicesLib
 )
 
 # Install files from the package:
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
index fe1f6d165f95..525d2387c4eb 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
@@ -8,7 +8,7 @@
 #include <arpa/inet.h>
 
 // Framework include(s).
-#include "AthOnnxComps/OnnxUtils.h"
+#include "AthOnnxUtils/OnnxUtils.h"
 
 namespace AthOnnx {
 
diff --git a/InnerDetector/InDetGNNTracking/CMakeLists.txt b/InnerDetector/InDetGNNTracking/CMakeLists.txt
index c2cce19dadf8..daa184c96473 100644
--- a/InnerDetector/InDetGNNTracking/CMakeLists.txt
+++ b/InnerDetector/InDetGNNTracking/CMakeLists.txt
@@ -18,7 +18,7 @@ atlas_add_component( InDetGNNTracking
     PixelReadoutGeometryLib SCT_ReadoutGeometry InDetSimData
     InDetPrepRawData TrkTrack TrkRIO_OnTrack InDetSimEvent
     AtlasHepMCLib InDetRIO_OnTrack InDetRawData TrkTruthData
-    AthOnnxInterfaces AthOnnxCompsLib 
+    AthOnnxInterfaces AthOnnxUtilsLib 
 )
 
 atlas_install_python_modules( python/*.py )
diff --git a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
index 454b5b143ea6..e07013addd91 100644
--- a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
+++ b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
@@ -7,7 +7,7 @@
 
 // Framework include(s).
 #include "PathResolver/PathResolver.h"
-#include "AthOnnxComps/OnnxUtils.h"
+#include "AthOnnxUtils/OnnxUtils.h"
 #include <cmath>
 
 InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool(
-- 
GitLab


From 1751a37373b666d57b9d818df7350cadc628583c Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Tue, 6 Feb 2024 18:11:21 +0100
Subject: [PATCH 15/18] address more comments

---
 Control/AthOnnx/AthOnnxComps/CMakeLists.txt   |  3 +-
 .../python/OnnxRuntimeInferenceConfig.py      |  6 +---
 .../python/OnnxRuntimeSessionConfig.py        |  6 ++++
 .../src/OnnxRuntimeInferenceTool.cxx          | 18 ++++------
 .../src/OnnxRuntimeInferenceTool.h            |  7 ++--
 .../src/OnnxRuntimeSessionToolCPU.cxx         | 31 ++++++++--------
 .../src/OnnxRuntimeSessionToolCPU.h           |  6 ++--
 .../src/OnnxRuntimeSessionToolCUDA.cxx        | 28 ++++++++-------
 .../src/OnnxRuntimeSessionToolCUDA.h          |  5 +--
 .../IOnnxRuntimeSessionTool.h                 |  3 +-
 .../AthOnnxUtils/AthOnnxUtils/OnnxUtils.h     | 18 +++++-----
 .../AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx    | 36 ++++++++++---------
 12 files changed, 87 insertions(+), 80 deletions(-)

diff --git a/Control/AthOnnx/AthOnnxComps/CMakeLists.txt b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
index f9169beeed1a..f2e46aa8b6b7 100644
--- a/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
+++ b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
@@ -7,8 +7,9 @@ atlas_subdir( AthOnnxComps )
 find_package( onnxruntime )
 
 # Library in the package.
-atlas_add_library( AthOnnxComps
+atlas_add_component( AthOnnxComps
    src/*.cxx
+   src/components/*.cxx
    INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
    LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib AthenaBaseComps AthenaKernel GaudiKernel  AthOnnxInterfaces AthOnnxUtilsLib
 )
diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py
index db570490aae7..9c06316bdaf6 100644
--- a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py
+++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py
@@ -14,11 +14,7 @@ def OnnxRuntimeInferenceToolCfg(flags,
 
     acc = ComponentAccumulator()
 
-    if model_fname is None:
-        raise ValueError("model_fname must be specified")
-
-    session_tool = acc.popToolsAndMerge(OnnxRuntimeSessionToolCfg(flags, execution_provider))
+    session_tool = acc.popToolsAndMerge(OnnxRuntimeSessionToolCfg(flags, model_fname, execution_provider))
     kwargs["ORTSessionTool"] = session_tool
-    kwargs["ModelFileName"] = model_fname
     acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeInferenceTool(name, **kwargs))
     return acc
diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py
index 1dd398327af4..32b9d4471b14 100644
--- a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py
+++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py
@@ -6,14 +6,20 @@ from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
 from typing import Optional
 
 def OnnxRuntimeSessionToolCfg(flags,
+                              model_fname: str,
                               execution_provider: Optional[OnnxRuntimeType] = None, 
                               name="OnnxRuntimeSessionTool", **kwargs):
     """"Configure OnnxRuntimeSessionTool in Control/AthOnnx/AthOnnxComps/src"""
     
     acc = ComponentAccumulator()
     
+    if model_fname is None:
+        raise ValueError("model_fname must be specified")
+    
     execution_provider = flags.AthOnnx.ExecutionProvider if execution_provider is None else execution_provider
     name += execution_provider.name
+
+    kwargs.setdefault("ModelFileName", model_fname)
     if execution_provider is OnnxRuntimeType.CPU:
         acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCPU(name, **kwargs))
     elif execution_provider is OnnxRuntimeType.CUDA:
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
index cb7bccb8bcde..a8dfa4e5031b 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
@@ -20,7 +20,6 @@ StatusCode AthOnnx::OnnxRuntimeInferenceTool::initialize()
     // Create the session.
     ATH_CHECK(m_onnxSessionTool.retrieve());
 
-    m_session = m_onnxSessionTool->createSession(m_modelFileName.value());
     ATH_CHECK(getNodeInfo());
 
     return StatusCode::SUCCESS;
@@ -33,17 +32,13 @@ StatusCode AthOnnx::OnnxRuntimeInferenceTool::finalize()
 
 StatusCode AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo()
 {
-    if (m_session == nullptr) {
-        ATH_MSG_ERROR("Session is not created");
-        return StatusCode::FAILURE;
-    }
-
+    auto& session = m_onnxSessionTool->session();
     // obtain the model information
-    m_numInputs = m_session->GetInputCount();
-    m_numOutputs = m_session->GetOutputCount();
+    m_numInputs = session.GetInputCount();
+    m_numOutputs = session.GetOutputCount();
 
-    AthOnnx::getInputNodeInfo(m_session, m_inputShapes, m_inputNodeNames);
-    AthOnnx::getOutputNodeInfo(m_session, m_outputShapes, m_outputNodeNames);
+    AthOnnx::getInputNodeInfo(session, m_inputShapes, m_inputNodeNames);
+    AthOnnx::getOutputNodeInfo(session, m_outputShapes, m_outputNodeNames);
 
     return StatusCode::SUCCESS;
 }
@@ -86,7 +81,7 @@ StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>&
 
     // Run the model.
     AthOnnx::inferenceWithIOBinding(
-            m_session, 
+            m_onnxSessionTool->session(), 
             m_inputNodeNames, inputTensors, 
             m_outputNodeNames, outputTensors);
 
@@ -95,7 +90,6 @@ StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>&
 
 void AthOnnx::OnnxRuntimeInferenceTool::printModelInfo() const
 {
-    ATH_MSG_INFO("Model file name: " << m_modelFileName.value());
     ATH_MSG_INFO("Number of inputs: " << m_numInputs);
     ATH_MSG_INFO("Number of outputs: " << m_numOutputs);
 
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
index e878d621327f..b8f3e3e5ad69 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
@@ -51,10 +51,9 @@ namespace AthOnnx {
             "AthOnnx::OnnxRuntimeInferenceToolCPU"
         };
         StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
-
-        std::unique_ptr<Ort::Session> m_session;
-        std::vector<const char*> m_inputNodeNames;
-        std::vector<const char*> m_outputNodeNames;
+        
+        std::vector<std::string> m_inputNodeNames;
+        std::vector<std::string> m_outputNodeNames;
 
     };
 } // namespace AthOnnx
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx
index 9b43b6fe69d8..85f1f8df115b 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx
@@ -15,19 +15,8 @@ StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::initialize()
 {
     // Get the Onnx Runtime service.
     ATH_CHECK(m_onnxRuntimeSvc.retrieve());
-
-    return StatusCode::SUCCESS;
-}
-
-StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::finalize()
-{
-    StatusCode sc = AlgTool::finalize();
-    return sc;
-}
-
-std::unique_ptr<Ort::Session> AthOnnx::OnnxRuntimeSessionToolCPU::createSession(
-    const std::string& modelFileName) const
-{
+    ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
+    
     // Create the session options.
     // TODO: Make this configurable.
     // other threading options: https://onnxruntime.ai/docs/performance/tune-performance/threading.html
@@ -39,5 +28,19 @@ std::unique_ptr<Ort::Session> AthOnnx::OnnxRuntimeSessionToolCPU::createSession(
     sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_ALL );
 
     // Create the session.
-    return std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), modelFileName.c_str(), sessionOptions);
+    m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
+
+    return StatusCode::SUCCESS;
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::finalize()
+{
+    m_session.reset();
+    ATH_MSG_DEBUG( "Ort::Session object deleted" );
+    return StatusCode::SUCCESS;
+}
+
+Ort::Session& AthOnnx::OnnxRuntimeSessionToolCPU::session() const
+{
+  return *m_session;
 }
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h
index d10e2c478ff6..a20e5e5d4f2e 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h
@@ -29,15 +29,17 @@ namespace AthOnnx {
         virtual StatusCode finalize() override final;
 
-        /// Create Onnx Runtime session
+        /// Access the Onnx Runtime session
-        virtual std::unique_ptr<Ort::Session> createSession(
-            const std::string& modelFileName) const override final;
+        virtual Ort::Session& session() const override final;
 
         protected:
         OnnxRuntimeSessionToolCPU() = delete;
         OnnxRuntimeSessionToolCPU(const OnnxRuntimeSessionToolCPU&) = delete;
         OnnxRuntimeSessionToolCPU& operator=(const OnnxRuntimeSessionToolCPU&) = delete;
 
+        private:
+        StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
         ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
+        std::unique_ptr<Ort::Session> m_session;
     };
 }
 
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
index 2c12b321b0c8..f7d39731da83 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
@@ -18,19 +18,9 @@ StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::initialize()
 {
     // Get the Onnx Runtime service.
     ATH_CHECK(m_onnxRuntimeSvc.retrieve());
-    return StatusCode::SUCCESS;
-}
-
-StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::finalize()
-{
-    return StatusCode::SUCCESS;
-}
 
-std::unique_ptr<Ort::Session> AthOnnx::OnnxRuntimeSessionToolCUDA::createSession(
-    const std::string& modelFileName) const
-{
     ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
-
+    // Create the session options.
     Ort::SessionOptions sessionOptions;
     sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
     sessionOptions.DisablePerSessionThreads();    // use global thread pool.
@@ -62,5 +52,19 @@ std::unique_ptr<Ort::Session> AthOnnx::OnnxRuntimeSessionToolCUDA::createSession
     sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
 
     // Create the session.
-    return std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), modelFileName.c_str(), sessionOptions);
+    m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(),  m_modelFileName.value().c_str(), sessionOptions);
+
+    return StatusCode::SUCCESS;
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::finalize()
+{
+    m_session.reset();
+    ATH_MSG_DEBUG( "Ort::Session object deleted" );
+    return StatusCode::SUCCESS;
+}
+
+Ort::Session& AthOnnx::OnnxRuntimeSessionToolCUDA::session() const
+{
+  return *m_session;
 }
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h
index 0963e99c4100..3d445ed73f50 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h
@@ -29,8 +29,7 @@ namespace AthOnnx {
         virtual StatusCode finalize() override final;
 
-        /// Create Onnx Runtime session
+        /// Access the Onnx Runtime session
-        virtual std::unique_ptr<Ort::Session> createSession(
-            const std::string& modelFileName) const override final;
+        virtual Ort::Session& session() const override final;
 
         protected:
         OnnxRuntimeSessionToolCUDA() = delete;
@@ -38,12 +37,14 @@ namespace AthOnnx {
         OnnxRuntimeSessionToolCUDA& operator=(const OnnxRuntimeSessionToolCUDA&) = delete;
 
         private:
+        StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
         /// The device ID to use.
         IntegerProperty m_deviceId{this, "DeviceId", 0};
         BooleanProperty m_enableMemoryShrinkage{this, "EnableMemoryShrinkage", false};
 
         /// runtime service
         ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
+        std::unique_ptr<Ort::Session> m_session;
     };
 }
 
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
index 1665cf49639f..4d94e68a840e 100644
--- a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
@@ -25,8 +25,7 @@ namespace AthOnnx {
         DeclareInterfaceID(IOnnxRuntimeSessionTool, 1, 0);
 
-        // Create Onnx Runtime session
+        // Access the Onnx Runtime session
-        virtual std::unique_ptr<Ort::Session> createSession(
-                const std::string& modelFileName) const = 0;
+        virtual Ort::Session& session() const = 0;
 
     };
 
diff --git a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
index fb78a07abaae..83ced9501c30 100644
--- a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
+++ b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
@@ -38,9 +38,9 @@ inline std::vector<T> flattenNestedVectors( const std::vector<std::vector<T>>& f
 // @param nodeNames The names of the input nodes in the computational graph.
 // the dataShape and nodeNames will be updated.
 void getInputNodeInfo(
-    const std::unique_ptr< Ort::Session >& session,
+    const Ort::Session& session,
     std::vector<std::vector<int64_t> >& dataShape, 
-    std::vector<const char*>& nodeNames);
+    std::vector<std::string>& nodeNames);
 
 // @brief Get the output data shape and node names (in the computational graph) from the onnx model
 // @param session The onnx session.
@@ -48,15 +48,15 @@ void getInputNodeInfo(
 // @param nodeNames The names of the output nodes in the computational graph.
 // the dataShape and nodeNames will be updated.
 void getOutputNodeInfo(
-    const std::unique_ptr< Ort::Session >& session,
+    const Ort::Session& session,
     std::vector<std::vector<int64_t> >& dataShape, 
-    std::vector<const char*>& nodeNames);
+    std::vector<std::string>& nodeNames);
 
 // Helper function to get node info
 void getNodeInfo(
-    const std::unique_ptr< Ort::Session >& session,
+    const Ort::Session& session,
     std::vector<std::vector<int64_t> >& dataShape, 
-    std::vector<const char*>& nodeNames,
+    std::vector<std::string>& nodeNames,
     bool isInput
 );
 
@@ -66,10 +66,10 @@ int64_t getTensorSize(const std::vector<int64_t>& dataShape);
 
 // Inference with IO binding. Better for performance, particularly for GPUs.
 // See https://onnxruntime.ai/docs/performance/tune-performance/iobinding.html
-void inferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session, 
-    const std::vector<const char*>& inputNames,
+void inferenceWithIOBinding(Ort::Session& session, 
+    const std::vector<std::string>& inputNames,
     const std::vector<Ort::Value>& inputData,
-    const std::vector<const char*>& outputNames,
+    const std::vector<std::string>& outputNames,
     const std::vector<Ort::Value>& outputData
 ); 
 
diff --git a/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
index cbf0b3663341..b92ab9e1f230 100644
--- a/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
+++ b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
@@ -2,53 +2,55 @@
 
 #include "AthOnnxUtils/OnnxUtils.h"
 #include <cassert>
+#include <string>
 
 namespace AthOnnx {
 
 void getNodeInfo(
-    const std::unique_ptr< Ort::Session >& session,
+    const Ort::Session& session,
     std::vector<std::vector<int64_t> >& dataShape, 
-    std::vector<const char*>& nodeNames,
+    std::vector<std::string>& nodeNames,
     bool isInput
 ){
     dataShape.clear();
     nodeNames.clear();
 
-    size_t numNodes = isInput? session->GetInputCount(): session->GetOutputCount();
+    size_t numNodes = isInput? session.GetInputCount(): session.GetOutputCount();
     dataShape.reserve(numNodes);
     nodeNames.reserve(numNodes);
 
     Ort::AllocatorWithDefaultOptions allocator;
     for( std::size_t i = 0; i < numNodes; i++ ) {
-        Ort::TypeInfo typeInfo = isInput? session->GetInputTypeInfo(i): session->GetOutputTypeInfo(i);
+        Ort::TypeInfo typeInfo = isInput? session.GetInputTypeInfo(i): session.GetOutputTypeInfo(i);
         auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
         dataShape.emplace_back(tensorInfo.GetShape());
 
-        char* nodeName = isInput? session->GetInputNameAllocated(i, allocator).release() : session->GetOutputNameAllocated(i, allocator).release();
-        nodeNames.push_back(nodeName);
+        char* nodeName = isInput? session.GetInputNameAllocated(i, allocator).release() : session.GetOutputNameAllocated(i, allocator).release();
+        nodeNames.emplace_back(nodeName);
+        delete[] nodeName;
      }
 }
 
 void getInputNodeInfo(
-    const std::unique_ptr< Ort::Session >& session,
+    const Ort::Session& session,
     std::vector<std::vector<int64_t> >& dataShape, 
-    std::vector<const char*>& nodeNames
+    std::vector<std::string>& nodeNames
 ){
     getNodeInfo(session, dataShape, nodeNames, true);
 }
 
 void getOutputNodeInfo(
-    const std::unique_ptr< Ort::Session >& session,
+    const Ort::Session& session,
     std::vector<std::vector<int64_t> >& dataShape,
-    std::vector<const char*>& nodeNames
+    std::vector<std::string>& nodeNames
 ) {
     getNodeInfo(session, dataShape, nodeNames, false);
 }
 
-void inferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session, 
-    const std::vector<const char*>& inputNames,
+void inferenceWithIOBinding(Ort::Session& session, 
+    const std::vector<std::string>& inputNames,
     const std::vector<Ort::Value>& inputData,
-    const std::vector<const char*>& outputNames,
+    const std::vector<std::string>& outputNames,
     const std::vector<Ort::Value>& outputData){
     
     if (inputNames.empty()) {
@@ -56,17 +58,17 @@ void inferenceWithIOBinding(const std::unique_ptr<Ort::Session>& session,
     }
     assert(inputNames.size() == inputData.size());
 
-    Ort::IoBinding iobinding(*session);
+    Ort::IoBinding iobinding(session);
     for(size_t idx = 0; idx < inputNames.size(); ++idx){
-        iobinding.BindInput(inputNames[idx], inputData[idx]);
+        iobinding.BindInput(inputNames[idx].data(), inputData[idx]);
     }
 
 
     for(size_t idx = 0; idx < outputNames.size(); ++idx){
-        iobinding.BindOutput(outputNames[idx], outputData[idx]);
+        iobinding.BindOutput(outputNames[idx].data(), outputData[idx]);
     }
 
-    session->Run(Ort::RunOptions{nullptr}, iobinding);
+    session.Run(Ort::RunOptions{nullptr}, iobinding);
 }
 
 int64_t getTensorSize(const std::vector<int64_t>& dataShape){
-- 
GitLab


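Taken together, the hunks above move session ownership into the session tool: clients now borrow the Ort::Session by reference and handle node names as std::string rather than raw const char*. A minimal usage sketch, assuming a client algorithm that holds the ORTSessionTool handle and has already prepared its Ort::Value tensors (inputData and outputData below are placeholders, not names from this patch):

    Ort::Session& session = m_onnxSessionTool->session();  // borrowed; owned by the tool

    std::vector<std::vector<int64_t>> inputShapes, outputShapes;
    std::vector<std::string> inputNames, outputNames;       // std::string replaces const char*
    AthOnnx::getInputNodeInfo(session, inputShapes, inputNames);
    AthOnnx::getOutputNodeInfo(session, outputShapes, outputNames);

    // inputData/outputData: pre-filled std::vector<Ort::Value>, assumed prepared elsewhere
    AthOnnx::inferenceWithIOBinding(session, inputNames, inputData, outputNames, outputData);
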
From 8600dea1f23789a27233c2c374e024f12c602a9d Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Tue, 6 Feb 2024 18:11:30 +0100
Subject: [PATCH 16/18] Drop the unused ModelFileName property from
 OnnxRuntimeInferenceTool

---
 Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
index b8f3e3e5ad69..da691ad7e237 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
@@ -49,12 +49,9 @@ namespace AthOnnx {
         ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{
             this, "ORTSessionTool", 
             "AthOnnx::OnnxRuntimeInferenceToolCPU"
-        };
-        StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
-        
+        };
         std::vector<std::string> m_inputNodeNames;
         std::vector<std::string> m_outputNodeNames;
-
     };
 } // namespace AthOnnx
 
-- 
GitLab

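With the session tool owning the Ort::Session, the inference tool no longer needs its own copy of the model path; the property removed here presumably lives on the session tool, the component that actually opens the model file. A sketch of that declaration, assuming the Athena property style used throughout this series:

    // Assumed location: the session tool (hypothetical placement, not shown in this patch)
    StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
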

From 4cdd90fee8d522dc4f6fc07e0c6fe606c740e28a Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Tue, 6 Feb 2024 18:16:56 +0100
Subject: [PATCH 17/18] Avoid redundant session() calls in
 OnnxRuntimeInferenceTool

---
 .../AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx   | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
index a8dfa4e5031b..341afe703e6d 100644
--- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
@@ -35,10 +35,10 @@ StatusCode AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo()
     auto& session = m_onnxSessionTool->session();
     // obtain the model information
     m_numInputs = session.GetInputCount();
-    m_numOutputs = session.session().GetOutputCount();
+    m_numOutputs = session.GetOutputCount();
 
-    AthOnnx::getInputNodeInfo(session.session(), m_inputShapes, m_inputNodeNames);
-    AthOnnx::getOutputNodeInfo(session.session(), m_outputShapes, m_outputNodeNames);
+    AthOnnx::getInputNodeInfo(session, m_inputShapes, m_inputNodeNames);
+    AthOnnx::getOutputNodeInfo(session, m_outputShapes, m_outputNodeNames);
 
     return StatusCode::SUCCESS;
 }
-- 
GitLab


From 08aaed045163b69ac9547189445a6af39c173b30 Mon Sep 17 00:00:00 2001
From: Xiangyang Ju <xiangyang.ju@gmail.com>
Date: Wed, 7 Feb 2024 17:05:17 +0100
Subject: [PATCH 18/18] Keep node names owned by AllocatedStringPtr instead of
 release()

---
 Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx               | 5 ++---
 .../AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx    | 2 +-
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
index b92ab9e1f230..471b12697a51 100644
--- a/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
+++ b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
@@ -25,9 +25,8 @@ void getNodeInfo(
         auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
         dataShape.emplace_back(tensorInfo.GetShape());
 
-        char* nodeName = isInput? session.GetInputNameAllocated(i, allocator).release() : session.GetOutputNameAllocated(i, allocator).release();
-        nodeNames.emplace_back(nodeName);
-        delete[] nodeName;
+        auto nodeName = isInput? session.GetInputNameAllocated(i, allocator) : session.GetOutputNameAllocated(i, allocator);
+        nodeNames.emplace_back(nodeName.get());
      }
 }
 
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
index 525d2387c4eb..42a58730d19e 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
@@ -81,7 +81,7 @@ namespace AthOnnx {
 	return StatusCode::FAILURE;
        }
      // read input file, and the target file for comparison.
-      ATH_MSG_INFO( "Using pixel file: " << m_pixelFileName );
+      ATH_MSG_INFO( "Using pixel file: " << m_pixelFileName.value() );
   
       m_input_tensor_values_notFlat = read_mnist_pixel_notFlat(m_pixelFileName);
       ATH_MSG_INFO("Total no. of samples: "<<m_input_tensor_values_notFlat.size());
-- 
GitLab
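
The rule behind this last patch: GetInputNameAllocated()/GetOutputNameAllocated() return an Ort::AllocatedStringPtr, a smart pointer whose deleter returns the memory through the Ort allocator. Calling release() and then delete[], as the earlier revision did, frees allocator-owned memory with the wrong deallocator. A minimal sketch of the safe pattern, assuming an Ort::Session named session:

    Ort::AllocatorWithDefaultOptions allocator;
    auto namePtr = session.GetInputNameAllocated(0, allocator); // Ort::AllocatedStringPtr (RAII)
    std::string name{namePtr.get()};   // copy while the owner is alive; freed via the allocator
    // Not this: char* p = session.GetInputNameAllocated(0, allocator).release(); delete[] p;
    // release() transfers ownership, but the memory must go back through the Ort allocator.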