diff --git a/Control/AthOnnx/AthOnnxComps/CMakeLists.txt b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f2e46aa8b6b728f803b10fcf0c0dd2c406407bb5 --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +# Declare the package's name. +atlas_subdir( AthOnnxComps ) + +# External dependencies. +find_package( onnxruntime ) + +# Library in the package. +atlas_add_component( AthOnnxComps + src/*.cxx + src/components/*.cxx + INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS} + LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib AthenaBaseComps AthenaKernel GaudiKernel AthOnnxInterfaces AthOnnxUtilsLib +) + +# Install python modules. +atlas_install_python_modules( python/*.py POST_BUILD_CMD ${ATLAS_FLAKE8} ) diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeFlags.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeFlags.py new file mode 100644 index 0000000000000000000000000000000000000000..6dcff38f38ecf71d62ef9f8355172722d7a2462a --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeFlags.py @@ -0,0 +1,29 @@ +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +from AthenaConfiguration.AthConfigFlags import AthConfigFlags +from AthenaConfiguration.Enums import FlagEnum + +class OnnxRuntimeType(FlagEnum): + CPU = 'CPU' + CUDA = 'CUDA' + # possible future backends. Uncomment when implemented. + # DML = 'DML' + # DNNL = 'DNNL' + # NUPHAR = 'NUPHAR' + # OPENVINO = 'OPENVINO' + # ROCM = 'ROCM' + # TENSORRT = 'TENSORRT' + # VITISAI = 'VITISAI' + # VULKAN = 'VULKAN' + +def createOnnxRuntimeFlags(): + icf = AthConfigFlags() + + icf.addFlag("AthOnnx.ExecutionProvider", OnnxRuntimeType.CPU, type=OnnxRuntimeType) + + return icf + +if __name__ == "__main__": + + flags = createOnnxRuntimeFlags() + flags.dump() diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py new file mode 100644 index 0000000000000000000000000000000000000000..9c06316bdaf693431279c7046884df4a8ea940c3 --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeInferenceConfig.py @@ -0,0 +1,20 @@ +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator +from AthenaConfiguration.ComponentFactory import CompFactory +from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType +from typing import Optional +from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg + +def OnnxRuntimeInferenceToolCfg(flags, + model_fname: Optional[str] = None, + execution_provider: Optional[OnnxRuntimeType] = None, + name="OnnxRuntimeInferenceTool", **kwargs): + """Configure OnnxRuntimeInferenceTool in Control/AthOnnx/AthOnnxComps/src""" + + acc = ComponentAccumulator() + + session_tool = acc.popToolsAndMerge(OnnxRuntimeSessionToolCfg(flags, model_fname, execution_provider)) + kwargs["ORTSessionTool"] = session_tool + acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeInferenceTool(name, **kwargs)) + return acc diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py new file mode 100644 index 0000000000000000000000000000000000000000..32b9d4471b14b5d678f5ee056d0be97e3c163d1c --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSessionConfig.py @@ -0,0 +1,30 
@@ +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator +from AthenaConfiguration.ComponentFactory import CompFactory +from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType +from typing import Optional + +def OnnxRuntimeSessionToolCfg(flags, + model_fname: str, + execution_provider: Optional[OnnxRuntimeType] = None, + name="OnnxRuntimeSessionTool", **kwargs): + """Configure OnnxRuntimeSessionTool in Control/AthOnnx/AthOnnxComps/src""" + + acc = ComponentAccumulator() + + if model_fname is None: + raise ValueError("model_fname must be specified") + + execution_provider = flags.AthOnnx.ExecutionProvider if execution_provider is None else execution_provider + name += execution_provider.name + + kwargs.setdefault("ModelFileName", model_fname) + if execution_provider is OnnxRuntimeType.CPU: + acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCPU(name, **kwargs)) + elif execution_provider is OnnxRuntimeType.CUDA: + acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCUDA(name, **kwargs)) + else: + raise ValueError("Unknown OnnxRuntime Execution Provider: %s" % execution_provider) + + return acc diff --git a/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSvcConfig.py b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSvcConfig.py new file mode 100644 index 0000000000000000000000000000000000000000..e006fe8232c47fe424c1eb51676cde3b7156637e --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/python/OnnxRuntimeSvcConfig.py @@ -0,0 +1,11 @@ +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator +from AthenaConfiguration.ComponentFactory import CompFactory + +def OnnxRuntimeSvcCfg(flags, name="OnnxRuntimeSvc", **kwargs): + acc = ComponentAccumulator() + + acc.addService(CompFactory.AthOnnx.OnnxRuntimeSvc(name, **kwargs), primary=True, create=True) + + return acc diff --git a/Control/AthOnnx/AthOnnxComps/python/__init__.py b/Control/AthOnnx/AthOnnxComps/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87ae30225ac7a009fc45760d78d66deeb4b80381 --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/python/__init__.py @@ -0,0 +1 @@ +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration diff --git a/Control/AthOnnx/AthOnnxComps/src/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnx/AthOnnxComps/src/ATLAS_CHECK_THREAD_SAFETY new file mode 100644 index 0000000000000000000000000000000000000000..6c51dc7f379791d1b1222c9760327037a5af1d2d --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/src/ATLAS_CHECK_THREAD_SAFETY @@ -0,0 +1 @@ +Control/AthOnnx/AthOnnxComps diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx new file mode 100644 index 0000000000000000000000000000000000000000..341afe703e6d1d2fb531536be376637afbfb445a --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.cxx @@ -0,0 +1,123 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "OnnxRuntimeInferenceTool.h" +#include "AthOnnxUtils/OnnxUtils.h" + +AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool( + const std::string& type, const std::string& name, const IInterface* parent ) + : base_class( type, name, parent ) +{ + declareInterface<IOnnxRuntimeInferenceTool>(this); +} + +StatusCode AthOnnx::OnnxRuntimeInferenceTool::initialize() +{ 
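+ // Initialization chain: retrieve the runtime service (which owns the Ort::Env), then the session tool (which owns the Ort::Session), then cache the model's input/output node info for later addInput()/addOutput() calls.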
+ // Get the Onnx Runtime service. + ATH_CHECK(m_onnxRuntimeSvc.retrieve()); + + // Create the session. + ATH_CHECK(m_onnxSessionTool.retrieve()); + + ATH_CHECK(getNodeInfo()); + + return StatusCode::SUCCESS; +} + +StatusCode AthOnnx::OnnxRuntimeInferenceTool::finalize() +{ + return StatusCode::SUCCESS; +} + +StatusCode AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo() +{ + auto& session = m_onnxSessionTool->session(); + // obtain the model information + m_numInputs = session.GetInputCount(); + m_numOutputs = session.GetOutputCount(); + + AthOnnx::getInputNodeInfo(session, m_inputShapes, m_inputNodeNames); + AthOnnx::getOutputNodeInfo(session, m_outputShapes, m_outputNodeNames); + + return StatusCode::SUCCESS; +} + + +void AthOnnx::OnnxRuntimeInferenceTool::setBatchSize(int64_t batchSize) +{ + if (batchSize <= 0) { + ATH_MSG_ERROR("Batch size should be positive"); + return; + } + + for (auto& shape : m_inputShapes) { + if (shape[0] == -1) { + shape[0] = batchSize; + } + } + + for (auto& shape : m_outputShapes) { + if (shape[0] == -1) { + shape[0] = batchSize; + } + } +} + +int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize(int64_t inputDataSize, int idx) const +{ + auto tensorSize = AthOnnx::getTensorSize(m_inputShapes[idx]); + if (tensorSize < 0) { + return inputDataSize / std::abs(tensorSize); + } else { + return -1; + } +} + +StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const +{ + assert (inputTensors.size() == m_numInputs); + assert (outputTensors.size() == m_numOutputs); + + // Run the model. + AthOnnx::inferenceWithIOBinding( + m_onnxSessionTool->session(), + m_inputNodeNames, inputTensors, + m_outputNodeNames, outputTensors); + + return StatusCode::SUCCESS; +} + +void AthOnnx::OnnxRuntimeInferenceTool::printModelInfo() const +{ + ATH_MSG_INFO("Number of inputs: " << m_numInputs); + ATH_MSG_INFO("Number of outputs: " << m_numOutputs); + + ATH_MSG_INFO("Input node names: "); + for (const auto& name : m_inputNodeNames) { + ATH_MSG_INFO("\t" << name); + } + + ATH_MSG_INFO("Output node names: "); + for (const auto& name : m_outputNodeNames) { + ATH_MSG_INFO("\t" << name); + } + + ATH_MSG_INFO("Input shapes: "); + for (const auto& shape : m_inputShapes) { + std::string shapeStr = "\t"; + for (const auto& dim : shape) { + shapeStr += std::to_string(dim) + " "; + } + ATH_MSG_INFO(shapeStr); + } + + ATH_MSG_INFO("Output shapes: "); + for (const auto& shape : m_outputShapes) { + std::string shapeStr = "\t"; + for (const auto& dim : shape) { + shapeStr += std::to_string(dim) + " "; + } + ATH_MSG_INFO(shapeStr); + } +} diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h new file mode 100644 index 0000000000000000000000000000000000000000..da691ad7e237ccd2fc8a5edfbd1f8d905b4a4c93 --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h @@ -0,0 +1,58 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +#ifndef OnnxRuntimeInferenceTool_H +#define OnnxRuntimeInferenceTool_H + +#include "AthenaBaseComps/AthAlgTool.h" +#include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h" +#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h" +#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h" +#include "GaudiKernel/ServiceHandle.h" +#include "GaudiKernel/ToolHandle.h" + +namespace AthOnnx { + // @class OnnxRuntimeInferenceTool + // + // @brief Tool to run inference on an ONNX model using the session provided by a session tool + 
// + // @author Xiangyang Ju <xiangyang.ju@cern.ch> + class OnnxRuntimeInferenceTool : public extends<AthAlgTool, IOnnxRuntimeInferenceTool> + { + public: + /// Standard constructor + OnnxRuntimeInferenceTool( const std::string& type, + const std::string& name, + const IInterface* parent ); + virtual ~OnnxRuntimeInferenceTool() = default; + + /// Initialize the tool + virtual StatusCode initialize() override; + /// Finalize the tool + virtual StatusCode finalize() override; + + virtual void setBatchSize(int64_t batchSize) override final; + virtual int64_t getBatchSize(int64_t inputDataSize, int idx = 0) const override final; + + virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const override final; + + virtual void printModelInfo() const override final; + + protected: + OnnxRuntimeInferenceTool() = delete; + OnnxRuntimeInferenceTool(const OnnxRuntimeInferenceTool&) = delete; + OnnxRuntimeInferenceTool& operator=(const OnnxRuntimeInferenceTool&) = delete; + + private: + StatusCode getNodeInfo(); + + ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"}; + ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{ + this, "ORTSessionTool", + "AthOnnx::OnnxRuntimeInferenceToolCPU" + }; + std::vector<std::string> m_inputNodeNames; + std::vector<std::string> m_outputNodeNames; + }; +} // namespace AthOnnx + +#endif diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx new file mode 100644 index 0000000000000000000000000000000000000000..85f1f8df115bdc9ac45129550cef7640bb888fa0 --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.cxx @@ -0,0 +1,46 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "OnnxRuntimeSessionToolCPU.h" + +AthOnnx::OnnxRuntimeSessionToolCPU::OnnxRuntimeSessionToolCPU( + const std::string& type, const std::string& name, const IInterface* parent ) + : base_class( type, name, parent ) +{ + declareInterface<IOnnxRuntimeSessionTool>(this); +} + +StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::initialize() +{ + // Get the Onnx Runtime service. + ATH_CHECK(m_onnxRuntimeSvc.retrieve()); + ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString()); + + // Create the session options. + // TODO: Make this configurable. + // other threading options: https://onnxruntime.ai/docs/performance/tune-performance/threading.html + // 1) SetIntraOpNumThreads( 1 ); + // 2) SetInterOpNumThreads( 1 ); + // 3) SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED ); + + Ort::SessionOptions sessionOptions; + sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_ALL ); + + // Create the session. 
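+ // The Ort::Env comes from OnnxRuntimeSvc; the Ort::Session created here is owned by this tool and released in finalize().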
+ m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions); + + return StatusCode::SUCCESS; +} + +StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::finalize() +{ + m_session.reset(); + ATH_MSG_DEBUG( "Ort::Session object deleted" ); + return StatusCode::SUCCESS; +} + +Ort::Session& AthOnnx::OnnxRuntimeSessionToolCPU::session() const +{ + return *m_session; +} diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h new file mode 100644 index 0000000000000000000000000000000000000000..a20e5e5d4f2ee0444b7a1812828ef1dbeac77551 --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCPU.h @@ -0,0 +1,46 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +#ifndef OnnxRuntimeSessionToolCPU_H +#define OnnxRuntimeSessionToolCPU_H + +#include "AthenaBaseComps/AthAlgTool.h" +#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h" +#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h" +#include "GaudiKernel/ServiceHandle.h" + +namespace AthOnnx { + // @class OnnxRuntimeSessionToolCPU + // + // @brief Tool to create Onnx Runtime session with CPU backend + // + // @author Xiangyang Ju <xiangyang.ju@cern.ch> + class OnnxRuntimeSessionToolCPU : public extends<AthAlgTool, IOnnxRuntimeSessionTool> + { + public: + /// Standard constructor + OnnxRuntimeSessionToolCPU( const std::string& type, + const std::string& name, + const IInterface* parent ); + virtual ~OnnxRuntimeSessionToolCPU() = default; + + /// Initialize the tool + virtual StatusCode initialize() override final; + /// Finalize the tool + virtual StatusCode finalize() override final; + + /// Access the Onnx Runtime session + virtual Ort::Session& session() const override final; + + protected: + OnnxRuntimeSessionToolCPU() = delete; + OnnxRuntimeSessionToolCPU(const OnnxRuntimeSessionToolCPU&) = delete; + OnnxRuntimeSessionToolCPU& operator=(const OnnxRuntimeSessionToolCPU&) = delete; + + private: + StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"}; + ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"}; + std::unique_ptr<Ort::Session> m_session; + }; +} + +#endif diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx new file mode 100644 index 0000000000000000000000000000000000000000..f7d39731da835de74267e3f6c68ceb19a7160142 --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.cxx @@ -0,0 +1,70 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "OnnxRuntimeSessionToolCUDA.h" + +AthOnnx::OnnxRuntimeSessionToolCUDA::OnnxRuntimeSessionToolCUDA( + const std::string& type, const std::string& name, const IInterface* parent ) + : base_class( type, name, parent ) +{ + declareInterface<IOnnxRuntimeSessionTool>(this); +} + +StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::initialize() +{ + // Get the Onnx Runtime service. + ATH_CHECK(m_onnxRuntimeSvc.retrieve()); + + ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString()); + // Create the session options. 
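+ // Unlike the CPU tool, per-session threads are disabled below so that the global thread pools created by OnnxRuntimeSvc are reused.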
+ Ort::SessionOptions sessionOptions; + sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED ); + sessionOptions.DisablePerSessionThreads(); // use global thread pool. + + // TODO: add more cuda options to the interface + // https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cc + // Options: https://onnxruntime.ai/docs/api/c/struct_ort_c_u_d_a_provider_options.html + OrtCUDAProviderOptions cuda_options; + cuda_options.device_id = m_deviceId; + cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive; + cuda_options.gpu_mem_limit = std::numeric_limits<size_t>::max(); + + // memory arena options for CUDA memory shrinkage + // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/utils.cc#L7 + if (m_enableMemoryShrinkage) { + Ort::ArenaCfg arena_cfg{0, 0, 1024, 0}; + // other options are not available in this release. + // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/test_inference.cc#L2802C21-L2802C21 + // arena_cfg.max_mem = 0; // let ORT pick default max memory + // arena_cfg.arena_extend_strategy = 0; // 0: kNextPowerOfTwo, 1: kSameAsRequested + // arena_cfg.initial_chunk_size_bytes = 1024; + // arena_cfg.max_dead_bytes_per_chunk = 0; + // arena_cfg.initial_growth_chunk_size_bytes = 256; + // arena_cfg.max_power_of_two_extend_bytes = 1L << 24; + + cuda_options.default_memory_arena_cfg = arena_cfg; + } + + sessionOptions.AppendExecutionProvider_CUDA(cuda_options); + + // Create the session. + m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions); + + return StatusCode::SUCCESS; +} + +StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::finalize() +{ + m_session.reset(); + ATH_MSG_DEBUG( "Ort::Session object deleted" ); + return StatusCode::SUCCESS; +} + +Ort::Session& AthOnnx::OnnxRuntimeSessionToolCUDA::session() const +{ + return *m_session; +} diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h new file mode 100644 index 0000000000000000000000000000000000000000..3d445ed73f508d775b6a574b7234ec41a9c56da6 --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSessionToolCUDA.h @@ -0,0 +1,51 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +#ifndef OnnxRuntimeSessionToolCUDA_H +#define OnnxRuntimeSessionToolCUDA_H + +#include "AthenaBaseComps/AthAlgTool.h" +#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h" +#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h" +#include "GaudiKernel/ServiceHandle.h" + +namespace AthOnnx { + // @class OnnxRuntimeSessionToolCUDA + // + // @brief Tool to create Onnx Runtime session with CUDA backend + // + // @author Xiangyang Ju <xiangyang.ju@cern.ch> + class OnnxRuntimeSessionToolCUDA : public extends<AthAlgTool, IOnnxRuntimeSessionTool> + { + public: + /// Standard constructor + OnnxRuntimeSessionToolCUDA( const std::string& type, + const std::string& name, + const IInterface* parent ); + virtual ~OnnxRuntimeSessionToolCUDA() = default; + + /// Initialize the tool + virtual StatusCode initialize() override final; + /// Finalize the tool + virtual StatusCode finalize() override final; + + /// Access the Onnx Runtime session + virtual Ort::Session& session() const override final; + + protected: + OnnxRuntimeSessionToolCUDA() = delete; + OnnxRuntimeSessionToolCUDA(const OnnxRuntimeSessionToolCUDA&) = delete; + 
OnnxRuntimeSessionToolCUDA& operator=(const OnnxRuntimeSessionToolCUDA&) = delete; + + private: + StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"}; + /// The device ID to use. + IntegerProperty m_deviceId{this, "DeviceId", 0}; + BooleanProperty m_enableMemoryShrinkage{this, "EnableMemoryShrinkage", false}; + + /// Runtime service + ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"}; + std::unique_ptr<Ort::Session> m_session; + }; +} + +#endif diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.cxx b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.cxx new file mode 100644 index 0000000000000000000000000000000000000000..c45e2dba6faad8383bb56cdf2305882e00b6e0e3 --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.cxx @@ -0,0 +1,43 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +// Local include(s). +#include "OnnxRuntimeSvc.h" +#include <core/session/onnxruntime_c_api.h> + +namespace AthOnnx { + OnnxRuntimeSvc::OnnxRuntimeSvc(const std::string& name, ISvcLocator* svc) : + asg::AsgService(name, svc) + { + declareServiceInterface<AthOnnx::IOnnxRuntimeSvc>(); + } + StatusCode OnnxRuntimeSvc::initialize() { + + // Create the environment object. + Ort::ThreadingOptions tp_options; + tp_options.SetGlobalIntraOpNumThreads(1); + tp_options.SetGlobalInterOpNumThreads(1); + + m_env = std::make_unique< Ort::Env >( + tp_options, ORT_LOGGING_LEVEL_WARNING, name().c_str()); + ATH_MSG_DEBUG( "Ort::Env object created" ); + + // Return gracefully. + return StatusCode::SUCCESS; + } + + StatusCode OnnxRuntimeSvc::finalize() { + + // Delete the environment object. + m_env.reset(); + ATH_MSG_DEBUG( "Ort::Env object deleted" ); + + // Return gracefully. + return StatusCode::SUCCESS; + } + + Ort::Env& OnnxRuntimeSvc::env() const { + + return *m_env; + } + +} // namespace AthOnnx diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.h new file mode 100644 index 0000000000000000000000000000000000000000..9f90bc998f65963e95aafedd4fc5dd41f2a312ff --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeSvc.h @@ -0,0 +1,58 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +#ifndef ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H +#define ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H + +// Local include(s). +#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h" + +// Framework include(s). +#include <AsgServices/AsgService.h> + +// ONNX include(s). +#include <core/session/onnxruntime_cxx_api.h> + +// System include(s). +#include <memory> + +namespace AthOnnx { + + /// Service implementing @c AthOnnx::IOnnxRuntimeSvc + /// + /// This is a very simple implementation, just managing the lifetime + /// of some Onnx Runtime C++ objects. 
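+ /// It creates a single Ort::Env with global intra- and inter-op thread pools (one thread each), which sessions share by disabling their per-session threads.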
+ /// + /// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch> + /// + class OnnxRuntimeSvc : public asg::AsgService, virtual public IOnnxRuntimeSvc { + + public: + + /// @name Function(s) inherited from @c Service + /// @{ + OnnxRuntimeSvc (const std::string& name, ISvcLocator* svc); + + /// Function initialising the service + virtual StatusCode initialize() override; + /// Function finalising the service + virtual StatusCode finalize() override; + + /// @} + + /// @name Function(s) inherited from @c AthOnnx::IOnnxRuntimeSvc + /// @{ + + /// Return the Onnx Runtime environment object + virtual Ort::Env& env() const override; + + /// @} + + private: + /// Global runtime environment for Onnx Runtime + std::unique_ptr< Ort::Env > m_env; + + }; // class OnnxRuntimeSvc + +} // namespace AthOnnx + +#endif // ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H diff --git a/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx b/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx new file mode 100644 index 0000000000000000000000000000000000000000..45ef514180f8b679a1f23b637bda662c88dff8b5 --- /dev/null +++ b/Control/AthOnnx/AthOnnxComps/src/components/AthOnnxComps_entries.cxx @@ -0,0 +1,13 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +// Local include(s). +#include "../OnnxRuntimeSvc.h" +#include "../OnnxRuntimeSessionToolCPU.h" +#include "../OnnxRuntimeSessionToolCUDA.h" +#include "../OnnxRuntimeInferenceTool.h" + +// Declare the package's components. +DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSvc ) +DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSessionToolCPU ) +DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSessionToolCUDA ) +DECLARE_COMPONENT( AthOnnx::OnnxRuntimeInferenceTool ) diff --git a/Control/AthOnnx/AthOnnxInterfaces/ATLAS_CHECK_THREAD_SAFETY b/Control/AthOnnx/AthOnnxInterfaces/ATLAS_CHECK_THREAD_SAFETY new file mode 100644 index 0000000000000000000000000000000000000000..9524c9b32dce4369d5e24025dd71ee41784e29c9 --- /dev/null +++ b/Control/AthOnnx/AthOnnxInterfaces/ATLAS_CHECK_THREAD_SAFETY @@ -0,0 +1 @@ +Control/AthOnnx/AthOnnxInterfaces diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h new file mode 100644 index 0000000000000000000000000000000000000000..467b25e9d76fa7470a4351b6db65ce82df75306b --- /dev/null +++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h @@ -0,0 +1,116 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +#ifndef AthOnnx_IOnnxRuntimeInferenceTool_H +#define AthOnnx_IOnnxRuntimeInferenceTool_H + +// Gaudi include(s). +#include "GaudiKernel/IAlgTool.h" +#include <memory> + +#include <core/session/onnxruntime_cxx_api.h> + + +namespace AthOnnx { + /** + * @class IOnnxRuntimeInferenceTool + * @brief Interface class for performing inference with Onnx Runtime. + * @details Interface class for performing inference with Onnx Runtime. + * It is thread safe, supports models with any number of inputs and outputs, + * supports models with dynamic batch size, and uses IO binding for inference. It defines a standardized procedure to + * perform Onnx Runtime inference. The procedure is as follows, assuming the tool `m_onnxTool` is created and initialized: + * 1. create input tensors from the input data: + * ```c++ + * std::vector<Ort::Value> inputTensors; + * std::vector<float> inputData_1; // The input data is filled by users, possibly from the event information. 
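+ * // For the MNIST example algorithm, inputData_1 would hold batchSize x 28 x 28 floats.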
+ * int64_t batchSize = m_onnxTool->getBatchSize(inputData_1.size(), 0); // The batch size is determined by the input data size to support dynamic batch size. + * m_onnxTool->addInput(inputTensors, inputData_1, 0, batchSize); + * std::vector<int64_t> inputData_2; // Some models may have multiple inputs. Add inputs one by one. + * int64_t batchSize_2 = m_onnxTool->getBatchSize(inputData_2.size(), 1); + * m_onnxTool->addInput(inputTensors, inputData_2, 1, batchSize_2); + * ``` + * 2. create output tensors: + * ```c++ + * std::vector<Ort::Value> outputTensors; + * std::vector<float> outputData; // The output data will be filled by the onnx session. + * m_onnxTool->addOutput(outputTensors, outputData, 0, batchSize); + * ``` + * 3. perform inference: + * ```c++ + * m_onnxTool->inference(inputTensors, outputTensors); + * ``` + * 4. Model outputs will be automatically filled into outputData. + * + * + * @author Xiangyang Ju <xju@cern.ch> + */ + class IOnnxRuntimeInferenceTool : virtual public IAlgTool + { + public: + + virtual ~IOnnxRuntimeInferenceTool() = default; + + // @name InterfaceID + DeclareInterfaceID(IOnnxRuntimeInferenceTool, 1, 0); + + /** + * @brief set the batch size. + * @details If the model has dynamic batch size, + * the batchSize value will be set in both the input and output shapes + */ + virtual void setBatchSize(int64_t batchSize) = 0; + + /** + * @brief method for determining the batch size from the data size + * @param dataSize the size of the input data, like std::vector<T>::size() + * @param idx the index of the input node + * @return the batch size, which equals dataSize divided by the product of the remaining dimensions. + */ + virtual int64_t getBatchSize(int64_t dataSize, int idx = 0) const = 0; + + /** + * @brief add the input data to the input tensors + * @param inputTensors the input tensor container + * @param data the input data + * @param idx the index of the input node + * @param batchSize the batch size + * @return StatusCode::SUCCESS if the input data is added successfully + */ + template <typename T> + StatusCode addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const; + + /** + * @brief add the output data to the output tensors + * @param outputTensors the output tensor container + * @param data the output data + * @param idx the index of the output node + * @param batchSize the batch size + * @return StatusCode::SUCCESS if the output data is added successfully + */ + template <typename T> + StatusCode addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const; + + /** + * @brief perform inference + * @param inputTensors the input tensor container + * @param outputTensors the output tensor container + * @return StatusCode::SUCCESS if the inference is performed successfully + */ + virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const = 0; + + virtual void printModelInfo() const = 0; + + protected: + int m_numInputs; + int m_numOutputs; + std::vector<std::vector<int64_t> > m_inputShapes; + std::vector<std::vector<int64_t> > m_outputShapes; + + private: + template <typename T> + Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const; + + }; + + #include "IOnnxRuntimeInferenceTool.icc" +} // namespace AthOnnx + +#endif diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc 
b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc new file mode 100644 index 0000000000000000000000000000000000000000..1c24b11c09d8f1b5f058510eece8a93dfc25030a --- /dev/null +++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeInferenceTool.icc @@ -0,0 +1,46 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +template <typename T> +Ort::Value AthOnnx::IOnnxRuntimeInferenceTool::createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const +{ + std::vector<int64_t> dataShapeCopy = dataShape; + + if (batchSize > 0) { + for (auto& shape: dataShapeCopy) { + if (shape == -1) { + shape = batchSize; + break; + } + } + } + + auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + return Ort::Value::CreateTensor<T>( + memoryInfo, data.data(), data.size(), dataShapeCopy.data(), dataShapeCopy.size() + ); +} + +template <typename T> +StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx, int64_t batchSize) const +{ + if (idx >= m_numInputs || idx < 0) { + return StatusCode::FAILURE; + } + + inputTensors.push_back(std::move(createTensor(data, m_inputShapes[idx], batchSize))); + return StatusCode::SUCCESS; +} + +template <typename T> +StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx, int64_t batchSize) const +{ + if (idx >= m_numOutputs || idx < 0) { + return StatusCode::FAILURE; + } + auto tensorSize = std::accumulate(m_outputShapes[idx].begin(), m_outputShapes[idx].end(), int64_t{1}, std::multiplies<int64_t>()); + if (tensorSize < 0) { + tensorSize = std::abs(tensorSize) * batchSize; + } + data.resize(tensorSize); + outputTensors.push_back(std::move(createTensor(data, m_outputShapes[idx], batchSize))); + return StatusCode::SUCCESS; +} diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h new file mode 100644 index 0000000000000000000000000000000000000000..4d94e68a840ea27896c436adc81c312cdc99a379 --- /dev/null +++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSessionTool.h @@ -0,0 +1,34 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +#ifndef AthOnnx_IOnnxRuntimeSessionTool_H +#define AthOnnx_IOnnxRuntimeSessionTool_H + +// Gaudi include(s). +#include "GaudiKernel/IAlgTool.h" + +#include <core/session/onnxruntime_cxx_api.h> + + +namespace AthOnnx { + // @class IOnnxRuntimeSessionTool + // + // Interface class for creating Onnx Runtime sessions. 
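+ // Concrete implementations (OnnxRuntimeSessionToolCPU, OnnxRuntimeSessionToolCUDA) own the Ort::Session and hand out a reference via session().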
+ // + // @author Xiangyang Ju <xju@cern.ch> + // + class IOnnxRuntimeSessionTool : virtual public IAlgTool + { + public: + + virtual ~IOnnxRuntimeSessionTool() = default; + + // @name InterfaceID + DeclareInterfaceID(IOnnxRuntimeSessionTool, 1, 0); + + // Access the Onnx Runtime session + virtual Ort::Session& session() const = 0; + + }; + +} // namespace AthOnnx + +#endif diff --git a/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSvc.h b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSvc.h new file mode 100644 index 0000000000000000000000000000000000000000..b546392b55b40111169cd90d40a473f28d24e886 --- /dev/null +++ b/Control/AthOnnx/AthOnnxInterfaces/AthOnnxInterfaces/IOnnxRuntimeSvc.h @@ -0,0 +1,41 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +#ifndef ATHONNXINTERFACES_IONNXRUNTIMESVC_H +#define ATHONNXINTERFACES_IONNXRUNTIMESVC_H + +// Gaudi include(s). +#include <AsgServices/IAsgService.h> + +// Onnx include(s). +#include <core/session/onnxruntime_cxx_api.h> + + +/// Namespace holding all of the Onnx Runtime code +namespace AthOnnx { + + // @class IOnnxRuntimeSvc + /// Service used for managing global objects used by Onnx Runtime + /// + /// In order to allow multiple clients to use Onnx Runtime at the same + /// time, this service is used to manage the objects that must only + /// be created once in the Athena process. + /// + /// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch> + /// + class IOnnxRuntimeSvc : virtual public asg::IAsgService{ + + public: + /// Virtual destructor, to make vtable happy + virtual ~IOnnxRuntimeSvc() = default; + + /// Declare the interface that this class provides + DeclareInterfaceID (AthOnnx::IOnnxRuntimeSvc, 1, 0); + + /// Return the Onnx Runtime environment object + virtual Ort::Env& env() const = 0; + + }; // class IOnnxRuntimeSvc + +} // namespace AthOnnx + +#endif diff --git a/Control/AthOnnx/AthOnnxInterfaces/CMakeLists.txt b/Control/AthOnnx/AthOnnxInterfaces/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..72b84c814183a8534b0459a35183f015c976e276 --- /dev/null +++ b/Control/AthOnnx/AthOnnxInterfaces/CMakeLists.txt @@ -0,0 +1,12 @@ +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +# Declare the package name: +atlas_subdir( AthOnnxInterfaces ) + +# Component(s) in the package: +atlas_add_library( AthOnnxInterfaces + AthOnnxInterfaces/*.h + INTERFACE + PUBLIC_HEADERS AthOnnxInterfaces + LINK_LIBRARIES AsgTools AthLinks GaudiKernel + ) diff --git a/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..83ced9501c30d5278a5ac557d61ab66203e5f678 --- /dev/null +++ b/Control/AthOnnx/AthOnnxUtils/AthOnnxUtils/OnnxUtils.h @@ -0,0 +1,81 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +#ifndef Onnx_UTILS_H +#define Onnx_UTILS_H + +#include <memory> +#include <vector> + +// Onnx Runtime include(s). +#include <core/session/onnxruntime_cxx_api.h> + +namespace AthOnnx { + +// @author Xiangyang Ju <xiangyang.ju@cern.ch> + +// @brief Convert a vector of vectors to a single vector. +// @param features The vector of vectors to be flattened. +// @return A single vector containing all the elements of the input vector of vectors. +template<typename T> +inline std::vector<T> flattenNestedVectors( const std::vector<std::vector<T>>& features) { + // 1. 
Compute the total size required. + int total_size = 0; + for (const auto& feature : features) total_size += feature.size(); + + std::vector<T> flatten1D; + flatten1D.reserve(total_size); + + for (const auto& feature : features) + for (const auto& elem : feature) + flatten1D.push_back(elem); + + return flatten1D; +} + +// @brief Get the input data shape and node names (in the computational graph) from the onnx model +// @param session The onnx session. +// @param dataShape The shape of the input data. Note that there may be multiple inputs. +// @param nodeNames The names of the input nodes in the computational graph. +// Both dataShape and nodeNames are updated in place. +void getInputNodeInfo( + const Ort::Session& session, + std::vector<std::vector<int64_t> >& dataShape, + std::vector<std::string>& nodeNames); + +// @brief Get the output data shape and node names (in the computational graph) from the onnx model +// @param session The onnx session. +// @param dataShape The shape of the output data. +// @param nodeNames The names of the output nodes in the computational graph. +// Both dataShape and nodeNames are updated in place. +void getOutputNodeInfo( + const Ort::Session& session, + std::vector<std::vector<int64_t> >& dataShape, + std::vector<std::string>& nodeNames); + +// Helper function to get node info +void getNodeInfo( + const Ort::Session& session, + std::vector<std::vector<int64_t> >& dataShape, + std::vector<std::string>& nodeNames, + bool isInput +); + +// @brief Count the total number of elements in a tensor. +// Useful for reserving space for the output data. +int64_t getTensorSize(const std::vector<int64_t>& dataShape); + +// Inference with IO binding. Better for performance, particularly for GPUs. +// See https://onnxruntime.ai/docs/performance/tune-performance/iobinding.html +void inferenceWithIOBinding(Ort::Session& session, + const std::vector<std::string>& inputNames, + const std::vector<Ort::Value>& inputData, + const std::vector<std::string>& outputNames, + const std::vector<Ort::Value>& outputData +); + +// @brief Create a tensor from a vector of data and its shape. +Ort::Value createTensor(std::vector<float>& data, const std::vector<int64_t>& dataShape); + + +} +#endif diff --git a/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt b/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..029f7e5a1ce8f947ed8bef66aa5b0367440fd325 --- /dev/null +++ b/Control/AthOnnx/AthOnnxUtils/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +# Declare the package's name. +atlas_subdir( AthOnnxUtils ) + +# External dependencies. +find_package( onnxruntime ) + +# Component(s) in the package. 
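+# AthOnnxUtilsLib is a plain linkable library (not a component library), so other packages can call the helper functions directly.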
+atlas_add_library( AthOnnxUtilsLib + AthOnnxUtils/*.h + src/*.cxx + PUBLIC_HEADERS AthOnnxUtils + INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS} + LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthenaKernel GaudiKernel AsgServicesLib +) diff --git a/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx new file mode 100644 index 0000000000000000000000000000000000000000..471b12697a518699cb1aa3b315d5563c397720c8 --- /dev/null +++ b/Control/AthOnnx/AthOnnxUtils/src/OnnxUtils.cxx @@ -0,0 +1,94 @@ +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +#include "AthOnnxUtils/OnnxUtils.h" +#include <cassert> +#include <string> + +namespace AthOnnx { + +void getNodeInfo( + const Ort::Session& session, + std::vector<std::vector<int64_t> >& dataShape, + std::vector<std::string>& nodeNames, + bool isInput +){ + dataShape.clear(); + nodeNames.clear(); + + size_t numNodes = isInput? session.GetInputCount(): session.GetOutputCount(); + dataShape.reserve(numNodes); + nodeNames.reserve(numNodes); + + Ort::AllocatorWithDefaultOptions allocator; + for( std::size_t i = 0; i < numNodes; i++ ) { + Ort::TypeInfo typeInfo = isInput? session.GetInputTypeInfo(i): session.GetOutputTypeInfo(i); + auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo(); + dataShape.emplace_back(tensorInfo.GetShape()); + + auto nodeName = isInput? session.GetInputNameAllocated(i, allocator) : session.GetOutputNameAllocated(i, allocator); + nodeNames.emplace_back(nodeName.get()); + } +} + +void getInputNodeInfo( + const Ort::Session& session, + std::vector<std::vector<int64_t> >& dataShape, + std::vector<std::string>& nodeNames +){ + getNodeInfo(session, dataShape, nodeNames, true); +} + +void getOutputNodeInfo( + const Ort::Session& session, + std::vector<std::vector<int64_t> >& dataShape, + std::vector<std::string>& nodeNames +) { + getNodeInfo(session, dataShape, nodeNames, false); +} + +void inferenceWithIOBinding(Ort::Session& session, + const std::vector<std::string>& inputNames, + const std::vector<Ort::Value>& inputData, + const std::vector<std::string>& outputNames, + const std::vector<Ort::Value>& outputData){ + + if (inputNames.empty()) { + throw std::runtime_error("Onnxruntime input data mapping cannot be empty"); + } + assert(inputNames.size() == inputData.size()); + + Ort::IoBinding iobinding(session); + for(size_t idx = 0; idx < inputNames.size(); ++idx){ + iobinding.BindInput(inputNames[idx].data(), inputData[idx]); + } + + + for(size_t idx = 0; idx < outputNames.size(); ++idx){ + iobinding.BindOutput(outputNames[idx].data(), outputData[idx]); + } + + session.Run(Ort::RunOptions{nullptr}, iobinding); +} + +int64_t getTensorSize(const std::vector<int64_t>& dataShape){ + int64_t size = 1; + for (const auto& dim : dataShape) { + size *= dim; + } + return size; +} + +Ort::Value createTensor(std::vector<float>& data, const std::vector<int64_t>& dataShape) +{ + auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); + + return Ort::Value::CreateTensor<float>( + memoryInfo, + data.data(), + data.size(), + dataShape.data(), + dataShape.size()); +} + + +} // namespace AthOnnx \ No newline at end of file diff --git a/Control/AthenaConfiguration/python/AllConfigFlags.py index 96efa0f8277a3b40e23382f65fd7bf5a18ceb5b8..f6946d79a7b89280a00b82089afb6bfec757cd97 100644 --- a/Control/AthenaConfiguration/python/AllConfigFlags.py +++ 
b/Control/AthenaConfiguration/python/AllConfigFlags.py @@ -495,6 +495,12 @@ def initConfigFlags(): return createLLPDFConfigFlags() _addFlagsCategory(acf, "Derivation.LLP", __llpDerivation, 'DerivationFrameworkLLP' ) + # onnxruntime flags + def __onnxruntime(): + from AthOnnxComps.OnnxRuntimeFlags import createOnnxRuntimeFlags + return createOnnxRuntimeFlags() + _addFlagsCategory(acf, "AthOnnx", __onnxruntime, 'AthOnnxComps') + # For AnalysisBase, pick up things grabbed in Athena by the functions above if not isGaudiEnv(): def EDMVersion(flags): diff --git a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt index 29281dfe4a6e390fb5f9925eebdaf0382ce83298..02249a193fcf65b0c0a5162f846c8a14dfda35a4 100644 --- a/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt +++ b/Control/AthenaExamples/AthExOnnxRuntime/CMakeLists.txt @@ -1,28 +1,31 @@ -# Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration # Declare the package's name. atlas_subdir( AthExOnnxRuntime ) -# Component(s) in the package. -atlas_add_library( AthExOnnxRuntimeLib - INTERFACE - PUBLIC_HEADERS AthExOnnxRuntime - INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS} - LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} GaudiKernel ) +find_package( onnxruntime ) +# Component(s) in the package. atlas_add_component( AthExOnnxRuntime src/*.h src/*.cxx src/components/*.cxx INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS} - LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthExOnnxRuntimeLib AthenaBaseComps GaudiKernel PathResolver AthOnnxruntimeServiceLib AthOnnxruntimeUtilsLib) + LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AthenaBaseComps GaudiKernel PathResolver + AthOnnxInterfaces AthOnnxUtilsLib AsgServicesLib +) -# Install files from the package. -atlas_install_joboptions( share/*.py ) +# Install files from the package: +atlas_install_python_modules( python/*.py + POST_BUILD_CMD ${ATLAS_FLAKE8} ) -# Set up tests for the package. 
-atlas_add_test( AthExOnnxRuntimeJob_serial - SCRIPT athena.py AthExOnnxRuntime/AthExOnnxRuntime_jobOptions.py - POST_EXEC_SCRIPT nopost.sh ) +# Test the package +atlas_add_test( AthExOnnxRuntimeTest + SCRIPT python -m AthExOnnxRuntime.AthExOnnxRuntime_test + PROPERTIES TIMEOUT 100 + POST_EXEC_SCRIPT noerror.sh + ) -atlas_add_test( AthExOnnxRuntimeJob_mt - SCRIPT athena.py --threads=2 AthExOnnxRuntime/AthExOnnxRuntime_jobOptions.py - POST_EXEC_SCRIPT nopost.sh ) +atlas_add_test( AthExOnnxRuntimeTest_pkl + SCRIPT athena --evtMax 2 --CA test_AthExOnnxRuntimeExampleCfg.pkl + DEPENDS AthExOnnxRuntimeTest + PROPERTIES TIMEOUT 100 + POST_EXEC_SCRIPT noerror.sh ) diff --git a/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py b/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6d243fbc6feb21ae87abbc09e244d8a3190a867c --- /dev/null +++ b/Control/AthenaExamples/AthExOnnxRuntime/python/AthExOnnxRuntime_test.py @@ -0,0 +1,45 @@ +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator +from AthenaConfiguration.ComponentFactory import CompFactory +from AthenaCommon import Constants +from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType + + +def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs): + acc = ComponentAccumulator() + + model_fname = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-02/MNIST_testModel.onnx" + execution_provider = OnnxRuntimeType.CPU + from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg + kwargs.setdefault("ORTInferenceTool", acc.popToolsAndMerge( + OnnxRuntimeInferenceToolCfg(flags, model_fname, execution_provider) + )) + + input_data = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-31/t10k-images-idx3-ubyte" + kwargs.setdefault("BatchSize", 3) + kwargs.setdefault("InputDataPixel", input_data) + kwargs.setdefault("OutputLevel", Constants.DEBUG) + acc.addEventAlgo(CompFactory.AthOnnx.EvaluateModel(name, **kwargs)) + + return acc + +if __name__ == "__main__": + from AthenaCommon.Logging import log as msg + from AthenaConfiguration.AllConfigFlags import initConfigFlags + from AthenaConfiguration.MainServicesConfig import MainServicesCfg + + msg.setLevel(Constants.DEBUG) + + flags = initConfigFlags() + flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CPU + flags.lock() + + acc = MainServicesCfg(flags) + acc.merge(AthExOnnxRuntimeExampleCfg(flags)) + acc.printConfig(withDetails=True, summariseProps=True) + + acc.store(open('test_AthExOnnxRuntimeExampleCfg.pkl','wb')) + + import sys + sys.exit(acc.run(2).isFailure()) diff --git a/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_jobOptions.py b/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_jobOptions.py deleted file mode 100644 index a35ab1f40e4506f22e532e7727af7e1436add528..0000000000000000000000000000000000000000 --- a/Control/AthenaExamples/AthExOnnxRuntime/share/AthExOnnxRuntime_jobOptions.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration - -# Set up / access the algorithm sequence. -from AthenaCommon.AlgSequence import AlgSequence -algSequence = AlgSequence() - -# Set up the job. 
-from AthExOnnxRuntime.AthExOnnxRuntimeConf import AthONNX__EvaluateModel -from AthOnnxruntimeService.AthOnnxruntimeServiceConf import AthONNX__ONNXRuntimeSvc - -from AthenaCommon.AppMgr import ServiceMgr -ServiceMgr += AthONNX__ONNXRuntimeSvc( OutputLevel = DEBUG ) -algSequence += AthONNX__EvaluateModel("AthONNX") - -# Get a random no. between 0 to 10k for test sample -from random import randint - -AthONNX = algSequence.AthONNX -AthONNX.TestSample = randint(0, 9999) -AthONNX.DoBatches = False -AthONNX.NumberOfBatches = 1 -AthONNX.SizeOfBatch = 1 -AthONNX.OutputLevel = DEBUG -AthONNX.UseCUDA = False - -# Run for 10 "events". -theApp.EvtMax = 2 diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx index fb1b6f25e6252467b9a9fb8439007dd9ce6dbd3c..42a58730d19e4094f0dbfa374f94b88d2b1e6687 100644 --- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx +++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.cxx @@ -1,14 +1,16 @@ -// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration +// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration // Local include(s). #include "EvaluateModel.h" #include <tuple> +#include <fstream> +#include <chrono> +#include <arpa/inet.h> // Framework include(s). -#include "PathResolver/PathResolver.h" -#include "AthOnnxruntimeUtils/OnnxUtils.h" +#include "AthOnnxUtils/OnnxUtils.h" -namespace AthONNX { +namespace AthOnnx { //******************************************************************* // for reading MNIST images @@ -66,195 +68,74 @@ namespace AthONNX { } StatusCode EvaluateModel::initialize() { - - // Access the service. - //ATH_CHECK( m_svc.retrieve() ); + // Fetch tools + ATH_CHECK( m_onnxTool.retrieve() ); + m_onnxTool->printModelInfo(); /***** The combination of no. of batches and batch size shouldn't cross the total smple size which is 10000 for this example *****/ - if(m_doBatches && (m_numberOfBatches*m_sizeOfBatch)>10000){ + if(m_batchSize > 10000){ ATH_MSG_INFO("The total no. of sample crossed the no. of available sample ...."); return StatusCode::FAILURE; - } - // Find the model file. - const std::string modelFileName = - PathResolverFindCalibFile( m_modelFileName ); - const std::string pixelFileName = - PathResolverFindCalibFile( m_pixelFileName ); - const std::string labelFileName = - PathResolverFindCalibFile( m_labelFileName ); - ATH_MSG_INFO( "Using model file: " << modelFileName ); - ATH_MSG_INFO( "Using pixel file: " << pixelFileName ); - ATH_MSG_INFO( "Using pixel file: " << labelFileName ); - // Set up the ONNX Runtime session. + } + // read input file, and the target file for comparison. + ATH_MSG_INFO( "Using pixel file: " << m_pixelFileName.value() ); - m_session = AthONNX::CreateORTSession(modelFileName, m_useCUDA); - - if (m_useCUDA) { - ATH_MSG_INFO( "Created the ONNX Runtime session on CUDA" ); - } else { - ATH_MSG_INFO( "Created the ONNX Runtime session on CPUs" ); - } - - m_input_tensor_values_notFlat = read_mnist_pixel_notFlat(pixelFileName); - std::vector<std::vector<float>> c = m_input_tensor_values_notFlat[0]; - m_output_tensor_values = read_mnist_label(labelFileName); - // Return gracefully. + m_input_tensor_values_notFlat = read_mnist_pixel_notFlat(m_pixelFileName); + ATH_MSG_INFO("Total no. 
of samples: "<<m_input_tensor_values_notFlat.size()); + return StatusCode::SUCCESS; - } +} StatusCode EvaluateModel::execute( const EventContext& /*ctx*/ ) const { - - Ort::AllocatorWithDefaultOptions allocator; - - /************************** Input Nodes *****************************/ - /*********************************************************************/ - - std::tuple<std::vector<int64_t>, std::vector<const char*> > inputInfo = AthONNX::GetInputNodeInfo(m_session); - std::vector<int64_t> input_node_dims = std::get<0>(inputInfo); - std::vector<const char*> input_node_names = std::get<1>(inputInfo); - - for( std::size_t i = 0; i < input_node_names.size(); i++ ) { - // print input node names - ATH_MSG_DEBUG("Input "<<i<<" : "<<" name= "<<input_node_names[i]); - // print input shapes/dims - ATH_MSG_DEBUG("Input "<<i<<" : num_dims= "<<input_node_dims.size()); - for (std::size_t j = 0; j < input_node_dims.size(); j++){ - ATH_MSG_DEBUG("Input "<<i<<" : dim "<<j<<"= "<<input_node_dims[j]); - } - } - - /************************** Output Nodes *****************************/ - /*********************************************************************/ - - std::tuple<std::vector<int64_t>, std::vector<const char*> > outputInfo = AthONNX::GetOutputNodeInfo(m_session); - std::vector<int64_t> output_node_dims = std::get<0>(outputInfo); - std::vector<const char*> output_node_names = std::get<1>(outputInfo); - - for( std::size_t i = 0; i < output_node_names.size(); i++ ) { - // print input node names - ATH_MSG_DEBUG("Output "<<i<<" : "<<" name= "<<output_node_names[i]); - - // print input shapes/dims - ATH_MSG_DEBUG("Output "<<i<<" : num_dims= "<<output_node_dims.size()); - for (std::size_t j = 0; j < output_node_dims.size(); j++){ - ATH_MSG_DEBUG("Output "<<i<<" : dim "<<j<<"= "<<output_node_dims[j]); - } - } - /************************* Score if input is not a batch ********************/ - /****************************************************************************/ - if(m_doBatches == false){ - - /************************************************************************************** - * input_node_dims[0] = -1; -1 needs to be replaced by the batch size; for no batch is 1 - * input_node_dims[1] = 28 - * input_node_dims[2] = 28 - ****************************************************************************************/ - - input_node_dims[0] = 1; - output_node_dims[0] = 1; - - /***************** Choose an example sample randomly ****************************/ - std::vector<std::vector<float>> input_tensor_values = m_input_tensor_values_notFlat[m_testSample]; - std::vector<float> flatten = AthONNX::FlattenInput_multiD_1D(input_tensor_values); - // Output label of corresponding m_input_tensor_values[m_testSample]; e.g 0, 1, 2, 3 etc - int output_tensor_values = m_output_tensor_values[m_testSample]; - - // For a check that the sample dimension is fully flatten (1x28x28 = 784) - ATH_MSG_DEBUG("Size of Flatten Input tensor: "<<flatten.size()); + // prepare inputs + std::vector<float> inputData; + for (int ibatch = 0; ibatch < m_batchSize; ibatch++){ + const std::vector<std::vector<float> >& imageData = m_input_tensor_values_notFlat[ibatch]; + std::vector<float> flatten = AthOnnx::flattenNestedVectors(imageData); + inputData.insert(inputData.end(), flatten.begin(), flatten.end()); + } - /************** Create input tensor object from input data values to feed into your model *********************/ - - Ort::Value input_tensor = AthONNX::TensorCreator(flatten, input_node_dims ); + int64_t batchSize = 
+  int64_t batchSize = m_onnxTool->getBatchSize(inputData.size());
+  ATH_MSG_INFO("Batch size is " << batchSize << ".");
+  assert(batchSize == m_batchSize);
-    /********* Convert 784 elements long flattened 1D array to 3D (1, 28, 28) onnx compatible tensor ************/
-    ATH_MSG_DEBUG("Input tensor size after converted to Ort tensor: "<<input_tensor.GetTensorTypeAndShapeInfo().GetShape());
-    // Makes sure input tensor has same dimensions as input layer of the model
-    assert(input_tensor.IsTensor()&&
-           input_tensor.GetTensorTypeAndShapeInfo().GetShape() == input_node_dims);
+  // bind the input data to the input tensor
+  std::vector<Ort::Value> inputTensors;
+  ATH_CHECK( m_onnxTool->addInput(inputTensors, inputData, 0, batchSize) );
-    /********* Score model by feeding input tensor and get output tensor in return *****************************/
+  // reserve space for the output data and bind it to the output tensor
+  std::vector<float> outputScores;
+  std::vector<Ort::Value> outputTensors;
+  ATH_CHECK( m_onnxTool->addOutput(outputTensors, outputScores, 0, batchSize) );
-    float* floatarr = AthONNX::Inference(m_session, input_node_names, input_tensor, output_node_names);
+  // run the inference;
+  // the output will be filled into outputScores.
+  ATH_CHECK( m_onnxTool->inference(inputTensors, outputTensors) );
-    // show true label for the test input
-    ATH_MSG_INFO("Label for the input test data = "<<output_tensor_values);
+  ATH_MSG_INFO("Predicted classes for the input test data:");
+  for(int ibatch = 0; ibatch < m_batchSize; ibatch++){
     float max = -999;
     int max_index;
     for (int i = 0; i < 10; i++){
-      ATH_MSG_DEBUG("Score for class "<<i<<" = "<<floatarr[i]);
-      if (max<floatarr[i]){
-        max = floatarr[i];
-        max_index = i;
+      int index = i + ibatch * 10;
+      ATH_MSG_DEBUG("Score for class "<< i <<" = "<<outputScores[index] << " in batch " << ibatch);
+      if (max < outputScores[index]){
+        max = outputScores[index];
+        max_index = i;
       }
     }
-    ATH_MSG_INFO("Class: "<<max_index<<" has the highest score: "<<floatarr[max_index]);
-
-  } // m_doBatches == false codition ends
-  /************************* Score if input is a batch ********************/
-  /****************************************************************************/
-  else {
-    /**************************************************************************************
-     Similar scoring structure like non batch execution but 1st demention needs to be replaced by batch size
-     for this example lets take 3 batches with batch size 5
-     * input_node_dims[0] = 5
-     * input_node_dims[1] = 28
-     * input_node_dims[2] = 28
-     ****************************************************************************************/
-
-    input_node_dims[0] = m_sizeOfBatch;
-    output_node_dims[0] = m_sizeOfBatch;
-
-    /************************** process multiple batches ********************************/
-    int l =0; /****** variable for distributing rows in m_input_tensor_values_notFlat equally into batches*****/
-    for (int i = 0; i < m_numberOfBatches; i++) {
-      ATH_MSG_DEBUG("Processing batch #" << i);
-      std::vector<float> batch_input_tensor_values;
-      for (int j = l; j < l+m_sizeOfBatch; j++) {
-
-        std::vector<float> flattened_input = AthONNX::FlattenInput_multiD_1D(m_input_tensor_values_notFlat[j]);
-        /******************For each batch we need a flattened (5 x 28 x 28 = 3920) 1D array******************************/
-        batch_input_tensor_values.insert(batch_input_tensor_values.end(), flattened_input.begin(), flattened_input.end());
-      }
-
-      Ort::Value batch_input_tensors = AthONNX::TensorCreator(batch_input_tensor_values, input_node_dims );
-
-      // Get pointer to output tensor float values
+    ATH_MSG_INFO("Class: "<<max_index<<" has the highest score: "<<max << " in batch " << ibatch);
+  }
-      float* floatarr = AthONNX::Inference(m_session, input_node_names, batch_input_tensors, output_node_names);
-      // show true label for the test input
-      for(int i = l; i<l+m_sizeOfBatch; i++){
-        ATH_MSG_INFO("Label for the input test data = "<<m_output_tensor_values[i]);
-        int k = (i-l)*10;
-        float max = -999;
-        int max_index = 0;
-        for (int j =k ; j < k+10; j++){
-          ATH_MSG_INFO("Score for class "<<j-k<<" = "<<floatarr[j]);
-          if (max<floatarr[j]){
-            max = floatarr[j];
-            max_index = j;
-          }
-        }
-        ATH_MSG_INFO("Class: "<<max_index-k<<" has the highest score: "<<floatarr[max_index]);
-      }
-      l = l+m_sizeOfBatch;
-    }
-  } // else/m_doBatches == True codition ends
-  // Return gracefully.
   return StatusCode::SUCCESS;
 }

 StatusCode EvaluateModel::finalize() {
-  // Delete the session object.
-  m_session.reset();
-
-  // Return gracefully.
   return StatusCode::SUCCESS;
 }

-} // namespace AthONNX
-
-
+} // namespace AthOnnx
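Aside — a sketch, not part of the patch: the rewritten execute() above delegates all tensor bookkeeping to IOnnxRuntimeInferenceTool, and the only size arithmetic left in the algorithm is the getBatchSize() call. Assuming the flattened 28x28 MNIST inputs this example uses, such a helper could reduce to the division below; guessBatchSize() and sampleSize are illustrative names, and the real tool presumably derives the sample size from the model's input shape instead.

    // Illustrative only: derive a batch size from the length of a flattened buffer.
    #include <cstddef>
    #include <cstdint>

    inline int64_t guessBatchSize(std::size_t dataSize) {
        constexpr std::size_t sampleSize = 28 * 28;  // one flattened MNIST image
        return static_cast<int64_t>(dataSize / sampleSize);
    }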
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
index aaae60dc28581c43837484e25d969fdc3aac2abf..d660543e5e584a75c390bc19aa4181cb722daf6f 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/EvaluateModel.h
@@ -4,24 +4,21 @@
 #define ATHEXONNXRUNTIME_EVALUATEMODEL_H

 // Local include(s).
-#include "AthOnnxruntimeService/IONNXRuntimeSvc.h"
+#include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h"
+
 // Framework include(s).
 #include "AthenaBaseComps/AthReentrantAlgorithm.h"
 #include "GaudiKernel/ServiceHandle.h"

-// ONNX Runtime include(s).
+// Onnx Runtime include(s).
 #include <core/session/onnxruntime_cxx_api.h>

 // System include(s).
 #include <memory>
 #include <string>
-#include <iostream>
-#include <fstream>
-#include <arpa/inet.h>
 #include <vector>
-#include <iterator>

-namespace AthONNX {
+namespace AthOnnx {

 /// Algorithm demonstrating the usage of the ONNX Runtime C++ API
 ///
@@ -53,31 +50,22 @@ namespace AthONNX {
 /// @{
 /// Name of the model file to load
-   Gaudi::Property< std::string > m_modelFileName{ this, "ModelFileName",
-      "dev/MLTest/2020-03-02/MNIST_testModel.onnx",
-      "Name of the model file to load" };
    Gaudi::Property< std::string > m_pixelFileName{ this, "InputDataPixel",
      "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte",
      "Name of the input pixel file to load" };
-   Gaudi::Property< std::string > m_labelFileName{ this, "InputDataLabel",
-      "dev/MLTest/2020-03-31/t10k-labels-idx1-ubyte",
-      "Name of the label file to load" };
-   Gaudi::Property<int> m_testSample {this, "TestSample", 0, "A Random Test Sample"};

    /// Following properties need to be considered if the .onnx model is evaluated in batch mode
-   Gaudi::Property<bool> m_doBatches {this, "DoBatches", false, "Processing events by batches"};
-   Gaudi::Property<int> m_numberOfBatches {this, "NumberOfBatches", 1, "No. of batches to be passed"};
-   Gaudi::Property<int> m_sizeOfBatch {this, "SizeOfBatch", 1, "No. of elements/example in a batch"};
+   Gaudi::Property<int> m_batchSize {this, "BatchSize", 1, "Number of elements/examples in a batch"};

-   // If runs on CUDA
-   Gaudi::Property<bool> m_useCUDA {this, "UseCUDA", false, "Use CUDA"};
+   /// Tool handle for the ONNX Runtime inference session
+   ToolHandle< IOnnxRuntimeInferenceTool > m_onnxTool{
+      this, "ORTInferenceTool", "AthOnnx::OnnxRuntimeInferenceTool"
+   };

-   std::unique_ptr< Ort::Session > m_session;
    std::vector<std::vector<std::vector<float>>> m_input_tensor_values_notFlat;
-   std::vector<int> m_output_tensor_values;

 }; // class EvaluateModel

-} // namespace AthONNX
+} // namespace AthOnnx

 #endif // ATHEXONNXRUNTIME_EVALUATEMODEL_H
diff --git a/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx b/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx
index cb25a4bfd9d2e30534c0158a17ae01e6ddd719ee..c89bab7da1362ad847381398705168c3c9e8896d 100644
--- a/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx
+++ b/Control/AthenaExamples/AthExOnnxRuntime/src/components/AthExOnnxRuntime_entries.cxx
@@ -3,4 +3,4 @@
 // Local include(s).
 #include "../EvaluateModel.h"

 // Declare the package's components.
-DECLARE_COMPONENT( AthONNX::EvaluateModel )
+DECLARE_COMPONENT( AthOnnx::EvaluateModel )
diff --git a/InnerDetector/InDetGNNTracking/CMakeLists.txt b/InnerDetector/InDetGNNTracking/CMakeLists.txt
index 2bcff0db22e5442b9126c9529931c23e492cd461..daa184c964735fcf3f488b0365b7a1b5ca476a1a 100644
--- a/InnerDetector/InDetGNNTracking/CMakeLists.txt
+++ b/InnerDetector/InDetGNNTracking/CMakeLists.txt
@@ -18,6 +18,7 @@ atlas_add_component( InDetGNNTracking
   PixelReadoutGeometryLib SCT_ReadoutGeometry InDetSimData InDetPrepRawData
   TrkTrack TrkRIO_OnTrack InDetSimEvent AtlasHepMCLib InDetRIO_OnTrack
   InDetRawData TrkTruthData
+  AthOnnxInterfaces AthOnnxUtilsLib
 )

 atlas_install_python_modules( python/*.py )
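Aside — a sketch, not part of the patch: the hunks below replace hand-rolled Ort::Value creation with the session tools' addInput()/addOutput() helpers. For reference, this is the binding pattern the removed code performed; bindFloatTensor() is a hypothetical name, and the shape argument is exactly what the tool is now assumed to look up from the model itself.

    #include <core/session/onnxruntime_cxx_api.h>
    #include <cstdint>
    #include <vector>

    // Bind a float buffer to a CPU-backed tensor of the given shape,
    // mirroring the removed Ort::Value::CreateTensor calls.
    inline void bindFloatTensor(std::vector<Ort::Value>& tensors,
                                std::vector<float>& data,
                                const std::vector<int64_t>& shape) {
        auto memoryInfo = Ort::MemoryInfo::CreateCpu(
            OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
        tensors.push_back(Ort::Value::CreateTensor<float>(
            memoryInfo, data.data(), data.size(), shape.data(), shape.size()));
    }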
#include "PathResolver/PathResolver.h" -#include "AthOnnxruntimeUtils/OnnxUtils.h" +#include "AthOnnxUtils/OnnxUtils.h" #include <cmath> InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool( @@ -18,20 +18,12 @@ InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool( } StatusCode InDet::SiGNNTrackFinderTool::initialize() { - initTrainedModels(); + ATH_CHECK( m_embedSessionTool.retrieve() ); + ATH_CHECK( m_filterSessionTool.retrieve() ); + ATH_CHECK( m_gnnSessionTool.retrieve() ); return StatusCode::SUCCESS; } -void InDet::SiGNNTrackFinderTool::initTrainedModels() { - std::string embedModelPath(m_inputMLModuleDir + "/torchscript/embedding.onnx"); - std::string filterModelPath(m_inputMLModuleDir + "/torchscript/filtering.onnx"); - std::string gnnModelPath(m_inputMLModuleDir + "/torchscript/gnn.onnx"); - - m_embedSession = AthONNX::CreateORTSession(embedModelPath, m_useCUDA); - m_filterSession = AthONNX::CreateORTSession(filterModelPath, m_useCUDA); - m_gnnSession = AthONNX::CreateORTSession(gnnModelPath, m_useCUDA); -} - StatusCode InDet::SiGNNTrackFinderTool::finalize() { StatusCode sc = AlgTool::finalize(); return sc; @@ -59,7 +51,7 @@ MsgStream& InDet::SiGNNTrackFinderTool::dumpevent( MsgStream& out ) const return out; } -void InDet::SiGNNTrackFinderTool::getTracks ( +StatusCode InDet::SiGNNTrackFinderTool::getTracks( const std::vector<const Trk::SpacePoint*>& spacepoints, std::vector<std::vector<uint32_t> >& tracks) const { @@ -84,34 +76,19 @@ void InDet::SiGNNTrackFinderTool::getTracks ( spacepointIDs.push_back(sp_idx++); } - Ort::AllocatorWithDefaultOptions allocator; - auto memoryInfo = Ort::MemoryInfo::CreateCpu( - OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault); - // ************ // Embedding // ************ std::vector<int64_t> eInputShape{numSpacepoints, spacepointFeatures}; - - std::vector<const char*> eInputNames{"sp_features"}; std::vector<Ort::Value> eInputTensor; - eInputTensor.push_back( - Ort::Value::CreateTensor<float>( - memoryInfo, inputValues.data(), inputValues.size(), - eInputShape.data(), eInputShape.size()) - ); + ATH_CHECK( m_embedSessionTool->addInput(eInputTensor, inputValues, 0, numSpacepoints) ); - std::vector<float> eOutputData(numSpacepoints * m_embeddingDim); - std::vector<const char*> eOutputNames{"embedding_output"}; - std::vector<int64_t> eOutputShape{numSpacepoints, m_embeddingDim}; std::vector<Ort::Value> eOutputTensor; - eOutputTensor.push_back( - Ort::Value::CreateTensor<float>( - memoryInfo, eOutputData.data(), eOutputData.size(), - eOutputShape.data(), eOutputShape.size()) - ); - AthONNX::InferenceWithIOBinding(m_embedSession, eInputNames, eInputTensor, eOutputNames, eOutputTensor); + std::vector<float> eOutputData; + ATH_CHECK( m_embedSessionTool->addOutput(eOutputTensor, eOutputData, 0, numSpacepoints) ); + + ATH_CHECK( m_embedSessionTool->inference(eInputTensor, eOutputTensor) ); // ************ // Building Edges @@ -123,29 +100,18 @@ void InDet::SiGNNTrackFinderTool::getTracks ( // ************ // Filtering // ************ - std::vector<const char*> fInputNames{"f_nodes", "f_edges"}; std::vector<Ort::Value> fInputTensor; fInputTensor.push_back( std::move(eInputTensor[0]) ); + ATH_CHECK( m_filterSessionTool->addInput(fInputTensor, edgeList, 1, numEdges) ); std::vector<int64_t> fEdgeShape{2, numEdges}; - fInputTensor.push_back( - Ort::Value::CreateTensor<int64_t>( - memoryInfo, edgeList.data(), edgeList.size(), - fEdgeShape.data(), fEdgeShape.size()) - ); - // filtering outputs - std::vector<const char*> fOutputNames{"f_edge_score"}; - 
-  std::vector<float> fOutputData(numEdges);
-  std::vector<int64_t> fOutputShape{numEdges, 1};
+  std::vector<float> fOutputData;
   std::vector<Ort::Value> fOutputTensor;
-  fOutputTensor.push_back(
-    Ort::Value::CreateTensor<float>(
-      memoryInfo, fOutputData.data(), fOutputData.size(),
-      fOutputShape.data(), fOutputShape.size())
-  );
-  AthONNX::InferenceWithIOBinding(m_filterSession, fInputNames, fInputTensor, fOutputNames, fOutputTensor);
+  ATH_CHECK( m_filterSessionTool->addOutput(fOutputTensor, fOutputData, 0, numEdges) );
+
+  ATH_CHECK( m_filterSessionTool->inference(fInputTensor, fOutputTensor) );

   // apply sigmoid to the filtering output data
   // and remove edges with score < filterCut
@@ -167,28 +133,18 @@ void InDet::SiGNNTrackFinderTool::getTracks (
   // ************
   // GNN
   // ************
-  std::vector<const char*> gInputNames{"g_nodes", "g_edges"};
   std::vector<Ort::Value> gInputTensor;
   gInputTensor.push_back( std::move(fInputTensor[0]) );
-  std::vector<int64_t> gEdgeShape{2, numEdgesAfterF};
-  gInputTensor.push_back(
-    Ort::Value::CreateTensor<int64_t>(
-      memoryInfo, edgesAfterFiltering.data(), edgesAfterFiltering.size(),
-      gEdgeShape.data(), gEdgeShape.size())
-  );
+  ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, edgesAfterFiltering, 1, numEdgesAfterF) );
+
   // gnn outputs
-  std::vector<const char*> gOutputNames{"gnn_edge_score"};
-  std::vector<float> gOutputData(numEdgesAfterF);
-  std::vector<int64_t> gOutputShape{numEdgesAfterF};
+  std::vector<float> gOutputData;
   std::vector<Ort::Value> gOutputTensor;
-  gOutputTensor.push_back(
-    Ort::Value::CreateTensor<float>(
-      memoryInfo, gOutputData.data(), gOutputData.size(),
-      gOutputShape.data(), gOutputShape.size())
-  );
-  AthONNX::InferenceWithIOBinding(m_gnnSession, gInputNames, gInputTensor, gOutputNames, gOutputTensor);
+  ATH_CHECK( m_gnnSessionTool->addOutput(gOutputTensor, gOutputData, 0, numEdgesAfterF) );
+
+  ATH_CHECK( m_gnnSessionTool->inference(gInputTensor, gOutputTensor) );

   // apply sigmoid to the gnn output data
   for(auto& v : gOutputData){
     v = 1.f / (1.f + std::exp(-v));
@@ -200,7 +156,7 @@ void InDet::SiGNNTrackFinderTool::getTracks (
   std::vector<int32_t> trackLabels(numSpacepoints);
   weaklyConnectedComponents<int64_t,float,int32_t>(numSpacepoints, rowIndices, colIndices, gOutputData, trackLabels);

-  if (trackLabels.size() == 0) return;
+  if (trackLabels.size() == 0) return StatusCode::SUCCESS;

   tracks.clear();
@@ -225,5 +181,6 @@ void InDet::SiGNNTrackFinderTool::getTracks (
     existTrkIdx++;
   }
 }
+  return StatusCode::SUCCESS;
 }
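Aside — a sketch, not part of the patch: both the filtering and GNN outputs in the file above pass through the same sigmoid-then-cut step before the surviving edges are reused. The pattern in isolation, with keepEdgesAbove() and filterCut as illustrative names only:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Convert raw scores to probabilities with a sigmoid (as in the patch) and
    // return the indices of the edges whose probability passes the cut.
    inline std::vector<std::size_t> keepEdgesAbove(const std::vector<float>& scores,
                                                   float filterCut) {
        std::vector<std::size_t> kept;
        for (std::size_t i = 0; i < scores.size(); ++i) {
            const float p = 1.f / (1.f + std::exp(-scores[i]));
            if (p >= filterCut) kept.push_back(i);
        }
        return kept;
    }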
-#include "AthOnnxruntimeService/IONNXRuntimeSvc.h" +#include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h" #include <core/session/onnxruntime_cxx_api.h> class MsgStream; @@ -46,7 +46,7 @@ namespace InDet{ * * @return */ - virtual void getTracks( + virtual StatusCode getTracks( const std::vector<const Trk::SpacePoint*>& spacepoints, std::vector<std::vector<uint32_t> >& tracks) const override; @@ -74,9 +74,15 @@ namespace InDet{ MsgStream& dumpevent (MsgStream& out) const; private: - std::unique_ptr< Ort::Session > m_embedSession; - std::unique_ptr< Ort::Session > m_filterSession; - std::unique_ptr< Ort::Session > m_gnnSession; + ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool { + this, "Embedding", "AthOnnx::OnnxRuntimeInferenceTool" + }; + ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool { + this, "Filtering", "AthOnnx::OnnxRuntimeInferenceTool" + }; + ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool { + this, "GNN", "AthOnnx::OnnxRuntimeInferenceTool" + }; }; diff --git a/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx b/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx index 121a710a66eec9ff5b9d3230f0c4dfeb44dd8069..d7caaa467f6fe633ef9ae313a97dfff712c492ea 100644 --- a/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx +++ b/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx @@ -68,7 +68,7 @@ StatusCode InDet::SiSPGNNTrackMaker::execute(const EventContext& ctx) const getData(m_SpacePointsSCTKey); std::vector<std::vector<uint32_t> > TT; - m_gnnTrackFinder->getTracks(spacePoints, TT); + ATH_CHECK(m_gnnTrackFinder->getTracks(spacePoints, TT)); ATH_MSG_DEBUG("Obtained " << TT.size() << " Tracks"); diff --git a/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h b/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h index b16efa8f42e9d47abfe4aa00980e02774cb7bf9e..52bb639c4f6a83abda1b6417632dadf5b4918337 100644 --- a/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h +++ b/InnerDetector/InDetRecTools/InDetRecToolInterfaces/InDetRecToolInterfaces/IGNNTrackFinder.h @@ -38,7 +38,7 @@ namespace InDet { * @param tracks a list of track candidates in terms of spacepoint indices. * @return */ - virtual void getTracks( + virtual StatusCode getTracks( const std::vector<const Trk::SpacePoint*>& spacepoints, std::vector<std::vector<uint32_t> >& tracks) const=0;