Commit 7f74bdea authored by Xiangyang Ju, committed by Tadej Novak

Infrastructure for Machine Learning inference with ONNX Runtime

Revert "remove redundant packages"

This reverts commit 0ac0c9a2.
parent bfe0e772
Merge request !68093: Infrastructure for Machine Learning inference with ONNX Runtime
# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
# Declare the package's name.
atlas_subdir( AthOnnxComps )
# External dependencies.
find_package( onnxruntime )
# Library in the package.
atlas_add_component( AthOnnxComps
src/*.cxx
src/components/*.cxx
INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib AthenaBaseComps AthenaKernel GaudiKernel AthOnnxInterfaces AthOnnxUtilsLib
)
# install python modules
atlas_install_python_modules( python/*.py POST_BUILD_CMD ${ATLAS_FLAKE8} )
# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
from AthenaConfiguration.AthConfigFlags import AthConfigFlags
from AthenaConfiguration.Enums import FlagEnum
class OnnxRuntimeType(FlagEnum):
CPU = 'CPU'
CUDA = 'CUDA'
# possible future backends. Uncomment when implemented.
# DML = 'DML'
# DNNL = 'DNNL'
# NUPHAR = 'NUPHAR'
# OPENVINO = 'OPENVINO'
# ROCM = 'ROCM'
# TENSORRT = 'TENSORRT'
# VITISAI = 'VITISAI'
# VULKAN = 'VULKAN'
def createOnnxRuntimeFlags():
icf = AthConfigFlags()
icf.addFlag("AthOnnx.ExecutionProvider", OnnxRuntimeType.CPU, type=OnnxRuntimeType)
return icf
if __name__ == "__main__":
flags = createOnnxRuntimeFlags()
flags.dump()
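A minimal usage sketch (not part of the commit): selecting the CUDA execution provider before locking the flags, assuming the AthOnnx flag category has been created via createOnnxRuntimeFlags().

```python
# Illustrative only: override the default (CPU) execution provider.
from AthOnnxComps.OnnxRuntimeFlags import createOnnxRuntimeFlags, OnnxRuntimeType

flags = createOnnxRuntimeFlags()
flags.AthOnnx.ExecutionProvider = OnnxRuntimeType.CUDA
flags.lock()
flags.dump()
```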
# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
from AthenaConfiguration.ComponentFactory import CompFactory
from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
from typing import Optional
from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg
def OnnxRuntimeInferenceToolCfg(flags,
model_fname: str = None,
execution_provider: Optional[OnnxRuntimeType] = None,
name="OnnxRuntimeInferenceTool", **kwargs):
"""Configure OnnxRuntimeInferenceTool in Control/AthOnnx/AthOnnxComps/src"""
acc = ComponentAccumulator()
session_tool = acc.popToolsAndMerge(OnnxRuntimeSessionToolCfg(flags, model_fname, execution_provider))
kwargs["ORTSessionTool"] = session_tool
acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeInferenceTool(name, **kwargs))
return acc
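A sketch of how a client configuration might attach the private inference tool; `MyAlgorithm`, its `OnnxTool` property, and the model path are hypothetical placeholders, not part of this merge request.

```python
# Hypothetical client configuration: "MyAlgorithm" and "OnnxTool" are placeholders.
from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
from AthenaConfiguration.ComponentFactory import CompFactory
from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg  # module name assumed

def MyAlgorithmCfg(flags, model_fname="path/to/model.onnx"):
    acc = ComponentAccumulator()
    tool = acc.popToolsAndMerge(OnnxRuntimeInferenceToolCfg(flags, model_fname))
    acc.addEventAlgo(CompFactory.MyAlgorithm("MyAlgorithm", OnnxTool=tool))
    return acc
```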
# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
from AthenaConfiguration.ComponentFactory import CompFactory
from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
from typing import Optional
def OnnxRuntimeSessionToolCfg(flags,
model_fname: str,
execution_provider: Optional[OnnxRuntimeType] = None,
name="OnnxRuntimeSessionTool", **kwargs):
""""Configure OnnxRuntimeSessionTool in Control/AthOnnx/AthOnnxComps/src"""
acc = ComponentAccumulator()
if model_fname is None:
raise ValueError("model_fname must be specified")
execution_provider = flags.AthOnnx.ExecutionProvider if execution_provider is None else execution_provider
name += execution_provider.name
kwargs.setdefault("ModelFileName", model_fname)
if execution_provider is OnnxRuntimeType.CPU:
acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCPU(name, **kwargs))
elif execution_provider is OnnxRuntimeType.CUDA:
acc.setPrivateTools(CompFactory.AthOnnx.OnnxRuntimeSessionToolCUDA(name, **kwargs))
else:
raise ValueError("Unknown OnnxRuntime Execution Provider: %s" % execution_provider)
return acc
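A sketch of requesting a specific execution provider per call, overriding flags.AthOnnx.ExecutionProvider; the model path is a placeholder and a CUDA-capable build and runtime are assumed.

```python
# Illustrative only: explicitly request the CUDA-backed session tool.
from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
from AthOnnxComps.OnnxRuntimeSessionConfig import OnnxRuntimeSessionToolCfg

def MySessionToolCfg(flags):
    return OnnxRuntimeSessionToolCfg(flags, "path/to/model.onnx",
                                     execution_provider=OnnxRuntimeType.CUDA)
```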
# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
from AthenaConfiguration.ComponentFactory import CompFactory
def OnnxRuntimeSvcCfg(flags, name="OnnxRuntimeSvc", **kwargs):
acc = ComponentAccumulator()
acc.addService(CompFactory.AthOnnx.OnnxRuntimeSvc(name, **kwargs), primary=True, create=True)
return acc
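A sketch of pulling the Onnx Runtime service into a job configuration; the module name hosting OnnxRuntimeSvcCfg is assumed here.

```python
# Illustrative only: add the OnnxRuntimeSvc to a job's ComponentAccumulator.
from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
from AthOnnxComps.OnnxRuntimeSvcConfig import OnnxRuntimeSvcCfg  # module name assumed

def jobSvcCfg(flags):
    acc = ComponentAccumulator()
    acc.getPrimaryAndMerge(OnnxRuntimeSvcCfg(flags))  # merges and returns the primary service
    return acc
```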
# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
Control/AthOnnx/AthOnnxComps
/*
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
*/
#include "OnnxRuntimeInferenceTool.h"
#include "AthOnnxUtils/OnnxUtils.h"
AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool(
const std::string& type, const std::string& name, const IInterface* parent )
: base_class( type, name, parent )
{
declareInterface<IOnnxRuntimeInferenceTool>(this);
}
StatusCode AthOnnx::OnnxRuntimeInferenceTool::initialize()
{
// Get the Onnx Runtime service.
ATH_CHECK(m_onnxRuntimeSvc.retrieve());
// Create the session.
ATH_CHECK(m_onnxSessionTool.retrieve());
ATH_CHECK(getNodeInfo());
return StatusCode::SUCCESS;
}
StatusCode AthOnnx::OnnxRuntimeInferenceTool::finalize()
{
return StatusCode::SUCCESS;
}
StatusCode AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo()
{
auto& session = m_onnxSessionTool->session();
// obtain the model information
m_numInputs = session.GetInputCount();
m_numOutputs = session.GetOutputCount();
AthOnnx::getInputNodeInfo(session, m_inputShapes, m_inputNodeNames);
AthOnnx::getOutputNodeInfo(session, m_outputShapes, m_outputNodeNames);
return StatusCode::SUCCESS;
}
void AthOnnx::OnnxRuntimeInferenceTool::setBatchSize(int64_t batchSize)
{
if (batchSize <= 0) {
ATH_MSG_ERROR("Batch size should be positive");
return;
}
for (auto& shape : m_inputShapes) {
if (shape[0] == -1) {
shape[0] = batchSize;
}
}
for (auto& shape : m_outputShapes) {
if (shape[0] == -1) {
shape[0] = batchSize;
}
}
}
int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize(int64_t inputDataSize, int idx) const
{
auto tensorSize = AthOnnx::getTensorSize(m_inputShapes[idx]);
if (tensorSize < 0) {
return inputDataSize / std::abs(tensorSize);
} else {
return -1;
}
}
StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const
{
assert (inputTensors.size() == m_numInputs);
assert (outputTensors.size() == m_numOutputs);
// Run the model.
AthOnnx::inferenceWithIOBinding(
m_onnxSessionTool->session(),
m_inputNodeNames, inputTensors,
m_outputNodeNames, outputTensors);
return StatusCode::SUCCESS;
}
void AthOnnx::OnnxRuntimeInferenceTool::printModelInfo() const
{
ATH_MSG_INFO("Number of inputs: " << m_numInputs);
ATH_MSG_INFO("Number of outputs: " << m_numOutputs);
ATH_MSG_INFO("Input node names: ");
for (const auto& name : m_inputNodeNames) {
ATH_MSG_INFO("\t" << name);
}
ATH_MSG_INFO("Output node names: ");
for (const auto& name : m_outputNodeNames) {
ATH_MSG_INFO("\t" << name);
}
ATH_MSG_INFO("Input shapes: ");
for (const auto& shape : m_inputShapes) {
std::string shapeStr = "\t";
for (const auto& dim : shape) {
shapeStr += std::to_string(dim) + " ";
}
ATH_MSG_INFO(shapeStr);
}
ATH_MSG_INFO("Output shapes: ");
for (const auto& shape : m_outputShapes) {
std::string shapeStr = "\t";
for (const auto& dim : shape) {
shapeStr += std::to_string(dim) + " ";
}
ATH_MSG_INFO(shapeStr);
}
}
// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
#ifndef OnnxRuntimeInferenceTool_H
#define OnnxRuntimeInferenceTool_H
#include "AthenaBaseComps/AthAlgTool.h"
#include "AthOnnxInterfaces/IOnnxRuntimeInferenceTool.h"
#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
#include "GaudiKernel/ServiceHandle.h"
#include "GaudiKernel/ToolHandle.h"
namespace AthOnnx {
// @class OnnxRuntimeInferenceTool
//
// @brief Tool to perform inference with an Onnx Runtime session
//
// @author Xiangyang Ju <xiangyang.ju@cern.ch>
class OnnxRuntimeInferenceTool : public extends<AthAlgTool, IOnnxRuntimeInferenceTool>
{
public:
/// Standard constructor
OnnxRuntimeInferenceTool( const std::string& type,
const std::string& name,
const IInterface* parent );
virtual ~OnnxRuntimeInferenceTool() = default;
/// Initialize the tool
virtual StatusCode initialize() override;
/// Finalize the tool
virtual StatusCode finalize() override;
virtual void setBatchSize(int64_t batchSize) override final;
virtual int64_t getBatchSize(int64_t inputDataSize, int idx = 0) const override final;
virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const override final;
virtual void printModelInfo() const override final;
protected:
OnnxRuntimeInferenceTool() = delete;
OnnxRuntimeInferenceTool(const OnnxRuntimeInferenceTool&) = delete;
OnnxRuntimeInferenceTool& operator=(const OnnxRuntimeInferenceTool&) = delete;
private:
StatusCode getNodeInfo();
ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{
this, "ORTSessionTool",
"AthOnnx::OnnxRuntimeInferenceToolCPU"
};
std::vector<std::string> m_inputNodeNames;
std::vector<std::string> m_outputNodeNames;
};
} // namespace AthOnnx
#endif
/*
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
*/
#include "OnnxRuntimeSessionToolCPU.h"
AthOnnx::OnnxRuntimeSessionToolCPU::OnnxRuntimeSessionToolCPU(
const std::string& type, const std::string& name, const IInterface* parent )
: base_class( type, name, parent )
{
declareInterface<IOnnxRuntimeSessionTool>(this);
}
StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::initialize()
{
// Get the Onnx Runtime service.
ATH_CHECK(m_onnxRuntimeSvc.retrieve());
ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
// Create the session options.
// TODO: Make this configurable.
// other threading options: https://onnxruntime.ai/docs/performance/tune-performance/threading.html
// 1) SetIntraOpNumThreads( 1 );
// 2) SetInterOpNumThreads( 1 );
// 3) SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
Ort::SessionOptions sessionOptions;
sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_ALL );
// Create the session.
m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
return StatusCode::SUCCESS;
}
StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::finalize()
{
m_session.reset();
ATH_MSG_DEBUG( "Ort::Session object deleted" );
return StatusCode::SUCCESS;
}
Ort::Session& AthOnnx::OnnxRuntimeSessionToolCPU::session() const
{
return *m_session;
}
// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
#ifndef OnnxRuntimeSessionToolCPU_H
#define OnnxRuntimeSessionToolCPU_H
#include "AthenaBaseComps/AthAlgTool.h"
#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
#include "GaudiKernel/ServiceHandle.h"
namespace AthOnnx {
// @class OnnxRuntimeSessionToolCPU
//
// @brief Tool to create Onnx Runtime session with CPU backend
//
// @author Xiangyang Ju <xiangyang.ju@cern.ch>
class OnnxRuntimeSessionToolCPU : public extends<AthAlgTool, IOnnxRuntimeSessionTool>
{
public:
/// Standard constructor
OnnxRuntimeSessionToolCPU( const std::string& type,
const std::string& name,
const IInterface* parent );
virtual ~OnnxRuntimeSessionToolCPU() = default;
/// Initialize the tool
virtual StatusCode initialize() override final;
/// Finalize the tool
virtual StatusCode finalize() override final;
/// Return the Onnx Runtime session
virtual Ort::Session& session() const override final;
protected:
OnnxRuntimeSessionToolCPU() = delete;
OnnxRuntimeSessionToolCPU(const OnnxRuntimeSessionToolCPU&) = delete;
OnnxRuntimeSessionToolCPU& operator=(const OnnxRuntimeSessionToolCPU&) = delete;
private:
StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
std::unique_ptr<Ort::Session> m_session;
};
}
#endif
/*
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
*/
#include "OnnxRuntimeSessionToolCUDA.h"
AthOnnx::OnnxRuntimeSessionToolCUDA::OnnxRuntimeSessionToolCUDA(
const std::string& type, const std::string& name, const IInterface* parent )
: base_class( type, name, parent )
{
declareInterface<IOnnxRuntimeSessionTool>(this);
}
StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::initialize()
{
// Get the Onnx Runtime service.
ATH_CHECK(m_onnxRuntimeSvc.retrieve());
ATH_MSG_INFO(" OnnxRuntime release: " << OrtGetApiBase()->GetVersionString());
// Create the session options.
Ort::SessionOptions sessionOptions;
sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED );
sessionOptions.DisablePerSessionThreads(); // use global thread pool.
// TODO: add more cuda options to the interface
// https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cc
// Options: https://onnxruntime.ai/docs/api/c/struct_ort_c_u_d_a_provider_options.html
OrtCUDAProviderOptions cuda_options;
cuda_options.device_id = m_deviceId;
cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
cuda_options.gpu_mem_limit = std::numeric_limits<size_t>::max();
// memory arena options for CUDA memory shrinkage
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/utils.cc#L7
if (m_enableMemoryShrinkage) {
Ort::ArenaCfg arena_cfg{0, 0, 1024, 0};
// other options are not available in this release.
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/shared_lib/test_inference.cc#L2802C21-L2802C21
// arena_cfg.max_mem = 0; // let ORT pick default max memory
// arena_cfg.arena_extend_strategy = 0; // 0: kNextPowerOfTwo, 1: kSameAsRequested
// arena_cfg.initial_chunk_size_bytes = 1024;
// arena_cfg.max_dead_bytes_per_chunk = 0;
// arena_cfg.initial_growth_chunk_size_bytes = 256;
// arena_cfg.max_power_of_two_extend_bytes = 1L << 24;
cuda_options.default_memory_arena_cfg = arena_cfg;
}
sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
// Create the session.
m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
return StatusCode::SUCCESS;
}
StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::finalize()
{
m_session.reset();
ATH_MSG_DEBUG( "Ort::Session object deleted" );
return StatusCode::SUCCESS;
}
Ort::Session& AthOnnx::OnnxRuntimeSessionToolCUDA::session() const
{
return *m_session;
}
// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
#ifndef OnnxRuntimeSessionToolCUDA_H
#define OnnxRuntimeSessionToolCUDA_H
#include "AthenaBaseComps/AthAlgTool.h"
#include "AthOnnxInterfaces/IOnnxRuntimeSessionTool.h"
#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
#include "GaudiKernel/ServiceHandle.h"
namespace AthOnnx {
// @class OnnxRuntimeSessionToolCUDA
//
// @brief Tool to create Onnx Runtime session with CUDA backend
//
// @author Xiangyang Ju <xiangyang.ju@cern.ch>
class OnnxRuntimeSessionToolCUDA : public extends<AthAlgTool, IOnnxRuntimeSessionTool>
{
public:
/// Standard constructor
OnnxRuntimeSessionToolCUDA( const std::string& type,
const std::string& name,
const IInterface* parent );
virtual ~OnnxRuntimeSessionToolCUDA() = default;
/// Initialize the tool
virtual StatusCode initialize() override final;
/// Finalize the tool
virtual StatusCode finalize() override final;
/// Return the Onnx Runtime session
virtual Ort::Session& session() const override final;
protected:
OnnxRuntimeSessionToolCUDA() = delete;
OnnxRuntimeSessionToolCUDA(const OnnxRuntimeSessionToolCUDA&) = delete;
OnnxRuntimeSessionToolCUDA& operator=(const OnnxRuntimeSessionToolCUDA&) = delete;
private:
StringProperty m_modelFileName{this, "ModelFileName", "", "The model file name"};
/// The device ID to use.
IntegerProperty m_deviceId{this, "DeviceId", 0};
BooleanProperty m_enableMemoryShrinkage{this, "EnableMemoryShrinkage", false};
/// runtime service
ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"};
std::unique_ptr<Ort::Session> m_session;
};
}
#endif
// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
// Local include(s).
#include "OnnxRuntimeSvc.h"
#include <core/session/onnxruntime_c_api.h>
namespace AthOnnx {
OnnxRuntimeSvc::OnnxRuntimeSvc(const std::string& name, ISvcLocator* svc) :
asg::AsgService(name, svc)
{
declareServiceInterface<AthOnnx::IOnnxRuntimeSvc>();
}
StatusCode OnnxRuntimeSvc::initialize() {
// Create the environment object.
Ort::ThreadingOptions tp_options;
tp_options.SetGlobalIntraOpNumThreads(1);
tp_options.SetGlobalInterOpNumThreads(1);
m_env = std::make_unique< Ort::Env >(
tp_options, ORT_LOGGING_LEVEL_WARNING, name().c_str());
ATH_MSG_DEBUG( "Ort::Env object created" );
// Return gracefully.
return StatusCode::SUCCESS;
}
StatusCode OnnxRuntimeSvc::finalize() {
// Delete the environment object.
m_env.reset();
ATH_MSG_DEBUG( "Ort::Env object deleted" );
// Return gracefully.
return StatusCode::SUCCESS;
}
Ort::Env& OnnxRuntimeSvc::env() const {
return *m_env;
}
} // namespace AthOnnx
// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
#ifndef ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H
#define ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H
// Local include(s).
#include "AthOnnxInterfaces/IOnnxRuntimeSvc.h"
// Framework include(s).
#include <AsgServices/AsgService.h>
// ONNX include(s).
#include <core/session/onnxruntime_cxx_api.h>
// System include(s).
#include <memory>
namespace AthOnnx {
/// Service implementing @c AthOnnx::IOnnxRuntimeSvc
///
/// This is a very simple implementation, just managing the lifetime
/// of some Onnx Runtime C++ objects.
///
/// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>
///
class OnnxRuntimeSvc : public asg::AsgService, virtual public IOnnxRuntimeSvc {
public:
/// @name Function(s) inherited from @c Service
/// @{
OnnxRuntimeSvc (const std::string& name, ISvcLocator* svc);
/// Function initialising the service
virtual StatusCode initialize() override;
/// Function finalising the service
virtual StatusCode finalize() override;
/// @}
/// @name Function(s) inherited from @c AthOnnx::IOnnxRuntimeSvc
/// @{
/// Return the Onnx Runtime environment object
virtual Ort::Env& env() const override;
/// @}
private:
/// Global runtime environment for Onnx Runtime
std::unique_ptr< Ort::Env > m_env;
}; // class OnnxRuntimeSvc
} // namespace AthOnnx
#endif // ATHONNXRUNTIMESERVICE_ONNXRUNTIMESVC_H
// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
// Local include(s).
#include "../OnnxRuntimeSvc.h"
#include "../OnnxRuntimeSessionToolCPU.h"
#include "../OnnxRuntimeSessionToolCUDA.h"
#include "../OnnxRuntimeInferenceTool.h"
// Declare the package's components.
DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSvc )
DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSessionToolCPU )
DECLARE_COMPONENT( AthOnnx::OnnxRuntimeSessionToolCUDA )
DECLARE_COMPONENT( AthOnnx::OnnxRuntimeInferenceTool )
Control/AthOnnx/AthOnnxInterfaces
// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
#ifndef AthOnnx_IOnnxRuntimeInferenceTool_H
#define AthOnnx_IOnnxRuntimeInferenceTool_H
// Gaudi include(s).
#include "GaudiKernel/IAlgTool.h"
#include <memory>
#include <core/session/onnxruntime_cxx_api.h>
namespace AthOnnx {
/**
* @class IOnnxRuntimeInferenceTool
* @brief Interface class for performing inference with Onnx Runtime sessions.
* @details It is thread safe, supports models with various numbers of inputs and outputs,
* and supports models with dynamic batch size. It defines a standardized procedure to
* perform Onnx Runtime inference. The procedure is as follows, assuming the tool `m_onnxTool` is created and initialized:
* 1. create input tensors from the input data:
* ```c++
* std::vector<Ort::Value> inputTensors;
* std::vector<float> inputData_1; // The input data is filled by users, possibly from the event information.
* int64_t batchSize = m_onnxTool->getBatchSize(inputData_1.size(), 0); // The batch size is determined by the input data size to support dynamic batch size.
* m_onnxTool->addInput(inputTensors, inputData_1, 0, batchSize);
* std::vector<int64_t> inputData_2; // Some models may have multiple inputs. Add inputs one by one.
* int64_t batchSize_2 = m_onnxTool->getBatchSize(inputData_2.size(), 1);
* m_onnxTool->addInput(inputTensors, inputData_2, 1, batchSize_2);
* ```
* 2. create output tensors:
* ```c++
* std::vector<Ort::Value> outputTensors;
* std::vector<float> outputData; // The output data will be filled by the onnx session.
* m_onnxTool->addOutput(outputTensors, outputData, 0, batchSize);
* ```
* 3. perform inference:
* ```c++
* m_onnxTool->inference(inputTensors, outputTensors);
* ```
* 4. The model outputs will be filled into outputData automatically.
*
*
* @author Xiangyang Ju <xju@cern.ch>
*/
class IOnnxRuntimeInferenceTool : virtual public IAlgTool
{
public:
virtual ~IOnnxRuntimeInferenceTool() = default;
// @name InterfaceID
DeclareInterfaceID(IOnnxRuntimeInferenceTool, 1, 0);
/**
* @brief set batch size.
* @details If the model has dynamic batch size,
* the batchSize value will be applied to both the input shapes and the output shapes
*/
virtual void setBatchSize(int64_t batchSize) = 0;
/**
* @brief methods for determining batch size from the data size
* @param dataSize the size of the input data, like std::vector<T>::size()
* @param idx the index of the input node
* @return the batch size, which equals dataSize divided by the product of the remaining dimensions.
*/
virtual int64_t getBatchSize(int64_t dataSize, int idx = 0) const = 0;
/**
* @brief add the input data to the input tensors
* @param inputTensors the input tensor container
* @param data the input data
* @param idx the index of the input node
* @param batchSize the batch size
* @return StatusCode::SUCCESS if the input data is added successfully
*/
template <typename T>
StatusCode addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const;
/**
* @brief add the output data to the output tensors
* @param outputTensors the output tensor container
* @param data the output data
* @param idx the index of the output node
* @param batchSize the batch size
* @return StatusCode::SUCCESS if the output data is added successfully
*/
template <typename T>
StatusCode addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx = 0, int64_t batchSize = -1) const;
/**
* @brief perform inference
* @param inputTensors the input tensor container
* @param outputTensors the output tensor container
* @return StatusCode::SUCCESS if the inference is performed successfully
*/
virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const = 0;
virtual void printModelInfo() const = 0;
protected:
int m_numInputs;
int m_numOutputs;
std::vector<std::vector<int64_t> > m_inputShapes;
std::vector<std::vector<int64_t> > m_outputShapes;
private:
template <typename T>
Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const;
};
#include "IOnnxRuntimeInferenceTool.icc"
} // namespace AthOnnx
#endif
// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
template <typename T>
Ort::Value AthOnnx::IOnnxRuntimeInferenceTool::createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const
{
std::vector<int64_t> dataShapeCopy = dataShape;
if (batchSize > 0) {
for (auto& shape: dataShapeCopy) {
if (shape == -1) {
shape = batchSize;
break;
}
}
}
auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
return Ort::Value::CreateTensor<T>(
memoryInfo, data.data(), data.size(), dataShapeCopy.data(), dataShapeCopy.size()
);
}
template <typename T>
StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
{
if (idx >= m_numInputs || idx < 0) {
return StatusCode::FAILURE;
}
inputTensors.push_back(std::move(createTensor(data, m_inputShapes[idx], batchSize)));
return StatusCode::SUCCESS;
}
template <typename T>
StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, int idx, int64_t batchSize) const
{
if (idx >= m_numOutputs || idx < 0) {
return StatusCode::FAILURE;
}
auto tensorSize = std::accumulate(m_outputShapes[idx].begin(), m_outputShapes[idx].end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
if (tensorSize < 0) {
tensorSize = std::abs(tensorSize) * batchSize;
}
data.resize(tensorSize);
outputTensors.push_back(std::move(createTensor(data, m_outputShapes[idx], batchSize)));
return StatusCode::SUCCESS;
}
// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
#ifndef AthOnnx_IOnnxRUNTIMESESSIONTool_H
#define AthOnnx_IOnnxRUNTIMESESSIONTool_H
// Gaudi include(s).
#include "GaudiKernel/IAlgTool.h"
#include <core/session/onnxruntime_cxx_api.h>
namespace AthOnnx {
// @class IOnnxRuntimeSessionTool
//
// Interface class for creating Onnx Runtime sessions.
//
// @author Xiangyang Ju <xju@cern.ch>
//
class IOnnxRuntimeSessionTool : virtual public IAlgTool
{
public:
virtual ~IOnnxRuntimeSessionTool() = default;
// @name InterfaceID
DeclareInterfaceID(IOnnxRuntimeSessionTool, 1, 0);
// Return the Onnx Runtime session
virtual Ort::Session& session() const = 0;
};
} // namespace AthOnnx
#endif