Skip to content
Snippets Groups Projects
Verified Commit 13fe8744 authored by Xiangyang Ju's avatar Xiangyang Ju
Browse files

path resolver

parent 1281974a
Branches onnx_infer
No related tags found
No related merge requests found
Pipeline #7722497 passed
......@@ -11,7 +11,7 @@ atlas_add_library( AthOnnxCompsLib
src/*.cxx
PUBLIC_HEADERS AthOnnxComps
INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS}
LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib AthOnnxInterfaces AthOnnxUtilsLib
LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} AsgTools AsgServicesLib AthOnnxInterfaces AthOnnxUtilsLib PathResolver
)
atlas_add_dictionary( AthOnnxCompsDict
......
......@@ -12,9 +12,7 @@ def OnnxRuntimeSessionToolCfg(flags,
""""Configure OnnxRuntimeSessionTool in Control/AthOnnx/AthOnnxComps/src"""
acc = ComponentAccumulator()
if model_fname is None:
raise ValueError("model_fname must be specified")
execution_provider = flags.AthOnnx.ExecutionProvider if execution_provider is None else execution_provider
name += execution_provider.name
......
......@@ -3,6 +3,7 @@
*/
#include "AthOnnxComps/OnnxRuntimeSessionToolCPU.h"
#include "PathResolver/PathResolver.h"
AthOnnx::OnnxRuntimeSessionToolCPU::OnnxRuntimeSessionToolCPU(const std::string& name )
: asg::AsgTool( name)
......@@ -26,7 +27,10 @@ StatusCode AthOnnx::OnnxRuntimeSessionToolCPU::initialize()
sessionOptions.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_ALL );
// Create the session.
m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
ATH_MSG_INFO("Asking model from: " << m_modelFileName.value());
std::string modelFilePath = PathResolver::find_file(m_modelFileName.value(), "CALIBPATH", PathResolver::RecursiveSearch);
ATH_MSG_INFO("Loading model from: " << modelFilePath);
m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), modelFilePath.c_str(), sessionOptions);
return StatusCode::SUCCESS;
}
......
/*
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
*/
/*
Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
*/
#include "AthOnnxComps/OnnxRuntimeSessionToolCUDA.h"
#include "PathResolver/PathResolver.h"
AthOnnx::OnnxRuntimeSessionToolCUDA::OnnxRuntimeSessionToolCUDA(const std::string& name )
: asg::AsgTool(name)
......@@ -50,7 +48,10 @@ StatusCode AthOnnx::OnnxRuntimeSessionToolCUDA::initialize()
sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
// Create the session.
m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), m_modelFileName.value().c_str(), sessionOptions);
ATH_MSG_INFO("Asking model from: " << m_modelFileName.value());
std::string modelFilePath = PathResolver::find_file(m_modelFileName.value(), "CALIBPATH", PathResolver::RecursiveSearch);
ATH_MSG_INFO("Loading model from: " << modelFilePath);
m_session = std::make_unique<Ort::Session>(m_onnxRuntimeSvc->env(), modelFilePath.c_str(), sessionOptions);
return StatusCode::SUCCESS;
}
......
......@@ -9,14 +9,14 @@ from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
acc = ComponentAccumulator()
model_fname = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-02/MNIST_testModel.onnx"
model_fname = "dev/MLTest/2020-03-02/MNIST_testModel.onnx"
execution_provider = OnnxRuntimeType.CPU
from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg
kwargs.setdefault("ORTInferenceTool", acc.popToolsAndMerge(
OnnxRuntimeInferenceToolCfg(flags, model_fname, execution_provider)
))
input_data = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
input_data = "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
kwargs.setdefault("BatchSize", 3)
kwargs.setdefault("InputDataPixel", input_data)
kwargs.setdefault("OutputLevel", Constants.DEBUG)
......
......@@ -9,14 +9,14 @@ from AthOnnxComps.OnnxRuntimeFlags import OnnxRuntimeType
def AthExOnnxRuntimeExampleCfg(flags, name="AthOnnxExample", **kwargs):
acc = ComponentAccumulator()
model_fname = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-02/MNIST_testModel.onnx"
model_fname = "dev/MLTest/2020-03-02/MNIST_testModel.onnx"
execution_provider = OnnxRuntimeType.CPU
from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg
kwargs.setdefault("ORTInferenceTool", acc.popToolsAndMerge(
OnnxRuntimeInferenceToolCfg(flags, model_fname, execution_provider)
))
input_data = "/cvmfs/atlas.cern.ch/repo/sw/database/GroupData/dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
input_data = "dev/MLTest/2020-03-31/t10k-images-idx3-ubyte"
kwargs.setdefault("BatchSize", 3)
kwargs.setdefault("InputDataPixel", input_data)
kwargs.setdefault("OutputLevel", Constants.DEBUG)
......
......@@ -6,6 +6,7 @@
// Framework include(s).
#include "AthOnnxUtils/OnnxUtils.h"
#include "EvaluateUtils.h"
#include "PathResolver/PathResolver.h"
namespace AthOnnx {
......@@ -23,9 +24,10 @@ namespace AthOnnx {
return StatusCode::FAILURE;
}
// read input file, and the target file for comparison.
ATH_MSG_INFO( "Using pixel file: " << m_pixelFileName.value() );
std::string pixelFilePath = PathResolver::find_file(m_pixelFileName.value(), "CALIBPATH", PathResolver::RecursiveSearch);
ATH_MSG_INFO( "Using pixel file: " << pixelFilePath );
m_input_tensor_values_notFlat = EvaluateUtils::read_mnist_pixel_notFlat(m_pixelFileName);
m_input_tensor_values_notFlat = EvaluateUtils::read_mnist_pixel_notFlat(pixelFilePath);
ATH_MSG_INFO("Total no. of samples: "<<m_input_tensor_values_notFlat.size());
return StatusCode::SUCCESS;
......
......@@ -6,6 +6,7 @@
// Framework include(s).
#include "AthOnnxUtils/OnnxUtils.h"
#include "EvaluateUtils.h"
#include "PathResolver/PathResolver.h"
namespace AthOnnx {
......@@ -18,9 +19,10 @@ StatusCode EvaluateModelWithAthInfer::initialize() {
return StatusCode::FAILURE;
}
// read input file, and the target file for comparison.
ATH_MSG_INFO( "Using pixel file: " << m_pixelFileName.value() );
std::string pixelFilePath = PathResolver::find_file(m_pixelFileName.value(), "CALIBPATH", PathResolver::RecursiveSearch);
ATH_MSG_INFO( "Using pixel file: " << pixelFilePath );
m_input_tensor_values_notFlat = EvaluateUtils::read_mnist_pixel_notFlat(m_pixelFileName);
m_input_tensor_values_notFlat = EvaluateUtils::read_mnist_pixel_notFlat(pixelFilePath);
ATH_MSG_INFO("Total no. of samples: "<<m_input_tensor_values_notFlat.size());
return StatusCode::SUCCESS;
......@@ -50,7 +52,7 @@ StatusCode EvaluateModelWithAthInfer::execute( [[maybe_unused]] const EventConte
ATH_CHECK(m_onnxTool->inference(inputData, outputData));
auto& outputScores = std::get<std::vector<float>>(outputData["dense_1/Softmax"].second);
ATH_MSG_INFO("Label for the input test data: ");
ATH_MSG_DEBUG("Label for the input test data: ");
for(int ibatch = 0; ibatch < m_batchSize; ibatch++){
float max = -999;
int max_index;
......@@ -62,7 +64,7 @@ StatusCode EvaluateModelWithAthInfer::execute( [[maybe_unused]] const EventConte
max_index = index;
}
}
ATH_MSG_INFO("Class: "<<max_index<<" has the highest score: "<<outputScores[max_index] << " in batch " << ibatch);
ATH_MSG_DEBUG("Class: "<<max_index<<" has the highest score: "<<outputScores[max_index] << " in batch " << ibatch);
}
return StatusCode::SUCCESS;
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment