Skip to content
Snippets Groups Projects
Commit ab053687 authored by Ricardo Woelker's avatar Ricardo Woelker Committed by Frank Winklmeier
Browse files

Use shared ONNX runtime in CaloMuonScoreTool

parent d6399fc1
No related branches found
No related tags found
No related merge requests found
Showing
with 23 additions and 165 deletions
...@@ -8,14 +8,13 @@ atlas_subdir( CaloTrkMuIdTools ) ...@@ -8,14 +8,13 @@ atlas_subdir( CaloTrkMuIdTools )
# External dependencies: # External dependencies:
find_package( CLHEP ) find_package( CLHEP )
find_package( ROOT COMPONENTS Core Tree MathCore Hist RIO pthread ) find_package( ROOT COMPONENTS Core Tree MathCore Hist RIO pthread )
find_package( onnxruntime )
# Component(s) in the package: # Component(s) in the package:
atlas_add_component( CaloTrkMuIdTools atlas_add_component( CaloTrkMuIdTools
src/*.cxx src/*.cxx
src/components/*.cxx src/components/*.cxx
INCLUDE_DIRS ${ROOT_INCLUDE_DIRS} ${CLHEP_INCLUDE_DIRS} ${ONNXRUNTIME_INCLUDE_DIRS} INCLUDE_DIRS ${ROOT_INCLUDE_DIRS} ${CLHEP_INCLUDE_DIRS}
LINK_LIBRARIES ${ROOT_LIBRARIES} ${CLHEP_LIBRARIES} ${ONNXRUNTIME_LIBRARIES} CaloEvent AthenaBaseComps StoreGateLib SGtests xAODTracking GaudiKernel ICaloTrkMuIdTools RecoToolInterfaces TrkExInterfaces CaloDetDescrLib CaloGeoHelpers CaloIdentifier CaloUtilsLib xAODCaloEvent ParticleCaloExtension TileDetDescr PathResolver TrkSurfaces TrkCaloExtension TrkEventPrimitives CaloTrackingGeometryLib ) LINK_LIBRARIES ${ROOT_LIBRARIES} ${CLHEP_LIBRARIES} CaloEvent AthenaBaseComps StoreGateLib SGtests xAODTracking GaudiKernel ICaloTrkMuIdTools RecoToolInterfaces TrkExInterfaces CaloDetDescrLib CaloGeoHelpers CaloIdentifier CaloUtilsLib xAODCaloEvent ParticleCaloExtension TileDetDescr PathResolver TrkSurfaces TrkCaloExtension TrkEventPrimitives CaloTrackingGeometryLib AthOnnxruntimeServiceLib)
# Install files from the package: # Install files from the package:
atlas_install_headers( CaloTrkMuIdTools ) atlas_install_headers( CaloTrkMuIdTools )
......
// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
#ifndef CALOTRKMUIDTOOLS_CALOMUONSCOREONNXRUNTIMESVC_H
#define CALOTRKMUIDTOOLS_CALOMUONSCOREONNXRUNTIMESVC_H
// Local include(s).
#include "ICaloTrkMuIdTools/ICaloMuonScoreONNXRuntimeSvc.h"
// Framework include(s).
#include "AthenaBaseComps/AthService.h"
// ONNX include(s).
#include <core/session/onnxruntime_cxx_api.h>
// System include(s).
#include <memory>
/// Service implementing @c ICaloMuonScoreONNXRuntimeSvc
///
/// This is a very simple implementation, just managing the lifetime
/// of some ONNX Runtime C++ objects.
///
/// Ported from Control/AthenaExamples/AthExOnnxRuntime (Ricardo Woelker <ricardo.woelker@cern.ch>)
///
/// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>
///
class CaloMuonScoreONNXRuntimeSvc : public extends< AthService, ICaloMuonScoreONNXRuntimeSvc > {
public:
/// Inherit the base class's constructor
using extends::extends;
/// @name Function(s) inherited from @c Service
/// @{
/// Function initialising the service
virtual StatusCode initialize() override;
/// Function finalising the service
virtual StatusCode finalize() override;
/// @}
/// @name Function(s) inherited from @c ICaloMuonScoreONNXRuntimeSvc
/// @{
/// Return the ONNX Runtime environment object
virtual Ort::Env& env() const override;
/// @}
private:
/// Global runtime environment for ONNX Runtime
std::unique_ptr< Ort::Env > m_env;
}; // class CaloMuonScoreONNXRuntimeSvc
#endif // CALOTRKMUIDTOOLS_CALOMUONSCOREONNXRUNTIMESVC_H
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#define CALOTRKMUIDTOOLS_CALOMUONSCORETOOL_H #define CALOTRKMUIDTOOLS_CALOMUONSCORETOOL_H
#include "ICaloTrkMuIdTools/ICaloMuonScoreTool.h" #include "ICaloTrkMuIdTools/ICaloMuonScoreTool.h"
#include "ICaloTrkMuIdTools/ICaloMuonScoreONNXRuntimeSvc.h" #include "AthOnnxruntimeService/IONNXRuntimeSvc.h"
#include "AthenaBaseComps/AthAlgTool.h" #include "AthenaBaseComps/AthAlgTool.h"
#include "GaudiKernel/ToolHandle.h" #include "GaudiKernel/ToolHandle.h"
#include "GaudiKernel/ServiceHandle.h" #include "GaudiKernel/ServiceHandle.h"
...@@ -67,27 +67,20 @@ public: ...@@ -67,27 +67,20 @@ public:
std::vector<float> getInputTensor(std::vector<float> &eta, std::vector<float> &phi, std::vector<float> &energy, std::vector<int> &sampling) const; std::vector<float> getInputTensor(std::vector<float> &eta, std::vector<float> &phi, std::vector<float> &energy, std::vector<int> &sampling) const;
private: private:
// Number of bins in eta
int m_etaBins = 30;
// Number of bins in phi Gaudi::Property<float> m_CaloCellAssociationConeSize {this, "CaloCellAssociationConeSize", 0.2, "Size of the cone within which calo cells are associated with a track particle"};
int m_phiBins = 30; Gaudi::Property<int> m_etaBins {this, "etaBins", 30, "Number of bins in eta"};
Gaudi::Property<int> m_phiBins {this, "phiBins", 30, "Number of bins in phi"};
// window in terms of abs(eta) to consider around the median eta value Gaudi::Property<float> m_etaCut {this, "etaCut", 0.25, "Eta cut on the calorimeter cells associated with the track particle after centering of the calorimeter image"};
float m_etaCut = 0.25; Gaudi::Property<float> m_phiCut {this, "phiCut", 0.25, "Phi cut on the calorimeter cells associated with the track particle after centering of the calorimeter image"};
Gaudi::Property<int> m_nChannels {this, "nChannels", 7, "Number of colour channels in the convolutional neural network"};
// window in terms of abs(phi) to consider around the median phi value
float m_phiCut = 0.25;
// Number of colour channels to consider in the convolutional neural network
int m_nChannels = 7;
ToolHandle <Rec::IParticleCaloCellAssociationTool> m_caloCellAssociationTool{this, "ParticleCaloCellAssociationTool", ""}; ToolHandle <Rec::IParticleCaloCellAssociationTool> m_caloCellAssociationTool{this, "ParticleCaloCellAssociationTool", ""};
/// Handle to @c IONNXRuntimeSvc /// Handle to @c AthONNX::IONNXRuntimeSvc
ServiceHandle< ICaloMuonScoreONNXRuntimeSvc > m_svc{ this, "CaloMuonScoreONNXRuntimeSvc", ServiceHandle< AthONNX::IONNXRuntimeSvc > m_svc{ this, "ONNXRuntimeSvc",
"CaloMuonScoreONNXRuntimeSvc", "AthONNX::ONNXRuntimeSvc",
"Name of the service to use" }; "CaloMuonScoreTool ONNXRuntimeSvc" };
std::unique_ptr< Ort::Session > m_session; std::unique_ptr< Ort::Session > m_session;
......
...@@ -21,10 +21,6 @@ ...@@ -21,10 +21,6 @@
calorimeter cell energy deposits using a convolutional calorimeter cell energy deposits using a convolutional
neural network. neural network.
@section CaloTrkMuIdTools_CaloMuonScoreONNXRuntimeSvcIntroduction CaloMuonScoreONNXRuntimeSvc
Service that maintains a ONNX session which holds
a tensorflow model and can perform inference on it.
@section CaloTrkMuIdTools_CaloMuonTagIntroduction CaloMuonTag @section CaloTrkMuIdTools_CaloMuonTagIntroduction CaloMuonTag
Muon tagger using calorimeter deposits. Muon tagger using calorimeter deposits.
A track is tagged when deposits above the noise treshold are found in the A track is tagged when deposits above the noise treshold are found in the
......
// Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
// Local include(s).
#include "CaloTrkMuIdTools/CaloMuonScoreONNXRuntimeSvc.h"
StatusCode CaloMuonScoreONNXRuntimeSvc::initialize() {
// Create the environment object.
m_env = std::make_unique< Ort::Env >( ORT_LOGGING_LEVEL_WARNING,
name().c_str() );
ATH_MSG_DEBUG( "Ort::Env object created" );
// Return gracefully.
return StatusCode::SUCCESS;
}
StatusCode CaloMuonScoreONNXRuntimeSvc::finalize() {
// Dekete the environment object.
m_env.reset();
ATH_MSG_DEBUG( "Ort::Env object deleted" );
// Return gracefully.
return StatusCode::SUCCESS;
}
Ort::Env& CaloMuonScoreONNXRuntimeSvc::env() const {
return *m_env;
}
...@@ -146,8 +146,10 @@ float CaloMuonScoreTool::getMuonScore( const xAOD::TrackParticle* trk ) const { ...@@ -146,8 +146,10 @@ float CaloMuonScoreTool::getMuonScore( const xAOD::TrackParticle* trk ) const {
ATH_MSG_DEBUG("Calculating muon score for track particle with eta="<<track_eta); ATH_MSG_DEBUG("Calculating muon score for track particle with eta="<<track_eta);
// - associate calocells to trackparticle, cone size 0.2, use cache ATH_MSG_DEBUG("Finding calo cell association for track particle within cone of delta R="<<m_CaloCellAssociationConeSize);
std::unique_ptr<const Rec::ParticleCellAssociation> association = m_caloCellAssociationTool->particleCellAssociation(*trk,0.2,nullptr);
// - associate calocells to trackparticle
std::unique_ptr<const Rec::ParticleCellAssociation> association = m_caloCellAssociationTool->particleCellAssociation(*trk,m_CaloCellAssociationConeSize,nullptr);
if(!association){ if(!association){
ATH_MSG_VERBOSE("Could not get particleCellAssociation"); ATH_MSG_VERBOSE("Could not get particleCellAssociation");
return -1.; return -1.;
......
...@@ -3,12 +3,9 @@ ...@@ -3,12 +3,9 @@
#include "CaloTrkMuIdTools/TrackDepositInCaloTool.h" #include "CaloTrkMuIdTools/TrackDepositInCaloTool.h"
#include "CaloTrkMuIdTools/CaloMuonLikelihoodTool.h" #include "CaloTrkMuIdTools/CaloMuonLikelihoodTool.h"
#include "CaloTrkMuIdTools/CaloMuonScoreTool.h" #include "CaloTrkMuIdTools/CaloMuonScoreTool.h"
#include "CaloTrkMuIdTools/CaloMuonScoreONNXRuntimeSvc.h"
DECLARE_COMPONENT( CaloMuonTag ) DECLARE_COMPONENT( CaloMuonTag )
DECLARE_COMPONENT( TrackEnergyInCaloTool ) DECLARE_COMPONENT( TrackEnergyInCaloTool )
DECLARE_COMPONENT( TrackDepositInCaloTool ) DECLARE_COMPONENT( TrackDepositInCaloTool )
DECLARE_COMPONENT( CaloMuonLikelihoodTool ) DECLARE_COMPONENT( CaloMuonLikelihoodTool )
DECLARE_COMPONENT( CaloMuonScoreTool ) DECLARE_COMPONENT( CaloMuonScoreTool )
DECLARE_COMPONENT( CaloMuonScoreONNXRuntimeSvc )
...@@ -5,10 +5,7 @@ ...@@ -5,10 +5,7 @@
# Declare the package name: # Declare the package name:
atlas_subdir( ICaloTrkMuIdTools ) atlas_subdir( ICaloTrkMuIdTools )
find_package( onnxruntime )
# Component(s) in the package: # Component(s) in the package:
atlas_add_library( ICaloTrkMuIdTools atlas_add_library( ICaloTrkMuIdTools
PUBLIC_HEADERS ICaloTrkMuIdTools PUBLIC_HEADERS ICaloTrkMuIdTools
INCLUDE_DIRS ${ONNXRUNTIME_INCLUDE_DIRS} LINK_LIBRARIES CaloEvent CaloIdentifier xAODCaloEvent xAODTracking GaudiKernel muonEvent TrkSurfaces TrkEventPrimitives TrkParameters TrkTrack CaloDetDescrLib )
LINK_LIBRARIES ${ONNXRUNTIME_LIBRARIES} CaloEvent CaloIdentifier xAODCaloEvent xAODTracking GaudiKernel muonEvent TrkSurfaces TrkEventPrimitives TrkParameters TrkTrack CaloDetDescrLib )
#ifndef CALOTRKMUIDTOOLS_ICALOMUONSCOREONNXRUNTIMESVC_H
#define CALOTRKMUIDTOOLS_ICALOMUONSCOREONNXRUNTIMESVC_H
// Gaudi include(s).
#include "GaudiKernel/IService.h"
// ONNX include(s).
#include <core/session/onnxruntime_cxx_api.h>
/// Namespace holding all of the ONNX Runtime example code
/// Service used for managing global objects used by ONNX Runtime
///
/// In order to allow multiple clients to use ONNX Runtime at the same
/// time, this service is used to manage the objects that must only
/// be created once in the Athena process.
///
/// Ported from Control/AthenaExamples/AthExOnnxRuntime (Ricardo Woelker <ricardo.woelker@cern.ch>)
///
/// @author Attila Krasznahorkay <Attila.Krasznahorkay@cern.ch>
///
class ICaloMuonScoreONNXRuntimeSvc : public virtual IService {
public:
/// Virtual destructor, to make vtable happy
virtual ~ICaloMuonScoreONNXRuntimeSvc() = default;
/// Declare an ID for this interface
DeclareInterfaceID( ICaloMuonScoreONNXRuntimeSvc, 1, 0 );
/// Return the ONNX Runtime environment object
virtual Ort::Env& env() const = 0;
}; // class ICaloMuonScoreONNXRuntimeSvc
#endif // CALOTRKMUIDTOOLS_ICALOMUONSCOREONNXRUNTIMESVC_H
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include "ICaloTrkMuIdTools/ICaloMuonScoreTool.h" #include "ICaloTrkMuIdTools/ICaloMuonScoreTool.h"
#include "ICaloTrkMuIdTools/ICaloMuonTag.h" #include "ICaloTrkMuIdTools/ICaloMuonTag.h"
#include "ICaloTrkMuIdTools/ITrackDepositInCaloTool.h" #include "ICaloTrkMuIdTools/ITrackDepositInCaloTool.h"
#include "ICaloTrkMuIdTools/ICaloMuonScoreONNXRuntimeSvc.h"
#include "TrkToolInterfaces/ITrackSelectorTool.h" #include "TrkToolInterfaces/ITrackSelectorTool.h"
#include "StoreGate/ReadHandleKey.h" #include "StoreGate/ReadHandleKey.h"
...@@ -103,7 +102,6 @@ namespace MuonCombined { ...@@ -103,7 +102,6 @@ namespace MuonCombined {
// --- CaloTrkMuIdTools --- // --- CaloTrkMuIdTools ---
ToolHandle<ICaloMuonLikelihoodTool> m_caloMuonLikelihood{this,"CaloMuonLikelihoodTool","CaloMuonLikelihoodTool/CaloMuonLikelihoodTool"}; ToolHandle<ICaloMuonLikelihoodTool> m_caloMuonLikelihood{this,"CaloMuonLikelihoodTool","CaloMuonLikelihoodTool/CaloMuonLikelihoodTool"};
ToolHandle<ICaloMuonScoreTool> m_caloMuonScoreTool{this, "CaloMuonScoreTool", "CaloMuonScoreTool/CaloMuonScoreTool"}; ToolHandle<ICaloMuonScoreTool> m_caloMuonScoreTool{this, "CaloMuonScoreTool", "CaloMuonScoreTool/CaloMuonScoreTool"};
ServiceHandle<ICaloMuonScoreONNXRuntimeSvc> m_caloMuonScoreONNXRuntimeSvc{this, "CaloMuonScoreONNXRuntimeSvc", "CaloMuonScoreTool/CaloMuonScoreONNXRuntimeSvc"};
ToolHandle<ICaloMuonTag> m_caloMuonTagLoose{this,"CaloMuonTagLoose","CaloMuonTag/CaloMuonTagLoose","CaloTrkMuIdTools::CaloMuonTag for loose tagging"}; ToolHandle<ICaloMuonTag> m_caloMuonTagLoose{this,"CaloMuonTagLoose","CaloMuonTag/CaloMuonTagLoose","CaloTrkMuIdTools::CaloMuonTag for loose tagging"};
ToolHandle<ICaloMuonTag> m_caloMuonTagTight{this,"CaloMuonTagTight","CaloMuonTag/CaloMuonTag","CaloTrkMuIdTools::CaloMuonTag for tight tagging"}; ToolHandle<ICaloMuonTag> m_caloMuonTagTight{this,"CaloMuonTagTight","CaloMuonTag/CaloMuonTag","CaloTrkMuIdTools::CaloMuonTag for tight tagging"};
......
...@@ -873,8 +873,6 @@ def CaloMuonScoreToolCfg(flags, name='CaloMuonScoreTool', **kwargs ): ...@@ -873,8 +873,6 @@ def CaloMuonScoreToolCfg(flags, name='CaloMuonScoreTool', **kwargs ):
from TrackToCalo.TrackToCaloConfig import ParticleCaloCellAssociationToolCfg from TrackToCalo.TrackToCaloConfig import ParticleCaloCellAssociationToolCfg
result = ParticleCaloCellAssociationToolCfg(flags) result = ParticleCaloCellAssociationToolCfg(flags)
kwargs.setdefault("ParticleCaloCellAssociationTool", result.popPrivateTools()) kwargs.setdefault("ParticleCaloCellAssociationTool", result.popPrivateTools())
caloMuonScoreSvc = CompFactory.CaloMuonScoreONNXRuntimeSvc(name="CaloMuonScoreONNXRuntimeSvc")
result.addService(caloMuonScoreSvc)
tool = CompFactory.CaloMuonScoreTool(name, **kwargs ) tool = CompFactory.CaloMuonScoreTool(name, **kwargs )
result.setPrivateTools(tool) result.setPrivateTools(tool)
return result return result
......
...@@ -39,9 +39,6 @@ def TrackDepositInCaloTool( name ='TrackDepositInCaloTool', **kwargs ): ...@@ -39,9 +39,6 @@ def TrackDepositInCaloTool( name ='TrackDepositInCaloTool', **kwargs ):
kwargs.setdefault("ParticleCaloCellAssociationTool", caloCellAssociationTool ) kwargs.setdefault("ParticleCaloCellAssociationTool", caloCellAssociationTool )
return CfgMgr.TrackDepositInCaloTool(name,**kwargs) return CfgMgr.TrackDepositInCaloTool(name,**kwargs)
def CaloMuonScoreONNXRuntimeSvc(name='CaloMuonScoreONNXRuntimeSvc', **kwargs):
return CfgMgr.CaloMuonScoreONNXRuntimeSvc(name, **kwargs)
def CaloMuonLikelihoodTool(name='CaloMuonLikelihoodTool', **kwargs ): def CaloMuonLikelihoodTool(name='CaloMuonLikelihoodTool', **kwargs ):
kwargs.setdefault("ParticleCaloExtensionTool", getPublicTool("MuonParticleCaloExtensionTool") ) kwargs.setdefault("ParticleCaloExtensionTool", getPublicTool("MuonParticleCaloExtensionTool") )
return CfgMgr.CaloMuonLikelihoodTool(name,**kwargs) return CfgMgr.CaloMuonLikelihoodTool(name,**kwargs)
...@@ -50,7 +47,11 @@ def CaloMuonScoreTool(name='CaloMuonScoreTool', **kwargs ): ...@@ -50,7 +47,11 @@ def CaloMuonScoreTool(name='CaloMuonScoreTool', **kwargs ):
from TrackToCalo.TrackToCaloConf import Rec__ParticleCaloCellAssociationTool from TrackToCalo.TrackToCaloConf import Rec__ParticleCaloCellAssociationTool
caloCellAssociationTool = Rec__ParticleCaloCellAssociationTool(ParticleCaloExtensionTool = getPublicTool("MuonParticleCaloExtensionTool")) caloCellAssociationTool = Rec__ParticleCaloCellAssociationTool(ParticleCaloExtensionTool = getPublicTool("MuonParticleCaloExtensionTool"))
kwargs.setdefault("ParticleCaloCellAssociationTool", caloCellAssociationTool ) kwargs.setdefault("ParticleCaloCellAssociationTool", caloCellAssociationTool )
kwargs.setdefault("CaloMuonScoreONNXRuntimeSvc", getService("CaloMuonScoreONNXRuntimeSvc") )
from AthOnnxruntimeService.AthOnnxruntimeServiceConf import AthONNX__ONNXRuntimeSvc
onnxRuntimeSvc = AthONNX__ONNXRuntimeSvc( )
kwargs.setdefault("ONNXRuntimeSvc", onnxRuntimeSvc)
return CfgMgr.CaloMuonScoreTool(name,**kwargs) return CfgMgr.CaloMuonScoreTool(name,**kwargs)
def MuonCaloTagTool( name='MuonCaloTagTool', **kwargs ): def MuonCaloTagTool( name='MuonCaloTagTool', **kwargs ):
......
...@@ -99,7 +99,6 @@ addTool("MuonCombinedRecExample.MuonCaloTagTool.CaloMuonTagLoose","CaloMuonTagLo ...@@ -99,7 +99,6 @@ addTool("MuonCombinedRecExample.MuonCaloTagTool.CaloMuonTagLoose","CaloMuonTagLo
addTool("MuonCombinedRecExample.MuonCaloTagTool.CaloMuonTag","CaloMuonTag") addTool("MuonCombinedRecExample.MuonCaloTagTool.CaloMuonTag","CaloMuonTag")
addTool("MuonCombinedRecExample.MuonCaloTagTool.CaloMuonLikelihoodTool","CaloMuonLikelihoodTool") addTool("MuonCombinedRecExample.MuonCaloTagTool.CaloMuonLikelihoodTool","CaloMuonLikelihoodTool")
addTool("MuonCombinedRecExample.MuonCaloTagTool.CaloMuonScoreTool","CaloMuonScoreTool") addTool("MuonCombinedRecExample.MuonCaloTagTool.CaloMuonScoreTool","CaloMuonScoreTool")
addService("MuonCombinedRecExample.MuonCaloTagTool.CaloMuonScoreONNXRuntimeSvc","CaloMuonScoreONNXRuntimeSvc")
####### muid tools ####### muid tools
addTool("MuonCombinedRecExample.MuonCombinedFitTools.MuonAlignmentUncertToolTheta","MuonAlignmentUncertToolTheta") addTool("MuonCombinedRecExample.MuonCombinedFitTools.MuonAlignmentUncertToolTheta","MuonAlignmentUncertToolTheta")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment