diff --git a/Control/AthOnnx/AthOnnxComps/CMakeLists.txt b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f2e46aa8b6b728f803b10fcf0c0dd2c406407bb5
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt
@@ -0,0 +1,18 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+# Declare the package's name.
+atlas_subdir( AthOnnxComps )
+
+# External dependencies.
+find_package( onnxruntime )
+
+# Component(s) in the package.
+atlas_add_component( AthOnnxComps
+   src/*.cxx
+   src/components/*.cxx
+   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib AthenaBaseComps AthenaKernel GaudiKernel AthOnnxInterfaces AthOnnxUtilsLib
+)
+
+# install python modules
+atlas_install_python_modules( python/*.py POST_BUILD_CMD ${ATLAS_FLAKE8} )
diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeFlags.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeFlags.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dcff38f38ecf71d62ef9f8355172722d7a2462a
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeFlags.py
@@ -0,0 +1,29 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.AthConfigFlags import AthConfigFlags
+from AthenaConfiguration.Enums import FlagEnum
+
+class OnnxRuntimeType(FlagEnum):
+    CPU = 'CPU'
+    CUDA = 'CUDA'
+    # Possible future backends. Uncomment when implemented.
+    # DML = 'DML'
+    # DNNL = 'DNNL'
+    # NUPHAR = 'NUPHAR'
+    # OPENVINO = 'OPENVINO'
+    # ROCM = 'ROCM'
+    # TENSORRT = 'TENSORRT'
+    # VITISAI = 'VITISAI'
+    # VULKAN = 'VULKAN'
+
+def createOnnxRuntimeFlags():
+    icf = AthConfigFlags()
+
+    icf.addFlag("AthOnnx.ExecutionProvider", OnnxRuntimeType.CPU, type=OnnxRuntimeType)
+
+    return icf
+
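+# A minimal usage sketch, assuming the standard flags container from
+# AthenaConfiguration.AllConfigFlags.initConfigFlags():
+#
+#   flags = initConfigFlags()
+#   flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA
+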
+if __name__ == "__main__":
+
+    flags = createOnnxRuntimeFlags()
+    flags.dump()
diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c06316bdaf693431279c7046884df4a8ea940c3
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py
@@ -0,0 +1,20 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaConfiguration.ComponentFactory import CompFactory
+from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
+from typing import Optional
+from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg
+
+def OnnxRuntimeInferenceToolCfg(flags, 
+                                model_fname: Optional[str] = None,
+                                execution_provider: Optional[OnnxRuntimeType] = None, 
+                                name="OnnxRuntimeInferenceTool", **kwargs):
+    """Configure OnnxRuntimeInferenceTool in Control/AthOnnx/AthOnnxComps/src"""
+
+    acc = ComponentAccumulator()
+
+    session_tool = acc.popToolsAndMerge(OnnxRuntimeSessionToolCfg(flags, model_fname, execution_provider))
+    kwargs["ORTSessionTool"] = session_tool
+    acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeInferenceTool(name, **kwargs))
+    return acc
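+
+# A usage sketch (the client algorithm "MyAlg" and the model path are
+# hypothetical; the tool is passed through the algorithm's "ORTInferenceTool"
+# property as in the AthExOnnxRuntime example):
+#
+#   inference_tool = acc.popToolsAndMerge(
+#       OnnxRuntimeInferenceToolCfg(flags, "path/to/model.onnx"))
+#   acc.addEventAlgo(CompFactory.AthOnnx.MyAlg("MyAlg", ORTInferenceTool=inference_tool))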
diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py
new file mode 100644
index 0000000000000000000000000000000000000000..32b9d4471b14b5d678f5ee056d0be97e3c163d1c
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py
@@ -0,0 +1,30 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaConfiguration.ComponentFactory import CompFactory
+from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
+from typing import Optional
+
+def OnnxRuntimeSessionToolCfg(flags,
+                              model_fname: str,
+                              execution_provider: Optional[OnnxRuntimeType] = None, 
+                              name="OnnxRuntimeSessionTool", **kwargs):
+    """Configure OnnxRuntimeSessionTool in Control/AthOnnx/AthOnnxComps/src"""
+    
+    acc = ComponentAccumulator()
+    
+    if model_fname is None:
+        raise ValueError("model_fname must be specified")
+    
+    execution_provider = flags.AthOnnx.ExecutionProvider if execution_provider is None else execution_provider
+    name += execution_provider.name
+
+    kwargs.setdefault("ModelFileName", model_fname)
+    if execution_provider is OnnxRuntimeType.CPU:
+        acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCPU(name, **kwargs))
+    elif execution_provider is OnnxRuntimeType.CUDA:
+        acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCUDA(name, **kwargs))
+    else:
+        raise ValueError("Unknown OnnxRuntime Execution Provider: %s" % execution_provider)
+
+    return acc
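+
+# A usage sketch (model path hypothetical): build a CUDA session tool explicitly with
+#
+#   session_tool = acc.popToolsAndMerge(
+#       OnnxRuntimeSessionToolCfg(flags, "path/to/model.onnx", OnnxRuntimeType.CUDA))
+#
+# The provider name is appended to the tool name, so the defaults yield
+# "OnnxRuntimeSessionToolCUDA" here and "OnnxRuntimeSessionToolCPU" for CPU.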
diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSvcConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSvcConfig.py
new file mode 100644
index 0000000000000000000000000000000000000000..e006fe8232c47fe424c1eb51676cde3b7156637e
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSvcConfig.py
@@ -0,0 +1,11 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaConfiguration.ComponentFactory import CompFactory
+
+def OnnxRuntimeSvcCfg(flags, name="OnnxRuntimeSvc", **kwargs):
+    acc = ComponentAccumulator()
+    
+    acc.addService(CompFactory.AthOnnx.OnnxRuntimeSvc(name, **kwargs), primary=True, create=True)
+
+    return acc
diff --git a/Control/AthOnnx/AthOnnxComps/python/__init__.py b/Control/AthOnnx/AthOnnxComps/python/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..87ae30225ac7a009fc45760d78d66deeb4b80381
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/python/__init__.py
@@ -0,0 +1 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
diff --git a/Control/AthOnnx/AthOnnxComps/src/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnx/AthOnnxComps/src/ATLAS_CHECK_THREAD_SAFETY
new file mode 100644
index 0000000000000000000000000000000000000000..6c51dc7f379791d1b1222c9760327037a5af1d2d
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/ATLAS_CHECK_THREAD_SAFETY
@@ -0,0 +1 @@
+Control/AthOnnx/AthOnnxComps
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..341afe703e6d1d2fb531536be376637afbfb445a
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx
@@ -0,0 +1,123 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "OnnxRuntimeInferenceTool.h"
+#include "AthOnnxUtils/OnnxUtils.h"
+
+AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool(
+  const std::string& type, const std::string& name, const IInterface* parent )
+  : base_class( type, name, parent )
+{
+  declareInterface<IOnnxRuntimeInferenceTool>(this);
+}
+
+StatusCode AthOnnx::OnnxRuntimeInferenceTool::initialize()
+{
+    // Get the Onnx Runtime service.
+    ATH_CHECK(m_onnxRuntimeSvc.retrieve());
+
+    // Create the session.
+    ATH_CHECK(m_onnxSessionTool.retrieve());
+
+    ATH_CHECK(getNodeInfo());
+
+    return StatusCode::SUCCESS;
+}
+
+StatusCode AthOnnx::OnnxRuntimeInferenceTool::finalize()
+{
+    return StatusCode::SUCCESS;
+}
+
+StatusCode AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo()
+{
+    auto& session = m_onnxSessionTool->session();
+    // obtain the model information
+    m_numInputs = session.GetInputCount();
+    m_numOutputs = session.GetOutputCount();
+
+    AthOnnx::getInputNodeInfo(session, m_inputShapes, m_inputNodeNames);
+    AthOnnx::getOutputNodeInfo(session, m_outputShapes, m_outputNodeNames);
+
+    return StatusCode::SUCCESS;
+}
+
+
+void AthOnnx::OnnxRuntimeInferenceTool::setBatchSize(int64_t batchSize)
+{
+    if (batchSize <= 0) {
+        ATH_MSG_ERROR("Batch size should be positive");
+        return;
+    }
+
+    for (auto& shape : m_inputShapes) {
+        if (shape[0] == -1) {
+            shape[0] = batchSize;
+        }
+    }
+    
+    for (auto& shape : m_outputShapes) {
+        if (shape[0] == -1) {
+            shape[0] = batchSize;
+        }
+    }
+}
+
+int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize(int64_t inputDataSize, int idx) const
+{
+    auto tensorSize = AthOnnx::getTensorSize(m_inputShapes[idx]);
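+    // A dynamic batch dimension (-1) makes the shape product negative,
+    // e.g. an input shape of {-1, 28, 28} gives tensorSize = -784, so
+    // 2352 input values imply a batch size of 3. A non-negative product
+    // means the shape is fully fixed and no batch size can be derived.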
+    if (tensorSize < 0) {
+        return inputDataSize / std::abs(tensorSize);
+    } else {
+        return -1;
+    }
+}
+
+StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const
+{
+    assert (inputTensors.size() == m_numInputs);
+    assert (outputTensors.size() == m_numOutputs);
+
+    // Run the model.
+    AthOnnx::inferenceWithIOBinding(
+            m_onnxSessionTool->session(), 
+            m_inputNodeNames, inputTensors, 
+            m_outputNodeNames, outputTensors);
+
+    return StatusCode::SUCCESS;
+}
+
+void AthOnnx::OnnxRuntimeInferenceTool::printModelInfo() const
+{
+    ATH_MSG_INFO("Number of inputs: " << m_numInputs);
+    ATH_MSG_INFO("Number of outputs: " << m_numOutputs);
+
+    ATH_MSG_INFO("Input node names: ");
+    for (const auto& name : m_inputNodeNames) {
+        ATH_MSG_INFO("\t" << name);
+    }
+
+    ATH_MSG_INFO("Output node names: ");
+    for (const auto& name : m_outputNodeNames) {
+        ATH_MSG_INFO("\t" << name);
+    }
+
+    ATH_MSG_INFO("Input shapes: ");
+    for (const auto& shape : m_inputShapes) {
+        std::string shapeStr = "\t";
+        for (const auto& dim : shape) {
+            shapeStr += std::to_string(dim) + " ";
+        }
+        ATH_MSG_INFO(shapeStr);
+    }
+
+    ATH_MSG_INFO("Output shapes: ");
+    for (const auto& shape : m_outputShapes) {
+        std::string shapeStr = "\t";
+        for (const auto& dim : shape) {
+            shapeStr += std::to_string(dim) + " ";
+        }
+        ATH_MSG_INFO(shapeStr);
+    }
+}
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
new file mode 100644
index 0000000000000000000000000000000000000000..da691ad7e237ccd2fc8a5edfbd1f8d905b4a4c93
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h
@@ -0,0 +1,58 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef OnnxRuntimeInferenceTool_H
+#define OnnxRuntimeInferenceTool_H
+
+#include "AthenaBaseComps/AthAlgTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
+#include "GaudiKernel/ServiceHandle.h"
+#include "GaudiKernel/ToolHandle.h"
+
+namespace AthOnnx {
+    // @class OnnxRuntimeInferenceTool
+    // 
+    // @brief Tool to perform Onnx Runtime inference
+    //
+    // @author Xiangyang Ju <xiangyang.ju@cern.ch>
+    class OnnxRuntimeInferenceTool :  public extends<AthAlgTool, IOnnxRuntimeInferenceTool>
+    {
+        public:
+        /// Standard constructor
+        OnnxRuntimeInferenceTool( const std::string& type,
+                                const std::string& name,
+                                const IInterface* parent );
+        virtual ~OnnxRuntimeInferenceTool() = default;
+
+        /// Initialize the tool
+        virtual StatusCode initialize() override;
+        /// Finalize the tool
+        virtual StatusCode finalize() override;
+
+        virtual void setBatchSize(int64_t batchSize) override final;
+        virtual int64_t getBatchSize(int64_t inputDataSize, int idx = 0) const override final;
+
+        virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const override final;
+
+        virtual void printModelInfo() const override final;
+
+        protected:
+        OnnxRuntimeInferenceTool() = delete;
+        OnnxRuntimeInferenceTool(const OnnxRuntimeInferenceTool&) = delete;
+        OnnxRuntimeInferenceTool& operator=(const OnnxRuntimeInferenceTool&) = delete;
+
+        private:
+        StatusCode getNodeInfo();
+
+        ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
+        ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{
+            this, "ORTSessionTool", 
+            "AthOnnx::OnnxRuntimeSessionToolCPU"
+        };
+        std::vector<std::string> m_inputNodeNames;
+        std::vector<std::string> m_outputNodeNames;
+    };
+} // namespace AthOnnx
+
+#endif
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..85f1f8df115bdc9ac45129550cef7640bb888fa0
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx
@@ -0,0 +1,46 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "OnnxRuntimeSessionToolCPU.h"
+
+AthOnnx::OnnxRuntimeSessionToolCPU::OnnxRuntimeSessionToolCPU(
+  const std::string& type, const std::string& name, const IInterface* parent )
+  : base_class( type, name, parent )
+{
+  declareInterface<IOnnxRuntimeSessionTool>(this);
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::initialize()
+{
+    // Get the Onnx Runtime service.
+    ATH_CHECK(m_onnxRuntimeSvc.retrieve());
+    ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
+    
+    // Create the session options.
+    // TODO: Make this configurable.
+    // other threading options: https://onnxruntime.ai/docs/performance/tune-performance/threading.html
+    // 1) SetIntraOpNumThreads( 1 );
+    // 2) SetInterOpNumThreads( 1 );
+    // 3) SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
+
+    Ort::SessionOptions sessionOptions;
+    sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_ALL );
+
+    // Create the session.
+    m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
+
+    return StatusCode::SUCCESS;
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::finalize()
+{
+    m_session.reset();
+    ATH_MSG_DEBUG( "Ort::Session object deleted" );
+    return StatusCode::SUCCESS;
+}
+
+Ort::Session& AthOnnx::OnnxRuntimeSessionToolCPU::session() const
+{
+  return *m_session;
+}
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h
new file mode 100644
index 0000000000000000000000000000000000000000..a20e5e5d4f2ee0444b7a1812828ef1dbeac77551
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h
@@ -0,0 +1,46 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef OnnxRuntimeSessionToolCPU_H
+#define OnnxRuntimeSessionToolCPU_H
+
+#include "AthenaBaseComps/AthAlgTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
+#include "GaudiKernel/ServiceHandle.h"
+
+namespace AthOnnx {
+    // @class OnnxRuntimeSessionToolCPU
+    // 
+    // @brief Tool to create Onnx Runtime session with CPU backend
+    //
+    // @author Xiangyang Ju <xiangyang.ju@cern.ch>
+    class OnnxRuntimeSessionToolCPU :  public extends<AthAlgTool, IOnnxRuntimeSessionTool>
+    {
+        public:
+        /// Standard constructor
+        OnnxRuntimeSessionToolCPU( const std::string& type,
+                                const std::string& name,
+                                const IInterface* parent );
+        virtual ~OnnxRuntimeSessionToolCPU() = default;
+
+        /// Initialize the tool
+        virtual StatusCode initialize() override final;
+        /// Finalize the tool
+        virtual StatusCode finalize() override final;
+
+        /// Return the Onnx Runtime session
+        virtual Ort::Session& session() const override final;
+
+        protected:
+        OnnxRuntimeSessionToolCPU() = delete;
+        OnnxRuntimeSessionToolCPU(const OnnxRuntimeSessionToolCPU&) = delete;
+        OnnxRuntimeSessionToolCPU& operator=(const OnnxRuntimeSessionToolCPU&) = delete;
+
+        private:
+        StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
+        ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
+        std::unique_ptr<Ort::Session> m_session;
+    };
+}
+
+#endif
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..f7d39731da835de74267e3f6c68ceb19a7160142
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx
@@ -0,0 +1,70 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "OnnxRuntimeSessionToolCUDA.h"
+
+#include <limits>
+
+AthOnnx::OnnxRuntimeSessionToolCUDA::OnnxRuntimeSessionToolCUDA(
+    const std::string& type, const std::string& name, const IInterface* parent )
+    : base_class( type, name, parent )
+{
+  declareInterface<IOnnxRuntimeSessionTool>(this);
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::initialize()
+{
+    // Get the Onnx Runtime service.
+    ATH_CHECK(m_onnxRuntimeSvc.retrieve());
+
+    ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
+    // Create the session options.
+    Ort::SessionOptions sessionOptions;
+    sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
+    sessionOptions.DisablePerSessionThreads();    // use global thread pool.
+
+    // TODO: add more cuda options to the interface
+    // https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cc
+    // Options: https://onnxruntime.ai/docs/api/c/struct_ort_c_u_d_a_provider_options.html
+    OrtCUDAProviderOptions cuda_options;
+    cuda_options.device_id = m_deviceId;
+    cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
+    cuda_options.gpu_mem_limit = std::numeric_limits<size_t>::max();
+
+    // memory arena options for CUDA memory shrinkage
+    // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/utils.cc#L7
+    // Keep the arena config alive beyond this scope: cuda_options stores only
+    // a pointer to it, which is read when the session is created below.
+    Ort::ArenaCfg arena_cfg{0, 0, 1024, 0};  // {max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk}
+    if (m_enableMemoryShrinkage) {
+        // other options are not available in this release.
+        // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/test_inference.cc#L2802C21-L2802C21
+        // arena_cfg.max_mem = 0;   // let ORT pick default max memory
+        // arena_cfg.arena_extend_strategy = 0;   // 0: kNextPowerOfTwo, 1: kSameAsRequested
+        // arena_cfg.initial_chunk_size_bytes = 1024;
+        // arena_cfg.max_dead_bytes_per_chunk = 0;
+        // arena_cfg.initial_growth_chunk_size_bytes = 256;
+        // arena_cfg.max_power_of_two_extend_bytes = 1L << 24;
+
+        cuda_options.default_memory_arena_cfg = arena_cfg;
+    }
+
+    sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
+
+    // Create the session.
+    m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
+
+    return StatusCode::SUCCESS;
+}
+
+StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::finalize()
+{
+    m_session.reset();
+    ATH_MSG_DEBUG( "Ort::Session object deleted" );
+    return StatusCode::SUCCESS;
+}
+
+Ort::Session& AthOnnx::OnnxRuntimeSessionToolCUDA::session() const
+{
+  return *m_session;
+}
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h
new file mode 100644
index 0000000000000000000000000000000000000000..3d445ed73f508d775b6a574b7234ec41a9c56da6
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h
@@ -0,0 +1,51 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef OnnxRuntimeSessionToolCUDA_H
+#define OnnxRuntimeSessionToolCUDA_H
+
+#include "AthenaBaseComps/AthAlgTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
+#include "GaudiKernel/ServiceHandle.h"
+
+namespace AthOnnx {
+    // @class OnnxRuntimeSessionToolCUDA
+    // 
+    // @brief Tool to create Onnx Runtime session with CUDA backend
+    //
+    // @author Xiangyang Ju <xiangyang.ju@cern.ch>
+    class OnnxRuntimeSessionToolCUDA :  public extends<AthAlgTool, IOnnxRuntimeSessionTool>
+    {
+        public:
+        /// Standard constructor
+        OnnxRuntimeSessionToolCUDA( const std::string& type,
+                                const std::string& name,
+                                const IInterface* parent );
+        virtual ~OnnxRuntimeSessionToolCUDA() = default;
+
+        /// Initialize the tool
+        virtual StatusCode initialize() override final;
+        /// Finalize the tool
+        virtual StatusCode finalize() override final;
+
+        /// Return the Onnx Runtime session
+        virtual Ort::Session& session() const override final;
+
+        protected:
+        OnnxRuntimeSessionToolCUDA() = delete;
+        OnnxRuntimeSessionToolCUDA(const OnnxRuntimeSessionToolCUDA&) = delete;
+        OnnxRuntimeSessionToolCUDA& operator=(const OnnxRuntimeSessionToolCUDA&) = delete;
+
+        private:
+        StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
+        /// The device ID to use.
+        IntegerProperty m_deviceId{this, "DeviceId", 0};
+        BooleanProperty m_enableMemoryShrinkage{this, "EnableMemoryShrinkage", false};
+
+        /// runtime service
+        ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
+        std::unique_ptr<Ort::Session> m_session;
+    };
+}
+
+#endif
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..c45e2dba6faad8383bb56cdf2305882e00b6e0e3
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.cxx
@@ -0,0 +1,43 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+// Local include(s).
+#include "OnnxRuntimeSvc.h"
+#include <core/session/onnxruntime_c_api.h>
+
+namespace AthOnnx {
+  OnnxRuntimeSvc::OnnxRuntimeSvc(const std::string& name, ISvcLocator* svc) :
+      asg::AsgService(name, svc)
+   {
+     declareServiceInterface<AthOnnx::IOnnxRuntimeSvc>();
+   }
+   StatusCode OnnxRuntimeSvc::initialize() {
+
+      // Create the environment object.
+      Ort::ThreadingOptions tp_options;
+      tp_options.SetGlobalIntraOpNumThreads(1);
+      tp_options.SetGlobalInterOpNumThreads(1);
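+      // These global thread pools are used only by sessions that opt in via
+      // Ort::SessionOptions::DisablePerSessionThreads(), as the CUDA session
+      // tool does; other sessions keep their own per-session pools.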
+
+      m_env = std::make_unique< Ort::Env >(
+            tp_options, ORT_LOGGING_LEVEL_WARNING, name().c_str());
+      ATH_MSG_DEBUG( "Ort::Env object created" );
+
+      // Return gracefully.
+      return StatusCode::SUCCESS;
+   }
+
+   StatusCode OnnxRuntimeSvc::finalize() {
+
+      // Delete the environment object.
+      m_env.reset();
+      ATH_MSG_DEBUG( "Ort::Env object deleted" );
+
+      // Return gracefully.
+      return StatusCode::SUCCESS;
+   }
+
+   Ort::Env& OnnxRuntimeSvc::env() const {
+
+      return *m_env;
+   }
+
+} // namespace AthOnnx
diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.h
new file mode 100644
index 0000000000000000000000000000000000000000..9f90bc998f65963e95aafedd4fc5dd41f2a312ff
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.h
@@ -0,0 +1,58 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef ATHONNXCOMPS_ONNXRUNTIMESVC_H
+#define ATHONNXCOMPS_ONNXRUNTIMESVC_H
+
+// Local include(s).
+#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
+
+// Framework include(s).
+#include <AsgServices/AsgService.h>
+
+// ONNX include(s).
+#include <core/session/onnxruntime_cxx_api.h>
+
+// System include(s).
+#include <memory>
+
+namespace AthOnnx {
+
+   /// Service implementing @c AthOnnx::IOnnxRuntimeSvc
+   ///
+   /// This is a very simple implementation, just managing the lifetime
+   /// of some Onnx Runtime C++ objects.
+   ///
+   /// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>
+   ///
+   class OnnxRuntimeSvc : public asg::AsgService, virtual public IOnnxRuntimeSvc {
+
+   public:
+
+      /// @name Function(s) inherited from @c Service
+      /// @{
+      OnnxRuntimeSvc (const std::string& name, ISvcLocator* svc);
+
+      /// Function initialising the service
+      virtual StatusCode initialize() override;
+      /// Function finalising the service
+      virtual StatusCode finalize() override;
+
+      /// @}
+
+      /// @name Function(s) inherited from @c AthOnnx::IOnnxRuntimeSvc
+      /// @{
+
+      /// Return the Onnx Runtime environment object
+      virtual Ort::Env& env() const override;
+
+      /// @}
+
+   private:
+      /// Global runtime environment for Onnx Runtime
+      std::unique_ptr< Ort::Env > m_env;
+
+   }; // class OnnxRuntimeSvc
+
+} // namespace AthOnnx
+
+#endif // ATHONNXCOMPS_ONNXRUNTIMESVC_H
diff --git a/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx b/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..45ef514180f8b679a1f23b637bda662c88dff8b5
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx
@@ -0,0 +1,13 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+// Local include(s).
+#include "../OnnxRuntimeSvc.h"
+#include "../OnnxRuntimeSessionToolCPU.h"
+#include "../OnnxRuntimeSessionToolCUDA.h"
+#include "../OnnxRuntimeInferenceTool.h"
+
+// Declare the package's components.
+DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSvc )
+DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSessionToolCPU )
+DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSessionToolCUDA )
+DECLARE_COMPONENT( AthOnnx::OnnxRuntimeInferenceTool )
diff --git a/Control/AthOnnx/AthOnnxInterfaces/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnx/AthOnnxInterfaces/ATLAS_CHECK_THREAD_SAFETY
new file mode 100644
index 0000000000000000000000000000000000000000..9524c9b32dce4369d5e24025dd71ee41784e29c9
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/ATLAS_CHECK_THREAD_SAFETY
@@ -0,0 +1 @@
+Control/AthOnnx/AthOnnxInterfaces
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h
new file mode 100644
index 0000000000000000000000000000000000000000..467b25e9d76fa7470a4351b6db65ce82df75306b
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h
@@ -0,0 +1,116 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+#ifndef AthOnnx_IOnnxRuntimeInferenceTool_H
+#define AthOnnx_IOnnxRuntimeInferenceTool_H
+
+// Gaudi include(s).
+#include "GaudiKernel/IAlgTool.h"
+#include <cstdlib>
+#include <memory>
+#include <numeric>
+
+#include <core/session/onnxruntime_cxx_api.h>
+
+
+namespace AthOnnx {
+    /**
+     * @class IOnnxRuntimeInferenceTool
+     * @brief Interface class for performing inference with Onnx Runtime.
+     * @details It is thread safe, supports models with various numbers of
+     * inputs and outputs, and supports models with dynamic batch size.
+     * It defines a standardized procedure to
+     * perform Onnx Runtime inference. The procedure is as follows, assuming the tool `m_onnxTool` is created and initialized:
+     *    1. create input tensors from the input data: 
+     *      ```c++
+     *         std::vector<Ort::Value> inputTensors;
+     *         std::vector<float> inputData_1;   // The input data is filled by users, possibly from the event information. 
+     *         int64_t batchSize = m_onnxTool->getBatchSize(inputData_1.size(), 0);  // The batch size is determined by the input data size to support dynamic batch size.
+     *         m_onnxTool->addInput(inputTensors, inputData_1, 0, batchSize);
+     *         std::vector<int64_t> inputData_2;  // Some models may have multiple inputs. Add inputs one by one.
+     *         int64_t batchSize_2 = m_onnxTool->getBatchSize(inputData_2.size(), 1);
+     *         m_onnxTool->addInput(inputTensors, inputData_2, 1, batchSize_2);
+     *     ```
+     *    2. create output tensors:
+     *      ```c++
+     *          std::vector<Ort::Value> outputTensors;
+     *          std::vector<float> outputData;   // The output data will be filled by the onnx session.
+     *          m_onnxTool->addOutput(outputTensors, outputData, 0, batchSize);
+     *      ```
+     *   3. perform inference:
+     *     ```c++
+     *        m_onnxTool->inference(inputTensors, outputTensors);
+     *    ```
+     *   4. Model outputs will be automatically filled into outputData.
+     * 
+     * 
+     * @author Xiangyang Ju <xju@cern.ch>
+     */
+    class IOnnxRuntimeInferenceTool : virtual public IAlgTool 
+    {
+        public:
+
+        virtual ~IOnnxRuntimeInferenceTool() = default;
+        
+        // @name InterfaceID
+        DeclareInterfaceID(IOnnxRuntimeInferenceTool, 1, 0);
+
+        /**
+         * @brief set batch size. 
+         * @details If the model has dynamic batch size, 
+         *          the batchSize value will be set to both input shapes and output shapes
+         */ 
+        virtual void setBatchSize(int64_t batchSize) = 0;
+
+        /**
+         * @brief method for determining the batch size from the data size
+         * @param dataSize the size of the input data, like std::vector<T>::size()
+         * @param idx the index of the input node
+         * @return the batch size, equal to dataSize divided by the product of the remaining dimensions.
+         */ 
+        virtual int64_t getBatchSize(int64_t dataSize, int idx = 0) const = 0;
+
+        /**
+         * @brief add the input data to the input tensors
+         * @param inputTensors the input tensor container
+         * @param data the input data
+         * @param idx the index of the input node
+         * @param batchSize the batch size
+         * @return StatusCode::SUCCESS if the input data is added successfully
+         */
+        template <typename T>
+        StatusCode addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const;
+
+        /**
+         * @brief add the output data to the output tensors
+         * @param outputTensors the output tensor container
+         * @param data the output data
+         * @param idx the index of the output node
+         * @param batchSize the batch size
+         * @return StatusCode::SUCCESS if the output data is added successfully
+         */
+        template <typename T>
+        StatusCode addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const;
+
+        /**
+         * @brief perform inference
+         * @param inputTensors the input tensor container
+         * @param outputTensors the output tensor container
+         * @return StatusCode::SUCCESS if the inference is performed successfully
+         */
+        virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const = 0;
+
+        virtual void printModelInfo() const = 0;
+
+        protected:
+        int m_numInputs = 0;
+        int m_numOutputs = 0;
+        std::vector<std::vector<int64_t> > m_inputShapes;
+        std::vector<std::vector<int64_t> > m_outputShapes;
+
+        private:
+        template <typename T>
+        Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const;
+
+    };
+
+    #include "IOnnxRuntimeInferenceTool.icc"
+} // namespace AthOnnx
+
+#endif
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc
new file mode 100644
index 0000000000000000000000000000000000000000..1c24b11c09d8f1b5f058510eece8a93dfc25030a
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc
@@ -0,0 +1,46 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+template <typename T>
+Ort::Value AthOnnx::IOnnxRuntimeInferenceTool::createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const
+{
+    std::vector<int64_t> dataShapeCopy = dataShape;
+
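+    // Substitute the first dynamic dimension (-1) with the requested batch
+    // size, e.g. {-1, 28, 28} with batchSize = 3 becomes {3, 28, 28}.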
+    if (batchSize > 0) {
+        for (auto& shape: dataShapeCopy) {
+            if (shape == -1) {
+                shape = batchSize;
+                break;
+            }
+        }
+    }
+
+    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+    return Ort::Value::CreateTensor<T>(
+        memoryInfo, data.data(), data.size(), dataShapeCopy.data(), dataShapeCopy.size()
+    );
+}
+
+template <typename T>
+StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
+{
+    if (idx >= m_numInputs || idx < 0) {
+        return StatusCode::FAILURE;
+    }
+
+    inputTensors.push_back(std::move(createTensor(data, m_inputShapes[idx], batchSize)));
+    return StatusCode::SUCCESS;
+}
+
+template <typename T>
+StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
+{
+    if (idx >= m_numOutputs || idx < 0) {
+        return StatusCode::FAILURE;
+    }
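+    // Size the output buffer: a dynamic shape such as {-1, 10} gives a
+    // negative product (-10), so with a batch size of 3 the buffer is
+    // resized to 30 elements before being bound to the output tensor.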
+    int64_t tensorSize = std::accumulate(m_outputShapes[idx].begin(), m_outputShapes[idx].end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
+    if (tensorSize < 0) {
+        tensorSize = std::abs(tensorSize) * batchSize;
+    }
+    }
+    data.resize(tensorSize);
+    outputTensors.push_back(std::move(createTensor(data, m_outputShapes[idx], batchSize)));
+    return StatusCode::SUCCESS;
+}
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
new file mode 100644
index 0000000000000000000000000000000000000000..4d94e68a840ea27896c436adc81c312cdc99a379
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h
@@ -0,0 +1,34 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+#ifndef AthOnnx_IOnnxRuntimeSessionTool_H
+#define AthOnnx_IOnnxRuntimeSessionTool_H
+
+// Gaudi include(s).
+#include "GaudiKernel/IAlgTool.h"
+
+#include <core/session/onnxruntime_cxx_api.h>
+
+
+namespace AthOnnx {
+    // @class IOnnxRuntimeSessionTool
+    //
+    // @brief Interface class for creating and accessing the Onnx Runtime session.
+    //
+    // @author Xiangyang Ju <xju@cern.ch>
+    //
+    class IOnnxRuntimeSessionTool : virtual public IAlgTool 
+    {
+        public:
+
+        virtual ~IOnnxRuntimeSessionTool() = default;
+        
+        // @name InterfaceID
+        DeclareInterfaceID(IOnnxRuntimeSessionTool, 1, 0);
+
+        // Return the Onnx Runtime session
+        virtual Ort::Session& session() const = 0;
+
+    };
+
+} // namespace AthOnnx
+
+#endif
diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSvc.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSvc.h
new file mode 100644
index 0000000000000000000000000000000000000000..b546392b55b40111169cd90d40a473f28d24e886
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSvc.h
@@ -0,0 +1,41 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef AthOnnx_IOnnxRuntimeSvc_H
+#define AthOnnx_IOnnxRuntimeSvc_H
+
+// Gaudi include(s).
+#include <AsgServices/IAsgService.h>
+
+// Onnx include(s).
+#include <core/session/onnxruntime_cxx_api.h>
+
+
+/// Namespace holding all of the Onnx Runtime code
+namespace AthOnnx {
+
+   /// Service used for managing global objects used by Onnx Runtime
+   ///
+   /// In order to allow multiple clients to use Onnx Runtime at the same
+   /// time, this service is used to manage the objects that must only
+   /// be created once in the Athena process.
+   ///
+   /// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>
+   ///
+   class IOnnxRuntimeSvc : virtual public asg::IAsgService{
+
+   public:
+      /// Virtual destructor, to make vtable happy
+      virtual ~IOnnxRuntimeSvc() = default;
+
+      /// Declare the interface that this class provides
+      DeclareInterfaceID (AthOnnx::IOnnxRuntimeSvc, 1, 0);
+
+      /// Return the Onnx Runtime environment object
+      virtual Ort::Env& env() const = 0;
+
+   }; // class IOnnxRuntimeSvc
+
+} // namespace AthOnnx
+
+#endif
diff --git a/Control/AthOnnx/AthOnnxInterfaces/CMakeLists.txt b/Control/AthOnnx/AthOnnxInterfaces/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..72b84c814183a8534b0459a35183f015c976e276
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxInterfaces/CMakeLists.txt
@@ -0,0 +1,12 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+# Declare the package name:
+atlas_subdir( AthOnnxInterfaces )
+
+# Component(s) in the package:
+atlas_add_library( AthOnnxInterfaces
+                   AthOnnxInterfaces/*.h
+                   INTERFACE
+                   PUBLIC_HEADERS AthOnnxInterfaces
+                   LINK_LIBRARIES AsgTools AthLinks GaudiKernel 
+                  )
diff --git a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..83ced9501c30d5278a5ac557d61ab66203e5f678
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h
@@ -0,0 +1,81 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#ifndef Onnx_UTILS_H
+#define Onnx_UTILS_H
+
+#include <memory>
+#include <vector>
+
+// Onnx Runtime include(s).
+#include <core/session/onnxruntime_cxx_api.h>
+
+namespace AthOnnx {
+
+// @author Xiangyang Ju <xiangyang.ju@cern.ch>
+
+// @brief Convert a vector of vectors to a single vector.
+// @param features The vector of vectors to be flattened.
+// @return A single vector containing all the elements of the input vector of vectors.
+template<typename T>
+inline std::vector<T> flattenNestedVectors( const std::vector<std::vector<T>>& features) {
+  // 1. Compute the total size required.
+  std::size_t total_size = 0;
+  for (const auto& feature : features) total_size += feature.size();
+  
+  std::vector<T> flatten1D;
+  flatten1D.reserve(total_size);
+
+  for (const auto& feature : features)
+    for (const auto& elem : feature)
+      flatten1D.push_back(elem);
+
+  return flatten1D;
+}
+
+// @brief Get the input data shape and node names (in the computational graph) from the onnx model
+// @param session The onnx session.
+// @param dataShape The shape of the input data. Note that there may be multiple inputs.
+// @param nodeNames The names of the input nodes in the computational graph.
+// the dataShape and nodeNames will be updated.
+void getInputNodeInfo(
+    const Ort::Session& session,
+    std::vector<std::vector<int64_t> >& dataShape, 
+    std::vector<std::string>& nodeNames);
+
+// @brief Get the output data shape and node names (in the computational graph) from the onnx model
+// @param session The onnx session.
+// @param dataShape The shape of the output data.
+// @param nodeNames The names of the output nodes in the computational graph.
+// the dataShape and nodeNames will be updated.
+void getOutputNodeInfo(
+    const Ort::Session& session,
+    std::vector<std::vector<int64_t> >& dataShape, 
+    std::vector<std::string>& nodeNames);
+
+// Helper function to get node info for either inputs or outputs.
+void getNodeInfo(
+    const Ort::Session& session,
+    std::vector<std::vector<int64_t> >& dataShape, 
+    std::vector<std::string>& nodeNames,
+    bool isInput
+);
+
+// @brief Count the total number of elements in a tensor,
+// useful for reserving space for the output data.
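+// e.g. {3, 28, 28} gives 2352, while the dynamic shape {-1, 28, 28} gives -784.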
+int64_t getTensorSize(const std::vector<int64_t>& dataShape);
+
+// Inference with IO binding. Better for performance, particularly for GPUs.
+// See https://onnxruntime.ai/docs/performance/tune-performance/iobinding.html
+void inferenceWithIOBinding(Ort::Session& session, 
+    const std::vector<std::string>& inputNames,
+    const std::vector<Ort::Value>& inputData,
+    const std::vector<std::string>& outputNames,
+    const std::vector<Ort::Value>& outputData
+); 
+
+// @brief Create a tensor from a vector of data and its shape.
+Ort::Value createTensor(std::vector<float>& data, const std::vector<int64_t>& dataShape);
+
+
+}
+#endif
diff --git a/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt b/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..029f7e5a1ce8f947ed8bef66aa5b0367440fd325
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+# Declare the package's name.
+atlas_subdir( AthOnnxUtils )
+
+# External dependencies.
+find_package( onnxruntime )
+
+# Component(s) in the package.
+atlas_add_library( AthOnnxUtilsLib
+   AthOnnxUtils/*.h 
+   src/*.cxx
+   PUBLIC_HEADERS AthOnnxUtils
+   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthenaKernel GaudiKernel AsgServicesLib 
+)
diff --git a/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..471b12697a518699cb1aa3b315d5563c397720c8
--- /dev/null
+++ b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx
@@ -0,0 +1,94 @@
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+#include "AthOnnxUtils/OnnxUtils.h"
+#include <cassert>
+#include <string>
+
+namespace AthOnnx {
+
+void getNodeInfo(
+    const Ort::Session& session,
+    std::vector<std::vector<int64_t> >& dataShape, 
+    std::vector<std::string>& nodeNames,
+    bool isInput
+){
+    dataShape.clear();
+    nodeNames.clear();
+
+    size_t numNodes = isInput? session.GetInputCount(): session.GetOutputCount();
+    dataShape.reserve(numNodes);
+    nodeNames.reserve(numNodes);
+
+    Ort::AllocatorWithDefaultOptions allocator;
+    for( std::size_t i = 0; i < numNodes; i++ ) {
+        Ort::TypeInfo typeInfo = isInput? session.GetInputTypeInfo(i): session.GetOutputTypeInfo(i);
+        auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
+        dataShape.emplace_back(tensorInfo.GetShape());
+
+        auto nodeName = isInput? session.GetInputNameAllocated(i, allocator) : session.GetOutputNameAllocated(i, allocator);
+        nodeNames.emplace_back(nodeName.get());
+     }
+}
+
+void getInputNodeInfo(
+    const Ort::Session& session,
+    std::vector<std::vector<int64_t> >& dataShape, 
+    std::vector<std::string>& nodeNames
+){
+    getNodeInfo(session, dataShape, nodeNames, true);
+}
+
+void getOutputNodeInfo(
+    const Ort::Session& session,
+    std::vector<std::vector<int64_t> >& dataShape,
+    std::vector<std::string>& nodeNames
+) {
+    getNodeInfo(session, dataShape, nodeNames, false);
+}
+
+void inferenceWithIOBinding(Ort::Session& session, 
+    const std::vector<std::string>& inputNames,
+    const std::vector<Ort::Value>& inputData,
+    const std::vector<std::string>& outputNames,
+    const std::vector<Ort::Value>& outputData){
+    
+    if (inputNames.empty()) {
+        throw std::runtime_error("Onnxruntime input data mapping cannot be empty");
+    }
+    assert(inputNames.size() == inputData.size());
+
+    Ort::IoBinding iobinding(session);
+    for(size_t idx = 0; idx < inputNames.size(); ++idx){
+        iobinding.BindInput(inputNames[idx].data(), inputData[idx]);
+    }
+
+    for(size_t idx = 0; idx < outputNames.size(); ++idx){
+        iobinding.BindOutput(outputNames[idx].data(), outputData[idx]);
+    }
+
+    session.Run(Ort::RunOptions{nullptr}, iobinding);
+}
+
+int64_t getTensorSize(const std::vector<int64_t>& dataShape){
+    int64_t size = 1;
+    for (const auto& dim : dataShape) {
+            size *= dim;
+    }
+    return size;
+}
+
+Ort::Value createTensor(std::vector<float>& data, const std::vector<int64_t>& dataShape)
+{
+    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); 
+
+    return Ort::Value::CreateTensor<float>(
+                                memoryInfo, 
+                                data.data(), 
+                                data.size(),  
+                                dataShape.data(), 
+                                dataShape.size());
+}
+
+} // namespace AthOnnx
\ No newline at end of file
diff --git a/Control/AthenaConfiguration/python/AllConfigFlags.py b/Control/AthenaConfiguration/python/AllConfigFlags.py
index 96efa0f8277a3b40e23382f65fd7bf5a18ceb5b8..f6946d79a7b89280a00b82089afb6bfec757cd97 100644
--- a/Control/AthenaConfiguration/python/AllConfigFlags.py
+++ b/Control/AthenaConfiguration/python/AllConfigFlags.py
@@ -495,6 +495,12 @@ def initConfigFlags():
         return createLLPDFConfigFlags()
     _addFlagsCategory(acf, "Derivation.LLP", __llpDerivation, 'DerivationFrameworkLLP' )
 
+    # onnxruntime flags
+    def __onnxruntime():
+        from AthOnnxComps.OnnxRuntimeFlags import createOnnxRuntimeFlags
+        return createOnnxRuntimeFlags()
+    _addFlagsCategory(acf, "AthOnnx", __onnxruntime, 'AthOnnxComps')
+
     # For AnalysisBase, pick up things grabbed in Athena by the functions above
     if not isGaudiEnv():
         def EDMVersion(flags):
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
index 29281dfe4a6e390fb5f9925eebdaf0382ce83298..02249a193fcf65b0c0a5162f846c8a14dfda35a4 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
+++ b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt
@@ -1,28 +1,31 @@
-# Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 
 # Declare the package's name.
 atlas_subdir( AthExOnnxRuntime )
 
-# Component(s) in the package.
-atlas_add_library( AthExOnnxRuntimeLib
-   INTERFACE
-   PUBLIC_HEADERS AthExOnnxRuntime
-   INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
-   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} GaudiKernel )
+find_package( onnxruntime )
 
+# Component(s) in the package.
 atlas_add_component( AthExOnnxRuntime
    src/*.h src/*.cxx src/components/*.cxx
    INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
-   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthExOnnxRuntimeLib AthenaBaseComps GaudiKernel PathResolver AthOnnxruntimeServiceLib AthOnnxruntimeUtilsLib)
+   LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthenaBaseComps GaudiKernel PathResolver 
+   AthOnnxInterfaces AthOnnxUtilsLib AsgServicesLib
+)
 
-# Install files from the package.
-atlas_install_joboptions( share/*.py )
+# Install files from the package:
+atlas_install_python_modules( python/*.py
+                              POST_BUILD_CMD ${ATLAS_FLAKE8} )
 
-# Set up tests for the package.
-atlas_add_test( AthExOnnxRuntimeJob_serial
-   SCRIPT athena.py AthExOnnxRuntime/AthExOnnxRuntime_jobOptions.py
-   POST_EXEC_SCRIPT nopost.sh )
+# Test the packages
+atlas_add_test( AthExOnnxRuntimeTest
+   SCRIPT python -m AthExOnnxRuntime.AthExOnnxRuntime_test
+   PROPERTIES TIMEOUT 100
+   POST_EXEC_SCRIPT noerror.sh 
+   )
 
-atlas_add_test( AthExOnnxRuntimeJob_mt
-   SCRIPT athena.py --threads=2 AthExOnnxRuntime/AthExOnnxRuntime_jobOptions.py
-   POST_EXEC_SCRIPT nopost.sh )
+atlas_add_test( AthExOnnxRuntimeTest_pkl
+   SCRIPT athena --evtMax 2 --CA test_AthExOnnxRuntimeExampleCfg.pkl
+   DEPENDS AthExOnnxRuntimeTest
+   PROPERTIES TIMEOUT 100
+   POST_EXEC_SCRIPT noerror.sh )
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py b/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d243fbc6feb21ae87abbc09e244d8a3190a867c
--- /dev/null
+++ b/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py
@@ -0,0 +1,45 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaConfiguration.ComponentFactory import CompFactory
+from AthenaCommon import Constants
+from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType 
+
+
+def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
+    acc = ComponentAccumulator()
+
+    model_fname = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-02/MNIST_testModel.onnx"
+    execution_provider = OnnxRuntimeType.CPU
+    from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg
+    kwargs.setdefault("ORTInferenceTool", acc.popToolsAndMerge(
+        OnnxRuntimeInferenceToolCfg(flags, model_fname, execution_provider)
+    ))
+
+    input_data = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
+    kwargs.setdefault("BatchSize", 3)
+    kwargs.setdefault("InputDataPixel", input_data)
+    kwargs.setdefault("OutputLevel", Constants.DEBUG)
+    acc.addEventAlgo(CompFactory.AthOnnx.EvaluateModel(name, **kwargs))
+
+    return acc
+
+if __name__ == "__main__":
+    from AthenaCommon.Logging import log as msg
+    from AthenaConfiguration.AllConfigFlags import initConfigFlags
+    from AthenaConfiguration.MainServicesConfig import MainServicesCfg
+
+    msg.setLevel(Constants.DEBUG)
+
+    flags = initConfigFlags()
+    flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU
+    flags.lock()
+
+    acc = MainServicesCfg(flags)
+    acc.merge(AthExOnnxRuntimeExampleCfg(flags))
+    acc.printConfig(withDetails=True, summariseProps=True)
+
+    acc.store(open('test_AthExOnnxRuntimeExampleCfg.pkl','wb'))
+
+    import sys
+    sys.exit(acc.run(2).isFailure())
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_jobOptions.py b/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_jobOptions.py
deleted file mode 100644
index a35ab1f40e4506f22e532e7727af7e1436add528..0000000000000000000000000000000000000000
--- a/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_jobOptions.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
-
-# Set up / access the algorithm sequence.
-from AthenaCommon.AlgSequence import AlgSequence
-algSequence = AlgSequence()
-
-# Set up the job.
-from AthExOnnxRuntime.AthExOnnxRuntimeConf import AthONNX__EvaluateModel
-from AthOnnxruntimeService.AthOnnxruntimeServiceConf import AthONNX__ONNXRuntimeSvc
-
-from AthenaCommon.AppMgr import ServiceMgr
-ServiceMgr += AthONNX__ONNXRuntimeSvc( OutputLevel = DEBUG )
-algSequence += AthONNX__EvaluateModel("AthONNX")
-
-# Get a	random no. between 0 to	10k for	test sample
-from random import randint
-
-AthONNX = algSequence.AthONNX
-AthONNX.TestSample = randint(0, 9999)
-AthONNX.DoBatches = False
-AthONNX.NumberOfBatches = 1
-AthONNX.SizeOfBatch = 1
-AthONNX.OutputLevel = DEBUG
-AthONNX.UseCUDA = False
-
-# Run for 10 "events".
-theApp.EvtMax = 2
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
index fb1b6f25e6252467b9a9fb8439007dd9ce6dbd3c..42a58730d19e4094f0dbfa374f94b88d2b1e6687 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx
@@ -1,14 +1,16 @@
-// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
+// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 
 // Local include(s).
 #include "EvaluateModel.h"
 #include <tuple>
+#include <fstream>
+#include <chrono>
+#include <arpa/inet.h>
 
 // Framework include(s).
-#include "PathResolver/PathResolver.h"
-#include "AthOnnxruntimeUtils/OnnxUtils.h"
+#include "AthOnnxUtils/OnnxUtils.h"
 
-namespace AthONNX {
+namespace AthOnnx {
 
    //*******************************************************************
    // for reading MNIST images
@@ -66,195 +68,74 @@ namespace AthONNX {
     }
 
    StatusCode EvaluateModel::initialize() {
-
-      // Access the service.
-      //ATH_CHECK( m_svc.retrieve() );
+    // Fetch tools
+    ATH_CHECK( m_onnxTool.retrieve() );
+    m_onnxTool->printModelInfo();
 
       /*****
        The combination of no. of batches and batch size shouldn't cross 
        the total smple size which is 10000 for this example
       *****/         
-      if(m_doBatches && (m_numberOfBatches*m_sizeOfBatch)>10000){
+      if(m_batchSize > 10000){
         ATH_MSG_INFO("The total no. of sample crossed the no. of available sample ....");
 	return StatusCode::FAILURE;
-      }
-      // Find the model file.
-      const std::string modelFileName =
-         PathResolverFindCalibFile( m_modelFileName );
-      const std::string pixelFileName =
-         PathResolverFindCalibFile( m_pixelFileName );
-      const std::string labelFileName =
-         PathResolverFindCalibFile( m_labelFileName );
-      ATH_MSG_INFO( "Using model file: " << modelFileName );
-      ATH_MSG_INFO( "Using pixel file: " << pixelFileName );
-      ATH_MSG_INFO( "Using pixel file: " << labelFileName );
-      // Set up the ONNX Runtime session.
+      }
+      // Read the input pixel file.
+      ATH_MSG_INFO( "Using pixel file: " << m_pixelFileName.value() );
   
-      m_session = AthONNX::CreateORTSession(modelFileName, m_useCUDA);
-
-      if (m_useCUDA) {
-         ATH_MSG_INFO( "Created the ONNX Runtime session on CUDA" );
-      } else {
-         ATH_MSG_INFO( "Created the ONNX Runtime session on CPUs" );
-      }
-      
-      m_input_tensor_values_notFlat = read_mnist_pixel_notFlat(pixelFileName);
-      std::vector<std::vector<float>> c = m_input_tensor_values_notFlat[0];
-      m_output_tensor_values = read_mnist_label(labelFileName);    
-      // Return gracefully.
+      m_input_tensor_values_notFlat = read_mnist_pixel_notFlat(m_pixelFileName);
+      ATH_MSG_INFO("Total no. of samples: "<<m_input_tensor_values_notFlat.size());
+    
       return StatusCode::SUCCESS;
-   }
+}
 
    StatusCode EvaluateModel::execute( const EventContext& /*ctx*/ ) const {
-     
-     Ort::AllocatorWithDefaultOptions allocator;
-  
-     /************************** Input Nodes *****************************/
-     /*********************************************************************/
-     
-     std::tuple<std::vector<int64_t>, std::vector<const char*> > inputInfo = AthONNX::GetInputNodeInfo(m_session);
-     std::vector<int64_t> input_node_dims = std::get<0>(inputInfo);
-     std::vector<const char*> input_node_names = std::get<1>(inputInfo);
-     
-     for( std::size_t i = 0; i < input_node_names.size(); i++ ) {
-        // print input node names
-        ATH_MSG_DEBUG("Input "<<i<<" : "<<" name= "<<input_node_names[i]);
    
-        // print input shapes/dims
-        ATH_MSG_DEBUG("Input "<<i<<" : num_dims= "<<input_node_dims.size());
-        for (std::size_t j = 0; j < input_node_dims.size(); j++){
-           ATH_MSG_DEBUG("Input "<<i<<" : dim "<<j<<"= "<<input_node_dims[j]);
-          }
-       }
-
-     /************************** Output Nodes *****************************/
-     /*********************************************************************/
-     
-     std::tuple<std::vector<int64_t>, std::vector<const char*> > outputInfo = AthONNX::GetOutputNodeInfo(m_session);
-     std::vector<int64_t> output_node_dims = std::get<0>(outputInfo);
-     std::vector<const char*> output_node_names = std::get<1>(outputInfo);
-
-     for( std::size_t i = 0; i < output_node_names.size(); i++ ) {
-        // print input node names
-        ATH_MSG_DEBUG("Output "<<i<<" : "<<" name= "<<output_node_names[i]);
-
-        // print input shapes/dims
-        ATH_MSG_DEBUG("Output "<<i<<" : num_dims= "<<output_node_dims.size());
-        for (std::size_t j = 0; j < output_node_dims.size(); j++){
-           ATH_MSG_DEBUG("Output "<<i<<" : dim "<<j<<"= "<<output_node_dims[j]);
-          }
-       }
-    /************************* Score if input is not a batch ********************/
-    /****************************************************************************/
-     if(m_doBatches == false){
-
-        /**************************************************************************************
-         * input_node_dims[0] = -1; -1 needs to be replaced by the batch size; for no batch is 1 
-         * input_node_dims[1] = 28
-         * input_node_dims[2] = 28
-        ****************************************************************************************/
-
-     	input_node_dims[0] = 1;
-     	output_node_dims[0] = 1;
- 
-       /***************** Choose an example sample randomly ****************************/  
-     	std::vector<std::vector<float>> input_tensor_values = m_input_tensor_values_notFlat[m_testSample];
-        std::vector<float> flatten = AthONNX::FlattenInput_multiD_1D(input_tensor_values);
-        // Output label of corresponding m_input_tensor_values[m_testSample]; e.g 0, 1, 2, 3 etc
-        int output_tensor_values = m_output_tensor_values[m_testSample];
-       
-        // For a check that the sample dimension is fully flatten (1x28x28 = 784)
-        ATH_MSG_DEBUG("Size of Flatten Input tensor: "<<flatten.size());
+   // prepare inputs
+   std::vector<float> inputData;
+   for (int ibatch = 0; ibatch < m_batchSize; ibatch++){
+      const std::vector<std::vector<float> >& imageData = m_input_tensor_values_notFlat[ibatch];
+      std::vector<float> flatten = AthOnnx::flattenNestedVectors(imageData);
+      inputData.insert(inputData.end(), flatten.begin(), flatten.end());
+   }
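+   // Each 28x28 image is flattened and appended, so inputData holds
+   // m_batchSize * 784 floats in one contiguous, batch-major buffer.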
 
-     	/************** Create input tensor object from input data values to feed into your model *********************/
-        
-        Ort::Value input_tensor = AthONNX::TensorCreator(flatten, input_node_dims );
+   int64_t batchSize = m_onnxTool->getBatchSize(inputData.size());
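+   // getBatchSize is expected to derive the batch size from the total
+   // number of input elements and the model's per-sample input size
+   // (28 x 28 = 784 here); the assert below cross-checks it.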
+   ATH_MSG_INFO("Batch size is " << batchSize << ".");
+   assert(batchSize == m_batchSize);
 
-        /********* Convert 784 elements long flattened 1D array to 3D (1, 28, 28) onnx compatible tensor ************/
-        ATH_MSG_DEBUG("Input tensor size after converted to Ort tensor: "<<input_tensor.GetTensorTypeAndShapeInfo().GetShape());     	
-        // Makes sure input tensor has same dimensions as input layer of the model
-        assert(input_tensor.IsTensor()&&
-     		input_tensor.GetTensorTypeAndShapeInfo().GetShape() == input_node_dims);
+   // bind the input data to the input tensor
+   std::vector<Ort::Value> inputTensors;
+   ATH_CHECK( m_onnxTool->addInput(inputTensors, inputData, 0, batchSize) );
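+   // addInput is assumed to bind inputData to input node 0, with the
+   // dynamic leading dimension of the input shape set to batchSize.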
 
-     	/********* Score model by feeding input tensor and get output tensor in return *****************************/
+   // reserve space for output data and bind it to the output tensor
+   std::vector<float> outputScores;
+   std::vector<Ort::Value> outputTensors;
+   ATH_CHECK( m_onnxTool->addOutput(outputTensors, outputScores, 0, batchSize) );
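+   // addOutput is assumed to size outputScores for batchSize samples and
+   // bind it to output node 0, so the inference writes straight into it.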
 
-        float* floatarr = AthONNX::Inference(m_session, input_node_names, input_tensor, output_node_names); 
+   // run the inference
+   // the output will be filled to the outputScores.
+   ATH_CHECK( m_onnxTool->inference(inputTensors, outputTensors) );
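+   // Assumed layout of outputScores (10 classes per MNIST sample):
+   //   { batch 0: score_0..score_9, batch 1: score_0..score_9, ... }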
 
-     	// show  true label for the test input
-     	ATH_MSG_INFO("Label for the input test data  = "<<output_tensor_values);
+   ATH_MSG_INFO("Predicted classes for the input test data:");
+   for(int ibatch = 0; ibatch < m_batchSize; ibatch++){
-     	float max = -999;
-     	int max_index;
+     	float max = -999;
+     	int max_index = 0;
      	for (int i = 0; i < 10; i++){
-       		ATH_MSG_DEBUG("Score for class "<<i<<" = "<<floatarr[i]);
-       		if (max<floatarr[i]){
-          		max = floatarr[i];
-          		max_index = i;
+       		int index = i + ibatch * 10;
+       		ATH_MSG_DEBUG("Score for class "<< i <<" = "<<outputScores[index] << " in batch " << ibatch);
+       		if (max < outputScores[index]){
+          		max = outputScores[index];
+          		max_index = index;
        		}
      	}
-     	ATH_MSG_INFO("Class: "<<max_index<<" has the highest score: "<<floatarr[max_index]);
-     
-     } // m_doBatches == false codition ends   
-    /************************* Score if input is a batch ********************/
-    /****************************************************************************/
-    else {
-        /**************************************************************************************
-         Similar scoring structure like non batch execution but 1st demention needs to be replaced by batch size
-         for this example lets take 3 batches with batch size 5
-         * input_node_dims[0] = 5
-         * input_node_dims[1] = 28
-         * input_node_dims[2] = 28
-        ****************************************************************************************/
-        
-     	input_node_dims[0] = m_sizeOfBatch;
-     	output_node_dims[0] = m_sizeOfBatch;   
-     
-	/************************** process multiple batches ********************************/
-        int l =0; /****** variable for distributing rows in m_input_tensor_values_notFlat equally into batches*****/ 
-     	for (int i = 0; i < m_numberOfBatches; i++) {
-     		ATH_MSG_DEBUG("Processing batch #" << i);
-      		std::vector<float> batch_input_tensor_values;
-      		for (int j = l; j < l+m_sizeOfBatch; j++) {
-                         
-                        std::vector<float> flattened_input = AthONNX::FlattenInput_multiD_1D(m_input_tensor_values_notFlat[j]);
-                        /******************For each batch we need a flattened (5 x 28 x 28 = 3920) 1D array******************************/
-        		batch_input_tensor_values.insert(batch_input_tensor_values.end(), flattened_input.begin(), flattened_input.end());
-        	}   
-
-                Ort::Value batch_input_tensors = AthONNX::TensorCreator(batch_input_tensor_values, input_node_dims );
-
-		// Get pointer to output tensor float values
+      ATH_MSG_INFO("Class: "<<max_index<<" has the highest score: "<<outputScores[max_index] << " in batch " << ibatch);
+   }
 
-                float* floatarr = AthONNX::Inference(m_session, input_node_names, batch_input_tensors, output_node_names);
-     		// show  true label for the test input
-		for(int i = l; i<l+m_sizeOfBatch; i++){
-     			ATH_MSG_INFO("Label for the input test data  = "<<m_output_tensor_values[i]);
-                	int k = (i-l)*10;
-                	float max = -999;
-                	int max_index = 0;
-     			for (int j =k ; j < k+10; j++){
-       				ATH_MSG_INFO("Score for class "<<j-k<<" = "<<floatarr[j]);
-       				if (max<floatarr[j]){
-          				max = floatarr[j];
-          				max_index = j;
-       				}
-     			}
-    	       	ATH_MSG_INFO("Class: "<<max_index-k<<" has the highest score: "<<floatarr[max_index]);
-       		} 
-           l = l+m_sizeOfBatch;
-          }
-     } // else/m_doBatches == True codition ends
-    // Return gracefully.
       return StatusCode::SUCCESS;
    }
    StatusCode EvaluateModel::finalize() {
 
-      // Delete the session object.
-      m_session.reset();
-
-      // Return gracefully.
       return StatusCode::SUCCESS;
    }
 
-} // namespace AthONNX
-
-
+} // namespace AthOnnx
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
index aaae60dc28581c43837484e25d969fdc3aac2abf..d660543e5e584a75c390bc19aa4181cb722daf6f 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
@@ -4,24 +4,21 @@
 #define ATHEXONNXRUNTIME_EVALUATEMODEL_H
 
 // Local include(s).
-#include "AthOnnxruntimeService/IONNXRuntimeSvc.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h"
+
 // Framework include(s).
 #include "AthenaBaseComps/AthReentrantAlgorithm.h"
 #include "GaudiKernel/ServiceHandle.h"
 
 // ONNX Runtime include(s).
 #include <core/session/onnxruntime_cxx_api.h>
 
 // System include(s).
 #include <memory>
 #include <string>
-#include <iostream> 
-#include <fstream>
-#include <arpa/inet.h>
 #include <vector>
-#include <iterator>
 
-namespace AthONNX {
+namespace AthOnnx {
 
    /// Algorithm demonstrating the usage of the ONNX Runtime C++ API
    ///
@@ -53,31 +50,22 @@ namespace AthONNX {
       /// @{
 
-      /// Name of the model file to load
+      /// Name of the input pixel file to load
-      Gaudi::Property< std::string > m_modelFileName{ this, "ModelFileName",
-         "dev/MLTest/2020-03-02/MNIST_testModel.onnx",
-         "Name of the model file to load" };
       Gaudi::Property< std::string > m_pixelFileName{ this, "InputDataPixel",
          "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte",
          "Name of the input pixel file to load" };
-      Gaudi::Property< std::string > m_labelFileName{ this, "InputDataLabel",
-         "dev/MLTest/2020-03-31/t10k-labels-idx1-ubyte",
-         "Name of the label file to load" };
-      Gaudi::Property<int> m_testSample {this, "TestSample", 0, "A Random Test Sample"};
 
-      /// Following properties needed to be consdered if the .onnx model is evaluated in batch mode
+      /// Number of samples per batch when the .onnx model is evaluated in batch mode
-      Gaudi::Property<bool> m_doBatches {this, "DoBatches", false, "Processing events by batches"};
-      Gaudi::Property<int> m_numberOfBatches {this, "NumberOfBatches", 1, "No. of batches to be passed"};
-      Gaudi::Property<int> m_sizeOfBatch {this, "SizeOfBatch", 1, "No. of elements/example in a batch"};
+      Gaudi::Property<int> m_batchSize {this, "BatchSize", 1, "No. of elements/example in a batch"};
 
-      // If runs on CUDA
-      Gaudi::Property<bool> m_useCUDA {this, "UseCUDA", false, "Use CUDA"};
+      /// Tool handle for the ONNX Runtime inference tool
+      ToolHandle< IOnnxRuntimeInferenceTool >  m_onnxTool{
+         this, "ORTInferenceTool", "AthOnnx::OnnxRuntimeInferenceTool"
+      };
       
-      std::unique_ptr< Ort::Session > m_session;
       std::vector<std::vector<std::vector<float>>> m_input_tensor_values_notFlat;
-      std::vector<int> m_output_tensor_values;
 
    }; // class EvaluateModel
 
-} // namespace AthONNX
+} // namespace AthOnnx
 
 #endif // ATHEXONNXRUNTIME_EVALUATEMODEL_H
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx b/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx
index cb25a4bfd9d2e30534c0158a17ae01e6ddd719ee..c89bab7da1362ad847381398705168c3c9e8896d 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx
@@ -3,4 +3,4 @@
 // Local include(s).
 #include "../EvaluateModel.h"
 // Declare the package's components.
-DECLARE_COMPONENT( AthONNX::EvaluateModel )
+DECLARE_COMPONENT( AthOnnx::EvaluateModel )
diff --git a/InnerDetector/InDetGNNTracking/CMakeLists.txt b/InnerDetector/InDetGNNTracking/CMakeLists.txt
index 2bcff0db22e5442b9126c9529931c23e492cd461..daa184c964735fcf3f488b0365b7a1b5ca476a1a 100644
--- a/InnerDetector/InDetGNNTracking/CMakeLists.txt
+++ b/InnerDetector/InDetGNNTracking/CMakeLists.txt
@@ -18,6 +18,7 @@ atlas_add_component( InDetGNNTracking
     PixelReadoutGeometryLib SCT_ReadoutGeometry InDetSimData
     InDetPrepRawData TrkTrack TrkRIO_OnTrack InDetSimEvent
     AtlasHepMCLib InDetRIO_OnTrack InDetRawData TrkTruthData
+    AthOnnxInterfaces AthOnnxUtilsLib 
 )
 
 atlas_install_python_modules( python/*.py )
diff --git a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
index a48e9ea21a01f8e12a0b54c156de12f2edc772ca..e07013addd91ad7cc6c8f05f482a3a944717dcd3 100644
--- a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
+++ b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx
@@ -7,7 +7,7 @@
 
 // Framework include(s).
 #include "PathResolver/PathResolver.h"
-#include "AthOnnxruntimeUtils/OnnxUtils.h"
+#include "AthOnnxUtils/OnnxUtils.h"
 #include <cmath>
 
 InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool(
@@ -18,20 +18,12 @@ InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool(
   }
 
 StatusCode InDet::SiGNNTrackFinderTool::initialize() {
-  initTrainedModels();
+  ATH_CHECK( m_embedSessionTool.retrieve() );
+  ATH_CHECK( m_filterSessionTool.retrieve() );
+  ATH_CHECK( m_gnnSessionTool.retrieve() );
   return StatusCode::SUCCESS;
 }
 
-void InDet::SiGNNTrackFinderTool::initTrainedModels() {
-  std::string embedModelPath(m_inputMLModuleDir + "/torchscript/embedding.onnx");
-  std::string filterModelPath(m_inputMLModuleDir + "/torchscript/filtering.onnx");
-  std::string gnnModelPath(m_inputMLModuleDir + "/torchscript/gnn.onnx");
-
-  m_embedSession = AthONNX::CreateORTSession(embedModelPath, m_useCUDA);
-  m_filterSession = AthONNX::CreateORTSession(filterModelPath, m_useCUDA);
-  m_gnnSession = AthONNX::CreateORTSession(gnnModelPath, m_useCUDA);
-}
-
 StatusCode InDet::SiGNNTrackFinderTool::finalize() {
   StatusCode sc = AlgTool::finalize();
   return sc;
@@ -59,7 +51,7 @@ MsgStream& InDet::SiGNNTrackFinderTool::dumpevent( MsgStream& out ) const
   return out;
 }
 
-void InDet::SiGNNTrackFinderTool::getTracks (
+StatusCode InDet::SiGNNTrackFinderTool::getTracks(
   const std::vector<const Trk::SpacePoint*>& spacepoints,
   std::vector<std::vector<uint32_t> >& tracks) const
 {
@@ -84,34 +76,19 @@ void InDet::SiGNNTrackFinderTool::getTracks (
     spacepointIDs.push_back(sp_idx++);
   }
 
-    Ort::AllocatorWithDefaultOptions allocator;
-    auto memoryInfo = Ort::MemoryInfo::CreateCpu(
-        OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
-
     // ************
     // Embedding
     // ************
 
     std::vector<int64_t> eInputShape{numSpacepoints, spacepointFeatures};
-
-    std::vector<const char*> eInputNames{"sp_features"};
     std::vector<Ort::Value> eInputTensor;
-    eInputTensor.push_back(
-        Ort::Value::CreateTensor<float>(
-            memoryInfo, inputValues.data(), inputValues.size(),
-            eInputShape.data(), eInputShape.size())
-    );
+    ATH_CHECK( m_embedSessionTool->addInput(eInputTensor, inputValues, 0, numSpacepoints) );
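+    // addInput is assumed to create the input tensor internally (memory
+    // info, shape {numSpacepoints, spacepointFeatures}) and bind
+    // inputValues to input node 0.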
 
-    std::vector<float> eOutputData(numSpacepoints * m_embeddingDim);
-    std::vector<const char*> eOutputNames{"embedding_output"};
-    std::vector<int64_t> eOutputShape{numSpacepoints, m_embeddingDim};
     std::vector<Ort::Value> eOutputTensor;
-    eOutputTensor.push_back(
-        Ort::Value::CreateTensor<float>(
-            memoryInfo, eOutputData.data(), eOutputData.size(),
-            eOutputShape.data(), eOutputShape.size())
-    );
-    AthONNX::InferenceWithIOBinding(m_embedSession, eInputNames, eInputTensor, eOutputNames, eOutputTensor);
+    std::vector<float> eOutputData;
+    ATH_CHECK( m_embedSessionTool->addOutput(eOutputTensor, eOutputData, 0, numSpacepoints) );
+
+    ATH_CHECK( m_embedSessionTool->inference(eInputTensor, eOutputTensor) );
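+    // eOutputData now holds one embedding vector per spacepoint, which the
+    // edge-building step below is assumed to use to form candidate edges.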
 
     // ************
     // Building Edges
@@ -123,29 +100,18 @@ void InDet::SiGNNTrackFinderTool::getTracks (
     // ************
     // Filtering
     // ************
-    std::vector<const char*> fInputNames{"f_nodes", "f_edges"};
     std::vector<Ort::Value> fInputTensor;
     fInputTensor.push_back(
         std::move(eInputTensor[0])
     );
+    ATH_CHECK( m_filterSessionTool->addInput(fInputTensor, edgeList, 1, numEdges) );
     std::vector<int64_t> fEdgeShape{2, numEdges};
-    fInputTensor.push_back(
-        Ort::Value::CreateTensor<int64_t>(
-            memoryInfo, edgeList.data(), edgeList.size(),
-            fEdgeShape.data(), fEdgeShape.size())
-    );
 
-    // filtering outputs
-    std::vector<const char*> fOutputNames{"f_edge_score"};
-    std::vector<float> fOutputData(numEdges);
-    std::vector<int64_t> fOutputShape{numEdges, 1};
+    std::vector<float> fOutputData;
     std::vector<Ort::Value> fOutputTensor;
-    fOutputTensor.push_back(
-        Ort::Value::CreateTensor<float>(
-            memoryInfo, fOutputData.data(), fOutputData.size(), 
-            fOutputShape.data(), fOutputShape.size())
-    );
-    AthONNX::InferenceWithIOBinding(m_filterSession, fInputNames, fInputTensor, fOutputNames, fOutputTensor);
+    ATH_CHECK( m_filterSessionTool->addOutput(fOutputTensor, fOutputData, 0, numEdges) );
+
+    ATH_CHECK( m_filterSessionTool->inference(fInputTensor, fOutputTensor) );
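+    // fOutputData is expected to hold one raw score per candidate edge;
+    // the sigmoid below maps it to a probability before the filterCut.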
 
     // apply sigmoid to the filtering output data
     // and remove edges with score < filterCut
@@ -167,28 +133,18 @@ void InDet::SiGNNTrackFinderTool::getTracks (
     // ************
     // GNN
     // ************
-    std::vector<const char*> gInputNames{"g_nodes", "g_edges"};
     std::vector<Ort::Value> gInputTensor;
     gInputTensor.push_back(
         std::move(fInputTensor[0])
     );
-    std::vector<int64_t> gEdgeShape{2, numEdgesAfterF};
-    gInputTensor.push_back(
-        Ort::Value::CreateTensor<int64_t>(
-            memoryInfo, edgesAfterFiltering.data(), edgesAfterFiltering.size(),
-            gEdgeShape.data(), gEdgeShape.size())
-    );
+    ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, edgesAfterFiltering, 1, numEdgesAfterF) );
+
     // gnn outputs
-    std::vector<const char*> gOutputNames{"gnn_edge_score"};
-    std::vector<float> gOutputData(numEdgesAfterF);
-    std::vector<int64_t> gOutputShape{numEdgesAfterF};
+    std::vector<float> gOutputData;
     std::vector<Ort::Value> gOutputTensor;
-    gOutputTensor.push_back(
-        Ort::Value::CreateTensor<float>(
-            memoryInfo, gOutputData.data(), gOutputData.size(), 
-            gOutputShape.data(), gOutputShape.size())
-    );
-    AthONNX::InferenceWithIOBinding(m_gnnSession, gInputNames, gInputTensor, gOutputNames, gOutputTensor);
+    ATH_CHECK( m_gnnSessionTool->addOutput(gOutputTensor, gOutputData, 0, numEdgesAfterF) );
+
+    ATH_CHECK( m_gnnSessionTool->inference(gInputTensor, gOutputTensor) );
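+    // gOutputData is expected to hold one raw score per surviving edge,
+    // used (after the sigmoid) as edge weights for connected components.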
     // apply sigmoid to the gnn output data
     for(auto& v : gOutputData){
         v = 1.f / (1.f + std::exp(-v));
@@ -200,7 +156,7 @@ void InDet::SiGNNTrackFinderTool::getTracks (
     std::vector<int32_t> trackLabels(numSpacepoints);
     weaklyConnectedComponents<int64_t,float,int32_t>(numSpacepoints, rowIndices, colIndices, gOutputData, trackLabels);
 
-    if (trackLabels.size() == 0)  return;
+    if (trackLabels.empty()) return StatusCode::SUCCESS;
 
     tracks.clear();
 
@@ -225,5 +181,6 @@ void InDet::SiGNNTrackFinderTool::getTracks (
             existTrkIdx++;
         }
     }
+    return StatusCode::SUCCESS;
 }
 
diff --git a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h
index 031de24d76081988f20581184487062d5394b6aa..3fca4b225e9784e64db5ad9da983f30a283ed70a 100644
--- a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h
+++ b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h
@@ -14,7 +14,7 @@
 #include "InDetRecToolInterfaces/IGNNTrackFinder.h"
 
 // ONNX Runtime include(s).
-#include "AthOnnxruntimeService/IONNXRuntimeSvc.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h"
 #include <core/session/onnxruntime_cxx_api.h>
 
 class MsgStream;
@@ -46,7 +46,7 @@ namespace InDet{
      * 
-     * @return 
+     * @return StatusCode to indicate success or failure.
      */
-    virtual void getTracks(
+    virtual StatusCode getTracks(
       const std::vector<const Trk::SpacePoint*>& spacepoints,
       std::vector<std::vector<uint32_t> >& tracks) const override;
 
@@ -74,9 +74,15 @@ namespace InDet{
     MsgStream&    dumpevent     (MsgStream&    out) const;
 
     private:
-    std::unique_ptr< Ort::Session > m_embedSession;
-    std::unique_ptr< Ort::Session > m_filterSession;
-    std::unique_ptr< Ort::Session > m_gnnSession;
+    ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool {
+      this, "Embedding", "AthOnnx::OnnxRuntimeInferenceTool"
+    };
+    ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool {
+      this, "Filtering", "AthOnnx::OnnxRuntimeInferenceTool"
+    };
+    ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool {
+      this, "GNN", "AthOnnx::OnnxRuntimeInferenceTool"
+    };
 
   };
 
diff --git a/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx b/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx
index 121a710a66eec9ff5b9d3230f0c4dfeb44dd8069..d7caaa467f6fe633ef9ae313a97dfff712c492ea 100644
--- a/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx
+++ b/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx
@@ -68,7 +68,7 @@ StatusCode InDet::SiSPGNNTrackMaker::execute(const EventContext& ctx) const
   getData(m_SpacePointsSCTKey);
 
   std::vector<std::vector<uint32_t> > TT;
-  m_gnnTrackFinder->getTracks(spacePoints, TT);
+  ATH_CHECK(m_gnnTrackFinder->getTracks(spacePoints, TT));
 
 
   ATH_MSG_DEBUG("Obtained " << TT.size() << " Tracks");
diff --git a/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h b/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h
index b16efa8f42e9d47abfe4aa00980e02774cb7bf9e..52bb639c4f6a83abda1b6417632dadf5b4918337 100644
--- a/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h
+++ b/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h
@@ -38,7 +38,7 @@ namespace InDet {
      * @param tracks a list of track candidates in terms of spacepoint indices.
-     * @return 
+     * @return StatusCode to indicate success or failure.
     */
-    virtual void getTracks(
+    virtual StatusCode getTracks(
       const std::vector<const Trk::SpacePoint*>& spacepoints,
       std::vector<std::vector<uint32_t> >& tracks) const=0;