From 335f3bb4edb02f73de50e2ff3da17fedb8d570f5 Mon Sep 17 00:00:00 2001
From: Davide Di Croce <davide.di.croce@cern.ch>
Date: Tue, 18 Feb 2025 11:30:43 +0100
Subject: [PATCH 1/9] First inference infrastructure

---
 .../MuonBucketGraph/FCGraphMaker.cpp          |  10 +
 .../MuonBucketGraph/FCGraphMaker.h            |  21 +
 .../MuonBucketGraph/GraphFramer.h             |  14 +
 .../MuonLearning/MuonInference/CMakeLists.txt |  16 +
 .../MuonInference/InferenceInterface.cpp      |  10 +
 .../MuonInference/InferenceInterface.h        |  17 +
 .../MuonInference/python/InferenceConfig.py   |  11 +
 .../src/GraphInferenceToolBase.cxx            | 123 ++++++
 .../src/GraphInferenceToolBase.h              |  55 +++
 .../MuonInference/src/InferenceAlg.cxx        |  24 ++
 .../MuonInference/src/InferenceAlg.h          |  25 ++
 .../src/components/MuonInference_entries.cxx  |   8 +
 .../MuonInferenceInterfaces/CMakeLists.txt    |  17 +
 .../MuonInferenceInterfaces/GraphData.h       |  50 +++
 .../IGraphInferenceTool.h                     |  22 +
 .../MuonInferenceInterfaces/LayerBucket.h     |  35 ++
 .../MuonInferenceInterfaces/NodeConnector.h   |  50 +++
 .../MuonInferenceInterfaces/NodeFeature.h     |  50 +++
 .../NodeFeatureFactory.h                      |  26 ++
 .../MuonInferenceInterfaces/NodeFeatureList.h |  63 +++
 .../src/LayerBucket.cxx                       |  32 ++
 .../src/NodeFeatureFactory.cxx                |  91 ++++
 .../src/NodeFeatureList.cxx                   |  94 +++++
 .../MuonLearning/MuonSPId/CMakeLists.txt      |  24 ++
 .../MuonSPId/python/MuonSPIdDumpConfig.py     |  17 +
 .../MuonSPId/python/muonSPIdDump.py           |  43 ++
 .../MuonSPId/python/muonSPIdDump_data.py      |  38 ++
 .../MuonSPId/src/SPIdDumperAlg.cxx            | 399 ++++++++++++++++++
 .../MuonLearning/MuonSPId/src/SPIdDumperAlg.h |  95 +++++
 .../MuonSPId/src/SPIdentifierAlg.cxx          | 340 +++++++++++++++
 .../MuonSPId/src/SPIdentifierAlg.h            |  91 ++++
 .../src/components/MuonSPId_entries.cxx       |   7 +
 32 files changed, 1918 insertions(+)
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.cpp
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/GraphFramer.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/CMakeLists.txt
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.cpp
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/python/InferenceConfig.py
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.cxx
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/components/MuonInference_entries.cxx
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/CMakeLists.txt
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/GraphData.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/IGraphInferenceTool.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeConnector.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeature.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureFactory.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureList.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/LayerBucket.cxx
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureList.cxx
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/CMakeLists.txt
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/MuonSPIdDumpConfig.py
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.cxx
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.h
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/components/MuonSPId_entries.cxx

diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.cpp b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.cpp
new file mode 100644
index 000000000000..becdc324c032
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.cpp
@@ -0,0 +1,10 @@
+#include "FCGraphMaker.h"
+
+FCGraphMaker::FCGraphMaker(const SpacePointBucket* bucket, const GraphFramer& graphType, const IMuonIdHelperSvc* idHelperSvc)
+    : m_bucket(bucket), m_graphType(graphType), m_idHelperSvc(idHelperSvc) {}
+
+void FCGraphMaker::extractGraphData(std::vector<float>& features, std::vector<int64_t>& edge_src, std::vector<int64_t>& edge_dst) {
+	// implementation
+
+}
+
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.h
new file mode 100644
index 000000000000..f3924b6876d3
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/FCGraphMaker.h
@@ -0,0 +1,21 @@
+#ifndef FC_GRAPH_MAKER_H
+#define FC_GRAPH_MAKER_H
+
+#include "GraphFramer.h"
+#include "MuonIdHelpers/IMuonIdHelperSvc.h"
+#include <vector>
+#include <cstdint>
+
+class SpacePointBucket;
+
+class FCGraphMaker {
+public:
+    FCGraphMaker(const SpacePointBucket* bucket, const GraphFramer& graphType, const IMuonIdHelperSvc* idHelperSvc);
+    void extractGraphData(std::vector<float>& features, std::vector<int64_t>& edge_src, std::vector<int64_t>& edge_dst);
+private:
+    const SpacePointBucket* m_bucket;
+    GraphFramer m_graphType;
+    const IMuonIdHelperSvc* m_idHelperSvc;
+};
+
+#endif // FC_GRAPH_MAKER_H
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/GraphFramer.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/GraphFramer.h
new file mode 100644
index 000000000000..acf92d271a10
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonBucketGraph/GraphFramer.h
@@ -0,0 +1,14 @@
+#ifndef GRAPH_FRAMER_H
+#define GRAPH_FRAMER_H
+
+class GraphFramer {
+public:
+    GraphFramer(bool sparse, bool classification);
+    bool isSparse() const;
+    bool isClassification() const;
+private:
+    bool m_sparse;
+    bool m_classification;
+};
+
+#endif // GRAPH_FRAMER_H
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/CMakeLists.txt b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/CMakeLists.txt
new file mode 100644
index 000000000000..c78436b8cd0b
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+
+################################################################################
+# Package: MuonInferenceInterfaces
+################################################################################
+
+# Declare the package name:
+atlas_subdir( MuonInference )
+
+
+find_package( onnxruntime REQUIRED)
+
+atlas_add_component( MuonInference
+                     src/components/*.cxx src/*.cxx 
+                     LINK_LIBRARIES AthenaKernel StoreGateLib  MuonInferenceInterfaces PathResolver
+                                    ${ONNXRUNTIME_LIBRARIES})
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.cpp b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.cpp
new file mode 100644
index 000000000000..3115feae747a
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.cpp
@@ -0,0 +1,10 @@
+#include "InferenceInterface.h"
+
+InferenceInterface::InferenceInterface(const std::string& model_path) 
+    : m_env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference"), 
+      m_session(m_env, model_path.c_str(), Ort::SessionOptions{}) {}
+
+std::vector<float> InferenceInterface::runInference(const std::vector<float>& features, const std::vector<int64_t>& edge_index) {
+    // Implement ONNX model inference logic
+    return {}; // Placeholder return
+}
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.h
new file mode 100644
index 000000000000..f86138bd8a7b
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/InferenceInterface.h
@@ -0,0 +1,17 @@
+#ifndef INFERENCE_INTERFACE_H
+#define INFERENCE_INTERFACE_H
+
+#include <vector>
+#include <string>
+#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
+
+class InferenceInterface {
+public:
+    InferenceInterface(const std::string& model_path);
+    std::vector<float> runInference(const std::vector<float>& features, const std::vector<int64_t>& edge_index);
+private:
+    Ort::Session m_session;
+    Ort::Env m_env;
+};
+
+#endif // INFERENCE_INTERFACE_H
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/python/InferenceConfig.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/python/InferenceConfig.py
new file mode 100644
index 000000000000..bee02114be11
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/python/InferenceConfig.py
@@ -0,0 +1,11 @@
+# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaConfiguration.ComponentFactory import CompFactory
+
+def GraphAssemblyToolCfg(flags, name = "GraphAssemblyTool", **kwargs):
+    result = ComponentAccumulator()
+    the_tool = CompFactory.MuonML.GraphAssemblyTool(name, **kwargs)
+    result.setPrivateTools(the_tool)
+    return result
+
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx
new file mode 100644
index 000000000000..51e666263121
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx
@@ -0,0 +1,123 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#include "GraphInferenceToolBase.h"
+
+#include "MuonInferenceInterfaces/GraphData.h"
+#include "MuonInferenceInterfaces/NodeFeatureList.h"
+#include "MuonPatternHelpers/MatrixUtils.h"
+
+
+#include "PathResolver/PathResolver.h"
+namespace MuonML{
+    StatusCode GraphInferenceToolBase::setupModel() {
+        const std::string  modelPath = PathResolver::FindCalibFile(m_modelPath);
+        if (modelPath.empty()) {
+            ATH_MSG_FATAL("No such file or directory "<<m_modelPath<<".");
+            return StatusCode::FAILURE;
+        }
+        try {
+            Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference");
+            Ort::SessionOptions session_options;
+            session_options.SetIntraOpNumThreads(1);
+            session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
+
+            m_model = std::make_unique<Ort::Session>(env, modelPath.c_str(), session_options);
+            ATH_MSG_DEBUG("Successfully loaded infernce model from "<<modelPath);
+        
+            
+            Ort::ModelMetadata metadata = m_model->GetModelMetadata();
+            Ort::AllocatorWithDefaultOptions allocator;
+            Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator);
+            
+            if (feature_json_ptr) {
+                std::string feature_json = feature_json_ptr.get();
+                nlohmann::json json_obj = nlohmann::json::parse(feature_json);
+                for (const auto& feature : json_obj) {
+                    m_graphFeatures.addFeature(feature.get<std::string>(), msgStream());
+                }
+            }
+        } catch (const std::exception& e) {
+            ATH_MSG_ERROR("Failed to retrieve feature from ONNX model: " << e.what());
+            return StatusCode::FAILURE;
+        }
+
+        if (!m_graphFeatures.isValid()) {
+            ATH_MSG_FATAL("No graph features have been parsed. Please check the model");
+            return StatusCode::FAILURE;
+        }
+        ATH_CHECK(m_spacePointKey.initialize());
+        return StatusCode::SUCCESS;
+    }
+    StatusCode GraphInferenceToolBase::buildGraph(const EventContext& ctx,
+                                                  GraphRawData& graphData) const {
+
+        /** Check whether the graph needs a rebuild */
+        if ((*graphData.previousList) != m_graphFeatures) {
+            graphData.graph.reset();
+        }
+        /** Don't launch the rebuild of the graph */
+        if (graphData.graph) {
+            return StatusCode::SUCCESS;
+        }        
+        if (!m_graphFeatures.isValid()) {
+            ATH_MSG_ERROR("The feature list is in complete. Either it has no features or no node connector set");
+            return StatusCode::FAILURE;
+        }
+    
+        graphData.graph = std::make_unique<InferenceGraph>();
+    
+        SG::ReadHandle spacePoints{m_spacePointKey, ctx};
+        ATH_CHECK(spacePoints.isPresent());
+    
+        int64_t nNodes{0}, possConn{0};
+        graphData.spacePointsInBucket.clear();
+        graphData.spacePointsInBucket.reserve(spacePoints->size());
+
+        for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
+            nNodes += graphData.spacePointsInBucket.emplace_back(bucket->size());
+            possConn += MuonR4::sumUp(graphData.spacePointsInBucket.back());
+        }
+        graphData.nodeIndex = 0;
+        graphData.featureLeaves.resize(nNodes * m_graphFeatures.numFeatures());
+        graphData.currLeave = graphData.featureLeaves.begin();
+        
+
+        graphData.srcEdges.reserve(possConn);
+        graphData.desEdges.reserve(possConn);
+    
+    
+        /** Fill the graph edge features and all their respective connections */
+        for (const MuonR4::SpacePointBucket* bucket : *spacePoints) {
+            const LayerSpBucket mlBucket{*bucket};
+            m_graphFeatures.fillInData(mlBucket, graphData);
+        }
+
+        graphData.srcEdges.insert(graphData.srcEdges.end(), std::make_move_iterator(graphData.desEdges.begin()),
+                                 std::make_move_iterator(graphData.desEdges.end()));
+
+        std::vector<int64_t> featShape{nNodes, static_cast<int64_t>(m_graphFeatures.numFeatures())};  // (N, 8)
+        std::vector<int64_t> edgeShape{2, static_cast<int64_t>(graphData.srcEdges.size() / 2)};  // (2, E)
+
+        Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+
+        
+        graphData.graph->dataTensor.emplace_back(Ort::Value::CreateTensor<float>(memInfo, 
+                                                    graphData.featureLeaves.data(), graphData.featureLeaves.size(), 
+                                                    featShape.data(), featShape.size()));
+
+
+
+        Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(memInfo, graphData.srcEdges.data(), 
+                                                                   graphData.srcEdges.size(), edgeShape.data(), edgeShape.size());
+                                
+        
+        graphData.previousList = &m_graphFeatures;
+
+        graphData.srcEdges.clear();
+        graphData.desEdges.clear();
+        graphData.featureLeaves.clear();
+        return StatusCode::SUCCESS;
+    }
+    
+}
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.h
new file mode 100644
index 000000000000..f27ed78ee67e
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.h
@@ -0,0 +1,55 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#ifndef MUONINFERENCETOOLS_GRAPHINFERENCETOOL_H
+#define MUONINFERENCETOOLS_GRAPHINFERENCETOOL_H
+
+#include "MuonInferenceInterfaces/IGraphInferenceTool.h"
+#include "MuonInferenceInterfaces/NodeFeatureList.h"
+#include "MuonInferenceInterfaces/GraphData.h"
+
+
+#include "MuonSpacePoint/SpacePointContainer.h"
+
+#include "AthenaBaseComps/AthAlgTool.h"
+#include "StoreGate/ReadHandleKey.h"
+#include <onnxruntime_cxx_api.h> // is this somewhere else?
+#include "nlohmann/json.hpp"
+
+namespace MuonML{
+    /** @brief Baseline tool to handle the  */
+    class GraphInferenceToolBase : public extends<AthAlgTool, IGraphInferenceTool> {
+        public:
+            /** @brief Keep the constructor of the parent class */
+            using base_class::base_class;
+            /** @brief Fill up the GraphRawData and construct the graph for the ML inference with
+             *         ONNX. If the graph has been built by another inference tool and would be the 
+             *         same than this one the rebuild is skipped 
+             *  @param ctx: EventContext to access the space ponit container from StoreGate
+             *  @param graphData: Rerference to the data object to be filled. */
+            StatusCode buildGraph(const EventContext& ctx,
+                                  GraphRawData& graphData) const;
+        
+
+        protected:
+            StatusCode setupModel();
+            
+            const Ort::Session* model() const;
+        private:
+            
+            // input space points from SG
+            SG::ReadHandleKey<MuonR4::SpacePointContainer> m_spacePointKey{this, "SpacePointContainer", "MuonSpacePoints"};
+
+            /** @brief Location of the model file */
+            Gaudi::Property<std::string> m_modelPath{this, "ModelPath", ""};
+
+            /** @brief List of features to be used for the inference */
+            NodeFeatureList m_graphFeatures{};
+            /** @brief Pointer to the ONNX model session */
+            std::unique_ptr<Ort::Session> m_model{};
+
+    }; 
+
+}
+
+#endif
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.cxx
new file mode 100644
index 000000000000..63b9b1520f79
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.cxx
@@ -0,0 +1,24 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#include "InferenceAlg.h"
+
+#include "MuonInferenceInterfaces/GraphData.h"
+
+namespace MuonML{
+    StatusCode InferenceAlg::initialize() {
+        if (m_inferenceTools.empty()) {
+            ATH_MSG_ERROR("Provide at least one inference tool");
+            return StatusCode::FAILURE;
+        }
+        ATH_CHECK(m_inferenceTools.retrieve());
+        return StatusCode::SUCCESS;
+    }
+    StatusCode InferenceAlg::execute(const EventContext& ctx) const {
+        GraphRawData graphData{};
+        for (const auto& infTool : m_inferenceTools) {
+            ATH_CHECK(infTool->runGraphInference(ctx, graphData));
+        }
+        return StatusCode::SUCCESS;
+    }
+}
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.h
new file mode 100644
index 000000000000..1826c20ce1c7
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/InferenceAlg.h
@@ -0,0 +1,25 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#ifndef MUONINFERENCETOOLS_INFERENCEALG_H
+#define MUONINFERENCETOOLS_INFERENCEALG_H
+
+#include "AthenaBaseComps/AthReentrantAlgorithm.h"
+#include "MuonInferenceInterfaces/IGraphInferenceTool.h"
+
+namespace MuonML{
+    class InferenceAlg : public AthReentrantAlgorithm {
+        public:
+            using AthReentrantAlgorithm::AthReentrantAlgorithm;
+
+            virtual StatusCode initialize() override final;
+            virtual StatusCode execute(const EventContext& ctx) const override final;
+        private:
+            ToolHandleArray<IGraphInferenceTool> m_inferenceTools{this, "InferenceTools", {},
+                                                                  "List of machine learning inference tools to be processed with the same graph"};
+    };
+}
+
+
+
+#endif
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/components/MuonInference_entries.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/components/MuonInference_entries.cxx
new file mode 100644
index 000000000000..b51eafe64be1
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/components/MuonInference_entries.cxx
@@ -0,0 +1,8 @@
+
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "../InferenceAlg.h"
+
+DECLARE_COMPONENT(MuonML::InferenceAlg)
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/CMakeLists.txt b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/CMakeLists.txt
new file mode 100644
index 000000000000..439004a9651a
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+
+################################################################################
+# Package: MuonInferenceInterfaces
+################################################################################
+
+# Declare the package name:
+atlas_subdir( MuonInferenceInterfaces )
+
+find_package( onnxruntime REQUIRED)
+
+atlas_add_library( MuonInferenceInterfaces
+                   src/*.cxx
+                   PUBLIC_HEADERS MuonInferenceInterfaces
+                   LINK_LIBRARIES MuonReadoutGeometryR4 AthenaBaseComps MuonPatternHelpers
+                                  MuonSpacePoint MuonPatternEvent ${ONNXRUNTIME_LIBRARIES})
+
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/GraphData.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/GraphData.h
new file mode 100644
index 000000000000..585892f55ef8
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/GraphData.h
@@ -0,0 +1,50 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#ifndef MUONINFERENCEINTERACES_GRAPHNODE_H
+#define MUONINFERENCEINTERACES_GRAPHNODE_H
+
+#include <onnxruntime_cxx_api.h>
+#include <vector>
+#include <memory>
+
+namespace MuonML{
+    class NodeFeatureList;
+        
+    /** @brief Helper struct containing all the information needed to process  */
+    struct InferenceGraph {
+        /** @brief Vector of the inference input tensors */
+        std::vector<Ort::Value> dataTensor{};
+        /** @brief Vector of the input names. The raw char data
+         *         is neither owned or managed by the object */
+        std::vector<const char*> nameTensor{};
+     };
+    
+    /** @brief Helper struct to ship the Graph from the space point buckets 
+     *         to ONNX */
+    struct GraphRawData {
+        using FeatureVec_t = std::vector<float>;
+        using NodeConnectVec_t = std::vector<int64_t>;
+        using EdgeCounterVec_t = std::vector<int64_t>;
+        /** @brief Vector containing all features */
+        FeatureVec_t featureLeaves{};
+        /** @brief Vector encoding the source index of the */
+        NodeConnectVec_t srcEdges{};
+        /** @brief Vect  */
+        NodeConnectVec_t desEdges{};
+        /** @brief  Vector keeping track of how many space points are in each parsed bucket */
+        EdgeCounterVec_t spacePointsInBucket{};
+        /** @brief Pointer to the latest parsed NodeFeatureList */
+        const NodeFeatureList* previousList{};
+        /** @brief Pointer to the graph to be parsed to ONNX */
+        std::unique_ptr<InferenceGraph> graph{};
+        
+        /** @brief The following variables are needed to fill the consistently the raw data 
+         *         for the Graph Building*/
+        std::vector<float>::iterator currLeave{featureLeaves.begin()};
+        /** @brief Number of the already filled nodes */
+        unsigned int nodeIndex{0};
+    };
+}
+
+#endif
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/IGraphInferenceTool.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/IGraphInferenceTool.h
new file mode 100644
index 000000000000..fc5243780177
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/IGraphInferenceTool.h
@@ -0,0 +1,22 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#ifndef MUONINFERENCEINTERFACES_IGRAPHINFERENCETOOL_H
+#define MUONINFERENCEINTERFACES_IGRAPHINFERENCETOOL_H
+
+#include "GaudiKernel/IAlgTool.h"
+#include "GaudiKernel/EventContext.h"
+namespace MuonML {
+    struct GraphRawData;
+    class IGraphInferenceTool: virtual public IAlgTool {
+        public:
+            /** @brief Empty desctructor */
+            virtual ~IGraphInferenceTool() = default;
+            /** @brief Declaration of the interface */
+            DeclareInterfaceID(IGraphInferenceTool, 1, 0);
+
+            virtual StatusCode runGraphInference(const EventContext& ctx,
+                                                 GraphRawData& graph) const = 0;
+    };
+}
+#endif
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h
new file mode 100644
index 000000000000..bb083b0e3a5a
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h
@@ -0,0 +1,35 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#ifndef MUONINFERENCEINTERACES_LAYERBUCKET_H
+#define MUONINFERENCEINTERACES_LAYERBUCKET_H
+
+#include "MuonSpacePoint/SpacePointContainer.h"
+namespace MuonML{
+    /** @brief The LayerSpBucket is a space pointbucket where the points are internally 
+     *         sorted by their layer number as defined in the SpacePointLayerSorter. The
+     *         bucket also provides the layer number tag for each space point & the number of total layers
+      */
+    class LayerSpBucket : public std::vector<const MuonR4::SpacePoint*> {
+        public:
+          /** @brief Standard constructor taking the space point bucket */  
+          LayerSpBucket(const MuonR4::SpacePointBucket& bucket);
+          /** @brief Returns how many Mdt layers are inside the bucket */
+          uint8_t nMdtLayers() const {
+              return m_nMdtLay;
+          }
+          /** @brief Returns how many Strip layers are inside the bucket */
+          uint8_t nStripLayers() const {
+              return m_nStripLay;
+          }
+          /** @brief Returns the associated layer number of the i-the space point inside the bucket */
+          uint8_t layerNum(const size_t i) const {
+            return m_layNum[i];
+          }
+        private:
+            uint8_t m_nMdtLay{0};
+            uint8_t m_nStripLay{0};
+            std::vector<uint8_t> m_layNum{};
+    };
+}
+#endif
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeConnector.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeConnector.h
new file mode 100644
index 000000000000..2861aa4908ec
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeConnector.h
@@ -0,0 +1,50 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#ifndef MUONINFERENCEINTERACES_NODECONNECTOR_H
+#define MUONINFERENCEINTERACES_NODECONNECTOR_H
+
+#include <string>
+#include <functional>
+
+#include <MuonInferenceInterfaces/NodeFeature.h>
+namespace MuonML{
+    /** @brief The NodeConnector is indicating whether two space points inside a bucket,
+     *         the graph nodes, shall have a connection in their graph neural net representation.
+     *         In short terms, the node connector is std::function with the bucket & the two indices
+     *         to the space point inside as input arguments and then returning a boolean decision 
+     *         whether the connection shall be built.
+     *     In order, to verify that two instances of the neural net graph are identical,
+     *     the node connector also has a name as attribute. */
+    class NodeConnector {
+        public:
+            using Bucket_t = NodeFeature::Bucket_t;
+            /** @brief Function type to connect two space points in a bucket. The signature
+             *         takes the reference to the bucket and then the two indices of the space
+             *         points which shall be connected. The function needs to return true or false */
+            using Evaluator_t = std::function<bool(const Bucket_t&, size_t, size_t)>;
+            /** @brief Standard constructor taking the name  of the node connector & 
+             *         a connector function definition
+             *  @brief cName: Name of the connector function
+             *  @brief conFunc: Definition of the connector function */
+            NodeConnector(const std::string& cName, const Evaluator_t conFunc):
+                            m_name{cName}, m_func{conFunc} {} 
+            /** @brief Returns the name of the node connector */
+            const std::string& name() const {
+                return m_name;
+            }
+            /** @brief returns the decision of the connector function
+             *  @param bucket: Reference to the space point bucket
+             *  @param i: Index of the first space point to connect
+             *  @param j: Index of the second space point to connect */
+            bool connect(const Bucket_t& bucket, size_t i, size_t j) const {
+                return m_func(bucket, i, j);
+            }
+
+        private:
+            std::string m_name{};
+            Evaluator_t m_func{[](const Bucket_t, size_t, size_t) { return false; }};
+
+    };
+}
+#endif
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeature.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeature.h
new file mode 100644
index 000000000000..9356dc94a4dd
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeature.h
@@ -0,0 +1,50 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#ifndef MUONINFERENCEINTERACES_GRAPHFEATURE_H
+#define MUONINFERENCEINTERACES_GRAPHFEATURE_H
+
+
+#include <functional>
+#include <string>
+
+#include "MuonInferenceInterfaces/LayerBucket.h"
+
+namespace MuonML {
+    /** @brief The NodeFeature is the gluing instance to extract the information from the space point
+     *         inside a MuonBucket and then to parse it to the ML inference framework. */
+    class NodeFeature {
+        public:
+            /** @brief Abreviation of the Space point bucket type */
+            using Bucket_t = LayerSpBucket;
+
+            /** @brief Lambda function type to extract the feature from a bucket.
+             *  @arg: Bucket_t of interest
+             *  @arg: size_t Index of the space point to extract the bucket from */
+            using Func_t = std::function<double(const Bucket_t&, size_t)>;
+
+            /** @brief Standard constructor to build a feature
+             *  @param featName: Name of the feature used to distinguish whether 
+             *                   two feature lists are the same */
+            NodeFeature(const std::string& featName,
+                        const Func_t& extractFunc):
+                m_name{featName}, m_func{extractFunc}{}
+            /** @brief Standard move assignment & move constructor */
+            NodeFeature(NodeFeature&& other) = default;
+
+            /** @brief Returns the feature name */
+            const std::string& name() const {
+                return m_name;
+            }
+            /** @brief Extract the feature from a space point inside the bucket
+             *  @param bucket: Reference to the space point bucket of interest
+             *  @param spIndex: Index of the space point to extract the feature from */
+            double eval(const Bucket_t& bucket, size_t spIndex) const {
+                return m_func(bucket, spIndex);
+            }
+        private:
+            std::string m_name{};
+            Func_t m_func{[](const Bucket_t& , size_t) { return 0.; }};
+    };   
+}
+#endif
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureFactory.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureFactory.h
new file mode 100644
index 000000000000..42980dc4bc51
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureFactory.h
@@ -0,0 +1,26 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#ifndef MUONINFERENCEINTERACES_GRAPHFEATUREFACTORY_H
+#define MUONINFERENCEINTERACES_GRAPHFEATUREFACTORY_H
+
+#include "MuonInferenceInterfaces/NodeFeatureList.h"
+
+class MsgStream;
+
+namespace MuonML {
+    namespace Factory {
+      /** @brief Factory function that builds a NodeFeature from a predefined list of features
+       *  @param featName: Name of the feature inside the list
+       *  @param log: Refetrence to the message object for logging */
+      NodeFeatureList::Feature_t makeFeature(const std::string& featName, MsgStream& log);
+      /** @brief Factory function that builds a connector relation between two edges in the bucket.
+       *  @param connName: Name of the connection function inside the predefined list
+       *  @param log: Refetrence to the message object for logging */
+      NodeFeatureList::Connector_t makeConnector(const std::string& connName, MsgStream& log);
+
+
+    }
+}
+
+#endif
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureList.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureList.h
new file mode 100644
index 000000000000..96b4c08cb7b6
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/NodeFeatureList.h
@@ -0,0 +1,63 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#ifndef MUONINFERENCEINTERACES_GRAPHFEATURELIST_H
+#define MUONINFERENCEINTERACES_GRAPHFEATURELIST_H
+
+#include "MuonInferenceInterfaces/NodeFeature.h"
+#include "MuonInferenceInterfaces/NodeConnector.h"
+
+class MsgStream;
+
+namespace MuonML {
+    struct GraphRawData;
+    class NodeFeatureList {
+        public:
+            
+            using Feature_t = std::shared_ptr<const NodeFeature>;
+            using Connector_t = std::shared_ptr<const NodeConnector>;
+
+            using Bucket_t = NodeFeature::Bucket_t;
+            /** @brief Empty standard constructor */
+            NodeFeatureList() = default;
+            /** @brief Returns true if the features have pairwise
+             *         the same name */
+            bool operator==(const NodeFeatureList& other) const;
+            /** @brief Returns whether the NodeFeatureList is complete,
+             *         i.e. it must have at least one feature and the node
+             *         connector */
+            bool isValid() const;
+            /** @brief Returns the number of features in the list */
+            size_t numFeatures() const;
+            /** @brief Returns the name of the features in the list */
+            std::vector<std::string> featureNames() const;
+            /** @brief  */
+            void fillInData(const Bucket_t& bucket, 
+                            GraphRawData& graphData) const;
+            /** @brief Tries to add a new feature to the list using the predefined
+             *          list of features in the GraphFeatureFactory
+             *  @param featName: Name of the feature in the factory
+             *  @param msg: Reference to the message stream object for logging. */
+            bool addFeature(const std::string& featName, MsgStream& msg);
+            /** @brief Tries to add a particular feature to the list.
+             *  @param featPtr: Pointer to the instantiated feature
+             *  @param msg: Reference to the message stream object for logging. */
+            bool addFeature(const Feature_t& featPtr,  MsgStream& msg);
+
+            /** @brief Tries to set the graph connector based on the connector name.
+             *  @param conName: Name of the connector to extract from the factory
+             *  @param msg: Reference to the message stream object for logging. */
+            bool setConnector(const std::string& conName, MsgStream& msg);
+            /** @brief Sets the  */
+            void setConnector(const std::string& conName, NodeConnector::Evaluator_t evalFunc);
+
+            
+        private:
+            std::vector<Feature_t> m_features{};
+            Connector_t m_connector{};
+            
+    };
+}
+
+
+#endif
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/LayerBucket.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/LayerBucket.cxx
new file mode 100644
index 000000000000..f41b9a82084f
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/LayerBucket.cxx
@@ -0,0 +1,32 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#include "MuonInferenceInterfaces/LayerBucket.h"
+
+#include "MuonSpacePoint/SpacePointPerLayerSorter.h"
+#include "Acts/Utilities/Enumerate.hpp"
+namespace MuonML{
+    using HitVec = MuonR4::SpacePointPerLayerSorter::HitVec;
+    LayerSpBucket::LayerSpBucket(const MuonR4::SpacePointBucket& bucket) {
+        reserve(bucket.size());
+        m_layNum.resize(bucket.size());
+        const MuonR4::SpacePointPerLayerSorter sorter{bucket};
+        uint8_t globLayer{0};
+        m_nMdtLay = sorter.mdtHits().size();
+        m_nStripLay = sorter.stripHits().size();
+        for (const HitVec& spInLay : sorter.mdtHits()) {
+            for (const auto& [idx, spacePoint] : Acts::enumerate(spInLay)) {
+                push_back(spacePoint);
+                m_layNum[idx] = globLayer;
+            }
+            ++globLayer;
+        }
+        for (const HitVec& spInLay : sorter.stripHits()) {
+            for (const auto& [idx, spacePoint] : Acts::enumerate(spInLay)) {
+                push_back(spacePoint);
+                m_layNum[idx] = globLayer;
+            }
+            ++globLayer;
+        }
+    }
+}
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
new file mode 100644
index 000000000000..daf78ffe3330
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
@@ -0,0 +1,91 @@
+/*
+Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#include "MuonInferenceInterfaces/NodeFeatureFactory.h"
+#include "AthenaBaseComps/AthMessaging.h"
+
+#include <set>
+
+namespace MuonML {
+
+    using Feature_t = NodeFeatureList::Feature_t;
+    using Connector_t = NodeFeatureList::Connector_t;
+    using Bucket_t = NodeFeature::Bucket_t;
+
+    bool operator<(const std::string& a, const Feature_t & b) {
+        return a < b->name();
+    }
+    bool operator<( const Feature_t & a, const std::string& b) {
+        return a->name() < b;
+    }
+    bool operator<(const Feature_t& a, const Feature_t & b) {
+        return a->name() < b->name();
+    }
+
+    bool operator<(const std::string& a, const Connector_t & b) {
+        return a < b->name();
+    }
+    bool operator<( const Connector_t & a, const std::string& b) {
+        return a->name() < b;
+    }
+    bool operator<(const Connector_t& a, const Connector_t & b) {
+        return a->name() < b->name();
+    }
+     
+    namespace Factory {
+        Feature_t makeFeature(const std::string& featName, MsgStream& log) {
+
+            /** Predefine the known features in the pool  */
+            static const std::set<Feature_t, std::less<>> featurePool{
+                std::make_unique<NodeFeature>("driftR", 
+                        [](const Bucket_t& bucket, size_t index) {
+                            return bucket[index]->driftRadius();
+                        }),
+                std::make_unique<NodeFeature>("localX", 
+                        [](const Bucket_t& bucket, size_t index) {
+                            return bucket[index]->positionInChamber().x();
+                        }),
+            }; 
+            const auto feat_itr = featurePool.find(featName);
+            if(feat_itr != featurePool.end()){
+                if (log.level() <= MSG::DEBUG) {
+                    log<<MSG::DEBUG<<"Found graph feature "<<featName<<"."<<endmsg;
+                }
+                return *feat_itr;
+            }
+            std::stringstream available{};
+            for (const Feature_t& known : featurePool) {
+                available<<known->name()<<", ";
+            }
+            log<<MSG::ERROR<<"The feature "<<featName<<" is unknown to the feature factory. "
+               <<" Please check for typos w.r.t "<<available.str()<<". Otherwise augment "
+               <<__FILE__<<" with your desired feature "<<endmsg;
+            return nullptr;
+        }
+        NodeFeatureList::Connector_t makeConnector(const std::string& connName, MsgStream& log) {
+
+            static const std::set<Connector_t, std::less<>> connectorPool{
+                std::make_unique<NodeConnector>("fullyConnected", 
+                        [](const Bucket_t& , size_t , size_t ) {
+                            return true;
+                        }),
+            };
+            const auto feat_itr = connectorPool.find(connName);
+            if(feat_itr != connectorPool.end()){
+                if (log.level() <= MSG::DEBUG) {
+                    log<<MSG::DEBUG<<"Found graph connector "<<connName<<"."<<endmsg;
+                }
+                return *feat_itr;
+            }
+            std::stringstream available{};
+            for (const Connector_t& known : connectorPool) {
+                available<<known->name()<<", ";
+            }
+            log<<MSG::ERROR<<"The graph connector "<<connName<<" is unknown to the factory. "
+            <<" Please check for typos w.r.t "<<available.str()<<". Otherwise augment "
+            <<__FILE__<<" with your desired connection function. "<<endmsg;
+            return nullptr;
+        }
+    }
+
+}
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureList.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureList.cxx
new file mode 100644
index 000000000000..91bbb536d321
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureList.cxx
@@ -0,0 +1,94 @@
+/*
+  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+*/
+#include "MuonInferenceInterfaces/NodeFeatureList.h"
+#include "MuonInferenceInterfaces/NodeFeatureFactory.h"
+#include "MuonInferenceInterfaces/GraphData.h"
+
+
+#include "AthenaBaseComps/AthMessaging.h"
+#include "Acts/Utilities/Enumerate.hpp"
+#include "MuonPatternHelpers/MatrixUtils.h"
+
+using namespace MuonR4;
+namespace MuonML {
+   bool NodeFeatureList::isValid() const {
+      return m_connector && numFeatures();
+   }
+   bool NodeFeatureList::setConnector(const std::string& conName, MsgStream& msg) {
+      m_connector = Factory::makeConnector(conName , msg);
+      return m_connector != nullptr;
+   }
+
+   void NodeFeatureList::setConnector(const std::string& conName, NodeConnector::Evaluator_t evalFunc) {
+      m_connector = std::make_unique<NodeConnector>(conName, evalFunc);
+   }
+   bool NodeFeatureList::operator==(const NodeFeatureList& other) const {
+      if (numFeatures() != other.numFeatures()) {
+         return false;
+      }
+      if (!m_connector || !other.m_connector || m_connector->name() != other.m_connector->name()) {
+         return false;
+      }
+      for (size_t f =0 ; f < numFeatures(); ++f) {
+         if (m_features[f]->name() != other.m_features[f]->name()) {
+            return false;
+         }
+      }
+      return true;
+   }
+   size_t NodeFeatureList::numFeatures() const {
+        return m_features.size();
+     }
+     /** @brief Returns the name of the features in the list */
+     std::vector<std::string> NodeFeatureList::featureNames() const {
+        std::vector<std::string> names{};
+        std::ranges::transform(m_features,std::back_inserter(names),
+                               [](const Feature_t& ft){ return ft->name();});
+        return names;
+     }
+     bool NodeFeatureList::addFeature(const std::string& featName, MsgStream& msg) {
+        return addFeature(Factory::makeFeature(featName, msg), msg);
+     }
+  
+     bool NodeFeatureList::addFeature(const Feature_t& newFeat,  MsgStream& msg) {
+       if (!newFeat) {
+          msg<<MSG::ERROR<<"No feature has been parsed. "<<endmsg;
+          return false;
+       }
+       if (std::ranges::find_if(m_features, [&newFeat](const Feature_t& known){
+           return known == newFeat || known->name() == newFeat->name();
+          }) != m_features.end()) {
+           msg<<MSG::ERROR<<" The feature "<<newFeat->name()<<" has already been added & "
+              <<" cannot be added again. "<<endmsg;
+          return false;
+       }
+       if (msg.level() <= MSG::DEBUG) {
+          msg<<MSG::DEBUG<<__FILE__<<":"<<__LINE__<<" - Add new feature "<< newFeat->name()<<endmsg;
+       }
+       return true;
+     }
+     void NodeFeatureList::fillInData(const Bucket_t& bucket, 
+                                      GraphRawData& prepGraph) const {
+         for (size_t sp = 0 ; sp < bucket.size(); ++sp) {
+            /** @brief Fill the graph features  */
+            for(const Feature_t& feat : m_features) {
+               assert(prepGraph.currLeave != prepGraph.featureLeaves.end());
+               (*prepGraph.currLeave++) = feat->eval(bucket, sp);
+            }
+            for (size_t ot = 0; ot < sp; ++ot) {
+               if (m_connector->connect(bucket, sp, ot)) {
+                  /// Connection i->j
+                  prepGraph.srcEdges.emplace_back(ot);
+                  prepGraph.desEdges.emplace_back(sp);
+                  ///Connection j-> i
+                  prepGraph.srcEdges.emplace_back(sp);
+                  prepGraph.desEdges.emplace_back(ot);
+
+               }
+            }
+        }
+        prepGraph.nodeIndex+=bucket.size();
+     }
+     
+}
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/CMakeLists.txt b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/CMakeLists.txt
new file mode 100644
index 000000000000..df03d50d2f7d
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+
+# Declare the package name:
+atlas_subdir( MuonSPId )
+find_package( onnxruntime )
+
+
+# Component(s) in the package:
+atlas_add_component( MuonSPId
+                     src/*.cxx
+                     src/components/*.cxx
+                     LINK_LIBRARIES  AthenaBaseComps GeoPrimitives MuonIdHelpersLib  
+                                     MuonSimEvent xAODMuonSimHit StoreGateLib MdtCalibSvcLib
+                                     xAODMuonPrepData MuonSpacePoint xAODMuon MuonTesterTreeLib
+                                     MuonPatternEvent MuonPatternHelpers MuonPatternHelpers ${ONNXRUNTIME_LIBRARIES}
+                                     )
+                                     
+atlas_install_python_modules( python/*.py)
+
+atlas_add_test( testMuonSPId
+	SCRIPT python -m MuonSPId.muonSPIdDump --nEvents 5 --noSTGC --noMM
+                PROPERTIES TIMEOUT 600
+                PRIVATE_WORKING_DIRECTORY
+                POST_EXEC_SCRIPT noerror.sh)
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/MuonSPIdDumpConfig.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/MuonSPIdDumpConfig.py
new file mode 100644
index 000000000000..953d77712313
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/MuonSPIdDumpConfig.py
@@ -0,0 +1,17 @@
+#Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
+
+from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
+from AthenaConfiguration.ComponentFactory import CompFactory
+
+def MuonSPIdDumpCfg(flags, name="MuonSPIdMaker", **kwargs):
+    result = ComponentAccumulator()
+    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg
+    result.merge(MuonSpacePointFormationCfg(flags))
+    kwargs.setdefault("isMC", flags.Input.isMC)
+    #kwargs.setdefault("spIdValue", flags.spIdValue)
+    #from RngComps.RngCompsConfig import AthRNGSvcCfg
+    #kwargs.setdefault("RndmSvc", result.getPrimaryAndMerge(AthRNGSvcCfg(flags)))
+    the_alg = CompFactory.MuonR4.SPIdDumperAlg(name=name, **kwargs)
+    result.addEventAlgo(the_alg, primary = True)
+    return result
+
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py
new file mode 100644
index 000000000000..a07d728e8396
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py
@@ -0,0 +1,43 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+if __name__=="__main__":
+    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg, SetupArgParser, executeTest, setupHistSvcCfg
+    parser = SetupArgParser()
+    parser.set_defaults(nEvents = -1)
+    parser.set_defaults(outRootFile="MuonSPId_R3SimHits.root")
+    parser.set_defaults(inputFile=[
+                                   "/cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/MuonGeomRTT/R3SimHits.pool.root"
+                                    ])
+    parser.set_defaults(eventPrintoutLevel = 500)
+    args = parser.parse_args()
+
+    from AthenaConfiguration.AllConfigFlags import initConfigFlags
+    flags = initConfigFlags()
+    flags.PerfMon.doFullMonMT = True
+
+    flags, cfg = setupGeoR4TestCfg(args)
+
+    cfg.merge(setupHistSvcCfg(flags,outFile=args.outRootFile, outStream="MuonSPId"))
+
+    from MuonConfig.MuonDataPrepConfig import xAODUncalibMeasPrepCfg
+    cfg.merge(xAODUncalibMeasPrepCfg(flags))
+
+    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg 
+    cfg.merge(MuonSpacePointFormationCfg(flags))
+
+    from MuonPatternRecognitionAlgs.MuonHoughTransformAlgConfig import MuonPatternRecognitionCfg, MuonSegmentFittingAlgCfg
+    cfg.merge(MuonPatternRecognitionCfg(flags))
+
+    cfg.merge(MuonSegmentFittingAlgCfg(flags))
+    
+    ## add the SPId flags here
+    #newparser = SetupArgParser()
+    #newparser.set_defaults(spIdValue=0)
+    #newparser.set_defaults(isMC=True)
+    #newargs = newparser.parse_args()
+    #flags.spIdValue = newargs.spIdValue
+
+    from MuonSPId.MuonSPIdDumpConfig import MuonSPIdDumpCfg
+    cfg.merge(MuonSPIdDumpCfg(flags))
+
+    executeTest(cfg)
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py
new file mode 100644
index 000000000000..7e7f1be057f4
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py
@@ -0,0 +1,38 @@
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+
+if __name__=="__main__":
+    from MuonGeoModelTestR4.testGeoModel import setupGeoR4TestCfg, SetupArgParser, executeTest, setupHistSvcCfg
+    parser = SetupArgParser()
+    parser.set_defaults(nEvents = -1)
+    parser.set_defaults(outRootFile="MuonSPId_2022_13p6TeV_00431493.root")
+    parser.set_defaults(condTag="CONDBR2-BLKPA-2023-03")
+    parser.set_defaults(inputFile=[
+                                   "/cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/Tier0ChainTests/TCT_Run3/data22_13p6TeV.00431493.physics_Main.daq.RAW._lb0525._SFO-16._0001.data"
+                                    ])
+    parser.set_defaults(eventPrintoutLevel = 500)
+    args = parser.parse_args()
+
+    from AthenaConfiguration.AllConfigFlags import initConfigFlags
+    flags = initConfigFlags()
+    flags.PerfMon.doFullMonMT = True
+
+    flags, cfg = setupGeoR4TestCfg(args)
+
+    cfg.merge(setupHistSvcCfg(flags,outFile=args.outRootFile,
+                                    outStream="MuonSPId"))
+
+    from MuonConfig.MuonDataPrepConfig  import xAODUncalibMeasPrepCfg
+    cfg.merge(xAODUncalibMeasPrepCfg(flags))
+   
+    from MuonSpacePointFormation.SpacePointFormationConfig import MuonSpacePointFormationCfg 
+    cfg.merge(MuonSpacePointFormationCfg(flags))
+
+    from MuonPatternRecognitionAlgs.MuonHoughTransformAlgConfig import MuonPatternRecognitionCfg, MuonSegmentFittingAlgCfg
+    cfg.merge(MuonPatternRecognitionCfg(flags))
+    cfg.merge(MuonSegmentFittingAlgCfg(flags))
+    
+    from MuonSPId.MuonSPIdDumpConfig import MuonSPIdDumpCfg
+    cfg.merge(MuonSPIdDumpCfg(flags))
+
+    executeTest(cfg)
+
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx
new file mode 100644
index 000000000000..379257fe8b5b
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx
@@ -0,0 +1,399 @@
+/*
+   Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "SPIdDumperAlg.h"
+
+#include "Identifier/Identifier.h"
+#include "MuonIdHelpers/IMuonIdHelperSvc.h"
+#include "StoreGate/ReadHandle.h"
+#include "MuonTesterTree/EventHashBranch.h"
+#include "MuonSpacePoint/SpacePointPerLayerSorter.h"
+#include "xAODMeasurementBase/MeasurementDefs.h"
+#include "xAODMuonPrepData/UtilFunctions.h"
+#include "xAODMuonPrepData/MdtDriftCircle.h"
+#include <fstream>
+#include <TString.h>
+#include <AthenaKernel/RNGWrapper.h>
+#include "CLHEP/Random/RandFlat.h"
+
+namespace {
+    union bucketId{
+        int8_t fields[4];
+        int hash;
+    };
+
+}
+
+namespace MuonR4{
+    StatusCode SPIdDumperAlg::initialize() {
+        ATH_CHECK(m_readKey.initialize());
+        ATH_CHECK(m_idHelperSvc.retrieve());
+        ATH_CHECK(m_inSegmentKey.initialize(!m_inSegmentKey.empty()));
+        m_tree.addBranch(std::make_shared<MuonVal::EventHashBranch>(m_tree.tree()));
+        ATH_CHECK(m_tree.init(this));
+        ATH_CHECK(m_idHelperSvc.retrieve());
+
+
+        std::string onnx_path = m_modelPath;
+        ATH_MSG_INFO("Model Path: " << onnx_path);
+        ATH_MSG_INFO("SPId Cut Value: " << m_SPId_cut);
+
+        std::vector<std::string> feature_names;
+        try {
+
+            m_env             = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ONNXFeatureExtraction");
+            m_session_options = std::make_unique<Ort::SessionOptions>();
+            m_session_options->SetIntraOpNumThreads(1);
+            m_session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
+            m_session         = std::make_unique<Ort::Session>(*m_env, onnx_path.c_str(), *m_session_options);
+    
+            Ort::ModelMetadata metadata = m_session->GetModelMetadata();
+
+            Ort::AllocatorWithDefaultOptions allocator;
+            Ort::AllocatedStringPtr feature_json_ptr = metadata.LookupCustomMetadataMapAllocated("feature_names", allocator);
+            
+            if (feature_json_ptr) {
+                std::string feature_json = feature_json_ptr.get();
+                nlohmann::json json_obj = nlohmann::json::parse(feature_json);
+                
+                for (const auto& feature : json_obj) {
+                    feature_names.push_back(feature.get<std::string>());
+                }
+            }
+        } catch (const std::exception& e) {
+            ATH_MSG_ERROR("Failed to retrieve feature names from ONNX model: " << e.what());
+            return StatusCode::FAILURE;
+        }
+    
+        if (feature_names.empty()) {
+            ATH_MSG_ERROR("No feature metadata found in the ONNX model!");
+            return StatusCode::FAILURE;
+        }
+    
+        ATH_MSG_ALWAYS("Extracted Feature Names:");
+        for (const auto& name : feature_names) {
+            ATH_MSG_ALWAYS("- " << name);
+        }
+
+        ATH_MSG_DEBUG("Successfully initialized");
+        return StatusCode::SUCCESS;
+    }
+
+    StatusCode SPIdDumperAlg::finalize() {
+        ATH_CHECK(m_tree.write());
+        return StatusCode::SUCCESS;
+    }
+    
+    StatusCode SPIdDumperAlg::execute(){
+        const EventContext& ctx{Gaudi::Hive::currentContext()};
+
+        std::unordered_map <const SpacePointBucket*, std::vector<const MuonR4::Segment*>> segmentMap;  // MuonR4Segment 
+        
+        SG::ReadHandle readSegment(m_inSegmentKey, ctx);
+        ATH_CHECK(readSegment.isPresent());
+        for (const MuonR4::Segment* segment : *readSegment) {
+            segmentMap[segment->parent()->parentBucket()].push_back(segment);
+        }
+
+        SG::ReadHandle<SpacePointContainer> readHandle{m_readKey, ctx};
+        ATH_CHECK(readHandle.isPresent());
+
+        bool Sparse = false;
+        bool Labels = true;
+
+        std::vector<float> features;
+        std::vector<int64_t> edge_src;
+        std::vector<int64_t> edge_dst;
+        std::vector<int64_t> node_offsets;
+        std::vector<int64_t> labels;
+        size_t total_nodes = 0;
+
+        size_t bucket_index = 0;
+
+        for(const SpacePointBucket* bucket : *readHandle) {          
+
+            unsigned int segIdx{0};
+            std::unordered_map<const SpacePoint*, std::vector<int16_t>> spacePointToSegment;
+            
+            if (Labels) {
+                auto match_itr = segmentMap.find(bucket);
+                if (match_itr != segmentMap.end()) {
+                    for (const MuonR4::Segment* segment : match_itr->second) {
+                        for (const auto& meas : segment->measurements()) {
+                            spacePointToSegment[meas->spacePoint()].push_back(segIdx);
+                        }
+                        ++segIdx;
+                    }
+                }
+            }
+
+            SpacePointPerLayerSorter sorter{*bucket};
+
+            size_t num_points = bucket->size();
+
+            unsigned int layer{0};
+            Identifier prevLayer = Identifier();
+
+            std::vector<u_int16_t> Neighbors(num_points, 0);
+
+            size_t bucket_offset = total_nodes;
+            node_offsets.push_back(bucket_offset);
+            total_nodes += num_points;
+
+            float min_x = 3000;
+            float max_x = -1;
+
+            Identifier layId;
+            for (u_int16_t i = 0; i < num_points; i++) {
+                const auto sp = bucket->at(i);
+                const Identifier id = sp->identify();
+
+                if (sp->type() == xAOD::UncalibMeasType::MdtDriftCircleType) {
+                    const auto* dc = static_cast<const xAOD::MdtDriftCircle*>(sp->primaryMeasurement());
+                    if (dc->status() != Muon::MdtDriftCircleStatus::MdtStatusDriftTime){
+                        continue;
+                    }
+                    const MdtIdHelper& idHelper{m_idHelperSvc->mdtIdHelper()};
+                    layId = idHelper.channelID(idHelper.stationName(id), 1, idHelper.stationPhi(id), idHelper.multilayer(id), idHelper.tubeLayer(id), 1);
+                    if (layId != prevLayer) {
+                        layer++;
+                        prevLayer = layId;
+                    }
+                } else {
+                    layId = m_idHelperSvc->gasGapId(id);
+
+                    if (layId != prevLayer) {
+                        layer++;
+                        prevLayer = layId;
+                    }
+                }
+
+                if (Labels) {
+                    const std::vector<int16_t>& segIdxs = spacePointToSegment[sp.get()];
+                    if (segIdxs.size() > 0) {
+                        m_spoint_label.push_back(1);
+                    } else {
+                        m_spoint_label.push_back(0);
+                    }
+                }
+                
+                float x = sp->positionInChamber().x();
+                float y = sp->positionInChamber().y();
+                float z = sp->positionInChamber().z();
+
+
+
+                m_spoint_bucket.push_back(bucket_index);
+                m_spoint_x.push_back(x);
+                m_spoint_y.push_back(y);
+                m_spoint_z.push_back(z);
+                m_spoint_station.push_back(m_idHelperSvc->stationName(id));
+
+                m_spoint_driftR.push_back(sp->driftRadius());
+                m_spoint_layer.push_back(layer);
+
+                if (x < min_x) min_x = x;
+                if (x > max_x) max_x = x;
+
+
+                if (Sparse) {
+                    for (u_int16_t j = 0; j < num_points; j++) {
+                        if (i == j) continue;
+                        const auto sp2 = bucket->at(j);
+                        Identifier layId2;
+                        const Identifier id2 = sp2->identify();
+                        if (sp2->type() == xAOD::UncalibMeasType::MdtDriftCircleType) {
+                            const MdtIdHelper& idHelper{m_idHelperSvc->mdtIdHelper()};
+                            layId2 = idHelper.channelID(idHelper.stationName(id2), 1, idHelper.stationPhi(id2), idHelper.multilayer(id2), idHelper.tubeLayer(id2), 1);
+                        } else {
+                            layId2 = m_idHelperSvc->gasGapId(id2);
+                        }
+                        float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2));
+                        if (h < 500) {
+                            Neighbors[i]++;
+                        }
+                        if ( h < 2000 and layId != layId2) { 
+                            edge_src.push_back(i + bucket_offset);
+                            edge_dst.push_back(j + bucket_offset);
+                            m_spoint_edges.push_back(j + bucket_offset);
+                        }
+                    }
+                } else {
+                    for (u_int16_t j = 0; j < num_points; j++) {
+                        if (i == j) continue;    
+                        edge_src.push_back(i + bucket_offset);
+                        edge_dst.push_back(j + bucket_offset);
+                        m_spoint_edges.push_back(j + bucket_offset);
+                        if (j < i ) continue;
+                        const auto sp2 = bucket->at(j);
+                        float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2));
+                        if (h < 500) {
+                            Neighbors[i]++;
+                            Neighbors[j]++;
+                        }
+                    }
+                }
+
+                m_spoint_neighbors.push_back(Neighbors[i]);
+            } // end loop over space points
+
+            m_bucket_layers.push_back(layer);
+            float density = 0;
+            float bucket_size = bucket->coveredMax() - bucket->coveredMin();
+            if (bucket_size == 0 ){
+                bucket_size = max_x - min_x;
+                if (bucket_size == 0) bucket_size = 1;
+            }
+            density = num_points / bucket_size;
+            m_bucket_density.push_back(density);
+            bucket_index++;
+
+        } // end loop over buckets
+
+        features.reserve(total_nodes * 8);
+        for (size_t i = 0; i < total_nodes; i++) {
+            features.push_back(m_spoint_x[i]);
+            features.push_back(m_spoint_y[i]);
+            features.push_back(m_spoint_z[i]);
+            features.push_back(m_spoint_station[i]);
+            features.push_back(m_spoint_driftR[i]);
+            features.push_back(float(m_spoint_layer[i]) / (m_bucket_layers[m_spoint_bucket[i]] > 0 ? m_bucket_layers[m_spoint_bucket[i]] : 1));
+            features.push_back(m_spoint_neighbors[i]);
+            features.push_back(m_bucket_density[m_spoint_bucket[i]]);
+            if (Labels) {
+                labels.push_back(m_spoint_label[i]);
+            }
+        }
+
+        if (features.empty()) {
+            ATH_MSG_WARNING("No valid feature data available for inference. Skipping event.");
+            return StatusCode::SUCCESS;
+        }
+
+        std::vector<int64_t> edge_index;
+        edge_index.reserve(2 * edge_src.size());
+        edge_index.insert(edge_index.end(), edge_src.begin(), edge_src.end());
+        edge_index.insert(edge_index.end(), edge_dst.begin(), edge_dst.end());
+
+        ATH_MSG_DEBUG("Total nodes: " << total_nodes);
+        ATH_MSG_DEBUG("Features size: " << features.size());
+        ATH_MSG_DEBUG("Expected feature shape: (" << total_nodes << ", 8)");
+
+        if (features.size() != total_nodes * 8) {
+            ATH_MSG_ERROR("Feature size mismatch! Expected " << total_nodes * 8 << " but got " << features.size());
+        }
+
+        ATH_MSG_DEBUG("Edge index size: " << edge_index.size());
+        ATH_MSG_DEBUG("Expected edge_index shape: (2, " << edge_index.size() / 2 << ")");
+
+        if (edge_index.size() % 2 != 0) {
+            ATH_MSG_ERROR("Edge index format error! Size should be divisible by 2.");
+        }
+
+        
+        //Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference");
+        //Ort::SessionOptions session_options;
+        //session_options.SetIntraOpNumThreads(1);
+        //session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
+        //Ort::Session session(env, "torch_GatFourier_fcg_quantized.onnx", session_options);
+
+
+        Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+
+        std::vector<int64_t> feature_shape = {static_cast<int64_t>(total_nodes), 8};  // (N, 8)
+        std::vector<int64_t> edge_shape    = {2, static_cast<int64_t>(edge_index.size() / 2)};  // (2, E)
+
+        Ort::Value feature_tensor = Ort::Value::CreateTensor<float>(
+            memory_info, features.data(), features.size(), feature_shape.data(), feature_shape.size());
+
+        Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(
+            memory_info, edge_index.data(), edge_index.size(), edge_shape.data(), edge_shape.size());
+        
+        std::vector<const char*> input_names = {"features", "edge_index"};
+        std::vector<const char*> output_names = {"output"};
+
+        std::vector<Ort::Value> input_tensors;
+        input_tensors.push_back(std::move(feature_tensor));
+        input_tensors.push_back(std::move(edge_tensor));
+
+        //std::vector<Ort::Value> output_tensors = session.Run(
+        std::vector<Ort::Value> output_tensors = m_session->Run(
+            Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(),
+            input_tensors.size(), output_names.data(), output_names.size());
+
+        float* output_data = output_tensors[0].GetTensorMutableData<float>();
+        size_t output_size = output_tensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
+        std::vector<float> predictions(output_data, output_data + output_size);
+
+        for (size_t i = 0; i < output_size; i++) {
+            if (!std::isfinite(predictions[i])) {
+                ATH_MSG_ERROR("Non-finite prediction detected! Setting to zero.");
+                predictions[i] = 0.0f;
+            }
+        }
+
+        ATH_MSG_DEBUG("ONNX output size: " << output_size);
+        ATH_MSG_DEBUG("First 5 predictions:");
+        for (size_t i = 0; i < std::min(output_size, size_t(5)); i++) {
+            ATH_MSG_DEBUG("Prediction[" << i << "]: " << predictions[i]);
+        }
+
+        std::vector<int> binary_predictions(output_size);
+        for (size_t i = 0; i < output_size; i++) {
+            binary_predictions[i] = (predictions[i] > m_SPId_cut) ? 1 : 0;   // Why inference cut is different? -> this batch/graph building may be different from torch one?
+        }
+        ATH_MSG_DEBUG("First 5 Binary Predictions:");
+        for (size_t i = 0; i < std::min(output_size, size_t(5)); i++) {
+            ATH_MSG_DEBUG("Binary_Prediction[" << i << "]: " << binary_predictions[i]);
+        }
+
+        if (Labels) {
+            float loss     = 0;
+            float accuracy = 0;
+            float pred_1   = 0;
+            float lab_1    = 0;
+            float true_0   = 0;
+            //float true_1   = 0;
+            for (size_t i = 0; i < output_size; i++) {
+                if (binary_predictions[i] != labels[i] && labels[i] == 1) {
+                    loss++;
+                }
+                if (binary_predictions[i] == labels[i]) {
+                    accuracy++;
+                }
+                //if (binary_predictions[i] == labels[i] && labels[i] == 1) {
+                //    true_1++;
+                //}
+                if (binary_predictions[i] == labels[i] && labels[i] == 0) {
+                    true_0++;
+                }
+                if (binary_predictions[i] == 1) {
+                    pred_1++;
+                }
+                if (labels[i] == 1) {
+                    lab_1++;
+                }       
+            }
+            //ATH_MSG_ALWAYS("Hits:            " << output_size );
+            ATH_MSG_ALWAYS("Accuracy:        " << accuracy / output_size          << " ( Correct: " << accuracy << " / Hits: " << output_size << ")" );
+            ATH_MSG_ALWAYS("Signal Loss:     " << loss     / lab_1                << " ( Miss predicted: " << loss << " / Good Hits: " << lab_1 << ")" );
+            float lab_0 = output_size - lab_1;
+            if (lab_0 != 0) {
+                ATH_MSG_ALWAYS("Real Rejection:  " << true_0 / (output_size - lab_1)  << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" );
+            } else ATH_MSG_ALWAYS("Real Rejection:  " << 0  << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" );
+            ATH_MSG_ALWAYS("Purity (before): " << lab_1  / output_size            << " ( Good Hits: " << lab_1 << " / Hits: " << output_size << ")" );            
+            ATH_MSG_ALWAYS("Purity (after):  " << (lab_1 - loss) / pred_1         << " ( Good predicted: " << lab_1 - loss << " / Predicted: " << pred_1 << ")" );
+            ATH_MSG_ALWAYS("-----------------------------");
+        }
+
+        for (size_t i = 0; i < output_size; i++) {
+            m_spoint_predictions.push_back(binary_predictions[i]);
+        }
+
+        if (!m_tree.fill(ctx)) return StatusCode::FAILURE;
+        return StatusCode::SUCCESS;
+
+    }
+
+}
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h
new file mode 100644
index 000000000000..6a2f9b1c7d90
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h
@@ -0,0 +1,95 @@
+#ifndef MUON_SP_ID_DUMPER_H
+#define MUON_SP_ID_DUMPER_H
+/*
+   Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include <AthenaBaseComps/AthAlgorithm.h>
+#include "AthenaBaseComps/AthHistogramAlgorithm.h"
+
+#include <MuonIdHelpers/IMuonIdHelperSvc.h>
+#include "StoreGate/ReadHandleKey.h"
+#include "StoreGate/ReadDecorHandle.h"
+#include "StoreGate/ReadHandleKeyArray.h"
+#include "StoreGate/ReadCondHandleKey.h"
+
+#include <MuonPatternEvent/MuonPatternContainer.h>
+#include <MuonSpacePoint/SpacePointContainer.h>
+#include <ActsGeometryInterfaces/ActsGeometryContext.h>
+#include <MuonReadoutGeometryR4/MuonDetectorManager.h>
+
+#include "MuonTesterTree/MuonTesterTree.h"
+#include "MuonTesterTree/ThreeVectorBranch.h"
+#include "MuonTesterTree/IdentifierBranch.h"
+
+#include "xAODMuonSimHit/MuonSimHitContainer.h"
+#include "xAODMuon/MuonSegmentContainer.h"
+
+#include "AthenaKernel/IAthRNGSvc.h"
+#include "CLHEP/Random/RandomEngine.h"
+
+#include "nlohmann/json.hpp"
+
+
+// onnx runtime
+#include <onnxruntime_cxx_api.h>
+
+
+namespace MuonR4{
+class SPIdDumperAlg: public AthHistogramAlgorithm {
+
+   public:
+    using AthHistogramAlgorithm::AthHistogramAlgorithm;
+    ~SPIdDumperAlg() = default;
+
+    virtual StatusCode initialize() override final;
+    virtual StatusCode finalize() override final;
+    virtual StatusCode execute() override final;
+
+   private:
+
+    void fillChamberInfo(const MuonGMR4::Chamber* chamber); 
+
+    SG::ReadHandleKey<SpacePointContainer> m_readKey{this, "ReadKey", "MuonSpacePoints", "Key to the space point container"};
+    ServiceHandle<Muon::IMuonIdHelperSvc>  m_idHelperSvc{this, "MuonIdHelperSvc", "Muon::MuonIdHelperSvc/MuonIdHelperSvc"};
+
+    SG::ReadHandleKey<MuonR4::SegmentContainer> m_inSegmentKey{this, "SegmentKey", "R4MuonSegments"};
+    
+    Gaudi::Property<bool> m_isMC{this, "isMC", true};
+    Gaudi::Property<double> m_SPId_cut{this,"spIdValue", 0.0001};
+    Gaudi::Property<std::string> m_modelPath{this, "ModelPath", "torch_GatFourier_fcg_quantized.onnx"};
+
+    MuonVal::MuonTesterTree m_tree{"MuonSPId","MuonSPId"};
+
+    MuonVal::VectorBranch<float>&           m_bucket_density{m_tree.newVector<float>("bucket_density", 0)};
+    MuonVal::VectorBranch<uint8_t>&         m_bucket_layers{m_tree.newVector<uint8_t>("bucket_layers", 0)}; 
+
+    MuonVal::VectorBranch<uint8_t>&         m_spoint_bucket{m_tree.newVector<uint8_t>("bucket_index")};
+
+    //MuonVal::ThreeVectorBranch              m_spoint_localPosition{m_tree, "localPosition"}; 
+    MuonVal::VectorBranch<float>&           m_spoint_x{m_tree.newVector<float>("x")};
+    MuonVal::VectorBranch<float>&           m_spoint_y{m_tree.newVector<float>("y")};
+    MuonVal::VectorBranch<float>&           m_spoint_z{m_tree.newVector<float>("z")};
+
+    //MuonVal::MuonIdentifierBranch           m_spoint_id{m_tree, "id"};
+    MuonVal::VectorBranch<uint8_t>&         m_spoint_station{m_tree.newVector<uint8_t>("stationIndex")};
+    MuonVal::VectorBranch<uint8_t>&         m_spoint_layer{m_tree.newVector<uint8_t>("layer")};
+
+    MuonVal::VectorBranch<float>&           m_spoint_driftR{m_tree.newVector<float>("driftR")};
+    MuonVal::VectorBranch<float>&           m_spoint_neighbors{m_tree.newVector<float>("neighbors")};
+
+    MuonVal::VectorBranch<uint8_t>&         m_spoint_label{m_tree.newVector<uint8_t>("label")};
+    MuonVal::VectorBranch<uint8_t>&         m_spoint_predictions{m_tree.newVector<uint8_t>("predictions")};
+    MuonVal::VectorBranch<uint16_t>&        m_spoint_edges{m_tree.newVector<uint16_t>("edges")};
+
+    size_t m_event{0};
+
+    // ONNX Runtime objects
+    std::unique_ptr<Ort::Env> m_env;
+    std::unique_ptr<Ort::SessionOptions> m_session_options;
+    std::unique_ptr<Ort::Session> m_session;
+
+};
+}
+#endif
+
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.cxx
new file mode 100644
index 000000000000..a8ed24f5b37d
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.cxx
@@ -0,0 +1,340 @@
+/*
+   Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "SPIdentifierAlg.h"
+
+#include "Identifier/Identifier.h"
+#include "MuonIdHelpers/IMuonIdHelperSvc.h"
+#include "StoreGate/ReadHandle.h"
+#include "MuonTesterTree/EventHashBranch.h"
+#include "MuonSpacePoint/SpacePointPerLayerSorter.h"
+#include "xAODMeasurementBase/MeasurementDefs.h"
+#include "xAODMuonPrepData/UtilFunctions.h"
+#include "xAODMuonPrepData/MdtDriftCircle.h"
+#include <fstream>
+#include <TString.h>
+#include <AthenaKernel/RNGWrapper.h>
+#include "CLHEP/Random/RandFlat.h"
+
+namespace {
+    union bucketId{
+        int8_t fields[4];
+        int hash;
+    };
+
+}
+
+namespace MuonR4{
+
+    StatusCode SPIdentifierAlg::initialize() {
+        ATH_CHECK(m_readKey.initialize());
+        ATH_CHECK(m_idHelperSvc.retrieve());
+        ATH_CHECK(m_inSegmentKey.initialize(!m_inSegmentKey.empty()));
+        m_tree.addBranch(std::make_shared<MuonVal::EventHashBranch>(m_tree.tree()));
+        ATH_CHECK(m_tree.init(this));
+        ATH_CHECK(m_idHelperSvc.retrieve());
+
+        ATH_MSG_DEBUG("Successfully initialized");
+
+        return StatusCode::SUCCESS;
+    }
+
+    StatusCode SPIdentifierAlg::finalize() {
+        ATH_CHECK(m_tree.write());
+        return StatusCode::SUCCESS;
+    }
+    
+    StatusCode SPIdentifierAlg::execute(){
+        const EventContext& ctx{Gaudi::Hive::currentContext()};
+
+        std::unordered_map <const SpacePointBucket*, std::vector<const MuonR4::Segment*>> segmentMap;
+        
+        SG::ReadHandle readSegment(m_inSegmentKey, ctx);
+        ATH_CHECK(readSegment.isPresent());
+        for (const MuonR4::Segment* segment : *readSegment) {
+            segmentMap[segment->parent()->parentBucket()].push_back(segment);
+        }
+
+        SG::ReadHandle<SpacePointContainer> readHandle{m_readKey, ctx};
+        ATH_CHECK(readHandle.isPresent());
+
+        bool Sparse = false;
+        bool Labels = true;
+
+        std::vector<float> features;
+        std::vector<int64_t> edge_src;
+        std::vector<int64_t> edge_dst;
+        std::vector<int64_t> node_offsets;
+        std::vector<int64_t> labels;
+        size_t total_nodes = 0;
+        size_t bucket_index = 0;
+
+        for(const SpacePointBucket* bucket : *readHandle) {          
+
+            unsigned int segIdx{0};
+            std::unordered_map<const SpacePoint*, std::vector<int16_t>> spacePointToSegment;
+            
+            if (Labels) {
+                auto match_itr = segmentMap.find(bucket);
+                if (match_itr != segmentMap.end()) {
+                    for (const MuonR4::Segment* segment : match_itr->second) {
+                        for (const auto& meas : segment->measurements()) {
+                            spacePointToSegment[meas->spacePoint()].push_back(segIdx);
+                        }
+                        ++segIdx;
+                    }
+                }
+            }
+
+            SpacePointPerLayerSorter sorter{*bucket};
+
+            size_t num_points = bucket->size();
+
+            unsigned int layer{0};
+            Identifier prevLayer = Identifier();
+
+            std::vector<u_int16_t> Neighbors(num_points, 0);
+
+            size_t bucket_offset = total_nodes;
+            node_offsets.push_back(bucket_offset);
+            total_nodes += num_points;
+
+            float min_x = 3000;
+            float max_x = -1;
+
+            Identifier layId;
+            for (u_int16_t i = 0; i < num_points; i++) {
+                const auto sp = bucket->at(i);
+                const Identifier id = sp->identify();
+
+                if (sp->type() == xAOD::UncalibMeasType::MdtDriftCircleType) {
+                    const auto* dc = static_cast<const xAOD::MdtDriftCircle*>(sp->primaryMeasurement());
+                    if (dc->status() != Muon::MdtDriftCircleStatus::MdtStatusDriftTime){
+                        continue;
+                    }
+                    const MdtIdHelper& idHelper{m_idHelperSvc->mdtIdHelper()};
+                    layId = idHelper.channelID(idHelper.stationName(id), 1, idHelper.stationPhi(id), idHelper.multilayer(id), idHelper.tubeLayer(id), 1);
+                    if (layId != prevLayer) {
+                        layer++;
+                        prevLayer = layId;
+                    }
+                } else {
+                    layId = m_idHelperSvc->gasGapId(id);
+
+                    if (layId != prevLayer) {
+                        layer++;
+                        prevLayer = layId;
+                    }
+                }
+
+                if (Labels) {
+                    const std::vector<int16_t>& segIdxs = spacePointToSegment[sp.get()];
+                    if (segIdxs.size() > 0) {
+                        m_spoint_label.push_back(1);
+                    } else {
+                        m_spoint_label.push_back(0);
+                    }
+                }
+                
+                float x = sp->positionInChamber().x();
+                float y = sp->positionInChamber().y();
+                float z = sp->positionInChamber().z();
+
+
+
+                m_spoint_bucket.push_back(bucket_index);
+                m_spoint_x.push_back(x);
+                m_spoint_y.push_back(y);
+                m_spoint_z.push_back(z);
+                m_spoint_station.push_back(m_idHelperSvc->stationName(id));
+
+                m_spoint_driftR.push_back(sp->driftRadius());
+                m_spoint_layer.push_back(layer);
+
+                if (x < min_x) min_x = x;
+                if (x > max_x) max_x = x;
+
+
+                if (Sparse) {
+                    for (u_int16_t j = 0; j < num_points; j++) {
+                        if (i == j) continue;
+                        const auto sp2 = bucket->at(j);
+                        Identifier layId2;
+                        const Identifier id2 = sp2->identify();
+                        if (sp2->type() == xAOD::UncalibMeasType::MdtDriftCircleType) {
+                            const MdtIdHelper& idHelper{m_idHelperSvc->mdtIdHelper()};
+                            layId2 = idHelper.channelID(idHelper.stationName(id2), 1, idHelper.stationPhi(id2), idHelper.multilayer(id2), idHelper.tubeLayer(id2), 1);
+                        } else {
+                            layId2 = m_idHelperSvc->gasGapId(id2);
+                        }
+                        float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2));
+                        if (h < 500) {
+                            Neighbors[i]++;
+                        }
+                        if ( h < 2000 and layId != layId2) { 
+                            edge_src.push_back(i + bucket_offset);
+                            edge_dst.push_back(j + bucket_offset);
+                            m_spoint_edges.push_back(j + bucket_offset);
+                        }
+                    }
+                } else {
+                    for (u_int16_t j = 0; j < num_points; j++) {
+                        if (i == j) continue;    
+                        edge_src.push_back(i + bucket_offset);
+                        edge_dst.push_back(j + bucket_offset);
+                        m_spoint_edges.push_back(j + bucket_offset);
+                        if (j < i ) continue;
+                        const auto sp2 = bucket->at(j);
+                        float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2));
+                        if (h < 500) {
+                            Neighbors[i]++;
+                            Neighbors[j]++;
+                        }
+                    }
+                }
+
+                m_spoint_neighbors.push_back(Neighbors[i]);
+            } // end loop over space points
+
+            m_bucket_layers.push_back(layer);
+            float density = 0;
+            float bucket_size = bucket->coveredMax() - bucket->coveredMin();
+            if (bucket_size == 0 ){
+                bucket_size = max_x - min_x;
+                if (bucket_size == 0) bucket_size = 1;
+            }
+            density = num_points / bucket_size;
+            m_bucket_density.push_back(density);
+            bucket_index++;
+
+        } // end loop over buckets
+
+        features.reserve(total_nodes * 8);
+        for (size_t i = 0; i < total_nodes; i++) {
+            features.push_back(m_spoint_x[i]);
+            features.push_back(m_spoint_y[i]);
+            features.push_back(m_spoint_z[i]);
+            features.push_back(m_spoint_station[i]);
+            features.push_back(m_spoint_driftR[i]);
+            features.push_back(float(m_spoint_layer[i]) / (m_bucket_layers[m_spoint_bucket[i]] > 0 ? m_bucket_layers[m_spoint_bucket[i]] : 1));
+            features.push_back(m_spoint_neighbors[i]);
+            features.push_back(m_bucket_density[m_spoint_bucket[i]]);
+            if (Labels) {
+                labels.push_back(m_spoint_label[i]);
+            }
+        }
+
+        if (features.empty()) {
+            ATH_MSG_WARNING("No valid feature data available for inference. Skipping event.");
+            return StatusCode::SUCCESS;
+        }
+
+        std::vector<int64_t> edge_index;
+        edge_index.reserve(2 * edge_src.size());
+        edge_index.insert(edge_index.end(), edge_src.begin(), edge_src.end());
+        edge_index.insert(edge_index.end(), edge_dst.begin(), edge_dst.end());
+
+        ATH_MSG_DEBUG("Total nodes: " << total_nodes);
+        ATH_MSG_DEBUG("Features size: " << features.size());
+        ATH_MSG_DEBUG("Expected feature shape: (" << total_nodes << ", 8)");
+
+        if (features.size() != total_nodes * 8) {
+            ATH_MSG_ERROR("Feature size mismatch! Expected " << total_nodes * 8 << " but got " << features.size());
+        }
+
+        ATH_MSG_DEBUG("Edge index size: " << edge_index.size());
+        ATH_MSG_DEBUG("Expected edge_index shape: (2, " << edge_index.size() / 2 << ")");
+
+        if (edge_index.size() % 2 != 0) {
+            ATH_MSG_ERROR("Edge index format error! Size should be divisible by 2.");
+        }
+
+        Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference");
+        Ort::SessionOptions session_options;
+        session_options.SetIntraOpNumThreads(1);
+        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
+
+        Ort::Session session(env, "torch_GatFourier_fcg_quantized.onnx", session_options);
+
+        Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+
+        std::vector<int64_t> feature_shape = {static_cast<int64_t>(total_nodes), 8};  // (N, 8)
+        std::vector<int64_t> edge_shape    = {2, static_cast<int64_t>(edge_index.size() / 2)};  // (2, E)
+
+        Ort::Value feature_tensor = Ort::Value::CreateTensor<float>(
+            memory_info, features.data(), features.size(), feature_shape.data(), feature_shape.size());
+
+        Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(
+            memory_info, edge_index.data(), edge_index.size(), edge_shape.data(), edge_shape.size());
+        
+        std::vector<const char*> input_names = {"features", "edge_index"};
+        std::vector<const char*> output_names = {"output"};
+
+        std::vector<Ort::Value> input_tensors;
+        input_tensors.push_back(std::move(feature_tensor));
+        input_tensors.push_back(std::move(edge_tensor));
+
+        std::vector<Ort::Value> output_tensors = session.Run(
+            Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(),
+            input_tensors.size(), output_names.data(), output_names.size());
+
+        float* output_data = output_tensors[0].GetTensorMutableData<float>();
+        size_t output_size = output_tensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
+        std::vector<float> predictions(output_data, output_data + output_size);
+
+        for (size_t i = 0; i < output_size; i++) {
+            if (!std::isfinite(predictions[i])) {
+                ATH_MSG_ERROR("Non-finite prediction detected! Setting to zero.");
+                predictions[i] = 0.0f;
+            }
+        }
+
+        ATH_MSG_DEBUG("ONNX output size: " << output_size);
+        ATH_MSG_DEBUG("First 5 predictions:");
+        for (size_t i = 0; i < std::min(output_size, size_t(5)); i++) {
+            ATH_MSG_DEBUG("Prediction[" << i << "]: " << predictions[i]);
+        }
+
+        std::vector<int> binary_predictions(output_size);
+        for (size_t i = 0; i < output_size; i++) {
+            binary_predictions[i] = (predictions[i] > 0.05) ? 1 : 0;
+        }
+        ATH_MSG_DEBUG("First 5 Binary Predictions:");
+        for (size_t i = 0; i < std::min(output_size, size_t(5)); i++) {
+            ATH_MSG_DEBUG("Binary_Prediction[" << i << "]: " << binary_predictions[i]);
+        }
+
+        if (Labels) {
+            float loss = 0;
+            float accuracy = 0;
+            for (size_t i = 0; i < output_size; i++) {
+                if (binary_predictions[i] != labels[i] && labels[i] == 1) {
+                    loss++;
+                }
+                if (binary_predictions[i] == labels[i]) {
+                    accuracy++;
+                }
+            }
+            ATH_MSG_ALWAYS("Hits:     " << output_size);
+            ATH_MSG_ALWAYS("Loss:     " << loss / output_size);
+            ATH_MSG_ALWAYS("Accuracy: " << accuracy / output_size);
+        }
+
+        for (size_t i = 0; i < output_size; i++) {
+            m_spoint_predictions.push_back(binary_predictions[i]);
+        }
+
+        if (!m_tree.fill(ctx)) return StatusCode::FAILURE;
+        return StatusCode::SUCCESS;
+
+    }
+
+    CLHEP::HepRandomEngine* SPIdentifierAlg::getRandomEngine(const EventContext&ctx) const {
+        ATHRNG::RNGWrapper* rngWrapper = m_rndmSvc->getEngine(this, m_streamName);
+        std::string rngName = m_streamName;
+        rngWrapper->setSeed(rngName, ctx);
+        return rngWrapper->getEngine(ctx);
+    }
+
+}
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.h
new file mode 100644
index 000000000000..b135b4a45864
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdentifierAlg.h
@@ -0,0 +1,91 @@
+#ifndef MUON_SP_ID_DUMPER_H
+#define MUON_SP_ID_DUMPER_H
+/*
+   Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include <AthenaBaseComps/AthAlgorithm.h>
+#include "AthenaBaseComps/AthHistogramAlgorithm.h"
+
+#include <MuonIdHelpers/IMuonIdHelperSvc.h>
+#include "StoreGate/ReadHandleKey.h"
+#include "StoreGate/ReadDecorHandle.h"
+#include "StoreGate/ReadHandleKeyArray.h"
+#include "StoreGate/ReadCondHandleKey.h"
+
+#include <MuonPatternEvent/MuonPatternContainer.h>
+#include <MuonSpacePoint/SpacePointContainer.h>
+#include <ActsGeometryInterfaces/ActsGeometryContext.h>
+#include <MuonReadoutGeometryR4/MuonDetectorManager.h>
+
+#include "MuonTesterTree/MuonTesterTree.h"
+#include "MuonTesterTree/ThreeVectorBranch.h"
+#include "MuonTesterTree/IdentifierBranch.h"
+
+#include "xAODMuonSimHit/MuonSimHitContainer.h"
+#include "xAODMuon/MuonSegmentContainer.h"
+
+#include "AthenaKernel/IAthRNGSvc.h"
+#include "CLHEP/Random/RandomEngine.h"
+
+// onnx runtime
+#include <onnxruntime_cxx_api.h>
+
+
+namespace MuonR4{
+class SPIdentifierAlg: public AthHistogramAlgorithm {
+
+   public:
+    using AthHistogramAlgorithm::AthHistogramAlgorithm;
+    ~SPIdentifierAlg() = default;
+
+    virtual StatusCode initialize() override final;
+    virtual StatusCode finalize() override final;
+    virtual StatusCode execute() override final;
+
+   private:
+
+    void fillChamberInfo(const MuonGMR4::Chamber* chamber); 
+
+    SG::ReadHandleKey<SpacePointContainer> m_readKey{this, "ReadKey", "MuonSpacePoints", "Key to the space point container"};
+    ServiceHandle<Muon::IMuonIdHelperSvc> m_idHelperSvc{this, "MuonIdHelperSvc", "Muon::MuonIdHelperSvc/MuonIdHelperSvc"};
+
+    SG::ReadHandleKey<MuonR4::SegmentContainer> m_inSegmentKey{this, "SegmentKey", "R4MuonSegments"};
+
+    //SG::ReadHandleKey<ActsGeometryContext> m_geoCtxKey{this, "AlignmentKey", "ActsAlignment", "cond handle key"};
+    
+    Gaudi::Property<bool> m_isMC{this, "isMC", true};
+    //Gaudi::Property<double> m_fracToKeep{this,"dataFracToKeep", 1}; // 0.055 to balanced dataset without MC
+    Gaudi::Property<std::string> m_streamName{this, "StreamName", ""};
+    ServiceHandle<IAthRNGSvc> m_rndmSvc{this, "RndmSvc", "AthRNGSvc", ""};
+    CLHEP::HepRandomEngine* getRandomEngine(const EventContext&ctx) const;
+
+    MuonVal::MuonTesterTree m_tree{"MuonSPId","MuonSPId"};
+
+    MuonVal::VectorBranch<float>&           m_bucket_density{m_tree.newVector<float>("bucket_density", 0)};
+    MuonVal::VectorBranch<uint8_t>&         m_bucket_layers{m_tree.newVector<uint8_t>("bucket_layers", 0)}; 
+
+    MuonVal::VectorBranch<uint8_t>&         m_spoint_bucket{m_tree.newVector<uint8_t>("bucket_index")};
+
+    //MuonVal::ThreeVectorBranch              m_spoint_localPosition{m_tree, "localPosition"}; 
+    MuonVal::VectorBranch<float>&           m_spoint_x{m_tree.newVector<float>("x")};
+    MuonVal::VectorBranch<float>&           m_spoint_y{m_tree.newVector<float>("y")};
+    MuonVal::VectorBranch<float>&           m_spoint_z{m_tree.newVector<float>("z")};
+
+    //MuonVal::MuonIdentifierBranch           m_spoint_id{m_tree, "id"};
+    MuonVal::VectorBranch<uint8_t>&         m_spoint_station{m_tree.newVector<uint8_t>("stationIndex")};
+    MuonVal::VectorBranch<uint8_t>&         m_spoint_layer{m_tree.newVector<uint8_t>("layer")};
+
+    MuonVal::VectorBranch<float>&           m_spoint_driftR{m_tree.newVector<float>("driftR")};
+    MuonVal::VectorBranch<float>&           m_spoint_neighbors{m_tree.newVector<float>("neighbors")};
+
+    MuonVal::VectorBranch<uint8_t>&         m_spoint_label{m_tree.newVector<uint8_t>("label")};
+    MuonVal::VectorBranch<uint8_t>&         m_spoint_predictions{m_tree.newVector<uint8_t>("predictions")};
+    MuonVal::VectorBranch<uint16_t>&        m_spoint_edges{m_tree.newVector<uint16_t>("edges")};
+
+    size_t m_event{0};
+
+};
+}
+#endif
+
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/components/MuonSPId_entries.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/components/MuonSPId_entries.cxx
new file mode 100644
index 000000000000..bd50be95430b
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/components/MuonSPId_entries.cxx
@@ -0,0 +1,7 @@
+
+/*
+   Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+#include "../SPIdDumperAlg.h"
+DECLARE_COMPONENT(MuonR4::SPIdDumperAlg)
+
-- 
GitLab


From 2ddd20266f23e5e784866927364f87a63dfa8aac Mon Sep 17 00:00:00 2001
From: Davide Di Croce <davide.di.croce@cern.ch>
Date: Thu, 20 Feb 2025 15:28:01 +0100
Subject: [PATCH 2/9] Speeding up inference and fixing dumper

---
 .../MuonSPId/python/muonSPIdDump_data.py      |   2 +
 .../MuonSPId/src/SPIdDumperAlg.cxx            | 280 ++++++++++++------
 .../MuonLearning/MuonSPId/src/SPIdDumperAlg.h |  23 +-
 3 files changed, 200 insertions(+), 105 deletions(-)

diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py
index 7e7f1be057f4..187225123d58 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py
@@ -33,6 +33,8 @@ if __name__=="__main__":
     
     from MuonSPId.MuonSPIdDumpConfig import MuonSPIdDumpCfg
     cfg.merge(MuonSPIdDumpCfg(flags))
+    
+    #cfg.getService("MessageSvc").setDebug=["MuonSPIdMaker"]
 
     executeTest(cfg)
 
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx
index 379257fe8b5b..c7e2a5bc7ba6 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.cxx
@@ -15,6 +15,7 @@
 #include <fstream>
 #include <TString.h>
 #include <AthenaKernel/RNGWrapper.h>
+#include <vector>
 #include "CLHEP/Random/RandFlat.h"
 
 namespace {
@@ -44,7 +45,8 @@ namespace MuonR4{
 
             m_env             = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "ONNXFeatureExtraction");
             m_session_options = std::make_unique<Ort::SessionOptions>();
-            m_session_options->SetIntraOpNumThreads(1);
+            //m_session_options->SetIntraOpNumThreads(1);
+            m_session_options->SetInterOpNumThreads(std::thread::hardware_concurrency());
             m_session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
             m_session         = std::make_unique<Ort::Session>(*m_env, onnx_path.c_str(), *m_session_options);
     
@@ -71,9 +73,9 @@ namespace MuonR4{
             return StatusCode::FAILURE;
         }
     
-        ATH_MSG_ALWAYS("Extracted Feature Names:");
+        ATH_MSG_DEBUG("Extracted Feature Names:");
         for (const auto& name : feature_names) {
-            ATH_MSG_ALWAYS("- " << name);
+            ATH_MSG_DEBUG("- " << name);
         }
 
         ATH_MSG_DEBUG("Successfully initialized");
@@ -102,15 +104,38 @@ namespace MuonR4{
         bool Sparse = false;
         bool Labels = true;
 
-        std::vector<float> features;
-        std::vector<int64_t> edge_src;
-        std::vector<int64_t> edge_dst;
-        std::vector<int64_t> node_offsets;
-        std::vector<int64_t> labels;
-        size_t total_nodes = 0;
-
+        std::vector<float>    spoint_x, spoint_y, spoint_z, spoint_driftR, spoint_layer;
+        std::vector<uint8_t>  spoint_station, spoint_neighbors, spoint_bucket;
+        std::vector<uint16_t> spoint_edges, bucket_layers;
+        std::vector<uint64_t> edge_src, edge_dst, node_offsets;
+        std::vector<bool> spoint_label;
+
+        spoint_x.reserve(100000);
+        spoint_y.reserve(100000);
+        spoint_z.reserve(100000);
+        spoint_driftR.reserve(100000);
+        spoint_layer.reserve(100000);
+        spoint_station.reserve(100000);
+        spoint_bucket.reserve(100000);
+        spoint_neighbors.reserve(100000);
+        spoint_edges.reserve(100000);
+
+        size_t total_nodes  = 0;
         size_t bucket_index = 0;
 
+        std::vector<float>    spoint_predictions;
+
+        std::vector<float>   bucket_density;
+        std::vector<uint8_t> bucket_points;
+        std::vector<float>   features;
+        std::vector<bool>    labels;
+        
+
+        std::vector<Ort::Value>  input_tensors;
+        std::vector<int64_t>     edge_index;
+        std::vector<const char*> input_names = {"features", "edge_index"};
+        std::vector<const char*> output_names = {"output"};
+
         for(const SpacePointBucket* bucket : *readHandle) {          
 
             unsigned int segIdx{0};
@@ -139,13 +164,31 @@ namespace MuonR4{
 
             size_t bucket_offset = total_nodes;
             node_offsets.push_back(bucket_offset);
-            total_nodes += num_points;
 
+            //std::unordered_map<u_int16_t, u_int16_t> index_mapping;  // Maps original index to new valid index
+            std::vector<u_int16_t> index_mapping(num_points, UINT16_MAX);
+            u_int16_t valid_index = 0;
             float min_x = 3000;
             float max_x = -1;
 
             Identifier layId;
+
+            for (u_int16_t i = 0; i < num_points; i++) {
+                const auto sp = bucket->at(i);
+                if (sp->type() == xAOD::UncalibMeasType::MdtDriftCircleType) {
+                    const auto* dc = static_cast<const xAOD::MdtDriftCircle*>(sp->primaryMeasurement());
+                    if (dc->status() != Muon::MdtDriftCircleStatus::MdtStatusDriftTime) {
+                        continue;
+                    }
+                }
+                index_mapping[i] = valid_index;
+                valid_index++;
+            }            
+            total_nodes += valid_index;  
+
+
             for (u_int16_t i = 0; i < num_points; i++) {
+                if (index_mapping[i] == UINT16_MAX) continue;
                 const auto sp = bucket->at(i);
                 const Identifier id = sp->identify();
 
@@ -169,12 +212,12 @@ namespace MuonR4{
                     }
                 }
 
-                if (Labels) {
-                    const std::vector<int16_t>& segIdxs = spacePointToSegment[sp.get()];
+                const std::vector<int16_t>& segIdxs = spacePointToSegment[sp.get()];
+                if (Labels) {    
                     if (segIdxs.size() > 0) {
-                        m_spoint_label.push_back(1);
+                        spoint_label.push_back(1);
                     } else {
-                        m_spoint_label.push_back(0);
+                        spoint_label.push_back(0);
                     }
                 }
                 
@@ -182,24 +225,21 @@ namespace MuonR4{
                 float y = sp->positionInChamber().y();
                 float z = sp->positionInChamber().z();
 
-
-
-                m_spoint_bucket.push_back(bucket_index);
-                m_spoint_x.push_back(x);
-                m_spoint_y.push_back(y);
-                m_spoint_z.push_back(z);
-                m_spoint_station.push_back(m_idHelperSvc->stationName(id));
-
-                m_spoint_driftR.push_back(sp->driftRadius());
-                m_spoint_layer.push_back(layer);
+                spoint_x.push_back(x);
+                spoint_y.push_back(y);
+                spoint_z.push_back(z);
+                spoint_driftR.push_back(sp->driftRadius());
+                spoint_station.push_back(m_idHelperSvc->stationName(id));
+                spoint_layer.push_back(layer);
+                spoint_bucket.push_back(bucket_index);
 
                 if (x < min_x) min_x = x;
                 if (x > max_x) max_x = x;
 
-
                 if (Sparse) {
                     for (u_int16_t j = 0; j < num_points; j++) {
-                        if (i == j) continue;
+                        if (i == j || (index_mapping[j] == UINT16_MAX) ) continue; 
+                        //if (i == j) continue;
                         const auto sp2 = bucket->at(j);
                         Identifier layId2;
                         const Identifier id2 = sp2->identify();
@@ -209,36 +249,43 @@ namespace MuonR4{
                         } else {
                             layId2 = m_idHelperSvc->gasGapId(id2);
                         }
-                        float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2));
-                        if (h < 500) {
-                            Neighbors[i]++;
+                        //float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2));
+                        float dx = sp2->positionInChamber().x() - x;
+                        float dy = sp2->positionInChamber().y() - y;
+                        float h_squared = dx * dx + dy * dy;
+                        if (h_squared < 250000) {
+                            Neighbors[index_mapping[i]]++;
                         }
-                        if ( h < 2000 and layId != layId2) { 
-                            edge_src.push_back(i + bucket_offset);
-                            edge_dst.push_back(j + bucket_offset);
-                            m_spoint_edges.push_back(j + bucket_offset);
+                        if ( h_squared < 4000000 and layId != layId2) { 
+                            edge_src.push_back(index_mapping[i] + bucket_offset);
+                            edge_dst.push_back(index_mapping[j] + bucket_offset);
+                            spoint_edges.push_back(index_mapping[j] + bucket_offset);
                         }
                     }
                 } else {
                     for (u_int16_t j = 0; j < num_points; j++) {
-                        if (i == j) continue;    
-                        edge_src.push_back(i + bucket_offset);
-                        edge_dst.push_back(j + bucket_offset);
-                        m_spoint_edges.push_back(j + bucket_offset);
-                        if (j < i ) continue;
+                        //if (i == j) continue;    
+                        if (i == j || (index_mapping[j] == UINT16_MAX) ) continue; // Skip removed points
+                        edge_src.push_back(index_mapping[i] + bucket_offset);
+                        edge_dst.push_back(index_mapping[j] + bucket_offset);
+                        spoint_edges.push_back(index_mapping[j] + bucket_offset);
+                        if ( j < i ) continue;
                         const auto sp2 = bucket->at(j);
-                        float h = sqrt(pow(sp2->positionInChamber().x() - x, 2) + pow(sp2->positionInChamber().y() - y, 2));
-                        if (h < 500) {
-                            Neighbors[i]++;
-                            Neighbors[j]++;
+                        float dx = sp2->positionInChamber().x() - x;
+                        float dy = sp2->positionInChamber().y() - y;
+                        float h_squared = dx * dx + dy * dy;
+                        if (h_squared < 250000) {
+                            Neighbors[index_mapping[i]]++;
+                            Neighbors[index_mapping[j]]++;
                         }
                     }
                 }
+                
 
-                m_spoint_neighbors.push_back(Neighbors[i]);
+                spoint_neighbors.push_back(Neighbors[index_mapping[i]]);
             } // end loop over space points
 
-            m_bucket_layers.push_back(layer);
+            bucket_layers.push_back(layer);
             float density = 0;
             float bucket_size = bucket->coveredMax() - bucket->coveredMin();
             if (bucket_size == 0 ){
@@ -246,32 +293,49 @@ namespace MuonR4{
                 if (bucket_size == 0) bucket_size = 1;
             }
             density = num_points / bucket_size;
-            m_bucket_density.push_back(density);
+            bucket_density.push_back(density);
+            bucket_points.push_back(num_points);
             bucket_index++;
-
+            
         } // end loop over buckets
 
         features.reserve(total_nodes * 8);
+        labels.reserve(total_nodes);
+
+
         for (size_t i = 0; i < total_nodes; i++) {
-            features.push_back(m_spoint_x[i]);
-            features.push_back(m_spoint_y[i]);
-            features.push_back(m_spoint_z[i]);
-            features.push_back(m_spoint_station[i]);
-            features.push_back(m_spoint_driftR[i]);
-            features.push_back(float(m_spoint_layer[i]) / (m_bucket_layers[m_spoint_bucket[i]] > 0 ? m_bucket_layers[m_spoint_bucket[i]] : 1));
-            features.push_back(m_spoint_neighbors[i]);
-            features.push_back(m_bucket_density[m_spoint_bucket[i]]);
+            float relative_layer = float(spoint_layer[i]) / (bucket_layers[spoint_bucket[i]] > 0 ? bucket_layers[spoint_bucket[i]] : 1);
+
+            //if (!std::isfinite(spoint_x[i]))         ATH_MSG_WARNING("X is NaN: " << spoint_x[i] << " in sp " << i);
+            //if (!std::isfinite(spoint_y[i]))         ATH_MSG_WARNING("Y is NaN: " << spoint_y[i] << " in sp " << i);
+            //if (!std::isfinite(spoint_z[i]))         ATH_MSG_WARNING("Z is NaN: " << spoint_z[i] << " in sp " << i);
+            //if (!std::isfinite(spoint_driftR[i]))    ATH_MSG_WARNING("DriftR is NaN: " << spoint_driftR[i] << " in sp " << i);
+            //if (!std::isfinite(spoint_neighbors[i])) ATH_MSG_WARNING("Neighbors is NaN: " << spoint_neighbors[i] << " in sp " << i);
+            //if (!std::isfinite(bucket_density[spoint_bucket[i]])) ATH_MSG_WARNING("Density is NaN: " << bucket_density[spoint_bucket[i]] << " in sp " << i);
+            //if (!std::isfinite(relative_layer)) ATH_MSG_WARNING("Relative Layer is NaN: " << relative_layer << " in sp " << i);
+
+            features.insert(features.end(), {
+                spoint_x[i],
+                spoint_y[i],
+                spoint_z[i],
+                static_cast<float>(spoint_station[i]), 
+                spoint_driftR[i],
+                relative_layer,
+                static_cast<float>(spoint_neighbors[i]),
+                static_cast<float>(bucket_density[spoint_bucket[i]]) 
+            });
+
             if (Labels) {
-                labels.push_back(m_spoint_label[i]);
+                labels.push_back(spoint_label[i]);
             }
         }
 
-        if (features.empty()) {
-            ATH_MSG_WARNING("No valid feature data available for inference. Skipping event.");
-            return StatusCode::SUCCESS;
+        for (size_t i = 0; i < edge_src.size(); i++) {
+            if (edge_src[i] >= total_nodes || edge_dst[i] >= total_nodes) {
+                ATH_MSG_WARNING("Edge connection out of range: (" << edge_src[i] << ", " << edge_dst[i] << ") - Total nodes: " << total_nodes);
+            }
         }
 
-        std::vector<int64_t> edge_index;
         edge_index.reserve(2 * edge_src.size());
         edge_index.insert(edge_index.end(), edge_src.begin(), edge_src.end());
         edge_index.insert(edge_index.end(), edge_dst.begin(), edge_dst.end());
@@ -291,13 +355,17 @@ namespace MuonR4{
             ATH_MSG_ERROR("Edge index format error! Size should be divisible by 2.");
         }
 
+        //for (size_t i = 0; i < features.size(); i++) {
+        //    if (!std::isfinite(features[i])) {
+        //        ATH_MSG_WARNING("Feature[" << i << "] contains NaN or Inf: " << features[i]);
+        //    }
+        //}
+        //for (size_t i = 0; i < edge_index.size(); i++) {
+        //    if (edge_index[i] < 0) {
+        //        ATH_MSG_WARNING("Edge index[" << i << "] is negative: " << edge_index[i]);
+        //    }
+        //}
         
-        //Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXInference");
-        //Ort::SessionOptions session_options;
-        //session_options.SetIntraOpNumThreads(1);
-        //session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
-        //Ort::Session session(env, "torch_GatFourier_fcg_quantized.onnx", session_options);
-
 
         Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
 
@@ -310,21 +378,25 @@ namespace MuonR4{
         Ort::Value edge_tensor = Ort::Value::CreateTensor<int64_t>(
             memory_info, edge_index.data(), edge_index.size(), edge_shape.data(), edge_shape.size());
         
-        std::vector<const char*> input_names = {"features", "edge_index"};
-        std::vector<const char*> output_names = {"output"};
 
-        std::vector<Ort::Value> input_tensors;
         input_tensors.push_back(std::move(feature_tensor));
         input_tensors.push_back(std::move(edge_tensor));
 
-        //std::vector<Ort::Value> output_tensors = session.Run(
+        Ort::RunOptions run_options;
+        run_options.SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_WARNING);
         std::vector<Ort::Value> output_tensors = m_session->Run(
-            Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(),
+            run_options, input_names.data(), input_tensors.data(),
             input_tensors.size(), output_names.data(), output_names.size());
 
+        //std::vector<Ort::Value> output_tensors = m_session->Run(
+        //    Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(),
+        //    input_tensors.size(), output_names.data(), output_names.size());
+
         float* output_data = output_tensors[0].GetTensorMutableData<float>();
         size_t output_size = output_tensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
+
         std::vector<float> predictions(output_data, output_data + output_size);
+        ATH_MSG_DEBUG("Predictions shape: (" << predictions.size() << ")");
 
         for (size_t i = 0; i < output_size; i++) {
             if (!std::isfinite(predictions[i])) {
@@ -339,13 +411,14 @@ namespace MuonR4{
             ATH_MSG_DEBUG("Prediction[" << i << "]: " << predictions[i]);
         }
 
-        std::vector<int> binary_predictions(output_size);
+        std::vector<u_int8_t> binary_predictions(output_size);
+        ATH_MSG_DEBUG("Binary shape: (" << binary_predictions.size() << ")");
+        ATH_MSG_DEBUG("Cut Value: " << m_SPId_cut);
         for (size_t i = 0; i < output_size; i++) {
-            binary_predictions[i] = (predictions[i] > m_SPId_cut) ? 1 : 0;   // Why inference cut is different? -> this batch/graph building may be different from torch one?
-        }
-        ATH_MSG_DEBUG("First 5 Binary Predictions:");
-        for (size_t i = 0; i < std::min(output_size, size_t(5)); i++) {
-            ATH_MSG_DEBUG("Binary_Prediction[" << i << "]: " << binary_predictions[i]);
+            binary_predictions[i] = (predictions[i] > m_SPId_cut) ? 1 : 0; 
+            if (i < 5) {
+                ATH_MSG_DEBUG("Binary_Prediction[" << i << " of first 5]: " << static_cast<int>(binary_predictions[i]));
+            }
         }
 
         if (Labels) {
@@ -354,7 +427,7 @@ namespace MuonR4{
             float pred_1   = 0;
             float lab_1    = 0;
             float true_0   = 0;
-            //float true_1   = 0;
+            
             for (size_t i = 0; i < output_size; i++) {
                 if (binary_predictions[i] != labels[i] && labels[i] == 1) {
                     loss++;
@@ -362,9 +435,6 @@ namespace MuonR4{
                 if (binary_predictions[i] == labels[i]) {
                     accuracy++;
                 }
-                //if (binary_predictions[i] == labels[i] && labels[i] == 1) {
-                //    true_1++;
-                //}
                 if (binary_predictions[i] == labels[i] && labels[i] == 0) {
                     true_0++;
                 }
@@ -375,23 +445,51 @@ namespace MuonR4{
                     lab_1++;
                 }       
             }
-            //ATH_MSG_ALWAYS("Hits:            " << output_size );
-            ATH_MSG_ALWAYS("Accuracy:        " << accuracy / output_size          << " ( Correct: " << accuracy << " / Hits: " << output_size << ")" );
-            ATH_MSG_ALWAYS("Signal Loss:     " << loss     / lab_1                << " ( Miss predicted: " << loss << " / Good Hits: " << lab_1 << ")" );
+            ATH_MSG_DEBUG("Accuracy:        " << accuracy / output_size          << " ( Correct: " << accuracy << " / Hits: " << output_size << ")" );
+            ATH_MSG_DEBUG("Signal Loss:     " << loss     / lab_1                << " ( Miss predicted: " << loss << " / Good Hits: " << lab_1 << ")" );
             float lab_0 = output_size - lab_1;
             if (lab_0 != 0) {
-                ATH_MSG_ALWAYS("Real Rejection:  " << true_0 / (output_size - lab_1)  << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" );
-            } else ATH_MSG_ALWAYS("Real Rejection:  " << 0  << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" );
-            ATH_MSG_ALWAYS("Purity (before): " << lab_1  / output_size            << " ( Good Hits: " << lab_1 << " / Hits: " << output_size << ")" );            
-            ATH_MSG_ALWAYS("Purity (after):  " << (lab_1 - loss) / pred_1         << " ( Good predicted: " << lab_1 - loss << " / Predicted: " << pred_1 << ")" );
-            ATH_MSG_ALWAYS("-----------------------------");
+                ATH_MSG_DEBUG("Real Rejection:  " << true_0 / (output_size - lab_1)  << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" );
+            } else ATH_MSG_DEBUG("Real Rejection:  " << 0  << " ( Correct Rejection: " << true_0 << " / Bad Hits: " << output_size - lab_1 << ")" );
+            ATH_MSG_DEBUG("Purity (before): " << lab_1  / output_size            << " ( Good Hits: " << lab_1 << " / Hits: " << output_size << ")" );            
+            ATH_MSG_DEBUG("Purity (after):  " << (lab_1 - loss) / pred_1         << " ( Good predicted: " << lab_1 - loss << " / Predicted: " << pred_1 << ")" );
+            ATH_MSG_DEBUG("-----------------------------");
         }
 
+        spoint_predictions.reserve(output_size);
         for (size_t i = 0; i < output_size; i++) {
-            m_spoint_predictions.push_back(binary_predictions[i]);
+            spoint_predictions.push_back(predictions[i]);
+        }
+
+        if (bucket_density.size() != bucket_layers.size() || bucket_density.size() != bucket_points.size()) {
+            ATH_MSG_ERROR("Bucket size mismatch! Density: " << bucket_density.size() << ", Layers: " << bucket_layers.size() << ", Points: " << bucket_points.size());
         }
+        size_t last_point = 0;
+        for (size_t i = 0; i < bucket_points.size(); i++) {
+            m_bucket_layers  = bucket_layers[i];
+            m_bucket_points  = bucket_points[i];
+            m_bucket_density = bucket_density[i];
+
+            size_t end_point = last_point + bucket_points[i];
+            for (size_t j = last_point; j < end_point; j++) {
+                m_spoint_bucket.push_back(spoint_bucket[j]);
+                m_spoint_x.push_back(spoint_x[j]);
+                m_spoint_y.push_back(spoint_y[j]);
+                m_spoint_z.push_back(spoint_z[j]);
+                m_spoint_station.push_back(spoint_station[j]);
+                m_spoint_layer.push_back(spoint_layer[j]);
+                m_spoint_driftR.push_back(spoint_driftR[j]);
+                m_spoint_neighbors.push_back(spoint_neighbors[j]);
+                m_spoint_label.push_back(spoint_label[j]);
+                m_spoint_predictions.push_back(spoint_predictions[j]);
+                m_spoint_edges.push_back(spoint_edges[j] - node_offsets[i]);
+            }
+        
+            last_point = end_point;
 
-        if (!m_tree.fill(ctx)) return StatusCode::FAILURE;
+            if (!m_tree.fill(ctx)) return StatusCode::FAILURE;
+        }
+        
         return StatusCode::SUCCESS;
 
     }
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h
index 6a2f9b1c7d90..2467924774c2 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/src/SPIdDumperAlg.h
@@ -55,35 +55,30 @@ class SPIdDumperAlg: public AthHistogramAlgorithm {
 
     SG::ReadHandleKey<MuonR4::SegmentContainer> m_inSegmentKey{this, "SegmentKey", "R4MuonSegments"};
     
-    Gaudi::Property<bool> m_isMC{this, "isMC", true};
-    Gaudi::Property<double> m_SPId_cut{this,"spIdValue", 0.0001};
-    Gaudi::Property<std::string> m_modelPath{this, "ModelPath", "torch_GatFourier_fcg_quantized.onnx"};
+    Gaudi::Property<bool>         m_isMC{this, "isMC", true};
+    Gaudi::Property<double>       m_SPId_cut{this,"spIdValue", -7};
+    Gaudi::Property<std::string>  m_modelPath{this, "ModelPath", "EdgeGAT_FCG_8vars_quantized.onnx"};
 
     MuonVal::MuonTesterTree m_tree{"MuonSPId","MuonSPId"};
 
-    MuonVal::VectorBranch<float>&           m_bucket_density{m_tree.newVector<float>("bucket_density", 0)};
-    MuonVal::VectorBranch<uint8_t>&         m_bucket_layers{m_tree.newVector<uint8_t>("bucket_layers", 0)}; 
+    MuonVal::ScalarBranch<uint8_t>&         m_bucket_layers{m_tree.newScalar<uint8_t>("bucket_layers", 0)}; 
+    MuonVal::ScalarBranch<uint8_t>&         m_bucket_points{m_tree.newScalar<uint8_t>("bucket_points", 0)};
+    MuonVal::ScalarBranch<float>&           m_bucket_density{m_tree.newScalar<float>("bucket_density", 0)};
 
     MuonVal::VectorBranch<uint8_t>&         m_spoint_bucket{m_tree.newVector<uint8_t>("bucket_index")};
-
-    //MuonVal::ThreeVectorBranch              m_spoint_localPosition{m_tree, "localPosition"}; 
     MuonVal::VectorBranch<float>&           m_spoint_x{m_tree.newVector<float>("x")};
     MuonVal::VectorBranch<float>&           m_spoint_y{m_tree.newVector<float>("y")};
     MuonVal::VectorBranch<float>&           m_spoint_z{m_tree.newVector<float>("z")};
 
-    //MuonVal::MuonIdentifierBranch           m_spoint_id{m_tree, "id"};
     MuonVal::VectorBranch<uint8_t>&         m_spoint_station{m_tree.newVector<uint8_t>("stationIndex")};
-    MuonVal::VectorBranch<uint8_t>&         m_spoint_layer{m_tree.newVector<uint8_t>("layer")};
-
+    MuonVal::VectorBranch<float>&           m_spoint_layer{m_tree.newVector<float>("layer")};
     MuonVal::VectorBranch<float>&           m_spoint_driftR{m_tree.newVector<float>("driftR")};
-    MuonVal::VectorBranch<float>&           m_spoint_neighbors{m_tree.newVector<float>("neighbors")};
+    MuonVal::VectorBranch<uint8_t>&         m_spoint_neighbors{m_tree.newVector<uint8_t>("neighbors")};
 
     MuonVal::VectorBranch<uint8_t>&         m_spoint_label{m_tree.newVector<uint8_t>("label")};
-    MuonVal::VectorBranch<uint8_t>&         m_spoint_predictions{m_tree.newVector<uint8_t>("predictions")};
+    MuonVal::VectorBranch<float>&           m_spoint_predictions{m_tree.newVector<float>("predictions")};
     MuonVal::VectorBranch<uint16_t>&        m_spoint_edges{m_tree.newVector<uint16_t>("edges")};
 
-    size_t m_event{0};
-
     // ONNX Runtime objects
     std::unique_ptr<Ort::Env> m_env;
     std::unique_ptr<Ort::SessionOptions> m_session_options;
-- 
GitLab


From 5ab3b6f25ae67e0e00b3b291f87c34ad9aa13e43 Mon Sep 17 00:00:00 2001
From: Davide Di Croce <davide.di.croce@cern.ch>
Date: Mon, 24 Feb 2025 11:39:13 +0100
Subject: [PATCH 3/9] Update in MuonSPId

---
 .../MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx  | 2 +-
 .../MuonInferenceInterfaces/src/NodeFeatureFactory.cxx         | 2 ++
 .../MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py   | 3 ++-
 .../MuonLearning/MuonSPId/python/muonSPIdDump_data.py          | 2 +-
 4 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx
index 51e666263121..d55804c8e2b4 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphInferenceToolBase.cxx
@@ -23,7 +23,7 @@ namespace MuonML{
             session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
 
             m_model = std::make_unique<Ort::Session>(env, modelPath.c_str(), session_options);
-            ATH_MSG_DEBUG("Successfully loaded infernce model from "<<modelPath);
+            ATH_MSG_DEBUG("Successfully loaded inference model from "<<modelPath);
         
             
             Ort::ModelMetadata metadata = m_model->GetModelMetadata();
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
index daf78ffe3330..b9490d5d9ec1 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
@@ -45,6 +45,8 @@ namespace MuonML {
                         [](const Bucket_t& bucket, size_t index) {
                             return bucket[index]->positionInChamber().x();
                         }),
+
+                        // loop over the bucket and get Neighbors
             }; 
             const auto feat_itr = featurePool.find(featName);
             if(feat_itr != featurePool.end()){
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py
index a07d728e8396..cff3b9f224dd 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump.py
@@ -15,7 +15,8 @@ if __name__=="__main__":
     flags = initConfigFlags()
     flags.PerfMon.doFullMonMT = True
 
-    flags, cfg = setupGeoR4TestCfg(args)
+    #flags, cfg = setupGeoR4TestCfg(args)
+    flags, cfg = setupGeoR4TestCfg(args,flags)
 
     cfg.merge(setupHistSvcCfg(flags,outFile=args.outRootFile, outStream="MuonSPId"))
 
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py
index 187225123d58..4a5d6dd23ee8 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonSPId/python/muonSPIdDump_data.py
@@ -16,7 +16,7 @@ if __name__=="__main__":
     flags = initConfigFlags()
     flags.PerfMon.doFullMonMT = True
 
-    flags, cfg = setupGeoR4TestCfg(args)
+    flags, cfg = setupGeoR4TestCfg(args,flags)
 
     cfg.merge(setupHistSvcCfg(flags,outFile=args.outRootFile,
                                     outStream="MuonSPId"))
-- 
GitLab


From 21b77bf38a22d4435c79f3776f8e19767ac4e5c5 Mon Sep 17 00:00:00 2001
From: Davide Di Croce <davide.di.croce@cern.ch>
Date: Mon, 24 Feb 2025 13:26:31 +0100
Subject: [PATCH 4/9] Adding features

---
 .../src/NodeFeatureFactory.cxx                | 40 +++++++++++++++----
 1 file changed, 32 insertions(+), 8 deletions(-)

diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
index b9490d5d9ec1..a6d4a7398fd0 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
@@ -37,16 +37,40 @@ namespace MuonML {
 
             /** Predefine the known features in the pool  */
             static const std::set<Feature_t, std::less<>> featurePool{
-                std::make_unique<NodeFeature>("driftR", 
-                        [](const Bucket_t& bucket, size_t index) {
-                            return bucket[index]->driftRadius();
-                        }),
                 std::make_unique<NodeFeature>("localX", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket[index]->positionInChamber().x();
+                    }),
+                std::make_unique<NodeFeature>("localY", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket[index]->positionInChamber().y();
+                    }),
+                std::make_unique<NodeFeature>("localZ", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket[index]->positionInChamber().z();
+                    }),
+                std::make_unique<NodeFeature>("driftR", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket[index]->driftRadius();
+                    }),
+                std::make_unique<NodeFeature>("relativeLayer", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket.layerNum(index) / (bucket.nStripLayers() + bucket.nMdtLayers());
+                    }),
+                std::make_unique<NodeFeature>("neighbours", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        constexpr double radCut2 =  (50 *Gaudi::Units::cm * 50.*Gaudi::Units::cm);
+                        unsigned int n =0;
+                        for (size_t other =0 ; other < bucket.size(); ++ other){
+                            n+= index != other && (bucket[index]->positionInChamber() - bucket[other]->positionInChamber()).perp2() < radCut2;
+                        }
+                        return n;
+                    }),
+                std::make_unique<NodeFeature>("density", 
                         [](const Bucket_t& bucket, size_t index) {
-                            return bucket[index]->positionInChamber().x();
-                        }),
-
-                        // loop over the bucket and get Neighbors
+                            return bucket.size() / (bucket.coveredMax() - bucket.coveredMin());
+                    }),
+                
             }; 
             const auto feat_itr = featurePool.find(featName);
             if(feat_itr != featurePool.end()){
-- 
GitLab


From 399b7e96554797e5429bce8d543cba0870c6c40f Mon Sep 17 00:00:00 2001
From: Davide Di Croce <davide.di.croce@cern.ch>
Date: Mon, 24 Feb 2025 14:14:11 +0100
Subject: [PATCH 5/9] update of feature implementation

---
 .../src/NodeFeatureFactory.cxx                | 29 ++++++++++++++-----
 1 file changed, 21 insertions(+), 8 deletions(-)

diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
index a6d4a7398fd0..a262627223a2 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
@@ -34,8 +34,6 @@ namespace MuonML {
      
     namespace Factory {
         Feature_t makeFeature(const std::string& featName, MsgStream& log) {
-
-            /** Predefine the known features in the pool  */
             static const std::set<Feature_t, std::less<>> featurePool{
                 std::make_unique<NodeFeature>("localX", 
                     [](const Bucket_t& bucket, size_t index) {
@@ -44,11 +42,16 @@ namespace MuonML {
                 std::make_unique<NodeFeature>("localY", 
                     [](const Bucket_t& bucket, size_t index) {
                         return bucket[index]->positionInChamber().y();
+                        bucket[index]->positionInChamber().y();
                     }),
                 std::make_unique<NodeFeature>("localZ", 
                     [](const Bucket_t& bucket, size_t index) {
                         return bucket[index]->positionInChamber().z();
                     }),
+                std::make_unique<NodeFeature>("stationIndex", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket[index]->
+                    }),
                 std::make_unique<NodeFeature>("driftR", 
                     [](const Bucket_t& bucket, size_t index) {
                         return bucket[index]->driftRadius();
@@ -57,9 +60,9 @@ namespace MuonML {
                     [](const Bucket_t& bucket, size_t index) {
                         return bucket.layerNum(index) / (bucket.nStripLayers() + bucket.nMdtLayers());
                     }),
-                std::make_unique<NodeFeature>("neighbours", 
+                std::make_unique<NodeFeature>("neighbors", 
                     [](const Bucket_t& bucket, size_t index) {
-                        constexpr double radCut2 =  (50 *Gaudi::Units::cm * 50.*Gaudi::Units::cm);
+                        constexpr double radCut2 =  (50.*Gaudi::Units::cm * 50.*Gaudi::Units::cm);
                         unsigned int n =0;
                         for (size_t other =0 ; other < bucket.size(); ++ other){
                             n+= index != other && (bucket[index]->positionInChamber() - bucket[other]->positionInChamber()).perp2() < radCut2;
@@ -67,10 +70,20 @@ namespace MuonML {
                         return n;
                     }),
                 std::make_unique<NodeFeature>("density", 
-                        [](const Bucket_t& bucket, size_t index) {
-                            return bucket.size() / (bucket.coveredMax() - bucket.coveredMin());
-                    }),
-                
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket.size() / (bucket.coveredMax() - bucket.coveredMin());    // missing max and min of bucket
+                    }), 
+                // can I add isolation already?
+                std::make_unique<NodeFeature>("isolation", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        unsigned int  neighbors = 0;
+                        constexpr double radCut2 =  (50.*Gaudi::Units::cm * 50.*Gaudi::Units::cm);
+                        for (size_t other =0 ; other < bucket.size(); ++ other){
+                            neighbors+= index != other && (bucket[index]->positionInChamber() - bucket[other]->positionInChamber()).perp2() < radCut2;
+                        }
+                        float bucket_density = bucket.size() / (bucket.coveredMax() - bucket.coveredMin());
+                        return neighbors / bucket_density;
+                    }), 
             }; 
             const auto feat_itr = featurePool.find(featName);
             if(feat_itr != featurePool.end()){
-- 
GitLab


From 2530c6e07f9435e56ca285a1ad89c527790a30ff Mon Sep 17 00:00:00 2001
From: Davide Di Croce <davide.di.croce@cern.ch>
Date: Mon, 24 Feb 2025 15:24:11 +0100
Subject: [PATCH 6/9] Adding station index

---
 .../MuonInferenceInterfaces/LayerBucket.h             | 11 +++++++++++
 .../src/NodeFeatureFactory.cxx                        |  6 +++++-
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h
index bb083b0e3a5a..7eee70727f88 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/MuonInferenceInterfaces/LayerBucket.h
@@ -5,6 +5,7 @@
 #define MUONINFERENCEINTERACES_LAYERBUCKET_H
 
 #include "MuonSpacePoint/SpacePointContainer.h"
+#include "GaudiKernel/SystemOfUnits.h"
 namespace MuonML{
     /** @brief The LayerSpBucket is a space pointbucket where the points are internally 
      *         sorted by their layer number as defined in the SpacePointLayerSorter. The
@@ -26,10 +27,20 @@ namespace MuonML{
           uint8_t layerNum(const size_t i) const {
             return m_layNum[i];
           }
+          /** @brief Returns the max covered position of the bucket */
+          uint8_t coveredMax() const {
+              return m_max;
+          }
+          /** @brief Returns the min covered position of the bucket */
+          uint8_t coveredMin() const {
+              return m_min;
+          }
         private:
             uint8_t m_nMdtLay{0};
             uint8_t m_nStripLay{0};
             std::vector<uint8_t> m_layNum{};
+            double m_min{-20. *Gaudi::Units::m};
+            double m_max{20. * Gaudi::Units::m};
     };
 }
 #endif
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
index a262627223a2..5ea583274abb 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
@@ -50,7 +50,11 @@ namespace MuonML {
                     }),
                 std::make_unique<NodeFeature>("stationIndex", 
                     [](const Bucket_t& bucket, size_t index) {
-                        return bucket[index]->
+                        // use idHelperSvc to get stationName
+                        auto idHelper     = bucket[index]->msSector()->idHelperSvc();
+                        //auto idHelper   = sector->idHelperSvc();
+                        auto stationIdx = idHelper->stationName(bucket[index]->identify());
+                        return stationIdx;
                     }),
                 std::make_unique<NodeFeature>("driftR", 
                     [](const Bucket_t& bucket, size_t index) {
-- 
GitLab


From a527a36844517b5235ce2127a74fd69786749eab Mon Sep 17 00:00:00 2001
From: Davide Di Croce <davide.di.croce@cern.ch>
Date: Mon, 24 Feb 2025 16:06:32 +0100
Subject: [PATCH 7/9] Adding station index

---
 .../src/NodeFeatureFactory.cxx                | 29 ++++++++++++++-----
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
index 5ea583274abb..44ca8839cd16 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
@@ -50,11 +50,15 @@ namespace MuonML {
                     }),
                 std::make_unique<NodeFeature>("stationIndex", 
                     [](const Bucket_t& bucket, size_t index) {
-                        // use idHelperSvc to get stationName
-                        auto idHelper     = bucket[index]->msSector()->idHelperSvc();
-                        //auto idHelper   = sector->idHelperSvc();
-                        auto stationIdx = idHelper->stationName(bucket[index]->identify());
-                        return stationIdx;
+                        return bucket[index]->msSector()->idHelperSvc()->stationName(bucket[index]->identify());
+                    }),
+                std::make_unique<NodeFeature>("stationPhi", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket[index]->msSector()->idHelperSvc()->stationPhi(bucket[index]->identify());
+                    }),
+                std::make_unique<NodeFeature>("stationEta", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket[index]->msSector()->idHelperSvc()->stationEta(bucket[index]->identify());
                     }),
                 std::make_unique<NodeFeature>("driftR", 
                     [](const Bucket_t& bucket, size_t index) {
@@ -77,7 +81,6 @@ namespace MuonML {
                     [](const Bucket_t& bucket, size_t index) {
                         return bucket.size() / (bucket.coveredMax() - bucket.coveredMin());    // missing max and min of bucket
                     }), 
-                // can I add isolation already?
                 std::make_unique<NodeFeature>("isolation", 
                     [](const Bucket_t& bucket, size_t index) {
                         unsigned int  neighbors = 0;
@@ -85,9 +88,21 @@ namespace MuonML {
                         for (size_t other =0 ; other < bucket.size(); ++ other){
                             neighbors+= index != other && (bucket[index]->positionInChamber() - bucket[other]->positionInChamber()).perp2() < radCut2;
                         }
-                        float bucket_density = bucket.size() / (bucket.coveredMax() - bucket.coveredMin());
+                        float bucket_density = float(bucket.size()) / (bucket.coveredMax() - bucket.coveredMin());
                         return neighbors / bucket_density;
                     }), 
+                std::make_unique<NodeFeature>("covX", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket[index]->covariance()(Amg::x, Amg::x);
+                    }),
+                std::make_unique<NodeFeature>("covY", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket[index]->covariance()(Amg::y, Amg::y);
+                    }),
+                std::make_unique<NodeFeature>("covXY", 
+                    [](const Bucket_t& bucket, size_t index) {
+                        return bucket[index]->covariance()(Amg::x, Amg::y);
+                    }),
             }; 
             const auto feat_itr = featurePool.find(featName);
             if(feat_itr != featurePool.end()){
-- 
GitLab


From ed71f9ad26ea3d8ebc5aee392c4835a10df2e53f Mon Sep 17 00:00:00 2001
From: Davide Di Croce <davide.di.croce@cern.ch>
Date: Tue, 25 Feb 2025 13:49:32 +0100
Subject: [PATCH 8/9] Fixing variables name

---
 .../MuonInferenceInterfaces/src/NodeFeatureFactory.cxx        | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
index 44ca8839cd16..f326370fa3c0 100644
--- a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInferenceInterfaces/src/NodeFeatureFactory.cxx
@@ -64,7 +64,7 @@ namespace MuonML {
                     [](const Bucket_t& bucket, size_t index) {
                         return bucket[index]->driftRadius();
                     }),
-                std::make_unique<NodeFeature>("relativeLayer", 
+                std::make_unique<NodeFeature>("relative_layer", 
                     [](const Bucket_t& bucket, size_t index) {
                         return bucket.layerNum(index) / (bucket.nStripLayers() + bucket.nMdtLayers());
                     }),
@@ -77,7 +77,7 @@ namespace MuonML {
                         }
                         return n;
                     }),
-                std::make_unique<NodeFeature>("density", 
+                std::make_unique<NodeFeature>("bucket_density", 
                     [](const Bucket_t& bucket, size_t index) {
                         return bucket.size() / (bucket.coveredMax() - bucket.coveredMin());    // missing max and min of bucket
                     }), 
-- 
GitLab


From bfbfb3acb1ebbde07dfca4672954b923b7b7f36a Mon Sep 17 00:00:00 2001
From: Davide Di Croce <davide.di.croce@cern.ch>
Date: Wed, 26 Feb 2025 15:39:56 +0100
Subject: [PATCH 9/9] aInference module

---
 .../src/GraphBucketFilterTool.cxx             | 31 +++++++++++++++++++
 .../MuonInference/src/GraphBucketFilterTool.h | 26 ++++++++++++++++
 2 files changed, 57 insertions(+)
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.cxx
 create mode 100644 MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.h

diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.cxx b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.cxx
new file mode 100644
index 000000000000..b69522b9269f
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.cxx
@@ -0,0 +1,31 @@
+
+
+#include "GraphBucketFilterTool.h"
+
+#include "StoreGate/ReadHandle.h"
+#include "StoreGate/WriteHandle.h"
+namesspace MuonML{
+    StatusCode GraphBucketFilterTool::runGraphInference(const EventContext& ctx,
+                                                        GraphRawData& graph) const {
+
+        ATH_CHECK(buildGraph(ctx, graph));
+        SG::WriteHandle filteredSpacePoints{m_writeKey, ctx};
+        ATH_CHECK(filteredSpacePoints.record(std::make_unique<MuonR4::SpacePointContainer>()));
+        //// Inference part
+    
+        SG::ReadHandle inputSpacePoints{m_readKey, ctx};
+        ATH_CHECK(inputSpacePoints.isPresent());
+        for (const MuonR4::SpacePointBucket* bucket : *inputSpacePoints) {
+            std::unique_ptr<MuonR4::SpacePointBucket> filteredBucket = std::make_unique<MuonR4::SpacePointBucket>();
+            filteredSpacePoints->push_back(std::move(filteredBucket));
+        }
+        return StatusCode::SUCCESS;
+    }
+
+
+    StatusCode GraphBucketFilterTool::initialize() {
+        ATH_CHECK(setupModel());
+        ATH_CHECK(m_writeKey.initialize());
+        return StatusCode::SUCCESS;
+    }
+}
\ No newline at end of file
diff --git a/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.h b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.h
new file mode 100644
index 000000000000..d03d97f57a5f
--- /dev/null
+++ b/MuonSpectrometer/MuonPhaseII/MuonLearning/MuonInference/src/GraphBucketFilterTool.h
@@ -0,0 +1,26 @@
+#ifndef MUONINFERENCE_GRAPHBUCKETFILTERTOOL_H
+#define MUONINFERENCE_GRAPHBUCKETFILTERTOOL_H
+
+#include "GraphInferenceToolBase.h"
+#include "StoreGate/WriteHandleKey.h"
+
+#include "MuonSpacePoint/SpacePointContainer.h"
+namespace MuonML{
+    class GraphBucketFilterTool : public GraphInferenceToolBase {
+        public:
+            using GraphInferenceToolBase::GraphInferenceToolBase;
+
+            virtual StatusCode runGraphInference(const EventContext& ctx,
+                                                 GraphRawData& graph) const override final;
+
+
+            virtual StatusCode initialize() override final;
+        private:
+            
+        SG::WriteHandleKey<MuonR4::SpacePointContainer> m_writeKey{this, "WriteSpacePointKey", "FilteredMlSpacePoints"};
+
+
+    }
+}
+
+#endif
\ No newline at end of file
-- 
GitLab