From 335f3bb4edb02f73de50e2ff3da17fedb8d570f5 Mon Sep 17 00:00:00 2001 From: Davide Di Croce <davide.di.croce@cern.ch> Date: Tue, 18 Feb 2025 11:30:43 +0100 Subject: [PATCH 1/9] First inference infrastructure --- .../MuonBucketGraph/FCGraphMaker.cpp | 10 + .../MuonBucketGraph/FCGraphMaker.h | 21 + .../MuonBucketGraph/GraphFramer.h | 14 + .../MuonLearning/MuonInference/CMakeLists.txt | 16 + .../MuonInference/InferenceInterface.cpp | 10 + .../MuonInference/InferenceInterface.h | 17 + .../MuonInference/python/InferenceConfig.py | 11 + .../src/GraphInferenceToolBase.cxx | 123 ++++++ .../src/GraphInferenceToolBase.h | 55 +++ .../MuonInference/src/InferenceAlg.cxx | 24 ++ .../MuonInference/src/InferenceAlg.h | 25 ++ .../src/components/MuonInference_entries.cxx | 8 + .../MuonInferenceInterfaces/CMakeLists.txt | 17 + .../MuonInferenceInterfaces/GraphData.h | 50 +++ .../IGraphInferenceTool.h | 22 + .../MuonInferenceInterfaces/LayerBucket.h | 35 ++ .../MuonInferenceInterfaces/NodeConnector.h | 50 +++ .../MuonInferenceInterfaces/NodeFeature.h | 50 +++ .../NodeFeatureFactory.h | 26 ++ .../MuonInferenceInterfaces/NodeFeatureList.h | 63 +++ .../src/LayerBucket.cxx | 32 ++ .../src/NodeFeatureFactory.cxx | 91 ++++ .../src/NodeFeatureList.cxx | 94 +++++ .../MuonLearning/MuonSPId/CMakeLists.txt | 24 ++ .../MuonSPId/python/MuonSPIdDumpConfig.py | 17 + .../MuonSPId/python/muonSPIdDump.py | 43 ++ .../MuonSPId/python/muonSPIdDump_data.py | 38 ++ .../MuonSPId/src/SPIdDumperAlg.cxx | 399 ++++++++++++++++++ .../MuonLearning/MuonSPId/src/SPIdDumperAlg.h | 95 +++++ .../MuonSPId/src/SPIdentifierAlg.cxx | 340 +++++++++++++++ .../MuonSPId/src/SPIdentifierAlg.h | 91 ++++ .../src/components/MuonSPId_entries.cxx | 7 + 32 files changed, 1918 insertions(+) create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.cpp create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/GraphFramer.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/CMakeLists.txt create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.cpp create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/python/InferenceConfig.py create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.cxx create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/components/MuonInference_entries.cxx create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/CMakeLists.txt create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/GraphData.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/IGraphInferenceTool.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeConnector.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeature.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureFactory.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureList.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/LayerBucket.cxx create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureList.cxx create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/CMakeLists.txt create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/MuonSPIdDumpConfig.py create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.cxx create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.h create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/components/MuonSPId_entries.cxx diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.cpp b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.cpp new file mode 100644 index 000000000000..becdc324c032 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.cpp @@ -0,0 +1,10 @@ +#include "FCGraphMaker.h" + +FCGraphMaker::FCGraphMaker(const SpacePointBucket* bucket, const GraphFramer& graphType, const IMuonIdHelperSvc* idHelperSvc) + : m_bucket(bucket), m_graphType(graphType), m_idHelperSvc(idHelperSvc) {} + +void FCGraphMaker::extractGraphData(std::vector<float>& features, std::vector<int64_t>& edge_src, std::vector<int64_t>& edge_dst) { + // implementation + +} + diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.h new file mode 100644 index 000000000000..f3924b6876d3 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.h @@ -0,0 +1,21 @@ +#ifndef FC_GRAPH_MAKER_H +#define FC_GRAPH_MAKER_H + +#include "GraphFramer.h" +#include "MuonIdHelpers/IMuonIdHelperSvc.h" +#include <vector> +#include <cstdint> + +class SpacePointBucket; + +class FCGraphMaker { +public: + FCGraphMaker(const SpacePointBucket* bucket, const GraphFramer& graphType, const IMuonIdHelperSvc* idHelperSvc); + void extractGraphData(std::vector<float>& features, std::vector<int64_t>& edge_src, std::vector<int64_t>& edge_dst); +private: + const SpacePointBucket* m_bucket; + GraphFramer m_graphType; + const IMuonIdHelperSvc* m_idHelperSvc; +}; + +#endif // FC_GRAPH_MAKER_H diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/GraphFramer.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/GraphFramer.h new file mode 100644 index 000000000000..acf92d271a10 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/GraphFramer.h @@ -0,0 +1,14 @@ +#ifndef GRAPH_FRAMER_H +#define GRAPH_FRAMER_H + +class GraphFramer { +public: + GraphFramer(bool sparse, bool classification); + bool isSparse() const; + bool isClassification() const; +private: + bool m_sparse; + bool m_classification; +}; + +#endif // GRAPH_FRAMER_H diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/CMakeLists.txt b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/CMakeLists.txt new file mode 100644 index 000000000000..c78436b8cd0b --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration + +################################################################################ +# Package: MuonInferenceInterfaces +################################################################################ + +# Declare the package name: +atlas_subdir( MuonInference ) + + +find_package( onnxruntime REQUIRED) + +atlas_add_component( MuonInference + src/components/*.cxx src/*.cxx + LINK_LIBRARIES AthenaKernel StoreGateLib MuonInferenceInterfaces PathResolver + ${ONNXRUNTIME_LIBRARIES}) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.cpp b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.cpp new file mode 100644 index 000000000000..3115feae747a --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.cpp @@ -0,0 +1,10 @@ +#include "InferenceInterface.h" + +InferenceInterface::InferenceInterface(const std::string& model_path) + : m_env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference"), + m_session(m_env, model_path.c_str(), Ort::SessionOptions{}) {} + +std::vector<float> InferenceInterface::runInference(const std::vector<float>& features, const std::vector<int64_t>& edge_index) { + // Implement ONNX model inference logic + return {}; // Placeholder return +} diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.h new file mode 100644 index 000000000000..f86138bd8a7b --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.h @@ -0,0 +1,17 @@ +#ifndef INFERENCE_INTERFACE_H +#define INFERENCE_INTERFACE_H + +#include <vector> +#include <string> +#include <onnxruntime/core/session/onnxruntime_cxx_api.h> + +class InferenceInterface { +public: + InferenceInterface(const std::string& model_path); + std::vector<float> runInference(const std::vector<float>& features, const std::vector<int64_t>& edge_index); +private: + Ort::Session m_session; + Ort::Env m_env; +}; + +#endif // INFERENCE_INTERFACE_H diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/python/InferenceConfig.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/python/InferenceConfig.py new file mode 100644 index 000000000000..bee02114be11 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/python/InferenceConfig.py @@ -0,0 +1,11 @@ +# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration + +from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator +from AthenaConfiguration.ComponentFactory import CompFactory + +def GraphAssemblyToolCfg(flags, name = "GraphAssemblyTool", **kwargs): + result = ComponentAccumulator() + the_tool = CompFactory.MuonML.GraphAssemblyTool(name, **kwargs) + result.setPrivateTools(the_tool) + return result + diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx new file mode 100644 index 000000000000..51e666263121 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx @@ -0,0 +1,123 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#include "GraphInferenceToolBase.h" + +#include "MuonInferenceInterfaces/GraphData.h" +#include "MuonInferenceInterfaces/NodeFeatureList.h" +#include "MuonPatternHelpers/MatrixUtils.h" + + +#include "PathResolver/PathResolver.h" +namespace MuonML{ + StatusCode GraphInferenceToolBase::setupModel() { + const std::string modelPath = PathResolver::FindCalibFile(m_modelPath); + if (modelPath.empty()) { + ATH_MSG_FATAL("No such file or directory "<<m_modelPath<<"."); + return StatusCode::FAILURE; + } + try { + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference"); + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); + + m_model = std::make_unique<Ort::Session>(env, modelPath.c_str(), session_options); + ATH_MSG_DEBUG("Successfully loaded infernce model from "<<modelPath); + + + Ort::ModelMetadata metadata = m_model->GetModelMetadata(); + Ort::AllocatorWithDefaultOptions allocator; + Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator); + + if (feature_json_ptr) { + std::string feature_json = feature_json_ptr.get(); + nlohmann::json json_obj = nlohmann::json::parse(feature_json); + for (const auto& feature : json_obj) { + m_graphFeatures.addFeature(feature.get<std::string>(), msgStream()); + } + } + } catch (const std::exception& e) { + ATH_MSG_ERROR("Failed to retrieve feature from ONNX model: " << e.what()); + return StatusCode::FAILURE; + } + + if (!m_graphFeatures.isValid()) { + ATH_MSG_FATAL("No graph features have been parsed. Please check the model"); + return StatusCode::FAILURE; + } + ATH_CHECK(m_spacePointKey.initialize()); + return StatusCode::SUCCESS; + } + StatusCode GraphInferenceToolBase::buildGraph(const EventContext& ctx, + GraphRawData& graphData) const { + + /** Check whether the graph needs a rebuild */ + if ((*graphData.previousList) != m_graphFeatures) { + graphData.graph.reset(); + } + /** Don't launch the rebuild of the graph */ + if (graphData.graph) { + return StatusCode::SUCCESS; + } + if (!m_graphFeatures.isValid()) { + ATH_MSG_ERROR("The feature list is in complete. Either it has no features or no node connector set"); + return StatusCode::FAILURE; + } + + graphData.graph = std::make_unique<InferenceGraph>(); + + SG::ReadHandle spacePoints{m_spacePointKey, ctx}; + ATH_CHECK(spacePoints.isPresent()); + + int64_t nNodes{0}, possConn{0}; + graphData.spacePointsInBucket.clear(); + graphData.spacePointsInBucket.reserve(spacePoints->size()); + + for (const MuonR4::SpacePointBucket* bucket : *spacePoints) { + nNodes += graphData.spacePointsInBucket.emplace_back(bucket->size()); + possConn += MuonR4::sumUp(graphData.spacePointsInBucket.back()); + } + graphData.nodeIndex = 0; + graphData.featureLeaves.resize(nNodes * m_graphFeatures.numFeatures()); + graphData.currLeave = graphData.featureLeaves.begin(); + + + graphData.srcEdges.reserve(possConn); + graphData.desEdges.reserve(possConn); + + + /** Fill the graph edge features and all their respective connections */ + for (const MuonR4::SpacePointBucket* bucket : *spacePoints) { + const LayerSpBucket mlBucket{*bucket}; + m_graphFeatures.fillInData(mlBucket, graphData); + } + + graphData.srcEdges.insert(graphData.srcEdges.end(), std::make_move_iterator(graphData.desEdges.begin()), + std::make_move_iterator(graphData.desEdges.end())); + + std::vector<int64_t> featShape{nNodes, static_cast<int64_t>(m_graphFeatures.numFeatures())}; // (N, 8) + std::vector<int64_t> edgeShape{2, static_cast<int64_t>(graphData.srcEdges.size() / 2)}; // (2, E) + + Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + + + graphData.graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo, + graphData.featureLeaves.data(), graphData.featureLeaves.size(), + featShape.data(), featShape.size())); + + + + Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.srcEdges.data(), + graphData.srcEdges.size(), edgeShape.data(), edgeShape.size()); + + + graphData.previousList = &m_graphFeatures; + + graphData.srcEdges.clear(); + graphData.desEdges.clear(); + graphData.featureLeaves.clear(); + return StatusCode::SUCCESS; + } + +} \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.h new file mode 100644 index 000000000000..f27ed78ee67e --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.h @@ -0,0 +1,55 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#ifndef MUONINFERENCETOOLS_GRAPHINFERENCETOOL_H +#define MUONINFERENCETOOLS_GRAPHINFERENCETOOL_H + +#include "MuonInferenceInterfaces/IGraphInferenceTool.h" +#include "MuonInferenceInterfaces/NodeFeatureList.h" +#include "MuonInferenceInterfaces/GraphData.h" + + +#include "MuonSpacePoint/SpacePointContainer.h" + +#include "AthenaBaseComps/AthAlgTool.h" +#include "StoreGate/ReadHandleKey.h" +#include <onnxruntime_cxx_api.h> // is this somewhere else? +#include "nlohmann/json.hpp" + +namespace MuonML{ + /** @brief Baseline tool to handle the */ + class GraphInferenceToolBase : public extends<AthAlgTool, IGraphInferenceTool> { + public: + /** @brief Keep the constructor of the parent class */ + using base_class::base_class; + /** @brief Fill up the GraphRawData and construct the graph for the ML inference with + * ONNX. If the graph has been built by another inference tool and would be the + * same than this one the rebuild is skipped + * @param ctx: EventContext to access the space ponit container from StoreGate + * @param graphData: Rerference to the data object to be filled. */ + StatusCode buildGraph(const EventContext& ctx, + GraphRawData& graphData) const; + + + protected: + StatusCode setupModel(); + + const Ort::Session* model() const; + private: + + // input space points from SG + SG::ReadHandleKey<MuonR4::SpacePointContainer> m_spacePointKey{this, "SpacePointContainer", "MuonSpacePoints"}; + + /** @brief Location of the model file */ + Gaudi::Property<std::string> m_modelPath{this, "ModelPath", ""}; + + /** @brief List of features to be used for the inference */ + NodeFeatureList m_graphFeatures{}; + /** @brief Pointer to the ONNX model session */ + std::unique_ptr<Ort::Session> m_model{}; + + }; + +} + +#endif \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.cxx new file mode 100644 index 000000000000..63b9b1520f79 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.cxx @@ -0,0 +1,24 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#include "InferenceAlg.h" + +#include "MuonInferenceInterfaces/GraphData.h" + +namespace MuonML{ + StatusCode InferenceAlg::initialize() { + if (m_inferenceTools.empty()) { + ATH_MSG_ERROR("Provide at least one inference tool"); + return StatusCode::FAILURE; + } + ATH_CHECK(m_inferenceTools.retrieve()); + return StatusCode::SUCCESS; + } + StatusCode InferenceAlg::execute(const EventContext& ctx) const { + GraphRawData graphData{}; + for (const auto& infTool : m_inferenceTools) { + ATH_CHECK(infTool->runGraphInference(ctx, graphData)); + } + return StatusCode::SUCCESS; + } +} \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.h new file mode 100644 index 000000000000..1826c20ce1c7 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.h @@ -0,0 +1,25 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#ifndef MUONINFERENCETOOLS_INFERENCEALG_H +#define MUONINFERENCETOOLS_INFERENCEALG_H + +#include "AthenaBaseComps/AthReentrantAlgorithm.h" +#include "MuonInferenceInterfaces/IGraphInferenceTool.h" + +namespace MuonML{ + class InferenceAlg : public AthReentrantAlgorithm { + public: + using AthReentrantAlgorithm::AthReentrantAlgorithm; + + virtual StatusCode initialize() override final; + virtual StatusCode execute(const EventContext& ctx) const override final; + private: + ToolHandleArray<IGraphInferenceTool> m_inferenceTools{this, "InferenceTools", {}, + "List of machine learning inference tools to be processed with the same graph"}; + }; +} + + + +#endif \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/components/MuonInference_entries.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/components/MuonInference_entries.cxx new file mode 100644 index 000000000000..b51eafe64be1 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/components/MuonInference_entries.cxx @@ -0,0 +1,8 @@ + +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ + +#include "../InferenceAlg.h" + +DECLARE_COMPONENT(MuonML::InferenceAlg) \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/CMakeLists.txt b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/CMakeLists.txt new file mode 100644 index 000000000000..439004a9651a --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration + +################################################################################ +# Package: MuonInferenceInterfaces +################################################################################ + +# Declare the package name: +atlas_subdir( MuonInferenceInterfaces ) + +find_package( onnxruntime REQUIRED) + +atlas_add_library( MuonInferenceInterfaces + src/*.cxx + PUBLIC_HEADERS MuonInferenceInterfaces + LINK_LIBRARIES MuonReadoutGeometryR4 AthenaBaseComps MuonPatternHelpers + MuonSpacePoint MuonPatternEvent ${ONNXRUNTIME_LIBRARIES}) + diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/GraphData.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/GraphData.h new file mode 100644 index 000000000000..585892f55ef8 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/GraphData.h @@ -0,0 +1,50 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#ifndef MUONINFERENCEINTERACES_GRAPHNODE_H +#define MUONINFERENCEINTERACES_GRAPHNODE_H + +#include <onnxruntime_cxx_api.h> +#include <vector> +#include <memory> + +namespace MuonML{ + class NodeFeatureList; + + /** @brief Helper struct containing all the information needed to process */ + struct InferenceGraph { + /** @brief Vector of the inference input tensors */ + std::vector<Ort::Value> dataTensor{}; + /** @brief Vector of the input names. The raw char data + * is neither owned or managed by the object */ + std::vector<const char*> nameTensor{}; + }; + + /** @brief Helper struct to ship the Graph from the space point buckets + * to ONNX */ + struct GraphRawData { + using FeatureVec_t = std::vector<float>; + using NodeConnectVec_t = std::vector<int64_t>; + using EdgeCounterVec_t = std::vector<int64_t>; + /** @brief Vector containing all features */ + FeatureVec_t featureLeaves{}; + /** @brief Vector encoding the source index of the */ + NodeConnectVec_t srcEdges{}; + /** @brief Vect */ + NodeConnectVec_t desEdges{}; + /** @brief Vector keeping track of how many space points are in each parsed bucket */ + EdgeCounterVec_t spacePointsInBucket{}; + /** @brief Pointer to the latest parsed NodeFeatureList */ + const NodeFeatureList* previousList{}; + /** @brief Pointer to the graph to be parsed to ONNX */ + std::unique_ptr<InferenceGraph> graph{}; + + /** @brief The following variables are needed to fill the consistently the raw data + * for the Graph Building*/ + std::vector<float>::iterator currLeave{featureLeaves.begin()}; + /** @brief Number of the already filled nodes */ + unsigned int nodeIndex{0}; + }; +} + +#endif \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/IGraphInferenceTool.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/IGraphInferenceTool.h new file mode 100644 index 000000000000..fc5243780177 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/IGraphInferenceTool.h @@ -0,0 +1,22 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#ifndef MUONINFERENCEINTERFACES_IGRAPHINFERENCETOOL_H +#define MUONINFERENCEINTERFACES_IGRAPHINFERENCETOOL_H + +#include "GaudiKernel/IAlgTool.h" +#include "GaudiKernel/EventContext.h" +namespace MuonML { + struct GraphRawData; + class IGraphInferenceTool: virtual public IAlgTool { + public: + /** @brief Empty desctructor */ + virtual ~IGraphInferenceTool() = default; + /** @brief Declaration of the interface */ + DeclareInterfaceID(IGraphInferenceTool, 1, 0); + + virtual StatusCode runGraphInference(const EventContext& ctx, + GraphRawData& graph) const = 0; + }; +} +#endif \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h new file mode 100644 index 000000000000..bb083b0e3a5a --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h @@ -0,0 +1,35 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#ifndef MUONINFERENCEINTERACES_LAYERBUCKET_H +#define MUONINFERENCEINTERACES_LAYERBUCKET_H + +#include "MuonSpacePoint/SpacePointContainer.h" +namespace MuonML{ + /** @brief The LayerSpBucket is a space pointbucket where the points are internally + * sorted by their layer number as defined in the SpacePointLayerSorter. The + * bucket also provides the layer number tag for each space point & the number of total layers + */ + class LayerSpBucket : public std::vector<const MuonR4::SpacePoint*> { + public: + /** @brief Standard constructor taking the space point bucket */ + LayerSpBucket(const MuonR4::SpacePointBucket& bucket); + /** @brief Returns how many Mdt layers are inside the bucket */ + uint8_t nMdtLayers() const { + return m_nMdtLay; + } + /** @brief Returns how many Strip layers are inside the bucket */ + uint8_t nStripLayers() const { + return m_nStripLay; + } + /** @brief Returns the associated layer number of the i-the space point inside the bucket */ + uint8_t layerNum(const size_t i) const { + return m_layNum[i]; + } + private: + uint8_t m_nMdtLay{0}; + uint8_t m_nStripLay{0}; + std::vector<uint8_t> m_layNum{}; + }; +} +#endif \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeConnector.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeConnector.h new file mode 100644 index 000000000000..2861aa4908ec --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeConnector.h @@ -0,0 +1,50 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#ifndef MUONINFERENCEINTERACES_NODECONNECTOR_H +#define MUONINFERENCEINTERACES_NODECONNECTOR_H + +#include <string> +#include <functional> + +#include <MuonInferenceInterfaces/NodeFeature.h> +namespace MuonML{ + /** @brief The NodeConnector is indicating whether two space points inside a bucket, + * the graph nodes, shall have a connection in their graph neural net representation. + * In short terms, the node connector is std::function with the bucket & the two indices + * to the space point inside as input arguments and then returning a boolean decision + * whether the connection shall be built. + * In order, to verify that two instances of the neural net graph are identical, + * the node connector also has a name as attribute. */ + class NodeConnector { + public: + using Bucket_t = NodeFeature::Bucket_t; + /** @brief Function type to connect two space points in a bucket. The signature + * takes the reference to the bucket and then the two indices of the space + * points which shall be connected. The function needs to return true or false */ + using Evaluator_t = std::function<bool(const Bucket_t&, size_t, size_t)>; + /** @brief Standard constructor taking the name of the node connector & + * a connector function definition + * @brief cName: Name of the connector function + * @brief conFunc: Definition of the connector function */ + NodeConnector(const std::string& cName, const Evaluator_t conFunc): + m_name{cName}, m_func{conFunc} {} + /** @brief Returns the name of the node connector */ + const std::string& name() const { + return m_name; + } + /** @brief returns the decision of the connector function + * @param bucket: Reference to the space point bucket + * @param i: Index of the first space point to connect + * @param j: Index of the second space point to connect */ + bool connect(const Bucket_t& bucket, size_t i, size_t j) const { + return m_func(bucket, i, j); + } + + private: + std::string m_name{}; + Evaluator_t m_func{[](const Bucket_t, size_t, size_t) { return false; }}; + + }; +} +#endif \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeature.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeature.h new file mode 100644 index 000000000000..9356dc94a4dd --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeature.h @@ -0,0 +1,50 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#ifndef MUONINFERENCEINTERACES_GRAPHFEATURE_H +#define MUONINFERENCEINTERACES_GRAPHFEATURE_H + + +#include <functional> +#include <string> + +#include "MuonInferenceInterfaces/LayerBucket.h" + +namespace MuonML { + /** @brief The NodeFeature is the gluing instance to extract the information from the space point + * inside a MuonBucket and then to parse it to the ML inference framework. */ + class NodeFeature { + public: + /** @brief Abreviation of the Space point bucket type */ + using Bucket_t = LayerSpBucket; + + /** @brief Lambda function type to extract the feature from a bucket. + * @arg: Bucket_t of interest + * @arg: size_t Index of the space point to extract the bucket from */ + using Func_t = std::function<double(const Bucket_t&, size_t)>; + + /** @brief Standard constructor to build a feature + * @param featName: Name of the feature used to distinguish whether + * two feature lists are the same */ + NodeFeature(const std::string& featName, + const Func_t& extractFunc): + m_name{featName}, m_func{extractFunc}{} + /** @brief Standard move assignment & move constructor */ + NodeFeature(NodeFeature&& other) = default; + + /** @brief Returns the feature name */ + const std::string& name() const { + return m_name; + } + /** @brief Extract the feature from a space point inside the bucket + * @param bucket: Reference to the space point bucket of interest + * @param spIndex: Index of the space point to extract the feature from */ + double eval(const Bucket_t& bucket, size_t spIndex) const { + return m_func(bucket, spIndex); + } + private: + std::string m_name{}; + Func_t m_func{[](const Bucket_t& , size_t) { return 0.; }}; + }; +} +#endif \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureFactory.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureFactory.h new file mode 100644 index 000000000000..42980dc4bc51 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureFactory.h @@ -0,0 +1,26 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#ifndef MUONINFERENCEINTERACES_GRAPHFEATUREFACTORY_H +#define MUONINFERENCEINTERACES_GRAPHFEATUREFACTORY_H + +#include "MuonInferenceInterfaces/NodeFeatureList.h" + +class MsgStream; + +namespace MuonML { + namespace Factory { + /** @brief Factory function that builds a NodeFeature from a predefined list of features + * @param featName: Name of the feature inside the list + * @param log: Refetrence to the message object for logging */ + NodeFeatureList::Feature_t makeFeature(const std::string& featName, MsgStream& log); + /** @brief Factory function that builds a connector relation between two edges in the bucket. + * @param connName: Name of the connection function inside the predefined list + * @param log: Refetrence to the message object for logging */ + NodeFeatureList::Connector_t makeConnector(const std::string& connName, MsgStream& log); + + + } +} + +#endif \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureList.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureList.h new file mode 100644 index 000000000000..96b4c08cb7b6 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureList.h @@ -0,0 +1,63 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#ifndef MUONINFERENCEINTERACES_GRAPHFEATURELIST_H +#define MUONINFERENCEINTERACES_GRAPHFEATURELIST_H + +#include "MuonInferenceInterfaces/NodeFeature.h" +#include "MuonInferenceInterfaces/NodeConnector.h" + +class MsgStream; + +namespace MuonML { + struct GraphRawData; + class NodeFeatureList { + public: + + using Feature_t = std::shared_ptr<const NodeFeature>; + using Connector_t = std::shared_ptr<const NodeConnector>; + + using Bucket_t = NodeFeature::Bucket_t; + /** @brief Empty standard constructor */ + NodeFeatureList() = default; + /** @brief Returns true if the features have pairwise + * the same name */ + bool operator==(const NodeFeatureList& other) const; + /** @brief Returns whether the NodeFeatureList is complete, + * i.e. it must have at least one feature and the node + * connector */ + bool isValid() const; + /** @brief Returns the number of features in the list */ + size_t numFeatures() const; + /** @brief Returns the name of the features in the list */ + std::vector<std::string> featureNames() const; + /** @brief */ + void fillInData(const Bucket_t& bucket, + GraphRawData& graphData) const; + /** @brief Tries to add a new feature to the list using the predefined + * list of features in the GraphFeatureFactory + * @param featName: Name of the feature in the factory + * @param msg: Reference to the message stream object for logging. */ + bool addFeature(const std::string& featName, MsgStream& msg); + /** @brief Tries to add a particular feature to the list. + * @param featPtr: Pointer to the instantiated feature + * @param msg: Reference to the message stream object for logging. */ + bool addFeature(const Feature_t& featPtr, MsgStream& msg); + + /** @brief Tries to set the graph connector based on the connector name. + * @param conName: Name of the connector to extract from the factory + * @param msg: Reference to the message stream object for logging. */ + bool setConnector(const std::string& conName, MsgStream& msg); + /** @brief Sets the */ + void setConnector(const std::string& conName, NodeConnector::Evaluator_t evalFunc); + + + private: + std::vector<Feature_t> m_features{}; + Connector_t m_connector{}; + + }; +} + + +#endif \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/LayerBucket.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/LayerBucket.cxx new file mode 100644 index 000000000000..f41b9a82084f --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/LayerBucket.cxx @@ -0,0 +1,32 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#include "MuonInferenceInterfaces/LayerBucket.h" + +#include "MuonSpacePoint/SpacePointPerLayerSorter.h" +#include "Acts/Utilities/Enumerate.hpp" +namespace MuonML{ + using HitVec = MuonR4::SpacePointPerLayerSorter::HitVec; + LayerSpBucket::LayerSpBucket(const MuonR4::SpacePointBucket& bucket) { + reserve(bucket.size()); + m_layNum.resize(bucket.size()); + const MuonR4::SpacePointPerLayerSorter sorter{bucket}; + uint8_t globLayer{0}; + m_nMdtLay = sorter.mdtHits().size(); + m_nStripLay = sorter.stripHits().size(); + for (const HitVec& spInLay : sorter.mdtHits()) { + for (const auto& [idx, spacePoint] : Acts::enumerate(spInLay)) { + push_back(spacePoint); + m_layNum[idx] = globLayer; + } + ++globLayer; + } + for (const HitVec& spInLay : sorter.stripHits()) { + for (const auto& [idx, spacePoint] : Acts::enumerate(spInLay)) { + push_back(spacePoint); + m_layNum[idx] = globLayer; + } + ++globLayer; + } + } +} \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx new file mode 100644 index 000000000000..daf78ffe3330 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx @@ -0,0 +1,91 @@ +/* +Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#include "MuonInferenceInterfaces/NodeFeatureFactory.h" +#include "AthenaBaseComps/AthMessaging.h" + +#include <set> + +namespace MuonML { + + using Feature_t = NodeFeatureList::Feature_t; + using Connector_t = NodeFeatureList::Connector_t; + using Bucket_t = NodeFeature::Bucket_t; + + bool operator<(const std::string& a, const Feature_t & b) { + return a < b->name(); + } + bool operator<( const Feature_t & a, const std::string& b) { + return a->name() < b; + } + bool operator<(const Feature_t& a, const Feature_t & b) { + return a->name() < b->name(); + } + + bool operator<(const std::string& a, const Connector_t & b) { + return a < b->name(); + } + bool operator<( const Connector_t & a, const std::string& b) { + return a->name() < b; + } + bool operator<(const Connector_t& a, const Connector_t & b) { + return a->name() < b->name(); + } + + namespace Factory { + Feature_t makeFeature(const std::string& featName, MsgStream& log) { + + /** Predefine the known features in the pool */ + static const std::set<Feature_t, std::less<>> featurePool{ + std::make_unique<NodeFeature>("driftR", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]->driftRadius(); + }), + std::make_unique<NodeFeature>("localX", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]->positionInChamber().x(); + }), + }; + const auto feat_itr = featurePool.find(featName); + if(feat_itr != featurePool.end()){ + if (log.level() <= MSG::DEBUG) { + log<<MSG::DEBUG<<"Found graph feature "<<featName<<"."<<endmsg; + } + return *feat_itr; + } + std::stringstream available{}; + for (const Feature_t& known : featurePool) { + available<<known->name()<<", "; + } + log<<MSG::ERROR<<"The feature "<<featName<<" is unknown to the feature factory. " + <<" Please check for typos w.r.t "<<available.str()<<". Otherwise augment " + <<__FILE__<<" with your desired feature "<<endmsg; + return nullptr; + } + NodeFeatureList::Connector_t makeConnector(const std::string& connName, MsgStream& log) { + + static const std::set<Connector_t, std::less<>> connectorPool{ + std::make_unique<NodeConnector>("fullyConnected", + [](const Bucket_t& , size_t , size_t ) { + return true; + }), + }; + const auto feat_itr = connectorPool.find(connName); + if(feat_itr != connectorPool.end()){ + if (log.level() <= MSG::DEBUG) { + log<<MSG::DEBUG<<"Found graph connector "<<connName<<"."<<endmsg; + } + return *feat_itr; + } + std::stringstream available{}; + for (const Connector_t& known : connectorPool) { + available<<known->name()<<", "; + } + log<<MSG::ERROR<<"The graph connector "<<connName<<" is unknown to the factory. " + <<" Please check for typos w.r.t "<<available.str()<<". Otherwise augment " + <<__FILE__<<" with your desired connection function. "<<endmsg; + return nullptr; + } + } + +} \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureList.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureList.cxx new file mode 100644 index 000000000000..91bbb536d321 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureList.cxx @@ -0,0 +1,94 @@ +/* + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration +*/ +#include "MuonInferenceInterfaces/NodeFeatureList.h" +#include "MuonInferenceInterfaces/NodeFeatureFactory.h" +#include "MuonInferenceInterfaces/GraphData.h" + + +#include "AthenaBaseComps/AthMessaging.h" +#include "Acts/Utilities/Enumerate.hpp" +#include "MuonPatternHelpers/MatrixUtils.h" + +using namespace MuonR4; +namespace MuonML { + bool NodeFeatureList::isValid() const { + return m_connector && numFeatures(); + } + bool NodeFeatureList::setConnector(const std::string& conName, MsgStream& msg) { + m_connector = Factory::makeConnector(conName , msg); + return m_connector != nullptr; + } + + void NodeFeatureList::setConnector(const std::string& conName, NodeConnector::Evaluator_t evalFunc) { + m_connector = std::make_unique<NodeConnector>(conName, evalFunc); + } + bool NodeFeatureList::operator==(const NodeFeatureList& other) const { + if (numFeatures() != other.numFeatures()) { + return false; + } + if (!m_connector || !other.m_connector || m_connector->name() != other.m_connector->name()) { + return false; + } + for (size_t f =0 ; f < numFeatures(); ++f) { + if (m_features[f]->name() != other.m_features[f]->name()) { + return false; + } + } + return true; + } + size_t NodeFeatureList::numFeatures() const { + return m_features.size(); + } + /** @brief Returns the name of the features in the list */ + std::vector<std::string> NodeFeatureList::featureNames() const { + std::vector<std::string> names{}; + std::ranges::transform(m_features,std::back_inserter(names), + [](const Feature_t& ft){ return ft->name();}); + return names; + } + bool NodeFeatureList::addFeature(const std::string& featName, MsgStream& msg) { + return addFeature(Factory::makeFeature(featName, msg), msg); + } + + bool NodeFeatureList::addFeature(const Feature_t& newFeat, MsgStream& msg) { + if (!newFeat) { + msg<<MSG::ERROR<<"No feature has been parsed. "<<endmsg; + return false; + } + if (std::ranges::find_if(m_features, [&newFeat](const Feature_t& known){ + return known == newFeat || known->name() == newFeat->name(); + }) != m_features.end()) { + msg<<MSG::ERROR<<" The feature "<<newFeat->name()<<" has already been added & " + <<" cannot be added again. "<<endmsg; + return false; + } + if (msg.level() <= MSG::DEBUG) { + msg<<MSG::DEBUG<<__FILE__<<":"<<__LINE__<<" - Add new feature "<< newFeat->name()<<endmsg; + } + return true; + } + void NodeFeatureList::fillInData(const Bucket_t& bucket, + GraphRawData& prepGraph) const { + for (size_t sp = 0 ; sp < bucket.size(); ++sp) { + /** @brief Fill the graph features */ + for(const Feature_t& feat : m_features) { + assert(prepGraph.currLeave != prepGraph.featureLeaves.end()); + (*prepGraph.currLeave++) = feat->eval(bucket, sp); + } + for (size_t ot = 0; ot < sp; ++ot) { + if (m_connector->connect(bucket, sp, ot)) { + /// Connection i->j + prepGraph.srcEdges.emplace_back(ot); + prepGraph.desEdges.emplace_back(sp); + ///Connection j-> i + prepGraph.srcEdges.emplace_back(sp); + prepGraph.desEdges.emplace_back(ot); + + } + } + } + prepGraph.nodeIndex+=bucket.size(); + } + +} \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/CMakeLists.txt b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/CMakeLists.txt new file mode 100644 index 000000000000..df03d50d2f7d --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration + +# Declare the package name: +atlas_subdir( MuonSPId ) +find_package( onnxruntime ) + + +# Component(s) in the package: +atlas_add_component( MuonSPId + src/*.cxx + src/components/*.cxx + LINK_LIBRARIES AthenaBaseComps GeoPrimitives MuonIdHelpersLib + MuonSimEvent xAODMuonSimHit StoreGateLib MdtCalibSvcLib + xAODMuonPrepData MuonSpacePoint xAODMuon MuonTesterTreeLib + MuonPatternEvent MuonPatternHelpers MuonPatternHelpers ${ONNXRUNTIME_LIBRARIES} + ) + +atlas_install_python_modules( python/*.py) + +atlas_add_test( testMuonSPId + SCRIPT python -m MuonSPId.muonSPIdDump --nEvents 5 --noSTGC --noMM + PROPERTIES TIMEOUT 600 + PRIVATE_WORKING_DIRECTORY + POST_EXEC_SCRIPT noerror.sh) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/MuonSPIdDumpConfig.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/MuonSPIdDumpConfig.py new file mode 100644 index 000000000000..953d77712313 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/MuonSPIdDumpConfig.py @@ -0,0 +1,17 @@ +#Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration + +from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator +from AthenaConfiguration.ComponentFactory import CompFactory + +def MuonSPIdDumpCfg(flags, name="MuonSPIdMaker", **kwargs): + result = ComponentAccumulator() + from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg + result.merge(MuonSpacePointFormationCfg(flags)) + kwargs.setdefault("isMC", flags.Input.isMC) + #kwargs.setdefault("spIdValue", flags.spIdValue) + #from RngComps.RngCompsConfig import AthRNGSvcCfg + #kwargs.setdefault("RndmSvc", result.getPrimaryAndMerge(AthRNGSvcCfg(flags))) + the_alg = CompFactory.MuonR4.SPIdDumperAlg(name=name, **kwargs) + result.addEventAlgo(the_alg, primary = True) + return result + diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py new file mode 100644 index 000000000000..a07d728e8396 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py @@ -0,0 +1,43 @@ +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +if __name__=="__main__": + from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg, SetupArgParser, executeTest, setupHistSvcCfg + parser = SetupArgParser() + parser.set_defaults(nEvents = -1) + parser.set_defaults(outRootFile="MuonSPId_R3SimHits.root") + parser.set_defaults(inputFile=[ + "/cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/MuonGeomRTT/R3SimHits.pool.root" + ]) + parser.set_defaults(eventPrintoutLevel = 500) + args = parser.parse_args() + + from AthenaConfiguration.AllConfigFlags import initConfigFlags + flags = initConfigFlags() + flags.PerfMon.doFullMonMT = True + + flags, cfg = setupGeoR4TestCfg(args) + + cfg.merge(setupHistSvcCfg(flags,outFile=args.outRootFile, outStream="MuonSPId")) + + from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg + cfg.merge(xAODUncalibMeasPrepCfg(flags)) + + from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg + cfg.merge(MuonSpacePointFormationCfg(flags)) + + from MuonPatternRecognitionAlgs.MuonHoughTransformAlgConfig import MuonPatternRecognitionCfg, MuonSegmentFittingAlgCfg + cfg.merge(MuonPatternRecognitionCfg(flags)) + + cfg.merge(MuonSegmentFittingAlgCfg(flags)) + + ## add the SPId flags here + #newparser = SetupArgParser() + #newparser.set_defaults(spIdValue=0) + #newparser.set_defaults(isMC=True) + #newargs = newparser.parse_args() + #flags.spIdValue = newargs.spIdValue + + from MuonSPId.MuonSPIdDumpConfig import MuonSPIdDumpCfg + cfg.merge(MuonSPIdDumpCfg(flags)) + + executeTest(cfg) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py new file mode 100644 index 000000000000..7e7f1be057f4 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py @@ -0,0 +1,38 @@ +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration + +if __name__=="__main__": + from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg, SetupArgParser, executeTest, setupHistSvcCfg + parser = SetupArgParser() + parser.set_defaults(nEvents = -1) + parser.set_defaults(outRootFile="MuonSPId_2022_13p6TeV_00431493.root") + parser.set_defaults(condTag="CONDBR2-BLKPA-2023-03") + parser.set_defaults(inputFile=[ + "/cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/Tier0ChainTests/TCT_Run3/data22_13p6TeV.00431493.physics_Main.daq.RAW._lb0525._SFO-16._0001.data" + ]) + parser.set_defaults(eventPrintoutLevel = 500) + args = parser.parse_args() + + from AthenaConfiguration.AllConfigFlags import initConfigFlags + flags = initConfigFlags() + flags.PerfMon.doFullMonMT = True + + flags, cfg = setupGeoR4TestCfg(args) + + cfg.merge(setupHistSvcCfg(flags,outFile=args.outRootFile, + outStream="MuonSPId")) + + from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg + cfg.merge(xAODUncalibMeasPrepCfg(flags)) + + from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg + cfg.merge(MuonSpacePointFormationCfg(flags)) + + from MuonPatternRecognitionAlgs.MuonHoughTransformAlgConfig import MuonPatternRecognitionCfg, MuonSegmentFittingAlgCfg + cfg.merge(MuonPatternRecognitionCfg(flags)) + cfg.merge(MuonSegmentFittingAlgCfg(flags)) + + from MuonSPId.MuonSPIdDumpConfig import MuonSPIdDumpCfg + cfg.merge(MuonSPIdDumpCfg(flags)) + + executeTest(cfg) + diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx new file mode 100644 index 000000000000..379257fe8b5b --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx @@ -0,0 +1,399 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "SPIdDumperAlg.h" + +#include "Identifier/Identifier.h" +#include "MuonIdHelpers/IMuonIdHelperSvc.h" +#include "StoreGate/ReadHandle.h" +#include "MuonTesterTree/EventHashBranch.h" +#include "MuonSpacePoint/SpacePointPerLayerSorter.h" +#include "xAODMeasurementBase/MeasurementDefs.h" +#include "xAODMuonPrepData/UtilFunctions.h" +#include "xAODMuonPrepData/MdtDriftCircle.h" +#include <fstream> +#include <TString.h> +#include <AthenaKernel/RNGWrapper.h> +#include "CLHEP/Random/RandFlat.h" + +namespace { + union bucketId{ + int8_t fields[4]; + int hash; + }; + +} + +namespace MuonR4{ + StatusCode SPIdDumperAlg::initialize() { + ATH_CHECK(m_readKey.initialize()); + ATH_CHECK(m_idHelperSvc.retrieve()); + ATH_CHECK(m_inSegmentKey.initialize(!m_inSegmentKey.empty())); + m_tree.addBranch(std::make_shared<MuonVal::EventHashBranch>(m_tree.tree())); + ATH_CHECK(m_tree.init(this)); + ATH_CHECK(m_idHelperSvc.retrieve()); + + + std::string onnx_path = m_modelPath; + ATH_MSG_INFO("Model Path: " << onnx_path); + ATH_MSG_INFO("SPId Cut Value: " << m_SPId_cut); + + std::vector<std::string> feature_names; + try { + + m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ONNXFeatureExtraction"); + m_session_options = std::make_unique<Ort::SessionOptions>(); + m_session_options->SetIntraOpNumThreads(1); + m_session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); + m_session = std::make_unique<Ort::Session>(*m_env, onnx_path.c_str(), *m_session_options); + + Ort::ModelMetadata metadata = m_session->GetModelMetadata(); + + Ort::AllocatorWithDefaultOptions allocator; + Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator); + + if (feature_json_ptr) { + std::string feature_json = feature_json_ptr.get(); + nlohmann::json json_obj = nlohmann::json::parse(feature_json); + + for (const auto& feature : json_obj) { + feature_names.push_back(feature.get<std::string>()); + } + } + } catch (const std::exception& e) { + ATH_MSG_ERROR("Failed to retrieve feature names from ONNX model: " << e.what()); + return StatusCode::FAILURE; + } + + if (feature_names.empty()) { + ATH_MSG_ERROR("No feature metadata found in the ONNX model!"); + return StatusCode::FAILURE; + } + + ATH_MSG_ALWAYS("Extracted Feature Names:"); + for (const auto& name : feature_names) { + ATH_MSG_ALWAYS("- " << name); + } + + ATH_MSG_DEBUG("Successfully initialized"); + return StatusCode::SUCCESS; + } + + StatusCode SPIdDumperAlg::finalize() { + ATH_CHECK(m_tree.write()); + return StatusCode::SUCCESS; + } + + StatusCode SPIdDumperAlg::execute(){ + const EventContext& ctx{Gaudi::Hive::currentContext()}; + + std::unordered_map <const SpacePointBucket*, std::vector<const MuonR4::Segment*>> segmentMap; // MuonR4Segment + + SG::ReadHandle readSegment(m_inSegmentKey, ctx); + ATH_CHECK(readSegment.isPresent()); + for (const MuonR4::Segment* segment : *readSegment) { + segmentMap[segment->parent()->parentBucket()].push_back(segment); + } + + SG::ReadHandle<SpacePointContainer> readHandle{m_readKey, ctx}; + ATH_CHECK(readHandle.isPresent()); + + bool Sparse = false; + bool Labels = true; + + std::vector<float> features; + std::vector<int64_t> edge_src; + std::vector<int64_t> edge_dst; + std::vector<int64_t> node_offsets; + std::vector<int64_t> labels; + size_t total_nodes = 0; + + size_t bucket_index = 0; + + for(const SpacePointBucket* bucket : *readHandle) { + + unsigned int segIdx{0}; + std::unordered_map<const SpacePoint*, std::vector<int16_t>> spacePointToSegment; + + if (Labels) { + auto match_itr = segmentMap.find(bucket); + if (match_itr != segmentMap.end()) { + for (const MuonR4::Segment* segment : match_itr->second) { + for (const auto& meas : segment->measurements()) { + spacePointToSegment[meas->spacePoint()].push_back(segIdx); + } + ++segIdx; + } + } + } + + SpacePointPerLayerSorter sorter{*bucket}; + + size_t num_points = bucket->size(); + + unsigned int layer{0}; + Identifier prevLayer = Identifier(); + + std::vector<u_int16_t> Neighbors(num_points, 0); + + size_t bucket_offset = total_nodes; + node_offsets.push_back(bucket_offset); + total_nodes += num_points; + + float min_x = 3000; + float max_x = -1; + + Identifier layId; + for (u_int16_t i = 0; i < num_points; i++) { + const auto sp = bucket->at(i); + const Identifier id = sp->identify(); + + if (sp->type() == xAOD::UncalibMeasType::MdtDriftCircleType) { + const auto* dc = static_cast<const xAOD::MdtDriftCircle*>(sp->primaryMeasurement()); + if (dc->status() != Muon::MdtDriftCircleStatus::MdtStatusDriftTime){ + continue; + } + const MdtIdHelper& idHelper{m_idHelperSvc->mdtIdHelper()}; + layId = idHelper.channelID(idHelper.stationName(id), 1, idHelper.stationPhi(id), idHelper.multilayer(id), idHelper.tubeLayer(id), 1); + if (layId != prevLayer) { + layer++; + prevLayer = layId; + } + } else { + layId = m_idHelperSvc->gasGapId(id); + + if (layId != prevLayer) { + layer++; + prevLayer = layId; + } + } + + if (Labels) { + const std::vector<int16_t>& segIdxs = spacePointToSegment[sp.get()]; + if (segIdxs.size() > 0) { + m_spoint_label.push_back(1); + } else { + m_spoint_label.push_back(0); + } + } + + float x = sp->positionInChamber().x(); + float y = sp->positionInChamber().y(); + float z = sp->positionInChamber().z(); + + + + m_spoint_bucket.push_back(bucket_index); + m_spoint_x.push_back(x); + m_spoint_y.push_back(y); + m_spoint_z.push_back(z); + m_spoint_station.push_back(m_idHelperSvc->stationName(id)); + + m_spoint_driftR.push_back(sp->driftRadius()); + m_spoint_layer.push_back(layer); + + if (x < min_x) min_x = x; + if (x > max_x) max_x = x; + + + if (Sparse) { + for (u_int16_t j = 0; j < num_points; j++) { + if (i == j) continue; + const auto sp2 = bucket->at(j); + Identifier layId2; + const Identifier id2 = sp2->identify(); + if (sp2->type() == xAOD::UncalibMeasType::MdtDriftCircleType) { + const MdtIdHelper& idHelper{m_idHelperSvc->mdtIdHelper()}; + layId2 = idHelper.channelID(idHelper.stationName(id2), 1, idHelper.stationPhi(id2), idHelper.multilayer(id2), idHelper.tubeLayer(id2), 1); + } else { + layId2 = m_idHelperSvc->gasGapId(id2); + } + float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2)); + if (h < 500) { + Neighbors[i]++; + } + if ( h < 2000 and layId != layId2) { + edge_src.push_back(i + bucket_offset); + edge_dst.push_back(j + bucket_offset); + m_spoint_edges.push_back(j + bucket_offset); + } + } + } else { + for (u_int16_t j = 0; j < num_points; j++) { + if (i == j) continue; + edge_src.push_back(i + bucket_offset); + edge_dst.push_back(j + bucket_offset); + m_spoint_edges.push_back(j + bucket_offset); + if (j < i ) continue; + const auto sp2 = bucket->at(j); + float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2)); + if (h < 500) { + Neighbors[i]++; + Neighbors[j]++; + } + } + } + + m_spoint_neighbors.push_back(Neighbors[i]); + } // end loop over space points + + m_bucket_layers.push_back(layer); + float density = 0; + float bucket_size = bucket->coveredMax() - bucket->coveredMin(); + if (bucket_size == 0 ){ + bucket_size = max_x - min_x; + if (bucket_size == 0) bucket_size = 1; + } + density = num_points / bucket_size; + m_bucket_density.push_back(density); + bucket_index++; + + } // end loop over buckets + + features.reserve(total_nodes * 8); + for (size_t i = 0; i < total_nodes; i++) { + features.push_back(m_spoint_x[i]); + features.push_back(m_spoint_y[i]); + features.push_back(m_spoint_z[i]); + features.push_back(m_spoint_station[i]); + features.push_back(m_spoint_driftR[i]); + features.push_back(float(m_spoint_layer[i]) / (m_bucket_layers[m_spoint_bucket[i]] > 0 ? m_bucket_layers[m_spoint_bucket[i]] : 1)); + features.push_back(m_spoint_neighbors[i]); + features.push_back(m_bucket_density[m_spoint_bucket[i]]); + if (Labels) { + labels.push_back(m_spoint_label[i]); + } + } + + if (features.empty()) { + ATH_MSG_WARNING("No valid feature data available for inference. Skipping event."); + return StatusCode::SUCCESS; + } + + std::vector<int64_t> edge_index; + edge_index.reserve(2 * edge_src.size()); + edge_index.insert(edge_index.end(), edge_src.begin(), edge_src.end()); + edge_index.insert(edge_index.end(), edge_dst.begin(), edge_dst.end()); + + ATH_MSG_DEBUG("Total nodes: " << total_nodes); + ATH_MSG_DEBUG("Features size: " << features.size()); + ATH_MSG_DEBUG("Expected feature shape: (" << total_nodes << ", 8)"); + + if (features.size() != total_nodes * 8) { + ATH_MSG_ERROR("Feature size mismatch! Expected " << total_nodes * 8 << " but got " << features.size()); + } + + ATH_MSG_DEBUG("Edge index size: " << edge_index.size()); + ATH_MSG_DEBUG("Expected edge_index shape: (2, " << edge_index.size() / 2 << ")"); + + if (edge_index.size() % 2 != 0) { + ATH_MSG_ERROR("Edge index format error! Size should be divisible by 2."); + } + + + //Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference"); + //Ort::SessionOptions session_options; + //session_options.SetIntraOpNumThreads(1); + //session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); + //Ort::Session session(env, "torch_GatFourier_fcg_quantized.onnx", session_options); + + + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + + std::vector<int64_t> feature_shape = {static_cast<int64_t>(total_nodes), 8}; // (N, 8) + std::vector<int64_t> edge_shape = {2, static_cast<int64_t>(edge_index.size() / 2)}; // (2, E) + + Ort::Value feature_tensor = Ort::Value::CreateTensor<float>( + memory_info, features.data(), features.size(), feature_shape.data(), feature_shape.size()); + + Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>( + memory_info, edge_index.data(), edge_index.size(), edge_shape.data(), edge_shape.size()); + + std::vector<const char*> input_names = {"features", "edge_index"}; + std::vector<const char*> output_names = {"output"}; + + std::vector<Ort::Value> input_tensors; + input_tensors.push_back(std::move(feature_tensor)); + input_tensors.push_back(std::move(edge_tensor)); + + //std::vector<Ort::Value> output_tensors = session.Run( + std::vector<Ort::Value> output_tensors = m_session->Run( + Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(), + input_tensors.size(), output_names.data(), output_names.size()); + + float* output_data = output_tensors[0].GetTensorMutableData<float>(); + size_t output_size = output_tensors[0].GetTensorTypeAndShapeInfo().GetElementCount(); + std::vector<float> predictions(output_data, output_data + output_size); + + for (size_t i = 0; i < output_size; i++) { + if (!std::isfinite(predictions[i])) { + ATH_MSG_ERROR("Non-finite prediction detected! Setting to zero."); + predictions[i] = 0.0f; + } + } + + ATH_MSG_DEBUG("ONNX output size: " << output_size); + ATH_MSG_DEBUG("First 5 predictions:"); + for (size_t i = 0; i < std::min(output_size, size_t(5)); i++) { + ATH_MSG_DEBUG("Prediction[" << i << "]: " << predictions[i]); + } + + std::vector<int> binary_predictions(output_size); + for (size_t i = 0; i < output_size; i++) { + binary_predictions[i] = (predictions[i] > m_SPId_cut) ? 1 : 0; // Why inference cut is different? -> this batch/graph building may be different from torch one? + } + ATH_MSG_DEBUG("First 5 Binary Predictions:"); + for (size_t i = 0; i < std::min(output_size, size_t(5)); i++) { + ATH_MSG_DEBUG("Binary_Prediction[" << i << "]: " << binary_predictions[i]); + } + + if (Labels) { + float loss = 0; + float accuracy = 0; + float pred_1 = 0; + float lab_1 = 0; + float true_0 = 0; + //float true_1 = 0; + for (size_t i = 0; i < output_size; i++) { + if (binary_predictions[i] != labels[i] && labels[i] == 1) { + loss++; + } + if (binary_predictions[i] == labels[i]) { + accuracy++; + } + //if (binary_predictions[i] == labels[i] && labels[i] == 1) { + // true_1++; + //} + if (binary_predictions[i] == labels[i] && labels[i] == 0) { + true_0++; + } + if (binary_predictions[i] == 1) { + pred_1++; + } + if (labels[i] == 1) { + lab_1++; + } + } + //ATH_MSG_ALWAYS("Hits: " << output_size ); + ATH_MSG_ALWAYS("Accuracy: " << accuracy / output_size << " ( Correct: " << accuracy << " / Hits: " << output_size << ")" ); + ATH_MSG_ALWAYS("Signal Loss: " << loss / lab_1 << " ( Miss predicted: " << loss << " / Good Hits: " << lab_1 << ")" ); + float lab_0 = output_size - lab_1; + if (lab_0 != 0) { + ATH_MSG_ALWAYS("Real Rejection: " << true_0 / (output_size - lab_1) << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" ); + } else ATH_MSG_ALWAYS("Real Rejection: " << 0 << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" ); + ATH_MSG_ALWAYS("Purity (before): " << lab_1 / output_size << " ( Good Hits: " << lab_1 << " / Hits: " << output_size << ")" ); + ATH_MSG_ALWAYS("Purity (after): " << (lab_1 - loss) / pred_1 << " ( Good predicted: " << lab_1 - loss << " / Predicted: " << pred_1 << ")" ); + ATH_MSG_ALWAYS("-----------------------------"); + } + + for (size_t i = 0; i < output_size; i++) { + m_spoint_predictions.push_back(binary_predictions[i]); + } + + if (!m_tree.fill(ctx)) return StatusCode::FAILURE; + return StatusCode::SUCCESS; + + } + +} diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h new file mode 100644 index 000000000000..6a2f9b1c7d90 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h @@ -0,0 +1,95 @@ +#ifndef MUON_SP_ID_DUMPER_H +#define MUON_SP_ID_DUMPER_H +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include <AthenaBaseComps/AthAlgorithm.h> +#include "AthenaBaseComps/AthHistogramAlgorithm.h" + +#include <MuonIdHelpers/IMuonIdHelperSvc.h> +#include "StoreGate/ReadHandleKey.h" +#include "StoreGate/ReadDecorHandle.h" +#include "StoreGate/ReadHandleKeyArray.h" +#include "StoreGate/ReadCondHandleKey.h" + +#include <MuonPatternEvent/MuonPatternContainer.h> +#include <MuonSpacePoint/SpacePointContainer.h> +#include <ActsGeometryInterfaces/ActsGeometryContext.h> +#include <MuonReadoutGeometryR4/MuonDetectorManager.h> + +#include "MuonTesterTree/MuonTesterTree.h" +#include "MuonTesterTree/ThreeVectorBranch.h" +#include "MuonTesterTree/IdentifierBranch.h" + +#include "xAODMuonSimHit/MuonSimHitContainer.h" +#include "xAODMuon/MuonSegmentContainer.h" + +#include "AthenaKernel/IAthRNGSvc.h" +#include "CLHEP/Random/RandomEngine.h" + +#include "nlohmann/json.hpp" + + +// onnx runtime +#include <onnxruntime_cxx_api.h> + + +namespace MuonR4{ +class SPIdDumperAlg: public AthHistogramAlgorithm { + + public: + using AthHistogramAlgorithm::AthHistogramAlgorithm; + ~SPIdDumperAlg() = default; + + virtual StatusCode initialize() override final; + virtual StatusCode finalize() override final; + virtual StatusCode execute() override final; + + private: + + void fillChamberInfo(const MuonGMR4::Chamber* chamber); + + SG::ReadHandleKey<SpacePointContainer> m_readKey{this, "ReadKey", "MuonSpacePoints", "Key to the space point container"}; + ServiceHandle<Muon::IMuonIdHelperSvc> m_idHelperSvc{this, "MuonIdHelperSvc", "Muon::MuonIdHelperSvc/MuonIdHelperSvc"}; + + SG::ReadHandleKey<MuonR4::SegmentContainer> m_inSegmentKey{this, "SegmentKey", "R4MuonSegments"}; + + Gaudi::Property<bool> m_isMC{this, "isMC", true}; + Gaudi::Property<double> m_SPId_cut{this,"spIdValue", 0.0001}; + Gaudi::Property<std::string> m_modelPath{this, "ModelPath", "torch_GatFourier_fcg_quantized.onnx"}; + + MuonVal::MuonTesterTree m_tree{"MuonSPId","MuonSPId"}; + + MuonVal::VectorBranch<float>& m_bucket_density{m_tree.newVector<float>("bucket_density", 0)}; + MuonVal::VectorBranch<uint8_t>& m_bucket_layers{m_tree.newVector<uint8_t>("bucket_layers", 0)}; + + MuonVal::VectorBranch<uint8_t>& m_spoint_bucket{m_tree.newVector<uint8_t>("bucket_index")}; + + //MuonVal::ThreeVectorBranch m_spoint_localPosition{m_tree, "localPosition"}; + MuonVal::VectorBranch<float>& m_spoint_x{m_tree.newVector<float>("x")}; + MuonVal::VectorBranch<float>& m_spoint_y{m_tree.newVector<float>("y")}; + MuonVal::VectorBranch<float>& m_spoint_z{m_tree.newVector<float>("z")}; + + //MuonVal::MuonIdentifierBranch m_spoint_id{m_tree, "id"}; + MuonVal::VectorBranch<uint8_t>& m_spoint_station{m_tree.newVector<uint8_t>("stationIndex")}; + MuonVal::VectorBranch<uint8_t>& m_spoint_layer{m_tree.newVector<uint8_t>("layer")}; + + MuonVal::VectorBranch<float>& m_spoint_driftR{m_tree.newVector<float>("driftR")}; + MuonVal::VectorBranch<float>& m_spoint_neighbors{m_tree.newVector<float>("neighbors")}; + + MuonVal::VectorBranch<uint8_t>& m_spoint_label{m_tree.newVector<uint8_t>("label")}; + MuonVal::VectorBranch<uint8_t>& m_spoint_predictions{m_tree.newVector<uint8_t>("predictions")}; + MuonVal::VectorBranch<uint16_t>& m_spoint_edges{m_tree.newVector<uint16_t>("edges")}; + + size_t m_event{0}; + + // ONNX Runtime objects + std::unique_ptr<Ort::Env> m_env; + std::unique_ptr<Ort::SessionOptions> m_session_options; + std::unique_ptr<Ort::Session> m_session; + +}; +} +#endif + diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.cxx new file mode 100644 index 000000000000..a8ed24f5b37d --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.cxx @@ -0,0 +1,340 @@ +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include "SPIdentifierAlg.h" + +#include "Identifier/Identifier.h" +#include "MuonIdHelpers/IMuonIdHelperSvc.h" +#include "StoreGate/ReadHandle.h" +#include "MuonTesterTree/EventHashBranch.h" +#include "MuonSpacePoint/SpacePointPerLayerSorter.h" +#include "xAODMeasurementBase/MeasurementDefs.h" +#include "xAODMuonPrepData/UtilFunctions.h" +#include "xAODMuonPrepData/MdtDriftCircle.h" +#include <fstream> +#include <TString.h> +#include <AthenaKernel/RNGWrapper.h> +#include "CLHEP/Random/RandFlat.h" + +namespace { + union bucketId{ + int8_t fields[4]; + int hash; + }; + +} + +namespace MuonR4{ + + StatusCode SPIdentifierAlg::initialize() { + ATH_CHECK(m_readKey.initialize()); + ATH_CHECK(m_idHelperSvc.retrieve()); + ATH_CHECK(m_inSegmentKey.initialize(!m_inSegmentKey.empty())); + m_tree.addBranch(std::make_shared<MuonVal::EventHashBranch>(m_tree.tree())); + ATH_CHECK(m_tree.init(this)); + ATH_CHECK(m_idHelperSvc.retrieve()); + + ATH_MSG_DEBUG("Successfully initialized"); + + return StatusCode::SUCCESS; + } + + StatusCode SPIdentifierAlg::finalize() { + ATH_CHECK(m_tree.write()); + return StatusCode::SUCCESS; + } + + StatusCode SPIdentifierAlg::execute(){ + const EventContext& ctx{Gaudi::Hive::currentContext()}; + + std::unordered_map <const SpacePointBucket*, std::vector<const MuonR4::Segment*>> segmentMap; + + SG::ReadHandle readSegment(m_inSegmentKey, ctx); + ATH_CHECK(readSegment.isPresent()); + for (const MuonR4::Segment* segment : *readSegment) { + segmentMap[segment->parent()->parentBucket()].push_back(segment); + } + + SG::ReadHandle<SpacePointContainer> readHandle{m_readKey, ctx}; + ATH_CHECK(readHandle.isPresent()); + + bool Sparse = false; + bool Labels = true; + + std::vector<float> features; + std::vector<int64_t> edge_src; + std::vector<int64_t> edge_dst; + std::vector<int64_t> node_offsets; + std::vector<int64_t> labels; + size_t total_nodes = 0; + size_t bucket_index = 0; + + for(const SpacePointBucket* bucket : *readHandle) { + + unsigned int segIdx{0}; + std::unordered_map<const SpacePoint*, std::vector<int16_t>> spacePointToSegment; + + if (Labels) { + auto match_itr = segmentMap.find(bucket); + if (match_itr != segmentMap.end()) { + for (const MuonR4::Segment* segment : match_itr->second) { + for (const auto& meas : segment->measurements()) { + spacePointToSegment[meas->spacePoint()].push_back(segIdx); + } + ++segIdx; + } + } + } + + SpacePointPerLayerSorter sorter{*bucket}; + + size_t num_points = bucket->size(); + + unsigned int layer{0}; + Identifier prevLayer = Identifier(); + + std::vector<u_int16_t> Neighbors(num_points, 0); + + size_t bucket_offset = total_nodes; + node_offsets.push_back(bucket_offset); + total_nodes += num_points; + + float min_x = 3000; + float max_x = -1; + + Identifier layId; + for (u_int16_t i = 0; i < num_points; i++) { + const auto sp = bucket->at(i); + const Identifier id = sp->identify(); + + if (sp->type() == xAOD::UncalibMeasType::MdtDriftCircleType) { + const auto* dc = static_cast<const xAOD::MdtDriftCircle*>(sp->primaryMeasurement()); + if (dc->status() != Muon::MdtDriftCircleStatus::MdtStatusDriftTime){ + continue; + } + const MdtIdHelper& idHelper{m_idHelperSvc->mdtIdHelper()}; + layId = idHelper.channelID(idHelper.stationName(id), 1, idHelper.stationPhi(id), idHelper.multilayer(id), idHelper.tubeLayer(id), 1); + if (layId != prevLayer) { + layer++; + prevLayer = layId; + } + } else { + layId = m_idHelperSvc->gasGapId(id); + + if (layId != prevLayer) { + layer++; + prevLayer = layId; + } + } + + if (Labels) { + const std::vector<int16_t>& segIdxs = spacePointToSegment[sp.get()]; + if (segIdxs.size() > 0) { + m_spoint_label.push_back(1); + } else { + m_spoint_label.push_back(0); + } + } + + float x = sp->positionInChamber().x(); + float y = sp->positionInChamber().y(); + float z = sp->positionInChamber().z(); + + + + m_spoint_bucket.push_back(bucket_index); + m_spoint_x.push_back(x); + m_spoint_y.push_back(y); + m_spoint_z.push_back(z); + m_spoint_station.push_back(m_idHelperSvc->stationName(id)); + + m_spoint_driftR.push_back(sp->driftRadius()); + m_spoint_layer.push_back(layer); + + if (x < min_x) min_x = x; + if (x > max_x) max_x = x; + + + if (Sparse) { + for (u_int16_t j = 0; j < num_points; j++) { + if (i == j) continue; + const auto sp2 = bucket->at(j); + Identifier layId2; + const Identifier id2 = sp2->identify(); + if (sp2->type() == xAOD::UncalibMeasType::MdtDriftCircleType) { + const MdtIdHelper& idHelper{m_idHelperSvc->mdtIdHelper()}; + layId2 = idHelper.channelID(idHelper.stationName(id2), 1, idHelper.stationPhi(id2), idHelper.multilayer(id2), idHelper.tubeLayer(id2), 1); + } else { + layId2 = m_idHelperSvc->gasGapId(id2); + } + float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2)); + if (h < 500) { + Neighbors[i]++; + } + if ( h < 2000 and layId != layId2) { + edge_src.push_back(i + bucket_offset); + edge_dst.push_back(j + bucket_offset); + m_spoint_edges.push_back(j + bucket_offset); + } + } + } else { + for (u_int16_t j = 0; j < num_points; j++) { + if (i == j) continue; + edge_src.push_back(i + bucket_offset); + edge_dst.push_back(j + bucket_offset); + m_spoint_edges.push_back(j + bucket_offset); + if (j < i ) continue; + const auto sp2 = bucket->at(j); + float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2)); + if (h < 500) { + Neighbors[i]++; + Neighbors[j]++; + } + } + } + + m_spoint_neighbors.push_back(Neighbors[i]); + } // end loop over space points + + m_bucket_layers.push_back(layer); + float density = 0; + float bucket_size = bucket->coveredMax() - bucket->coveredMin(); + if (bucket_size == 0 ){ + bucket_size = max_x - min_x; + if (bucket_size == 0) bucket_size = 1; + } + density = num_points / bucket_size; + m_bucket_density.push_back(density); + bucket_index++; + + } // end loop over buckets + + features.reserve(total_nodes * 8); + for (size_t i = 0; i < total_nodes; i++) { + features.push_back(m_spoint_x[i]); + features.push_back(m_spoint_y[i]); + features.push_back(m_spoint_z[i]); + features.push_back(m_spoint_station[i]); + features.push_back(m_spoint_driftR[i]); + features.push_back(float(m_spoint_layer[i]) / (m_bucket_layers[m_spoint_bucket[i]] > 0 ? m_bucket_layers[m_spoint_bucket[i]] : 1)); + features.push_back(m_spoint_neighbors[i]); + features.push_back(m_bucket_density[m_spoint_bucket[i]]); + if (Labels) { + labels.push_back(m_spoint_label[i]); + } + } + + if (features.empty()) { + ATH_MSG_WARNING("No valid feature data available for inference. Skipping event."); + return StatusCode::SUCCESS; + } + + std::vector<int64_t> edge_index; + edge_index.reserve(2 * edge_src.size()); + edge_index.insert(edge_index.end(), edge_src.begin(), edge_src.end()); + edge_index.insert(edge_index.end(), edge_dst.begin(), edge_dst.end()); + + ATH_MSG_DEBUG("Total nodes: " << total_nodes); + ATH_MSG_DEBUG("Features size: " << features.size()); + ATH_MSG_DEBUG("Expected feature shape: (" << total_nodes << ", 8)"); + + if (features.size() != total_nodes * 8) { + ATH_MSG_ERROR("Feature size mismatch! Expected " << total_nodes * 8 << " but got " << features.size()); + } + + ATH_MSG_DEBUG("Edge index size: " << edge_index.size()); + ATH_MSG_DEBUG("Expected edge_index shape: (2, " << edge_index.size() / 2 << ")"); + + if (edge_index.size() % 2 != 0) { + ATH_MSG_ERROR("Edge index format error! Size should be divisible by 2."); + } + + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference"); + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); + + Ort::Session session(env, "torch_GatFourier_fcg_quantized.onnx", session_options); + + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + + std::vector<int64_t> feature_shape = {static_cast<int64_t>(total_nodes), 8}; // (N, 8) + std::vector<int64_t> edge_shape = {2, static_cast<int64_t>(edge_index.size() / 2)}; // (2, E) + + Ort::Value feature_tensor = Ort::Value::CreateTensor<float>( + memory_info, features.data(), features.size(), feature_shape.data(), feature_shape.size()); + + Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>( + memory_info, edge_index.data(), edge_index.size(), edge_shape.data(), edge_shape.size()); + + std::vector<const char*> input_names = {"features", "edge_index"}; + std::vector<const char*> output_names = {"output"}; + + std::vector<Ort::Value> input_tensors; + input_tensors.push_back(std::move(feature_tensor)); + input_tensors.push_back(std::move(edge_tensor)); + + std::vector<Ort::Value> output_tensors = session.Run( + Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(), + input_tensors.size(), output_names.data(), output_names.size()); + + float* output_data = output_tensors[0].GetTensorMutableData<float>(); + size_t output_size = output_tensors[0].GetTensorTypeAndShapeInfo().GetElementCount(); + std::vector<float> predictions(output_data, output_data + output_size); + + for (size_t i = 0; i < output_size; i++) { + if (!std::isfinite(predictions[i])) { + ATH_MSG_ERROR("Non-finite prediction detected! Setting to zero."); + predictions[i] = 0.0f; + } + } + + ATH_MSG_DEBUG("ONNX output size: " << output_size); + ATH_MSG_DEBUG("First 5 predictions:"); + for (size_t i = 0; i < std::min(output_size, size_t(5)); i++) { + ATH_MSG_DEBUG("Prediction[" << i << "]: " << predictions[i]); + } + + std::vector<int> binary_predictions(output_size); + for (size_t i = 0; i < output_size; i++) { + binary_predictions[i] = (predictions[i] > 0.05) ? 1 : 0; + } + ATH_MSG_DEBUG("First 5 Binary Predictions:"); + for (size_t i = 0; i < std::min(output_size, size_t(5)); i++) { + ATH_MSG_DEBUG("Binary_Prediction[" << i << "]: " << binary_predictions[i]); + } + + if (Labels) { + float loss = 0; + float accuracy = 0; + for (size_t i = 0; i < output_size; i++) { + if (binary_predictions[i] != labels[i] && labels[i] == 1) { + loss++; + } + if (binary_predictions[i] == labels[i]) { + accuracy++; + } + } + ATH_MSG_ALWAYS("Hits: " << output_size); + ATH_MSG_ALWAYS("Loss: " << loss / output_size); + ATH_MSG_ALWAYS("Accuracy: " << accuracy / output_size); + } + + for (size_t i = 0; i < output_size; i++) { + m_spoint_predictions.push_back(binary_predictions[i]); + } + + if (!m_tree.fill(ctx)) return StatusCode::FAILURE; + return StatusCode::SUCCESS; + + } + + CLHEP::HepRandomEngine* SPIdentifierAlg::getRandomEngine(const EventContext&ctx) const { + ATHRNG::RNGWrapper* rngWrapper = m_rndmSvc->getEngine(this, m_streamName); + std::string rngName = m_streamName; + rngWrapper->setSeed(rngName, ctx); + return rngWrapper->getEngine(ctx); + } + +} diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.h new file mode 100644 index 000000000000..b135b4a45864 --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.h @@ -0,0 +1,91 @@ +#ifndef MUON_SP_ID_DUMPER_H +#define MUON_SP_ID_DUMPER_H +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ + +#include <AthenaBaseComps/AthAlgorithm.h> +#include "AthenaBaseComps/AthHistogramAlgorithm.h" + +#include <MuonIdHelpers/IMuonIdHelperSvc.h> +#include "StoreGate/ReadHandleKey.h" +#include "StoreGate/ReadDecorHandle.h" +#include "StoreGate/ReadHandleKeyArray.h" +#include "StoreGate/ReadCondHandleKey.h" + +#include <MuonPatternEvent/MuonPatternContainer.h> +#include <MuonSpacePoint/SpacePointContainer.h> +#include <ActsGeometryInterfaces/ActsGeometryContext.h> +#include <MuonReadoutGeometryR4/MuonDetectorManager.h> + +#include "MuonTesterTree/MuonTesterTree.h" +#include "MuonTesterTree/ThreeVectorBranch.h" +#include "MuonTesterTree/IdentifierBranch.h" + +#include "xAODMuonSimHit/MuonSimHitContainer.h" +#include "xAODMuon/MuonSegmentContainer.h" + +#include "AthenaKernel/IAthRNGSvc.h" +#include "CLHEP/Random/RandomEngine.h" + +// onnx runtime +#include <onnxruntime_cxx_api.h> + + +namespace MuonR4{ +class SPIdentifierAlg: public AthHistogramAlgorithm { + + public: + using AthHistogramAlgorithm::AthHistogramAlgorithm; + ~SPIdentifierAlg() = default; + + virtual StatusCode initialize() override final; + virtual StatusCode finalize() override final; + virtual StatusCode execute() override final; + + private: + + void fillChamberInfo(const MuonGMR4::Chamber* chamber); + + SG::ReadHandleKey<SpacePointContainer> m_readKey{this, "ReadKey", "MuonSpacePoints", "Key to the space point container"}; + ServiceHandle<Muon::IMuonIdHelperSvc> m_idHelperSvc{this, "MuonIdHelperSvc", "Muon::MuonIdHelperSvc/MuonIdHelperSvc"}; + + SG::ReadHandleKey<MuonR4::SegmentContainer> m_inSegmentKey{this, "SegmentKey", "R4MuonSegments"}; + + //SG::ReadHandleKey<ActsGeometryContext> m_geoCtxKey{this, "AlignmentKey", "ActsAlignment", "cond handle key"}; + + Gaudi::Property<bool> m_isMC{this, "isMC", true}; + //Gaudi::Property<double> m_fracToKeep{this,"dataFracToKeep", 1}; // 0.055 to balanced dataset without MC + Gaudi::Property<std::string> m_streamName{this, "StreamName", ""}; + ServiceHandle<IAthRNGSvc> m_rndmSvc{this, "RndmSvc", "AthRNGSvc", ""}; + CLHEP::HepRandomEngine* getRandomEngine(const EventContext&ctx) const; + + MuonVal::MuonTesterTree m_tree{"MuonSPId","MuonSPId"}; + + MuonVal::VectorBranch<float>& m_bucket_density{m_tree.newVector<float>("bucket_density", 0)}; + MuonVal::VectorBranch<uint8_t>& m_bucket_layers{m_tree.newVector<uint8_t>("bucket_layers", 0)}; + + MuonVal::VectorBranch<uint8_t>& m_spoint_bucket{m_tree.newVector<uint8_t>("bucket_index")}; + + //MuonVal::ThreeVectorBranch m_spoint_localPosition{m_tree, "localPosition"}; + MuonVal::VectorBranch<float>& m_spoint_x{m_tree.newVector<float>("x")}; + MuonVal::VectorBranch<float>& m_spoint_y{m_tree.newVector<float>("y")}; + MuonVal::VectorBranch<float>& m_spoint_z{m_tree.newVector<float>("z")}; + + //MuonVal::MuonIdentifierBranch m_spoint_id{m_tree, "id"}; + MuonVal::VectorBranch<uint8_t>& m_spoint_station{m_tree.newVector<uint8_t>("stationIndex")}; + MuonVal::VectorBranch<uint8_t>& m_spoint_layer{m_tree.newVector<uint8_t>("layer")}; + + MuonVal::VectorBranch<float>& m_spoint_driftR{m_tree.newVector<float>("driftR")}; + MuonVal::VectorBranch<float>& m_spoint_neighbors{m_tree.newVector<float>("neighbors")}; + + MuonVal::VectorBranch<uint8_t>& m_spoint_label{m_tree.newVector<uint8_t>("label")}; + MuonVal::VectorBranch<uint8_t>& m_spoint_predictions{m_tree.newVector<uint8_t>("predictions")}; + MuonVal::VectorBranch<uint16_t>& m_spoint_edges{m_tree.newVector<uint16_t>("edges")}; + + size_t m_event{0}; + +}; +} +#endif + diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/components/MuonSPId_entries.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/components/MuonSPId_entries.cxx new file mode 100644 index 000000000000..bd50be95430b --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/components/MuonSPId_entries.cxx @@ -0,0 +1,7 @@ + +/* + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +*/ +#include "../SPIdDumperAlg.h" +DECLARE_COMPONENT(MuonR4::SPIdDumperAlg) + -- GitLab From 2ddd20266f23e5e784866927364f87a63dfa8aac Mon Sep 17 00:00:00 2001 From: Davide Di Croce <davide.di.croce@cern.ch> Date: Thu, 20 Feb 2025 15:28:01 +0100 Subject: [PATCH 2/9] Speeding up inference and fixing dumper --- .../MuonSPId/python/muonSPIdDump_data.py | 2 + .../MuonSPId/src/SPIdDumperAlg.cxx | 280 ++++++++++++------ .../MuonLearning/MuonSPId/src/SPIdDumperAlg.h | 23 +- 3 files changed, 200 insertions(+), 105 deletions(-) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py index 7e7f1be057f4..187225123d58 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py @@ -33,6 +33,8 @@ if __name__=="__main__": from MuonSPId.MuonSPIdDumpConfig import MuonSPIdDumpCfg cfg.merge(MuonSPIdDumpCfg(flags)) + + #cfg.getService("MessageSvc").setDebug=["MuonSPIdMaker"] executeTest(cfg) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx index 379257fe8b5b..c7e2a5bc7ba6 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx @@ -15,6 +15,7 @@ #include <fstream> #include <TString.h> #include <AthenaKernel/RNGWrapper.h> +#include <vector> #include "CLHEP/Random/RandFlat.h" namespace { @@ -44,7 +45,8 @@ namespace MuonR4{ m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ONNXFeatureExtraction"); m_session_options = std::make_unique<Ort::SessionOptions>(); - m_session_options->SetIntraOpNumThreads(1); + //m_session_options->SetIntraOpNumThreads(1); + m_session_options->SetInterOpNumThreads(std::thread::hardware_concurrency()); m_session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); m_session = std::make_unique<Ort::Session>(*m_env, onnx_path.c_str(), *m_session_options); @@ -71,9 +73,9 @@ namespace MuonR4{ return StatusCode::FAILURE; } - ATH_MSG_ALWAYS("Extracted Feature Names:"); + ATH_MSG_DEBUG("Extracted Feature Names:"); for (const auto& name : feature_names) { - ATH_MSG_ALWAYS("- " << name); + ATH_MSG_DEBUG("- " << name); } ATH_MSG_DEBUG("Successfully initialized"); @@ -102,15 +104,38 @@ namespace MuonR4{ bool Sparse = false; bool Labels = true; - std::vector<float> features; - std::vector<int64_t> edge_src; - std::vector<int64_t> edge_dst; - std::vector<int64_t> node_offsets; - std::vector<int64_t> labels; - size_t total_nodes = 0; - + std::vector<float> spoint_x, spoint_y, spoint_z, spoint_driftR, spoint_layer; + std::vector<uint8_t> spoint_station, spoint_neighbors, spoint_bucket; + std::vector<uint16_t> spoint_edges, bucket_layers; + std::vector<uint64_t> edge_src, edge_dst, node_offsets; + std::vector<bool> spoint_label; + + spoint_x.reserve(100000); + spoint_y.reserve(100000); + spoint_z.reserve(100000); + spoint_driftR.reserve(100000); + spoint_layer.reserve(100000); + spoint_station.reserve(100000); + spoint_bucket.reserve(100000); + spoint_neighbors.reserve(100000); + spoint_edges.reserve(100000); + + size_t total_nodes = 0; size_t bucket_index = 0; + std::vector<float> spoint_predictions; + + std::vector<float> bucket_density; + std::vector<uint8_t> bucket_points; + std::vector<float> features; + std::vector<bool> labels; + + + std::vector<Ort::Value> input_tensors; + std::vector<int64_t> edge_index; + std::vector<const char*> input_names = {"features", "edge_index"}; + std::vector<const char*> output_names = {"output"}; + for(const SpacePointBucket* bucket : *readHandle) { unsigned int segIdx{0}; @@ -139,13 +164,31 @@ namespace MuonR4{ size_t bucket_offset = total_nodes; node_offsets.push_back(bucket_offset); - total_nodes += num_points; + //std::unordered_map<u_int16_t, u_int16_t> index_mapping; // Maps original index to new valid index + std::vector<u_int16_t> index_mapping(num_points, UINT16_MAX); + u_int16_t valid_index = 0; float min_x = 3000; float max_x = -1; Identifier layId; + + for (u_int16_t i = 0; i < num_points; i++) { + const auto sp = bucket->at(i); + if (sp->type() == xAOD::UncalibMeasType::MdtDriftCircleType) { + const auto* dc = static_cast<const xAOD::MdtDriftCircle*>(sp->primaryMeasurement()); + if (dc->status() != Muon::MdtDriftCircleStatus::MdtStatusDriftTime) { + continue; + } + } + index_mapping[i] = valid_index; + valid_index++; + } + total_nodes += valid_index; + + for (u_int16_t i = 0; i < num_points; i++) { + if (index_mapping[i] == UINT16_MAX) continue; const auto sp = bucket->at(i); const Identifier id = sp->identify(); @@ -169,12 +212,12 @@ namespace MuonR4{ } } - if (Labels) { - const std::vector<int16_t>& segIdxs = spacePointToSegment[sp.get()]; + const std::vector<int16_t>& segIdxs = spacePointToSegment[sp.get()]; + if (Labels) { if (segIdxs.size() > 0) { - m_spoint_label.push_back(1); + spoint_label.push_back(1); } else { - m_spoint_label.push_back(0); + spoint_label.push_back(0); } } @@ -182,24 +225,21 @@ namespace MuonR4{ float y = sp->positionInChamber().y(); float z = sp->positionInChamber().z(); - - - m_spoint_bucket.push_back(bucket_index); - m_spoint_x.push_back(x); - m_spoint_y.push_back(y); - m_spoint_z.push_back(z); - m_spoint_station.push_back(m_idHelperSvc->stationName(id)); - - m_spoint_driftR.push_back(sp->driftRadius()); - m_spoint_layer.push_back(layer); + spoint_x.push_back(x); + spoint_y.push_back(y); + spoint_z.push_back(z); + spoint_driftR.push_back(sp->driftRadius()); + spoint_station.push_back(m_idHelperSvc->stationName(id)); + spoint_layer.push_back(layer); + spoint_bucket.push_back(bucket_index); if (x < min_x) min_x = x; if (x > max_x) max_x = x; - if (Sparse) { for (u_int16_t j = 0; j < num_points; j++) { - if (i == j) continue; + if (i == j || (index_mapping[j] == UINT16_MAX) ) continue; + //if (i == j) continue; const auto sp2 = bucket->at(j); Identifier layId2; const Identifier id2 = sp2->identify(); @@ -209,36 +249,43 @@ namespace MuonR4{ } else { layId2 = m_idHelperSvc->gasGapId(id2); } - float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2)); - if (h < 500) { - Neighbors[i]++; + //float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2)); + float dx = sp2->positionInChamber().x() - x; + float dy = sp2->positionInChamber().y() - y; + float h_squared = dx * dx + dy * dy; + if (h_squared < 250000) { + Neighbors[index_mapping[i]]++; } - if ( h < 2000 and layId != layId2) { - edge_src.push_back(i + bucket_offset); - edge_dst.push_back(j + bucket_offset); - m_spoint_edges.push_back(j + bucket_offset); + if ( h_squared < 4000000 and layId != layId2) { + edge_src.push_back(index_mapping[i] + bucket_offset); + edge_dst.push_back(index_mapping[j] + bucket_offset); + spoint_edges.push_back(index_mapping[j] + bucket_offset); } } } else { for (u_int16_t j = 0; j < num_points; j++) { - if (i == j) continue; - edge_src.push_back(i + bucket_offset); - edge_dst.push_back(j + bucket_offset); - m_spoint_edges.push_back(j + bucket_offset); - if (j < i ) continue; + //if (i == j) continue; + if (i == j || (index_mapping[j] == UINT16_MAX) ) continue; // Skip removed points + edge_src.push_back(index_mapping[i] + bucket_offset); + edge_dst.push_back(index_mapping[j] + bucket_offset); + spoint_edges.push_back(index_mapping[j] + bucket_offset); + if ( j < i ) continue; const auto sp2 = bucket->at(j); - float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2)); - if (h < 500) { - Neighbors[i]++; - Neighbors[j]++; + float dx = sp2->positionInChamber().x() - x; + float dy = sp2->positionInChamber().y() - y; + float h_squared = dx * dx + dy * dy; + if (h_squared < 250000) { + Neighbors[index_mapping[i]]++; + Neighbors[index_mapping[j]]++; } } } + - m_spoint_neighbors.push_back(Neighbors[i]); + spoint_neighbors.push_back(Neighbors[index_mapping[i]]); } // end loop over space points - m_bucket_layers.push_back(layer); + bucket_layers.push_back(layer); float density = 0; float bucket_size = bucket->coveredMax() - bucket->coveredMin(); if (bucket_size == 0 ){ @@ -246,32 +293,49 @@ namespace MuonR4{ if (bucket_size == 0) bucket_size = 1; } density = num_points / bucket_size; - m_bucket_density.push_back(density); + bucket_density.push_back(density); + bucket_points.push_back(num_points); bucket_index++; - + } // end loop over buckets features.reserve(total_nodes * 8); + labels.reserve(total_nodes); + + for (size_t i = 0; i < total_nodes; i++) { - features.push_back(m_spoint_x[i]); - features.push_back(m_spoint_y[i]); - features.push_back(m_spoint_z[i]); - features.push_back(m_spoint_station[i]); - features.push_back(m_spoint_driftR[i]); - features.push_back(float(m_spoint_layer[i]) / (m_bucket_layers[m_spoint_bucket[i]] > 0 ? m_bucket_layers[m_spoint_bucket[i]] : 1)); - features.push_back(m_spoint_neighbors[i]); - features.push_back(m_bucket_density[m_spoint_bucket[i]]); + float relative_layer = float(spoint_layer[i]) / (bucket_layers[spoint_bucket[i]] > 0 ? bucket_layers[spoint_bucket[i]] : 1); + + //if (!std::isfinite(spoint_x[i])) ATH_MSG_WARNING("X is NaN: " << spoint_x[i] << " in sp " << i); + //if (!std::isfinite(spoint_y[i])) ATH_MSG_WARNING("Y is NaN: " << spoint_y[i] << " in sp " << i); + //if (!std::isfinite(spoint_z[i])) ATH_MSG_WARNING("Z is NaN: " << spoint_z[i] << " in sp " << i); + //if (!std::isfinite(spoint_driftR[i])) ATH_MSG_WARNING("DriftR is NaN: " << spoint_driftR[i] << " in sp " << i); + //if (!std::isfinite(spoint_neighbors[i])) ATH_MSG_WARNING("Neighbors is NaN: " << spoint_neighbors[i] << " in sp " << i); + //if (!std::isfinite(bucket_density[spoint_bucket[i]])) ATH_MSG_WARNING("Density is NaN: " << bucket_density[spoint_bucket[i]] << " in sp " << i); + //if (!std::isfinite(relative_layer)) ATH_MSG_WARNING("Relative Layer is NaN: " << relative_layer << " in sp " << i); + + features.insert(features.end(), { + spoint_x[i], + spoint_y[i], + spoint_z[i], + static_cast<float>(spoint_station[i]), + spoint_driftR[i], + relative_layer, + static_cast<float>(spoint_neighbors[i]), + static_cast<float>(bucket_density[spoint_bucket[i]]) + }); + if (Labels) { - labels.push_back(m_spoint_label[i]); + labels.push_back(spoint_label[i]); } } - if (features.empty()) { - ATH_MSG_WARNING("No valid feature data available for inference. Skipping event."); - return StatusCode::SUCCESS; + for (size_t i = 0; i < edge_src.size(); i++) { + if (edge_src[i] >= total_nodes || edge_dst[i] >= total_nodes) { + ATH_MSG_WARNING("Edge connection out of range: (" << edge_src[i] << ", " << edge_dst[i] << ") - Total nodes: " << total_nodes); + } } - std::vector<int64_t> edge_index; edge_index.reserve(2 * edge_src.size()); edge_index.insert(edge_index.end(), edge_src.begin(), edge_src.end()); edge_index.insert(edge_index.end(), edge_dst.begin(), edge_dst.end()); @@ -291,13 +355,17 @@ namespace MuonR4{ ATH_MSG_ERROR("Edge index format error! Size should be divisible by 2."); } + //for (size_t i = 0; i < features.size(); i++) { + // if (!std::isfinite(features[i])) { + // ATH_MSG_WARNING("Feature[" << i << "] contains NaN or Inf: " << features[i]); + // } + //} + //for (size_t i = 0; i < edge_index.size(); i++) { + // if (edge_index[i] < 0) { + // ATH_MSG_WARNING("Edge index[" << i << "] is negative: " << edge_index[i]); + // } + //} - //Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference"); - //Ort::SessionOptions session_options; - //session_options.SetIntraOpNumThreads(1); - //session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); - //Ort::Session session(env, "torch_GatFourier_fcg_quantized.onnx", session_options); - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); @@ -310,21 +378,25 @@ namespace MuonR4{ Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>( memory_info, edge_index.data(), edge_index.size(), edge_shape.data(), edge_shape.size()); - std::vector<const char*> input_names = {"features", "edge_index"}; - std::vector<const char*> output_names = {"output"}; - std::vector<Ort::Value> input_tensors; input_tensors.push_back(std::move(feature_tensor)); input_tensors.push_back(std::move(edge_tensor)); - //std::vector<Ort::Value> output_tensors = session.Run( + Ort::RunOptions run_options; + run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING); std::vector<Ort::Value> output_tensors = m_session->Run( - Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(), + run_options, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), output_names.size()); + //std::vector<Ort::Value> output_tensors = m_session->Run( + // Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(), + // input_tensors.size(), output_names.data(), output_names.size()); + float* output_data = output_tensors[0].GetTensorMutableData<float>(); size_t output_size = output_tensors[0].GetTensorTypeAndShapeInfo().GetElementCount(); + std::vector<float> predictions(output_data, output_data + output_size); + ATH_MSG_DEBUG("Predictions shape: (" << predictions.size() << ")"); for (size_t i = 0; i < output_size; i++) { if (!std::isfinite(predictions[i])) { @@ -339,13 +411,14 @@ namespace MuonR4{ ATH_MSG_DEBUG("Prediction[" << i << "]: " << predictions[i]); } - std::vector<int> binary_predictions(output_size); + std::vector<u_int8_t> binary_predictions(output_size); + ATH_MSG_DEBUG("Binary shape: (" << binary_predictions.size() << ")"); + ATH_MSG_DEBUG("Cut Value: " << m_SPId_cut); for (size_t i = 0; i < output_size; i++) { - binary_predictions[i] = (predictions[i] > m_SPId_cut) ? 1 : 0; // Why inference cut is different? -> this batch/graph building may be different from torch one? - } - ATH_MSG_DEBUG("First 5 Binary Predictions:"); - for (size_t i = 0; i < std::min(output_size, size_t(5)); i++) { - ATH_MSG_DEBUG("Binary_Prediction[" << i << "]: " << binary_predictions[i]); + binary_predictions[i] = (predictions[i] > m_SPId_cut) ? 1 : 0; + if (i < 5) { + ATH_MSG_DEBUG("Binary_Prediction[" << i << " of first 5]: " << static_cast<int>(binary_predictions[i])); + } } if (Labels) { @@ -354,7 +427,7 @@ namespace MuonR4{ float pred_1 = 0; float lab_1 = 0; float true_0 = 0; - //float true_1 = 0; + for (size_t i = 0; i < output_size; i++) { if (binary_predictions[i] != labels[i] && labels[i] == 1) { loss++; @@ -362,9 +435,6 @@ namespace MuonR4{ if (binary_predictions[i] == labels[i]) { accuracy++; } - //if (binary_predictions[i] == labels[i] && labels[i] == 1) { - // true_1++; - //} if (binary_predictions[i] == labels[i] && labels[i] == 0) { true_0++; } @@ -375,23 +445,51 @@ namespace MuonR4{ lab_1++; } } - //ATH_MSG_ALWAYS("Hits: " << output_size ); - ATH_MSG_ALWAYS("Accuracy: " << accuracy / output_size << " ( Correct: " << accuracy << " / Hits: " << output_size << ")" ); - ATH_MSG_ALWAYS("Signal Loss: " << loss / lab_1 << " ( Miss predicted: " << loss << " / Good Hits: " << lab_1 << ")" ); + ATH_MSG_DEBUG("Accuracy: " << accuracy / output_size << " ( Correct: " << accuracy << " / Hits: " << output_size << ")" ); + ATH_MSG_DEBUG("Signal Loss: " << loss / lab_1 << " ( Miss predicted: " << loss << " / Good Hits: " << lab_1 << ")" ); float lab_0 = output_size - lab_1; if (lab_0 != 0) { - ATH_MSG_ALWAYS("Real Rejection: " << true_0 / (output_size - lab_1) << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" ); - } else ATH_MSG_ALWAYS("Real Rejection: " << 0 << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" ); - ATH_MSG_ALWAYS("Purity (before): " << lab_1 / output_size << " ( Good Hits: " << lab_1 << " / Hits: " << output_size << ")" ); - ATH_MSG_ALWAYS("Purity (after): " << (lab_1 - loss) / pred_1 << " ( Good predicted: " << lab_1 - loss << " / Predicted: " << pred_1 << ")" ); - ATH_MSG_ALWAYS("-----------------------------"); + ATH_MSG_DEBUG("Real Rejection: " << true_0 / (output_size - lab_1) << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" ); + } else ATH_MSG_DEBUG("Real Rejection: " << 0 << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" ); + ATH_MSG_DEBUG("Purity (before): " << lab_1 / output_size << " ( Good Hits: " << lab_1 << " / Hits: " << output_size << ")" ); + ATH_MSG_DEBUG("Purity (after): " << (lab_1 - loss) / pred_1 << " ( Good predicted: " << lab_1 - loss << " / Predicted: " << pred_1 << ")" ); + ATH_MSG_DEBUG("-----------------------------"); } + spoint_predictions.reserve(output_size); for (size_t i = 0; i < output_size; i++) { - m_spoint_predictions.push_back(binary_predictions[i]); + spoint_predictions.push_back(predictions[i]); + } + + if (bucket_density.size() != bucket_layers.size() || bucket_density.size() != bucket_points.size()) { + ATH_MSG_ERROR("Bucket size mismatch! Density: " << bucket_density.size() << ", Layers: " << bucket_layers.size() << ", Points: " << bucket_points.size()); } + size_t last_point = 0; + for (size_t i = 0; i < bucket_points.size(); i++) { + m_bucket_layers = bucket_layers[i]; + m_bucket_points = bucket_points[i]; + m_bucket_density = bucket_density[i]; + + size_t end_point = last_point + bucket_points[i]; + for (size_t j = last_point; j < end_point; j++) { + m_spoint_bucket.push_back(spoint_bucket[j]); + m_spoint_x.push_back(spoint_x[j]); + m_spoint_y.push_back(spoint_y[j]); + m_spoint_z.push_back(spoint_z[j]); + m_spoint_station.push_back(spoint_station[j]); + m_spoint_layer.push_back(spoint_layer[j]); + m_spoint_driftR.push_back(spoint_driftR[j]); + m_spoint_neighbors.push_back(spoint_neighbors[j]); + m_spoint_label.push_back(spoint_label[j]); + m_spoint_predictions.push_back(spoint_predictions[j]); + m_spoint_edges.push_back(spoint_edges[j] - node_offsets[i]); + } + + last_point = end_point; - if (!m_tree.fill(ctx)) return StatusCode::FAILURE; + if (!m_tree.fill(ctx)) return StatusCode::FAILURE; + } + return StatusCode::SUCCESS; } diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h index 6a2f9b1c7d90..2467924774c2 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h @@ -55,35 +55,30 @@ class SPIdDumperAlg: public AthHistogramAlgorithm { SG::ReadHandleKey<MuonR4::SegmentContainer> m_inSegmentKey{this, "SegmentKey", "R4MuonSegments"}; - Gaudi::Property<bool> m_isMC{this, "isMC", true}; - Gaudi::Property<double> m_SPId_cut{this,"spIdValue", 0.0001}; - Gaudi::Property<std::string> m_modelPath{this, "ModelPath", "torch_GatFourier_fcg_quantized.onnx"}; + Gaudi::Property<bool> m_isMC{this, "isMC", true}; + Gaudi::Property<double> m_SPId_cut{this,"spIdValue", -7}; + Gaudi::Property<std::string> m_modelPath{this, "ModelPath", "EdgeGAT_FCG_8vars_quantized.onnx"}; MuonVal::MuonTesterTree m_tree{"MuonSPId","MuonSPId"}; - MuonVal::VectorBranch<float>& m_bucket_density{m_tree.newVector<float>("bucket_density", 0)}; - MuonVal::VectorBranch<uint8_t>& m_bucket_layers{m_tree.newVector<uint8_t>("bucket_layers", 0)}; + MuonVal::ScalarBranch<uint8_t>& m_bucket_layers{m_tree.newScalar<uint8_t>("bucket_layers", 0)}; + MuonVal::ScalarBranch<uint8_t>& m_bucket_points{m_tree.newScalar<uint8_t>("bucket_points", 0)}; + MuonVal::ScalarBranch<float>& m_bucket_density{m_tree.newScalar<float>("bucket_density", 0)}; MuonVal::VectorBranch<uint8_t>& m_spoint_bucket{m_tree.newVector<uint8_t>("bucket_index")}; - - //MuonVal::ThreeVectorBranch m_spoint_localPosition{m_tree, "localPosition"}; MuonVal::VectorBranch<float>& m_spoint_x{m_tree.newVector<float>("x")}; MuonVal::VectorBranch<float>& m_spoint_y{m_tree.newVector<float>("y")}; MuonVal::VectorBranch<float>& m_spoint_z{m_tree.newVector<float>("z")}; - //MuonVal::MuonIdentifierBranch m_spoint_id{m_tree, "id"}; MuonVal::VectorBranch<uint8_t>& m_spoint_station{m_tree.newVector<uint8_t>("stationIndex")}; - MuonVal::VectorBranch<uint8_t>& m_spoint_layer{m_tree.newVector<uint8_t>("layer")}; - + MuonVal::VectorBranch<float>& m_spoint_layer{m_tree.newVector<float>("layer")}; MuonVal::VectorBranch<float>& m_spoint_driftR{m_tree.newVector<float>("driftR")}; - MuonVal::VectorBranch<float>& m_spoint_neighbors{m_tree.newVector<float>("neighbors")}; + MuonVal::VectorBranch<uint8_t>& m_spoint_neighbors{m_tree.newVector<uint8_t>("neighbors")}; MuonVal::VectorBranch<uint8_t>& m_spoint_label{m_tree.newVector<uint8_t>("label")}; - MuonVal::VectorBranch<uint8_t>& m_spoint_predictions{m_tree.newVector<uint8_t>("predictions")}; + MuonVal::VectorBranch<float>& m_spoint_predictions{m_tree.newVector<float>("predictions")}; MuonVal::VectorBranch<uint16_t>& m_spoint_edges{m_tree.newVector<uint16_t>("edges")}; - size_t m_event{0}; - // ONNX Runtime objects std::unique_ptr<Ort::Env> m_env; std::unique_ptr<Ort::SessionOptions> m_session_options; -- GitLab From 5ab3b6f25ae67e0e00b3b291f87c34ad9aa13e43 Mon Sep 17 00:00:00 2001 From: Davide Di Croce <davide.di.croce@cern.ch> Date: Mon, 24 Feb 2025 11:39:13 +0100 Subject: [PATCH 3/9] Update in MuonSPId --- .../MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx | 2 +- .../MuonInferenceInterfaces/src/NodeFeatureFactory.cxx | 2 ++ .../MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py | 3 ++- .../MuonLearning/MuonSPId/python/muonSPIdDump_data.py | 2 +- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx index 51e666263121..d55804c8e2b4 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx @@ -23,7 +23,7 @@ namespace MuonML{ session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); m_model = std::make_unique<Ort::Session>(env, modelPath.c_str(), session_options); - ATH_MSG_DEBUG("Successfully loaded infernce model from "<<modelPath); + ATH_MSG_DEBUG("Successfully loaded inference model from "<<modelPath); Ort::ModelMetadata metadata = m_model->GetModelMetadata(); diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx index daf78ffe3330..b9490d5d9ec1 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx @@ -45,6 +45,8 @@ namespace MuonML { [](const Bucket_t& bucket, size_t index) { return bucket[index]->positionInChamber().x(); }), + + // loop over the bucket and get Neighbors }; const auto feat_itr = featurePool.find(featName); if(feat_itr != featurePool.end()){ diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py index a07d728e8396..cff3b9f224dd 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py @@ -15,7 +15,8 @@ if __name__=="__main__": flags = initConfigFlags() flags.PerfMon.doFullMonMT = True - flags, cfg = setupGeoR4TestCfg(args) + #flags, cfg = setupGeoR4TestCfg(args) + flags, cfg = setupGeoR4TestCfg(args,flags) cfg.merge(setupHistSvcCfg(flags,outFile=args.outRootFile, outStream="MuonSPId")) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py index 187225123d58..4a5d6dd23ee8 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py @@ -16,7 +16,7 @@ if __name__=="__main__": flags = initConfigFlags() flags.PerfMon.doFullMonMT = True - flags, cfg = setupGeoR4TestCfg(args) + flags, cfg = setupGeoR4TestCfg(args,flags) cfg.merge(setupHistSvcCfg(flags,outFile=args.outRootFile, outStream="MuonSPId")) -- GitLab From 21b77bf38a22d4435c79f3776f8e19767ac4e5c5 Mon Sep 17 00:00:00 2001 From: Davide Di Croce <davide.di.croce@cern.ch> Date: Mon, 24 Feb 2025 13:26:31 +0100 Subject: [PATCH 4/9] Adding features --- .../src/NodeFeatureFactory.cxx | 40 +++++++++++++++---- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx index b9490d5d9ec1..a6d4a7398fd0 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx @@ -37,16 +37,40 @@ namespace MuonML { /** Predefine the known features in the pool */ static const std::set<Feature_t, std::less<>> featurePool{ - std::make_unique<NodeFeature>("driftR", - [](const Bucket_t& bucket, size_t index) { - return bucket[index]->driftRadius(); - }), std::make_unique<NodeFeature>("localX", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]->positionInChamber().x(); + }), + std::make_unique<NodeFeature>("localY", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]->positionInChamber().y(); + }), + std::make_unique<NodeFeature>("localZ", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]->positionInChamber().z(); + }), + std::make_unique<NodeFeature>("driftR", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]->driftRadius(); + }), + std::make_unique<NodeFeature>("relativeLayer", + [](const Bucket_t& bucket, size_t index) { + return bucket.layerNum(index) / (bucket.nStripLayers() + bucket.nMdtLayers()); + }), + std::make_unique<NodeFeature>("neighbours", + [](const Bucket_t& bucket, size_t index) { + constexpr double radCut2 = (50 *Gaudi::Units::cm * 50.*Gaudi::Units::cm); + unsigned int n =0; + for (size_t other =0 ; other < bucket.size(); ++ other){ + n+= index != other && (bucket[index]->positionInChamber() - bucket[other]->positionInChamber()).perp2() < radCut2; + } + return n; + }), + std::make_unique<NodeFeature>("density", [](const Bucket_t& bucket, size_t index) { - return bucket[index]->positionInChamber().x(); - }), - - // loop over the bucket and get Neighbors + return bucket.size() / (bucket.coveredMax() - bucket.coveredMin()); + }), + }; const auto feat_itr = featurePool.find(featName); if(feat_itr != featurePool.end()){ -- GitLab From 399b7e96554797e5429bce8d543cba0870c6c40f Mon Sep 17 00:00:00 2001 From: Davide Di Croce <davide.di.croce@cern.ch> Date: Mon, 24 Feb 2025 14:14:11 +0100 Subject: [PATCH 5/9] update of feature implementation --- .../src/NodeFeatureFactory.cxx | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx index a6d4a7398fd0..a262627223a2 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx @@ -34,8 +34,6 @@ namespace MuonML { namespace Factory { Feature_t makeFeature(const std::string& featName, MsgStream& log) { - - /** Predefine the known features in the pool */ static const std::set<Feature_t, std::less<>> featurePool{ std::make_unique<NodeFeature>("localX", [](const Bucket_t& bucket, size_t index) { @@ -44,11 +42,16 @@ namespace MuonML { std::make_unique<NodeFeature>("localY", [](const Bucket_t& bucket, size_t index) { return bucket[index]->positionInChamber().y(); + bucket[index]->positionInChamber().y(); }), std::make_unique<NodeFeature>("localZ", [](const Bucket_t& bucket, size_t index) { return bucket[index]->positionInChamber().z(); }), + std::make_unique<NodeFeature>("stationIndex", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]-> + }), std::make_unique<NodeFeature>("driftR", [](const Bucket_t& bucket, size_t index) { return bucket[index]->driftRadius(); @@ -57,9 +60,9 @@ namespace MuonML { [](const Bucket_t& bucket, size_t index) { return bucket.layerNum(index) / (bucket.nStripLayers() + bucket.nMdtLayers()); }), - std::make_unique<NodeFeature>("neighbours", + std::make_unique<NodeFeature>("neighbors", [](const Bucket_t& bucket, size_t index) { - constexpr double radCut2 = (50 *Gaudi::Units::cm * 50.*Gaudi::Units::cm); + constexpr double radCut2 = (50.*Gaudi::Units::cm * 50.*Gaudi::Units::cm); unsigned int n =0; for (size_t other =0 ; other < bucket.size(); ++ other){ n+= index != other && (bucket[index]->positionInChamber() - bucket[other]->positionInChamber()).perp2() < radCut2; @@ -67,10 +70,20 @@ namespace MuonML { return n; }), std::make_unique<NodeFeature>("density", - [](const Bucket_t& bucket, size_t index) { - return bucket.size() / (bucket.coveredMax() - bucket.coveredMin()); - }), - + [](const Bucket_t& bucket, size_t index) { + return bucket.size() / (bucket.coveredMax() - bucket.coveredMin()); // missing max and min of bucket + }), + // can I add isolation already? + std::make_unique<NodeFeature>("isolation", + [](const Bucket_t& bucket, size_t index) { + unsigned int neighbors = 0; + constexpr double radCut2 = (50.*Gaudi::Units::cm * 50.*Gaudi::Units::cm); + for (size_t other =0 ; other < bucket.size(); ++ other){ + neighbors+= index != other && (bucket[index]->positionInChamber() - bucket[other]->positionInChamber()).perp2() < radCut2; + } + float bucket_density = bucket.size() / (bucket.coveredMax() - bucket.coveredMin()); + return neighbors / bucket_density; + }), }; const auto feat_itr = featurePool.find(featName); if(feat_itr != featurePool.end()){ -- GitLab From 2530c6e07f9435e56ca285a1ad89c527790a30ff Mon Sep 17 00:00:00 2001 From: Davide Di Croce <davide.di.croce@cern.ch> Date: Mon, 24 Feb 2025 15:24:11 +0100 Subject: [PATCH 6/9] Adding station index --- .../MuonInferenceInterfaces/LayerBucket.h | 11 +++++++++++ .../src/NodeFeatureFactory.cxx | 6 +++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h index bb083b0e3a5a..7eee70727f88 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h @@ -5,6 +5,7 @@ #define MUONINFERENCEINTERACES_LAYERBUCKET_H #include "MuonSpacePoint/SpacePointContainer.h" +#include "GaudiKernel/SystemOfUnits.h" namespace MuonML{ /** @brief The LayerSpBucket is a space pointbucket where the points are internally * sorted by their layer number as defined in the SpacePointLayerSorter. The @@ -26,10 +27,20 @@ namespace MuonML{ uint8_t layerNum(const size_t i) const { return m_layNum[i]; } + /** @brief Returns the max covered position of the bucket */ + uint8_t coveredMax() const { + return m_max; + } + /** @brief Returns the min covered position of the bucket */ + uint8_t coveredMin() const { + return m_min; + } private: uint8_t m_nMdtLay{0}; uint8_t m_nStripLay{0}; std::vector<uint8_t> m_layNum{}; + double m_min{-20. *Gaudi::Units::m}; + double m_max{20. * Gaudi::Units::m}; }; } #endif \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx index a262627223a2..5ea583274abb 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx @@ -50,7 +50,11 @@ namespace MuonML { }), std::make_unique<NodeFeature>("stationIndex", [](const Bucket_t& bucket, size_t index) { - return bucket[index]-> + // use idHelperSvc to get stationName + auto idHelper = bucket[index]->msSector()->idHelperSvc(); + //auto idHelper = sector->idHelperSvc(); + auto stationIdx = idHelper->stationName(bucket[index]->identify()); + return stationIdx; }), std::make_unique<NodeFeature>("driftR", [](const Bucket_t& bucket, size_t index) { -- GitLab From a527a36844517b5235ce2127a74fd69786749eab Mon Sep 17 00:00:00 2001 From: Davide Di Croce <davide.di.croce@cern.ch> Date: Mon, 24 Feb 2025 16:06:32 +0100 Subject: [PATCH 7/9] Adding station index --- .../src/NodeFeatureFactory.cxx | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx index 5ea583274abb..44ca8839cd16 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx @@ -50,11 +50,15 @@ namespace MuonML { }), std::make_unique<NodeFeature>("stationIndex", [](const Bucket_t& bucket, size_t index) { - // use idHelperSvc to get stationName - auto idHelper = bucket[index]->msSector()->idHelperSvc(); - //auto idHelper = sector->idHelperSvc(); - auto stationIdx = idHelper->stationName(bucket[index]->identify()); - return stationIdx; + return bucket[index]->msSector()->idHelperSvc()->stationName(bucket[index]->identify()); + }), + std::make_unique<NodeFeature>("stationPhi", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]->msSector()->idHelperSvc()->stationPhi(bucket[index]->identify()); + }), + std::make_unique<NodeFeature>("stationEta", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]->msSector()->idHelperSvc()->stationEta(bucket[index]->identify()); }), std::make_unique<NodeFeature>("driftR", [](const Bucket_t& bucket, size_t index) { @@ -77,7 +81,6 @@ namespace MuonML { [](const Bucket_t& bucket, size_t index) { return bucket.size() / (bucket.coveredMax() - bucket.coveredMin()); // missing max and min of bucket }), - // can I add isolation already? std::make_unique<NodeFeature>("isolation", [](const Bucket_t& bucket, size_t index) { unsigned int neighbors = 0; @@ -85,9 +88,21 @@ namespace MuonML { for (size_t other =0 ; other < bucket.size(); ++ other){ neighbors+= index != other && (bucket[index]->positionInChamber() - bucket[other]->positionInChamber()).perp2() < radCut2; } - float bucket_density = bucket.size() / (bucket.coveredMax() - bucket.coveredMin()); + float bucket_density = float(bucket.size()) / (bucket.coveredMax() - bucket.coveredMin()); return neighbors / bucket_density; }), + std::make_unique<NodeFeature>("covX", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]->covariance()(Amg::x, Amg::x); + }), + std::make_unique<NodeFeature>("covY", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]->covariance()(Amg::y, Amg::y); + }), + std::make_unique<NodeFeature>("covXY", + [](const Bucket_t& bucket, size_t index) { + return bucket[index]->covariance()(Amg::x, Amg::y); + }), }; const auto feat_itr = featurePool.find(featName); if(feat_itr != featurePool.end()){ -- GitLab From ed71f9ad26ea3d8ebc5aee392c4835a10df2e53f Mon Sep 17 00:00:00 2001 From: Davide Di Croce <davide.di.croce@cern.ch> Date: Tue, 25 Feb 2025 13:49:32 +0100 Subject: [PATCH 8/9] Fixing variables name --- .../MuonInferenceInterfaces/src/NodeFeatureFactory.cxx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx index 44ca8839cd16..f326370fa3c0 100644 --- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx @@ -64,7 +64,7 @@ namespace MuonML { [](const Bucket_t& bucket, size_t index) { return bucket[index]->driftRadius(); }), - std::make_unique<NodeFeature>("relativeLayer", + std::make_unique<NodeFeature>("relative_layer", [](const Bucket_t& bucket, size_t index) { return bucket.layerNum(index) / (bucket.nStripLayers() + bucket.nMdtLayers()); }), @@ -77,7 +77,7 @@ namespace MuonML { } return n; }), - std::make_unique<NodeFeature>("density", + std::make_unique<NodeFeature>("bucket_density", [](const Bucket_t& bucket, size_t index) { return bucket.size() / (bucket.coveredMax() - bucket.coveredMin()); // missing max and min of bucket }), -- GitLab From bfbfb3acb1ebbde07dfca4672954b923b7b7f36a Mon Sep 17 00:00:00 2001 From: Davide Di Croce <davide.di.croce@cern.ch> Date: Wed, 26 Feb 2025 15:39:56 +0100 Subject: [PATCH 9/9] aInference module --- .../src/GraphBucketFilterTool.cxx | 31 +++++++++++++++++++ .../MuonInference/src/GraphBucketFilterTool.h | 26 ++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.cxx create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.h diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.cxx new file mode 100644 index 000000000000..b69522b9269f --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.cxx @@ -0,0 +1,31 @@ + + +#include "GraphBucketFilterTool.h" + +#include "StoreGate/ReadHandle.h" +#include "StoreGate/WriteHandle.h" +namesspace MuonML{ + StatusCode GraphBucketFilterTool::runGraphInference(const EventContext& ctx, + GraphRawData& graph) const { + + ATH_CHECK(buildGraph(ctx, graph)); + SG::WriteHandle filteredSpacePoints{m_writeKey, ctx}; + ATH_CHECK(filteredSpacePoints.record(std::make_unique<MuonR4::SpacePointContainer>())); + //// Inference part + + SG::ReadHandle inputSpacePoints{m_readKey, ctx}; + ATH_CHECK(inputSpacePoints.isPresent()); + for (const MuonR4::SpacePointBucket* bucket : *inputSpacePoints) { + std::unique_ptr<MuonR4::SpacePointBucket> filteredBucket = std::make_unique<MuonR4::SpacePointBucket>(); + filteredSpacePoints->push_back(std::move(filteredBucket)); + } + return StatusCode::SUCCESS; + } + + + StatusCode GraphBucketFilterTool::initialize() { + ATH_CHECK(setupModel()); + ATH_CHECK(m_writeKey.initialize()); + return StatusCode::SUCCESS; + } +} \ No newline at end of file diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.h new file mode 100644 index 000000000000..d03d97f57a5f --- /dev/null +++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.h @@ -0,0 +1,26 @@ +#ifndef MUONINFERENCE_GRAPHBUCKETFILTERTOOL_H +#define MUONINFERENCE_GRAPHBUCKETFILTERTOOL_H + +#include "GraphInferenceToolBase.h" +#include "StoreGate/WriteHandleKey.h" + +#include "MuonSpacePoint/SpacePointContainer.h" +namespace MuonML{ + class GraphBucketFilterTool : public GraphInferenceToolBase { + public: + using GraphInferenceToolBase::GraphInferenceToolBase; + + virtual StatusCode runGraphInference(const EventContext& ctx, + GraphRawData& graph) const override final; + + + virtual StatusCode initialize() override final; + private: + + SG::WriteHandleKey<MuonR4::SpacePointContainer> m_writeKey{this, "WriteSpacePointKey", "FilteredMlSpacePoints"}; + + + } +} + +#endif \ No newline at end of file -- GitLab