From 04afffcb4e7cae4e4b7679849fd67cfd8ed06d62 Mon Sep 17 00:00:00 2001 From: Romain Bouquet <romain.bouquet04@gmail.com> Date: Tue, 18 Feb 2025 17:53:42 +0000 Subject: [PATCH] FTAG OnnxUtil and OnnxOutput class renaming to SaltModel and SaltModelOutput FTAG OnnxUtil and OnnxOutput class renaming to SaltModel and SaltModelOutput --- .../ConstituentsLoader.h | 2 +- .../DataPrepUtilities.h | 2 +- .../InDetGNNHardScatterSelection/GNN.h | 12 +++--- .../InDetGNNHardScatterSelection/Root/GNN.cxx | 12 +++--- .../LeptonTaggers/DecoratePLIT.h | 6 +-- .../LeptonTaggers/src/DecoratePLIT.cxx | 14 +++---- .../FlavorTagDiscriminants/CMakeLists.txt | 4 +- .../ConstituentsLoader.h | 2 +- .../DataPrepUtilities.h | 2 +- .../FlavorTagDiscriminants/GNN.h | 14 +++---- .../{OnnxUtil.h => SaltModel.h} | 32 +++++++------- .../{OnnxOutput.h => SaltModelOutput.h} | 18 ++++---- .../FlavorTagDiscriminants/Root/GNN.cxx | 40 +++++++++--------- .../Root/{OnnxUtil.cxx => SaltModel.cxx} | 42 +++++++++---------- .../{OnnxOutput.cxx => SaltModelOutput.cxx} | 14 +++---- .../Root/BTaggingEfficiencyTool.cxx | 10 ++--- .../xAODBTaggingEfficiency/Root/OnnxUtil.cxx | 10 ++--- .../BTaggingEfficiencyTool.h | 4 +- .../{OnnxUtil.h => SaltModel.h} | 14 +++---- Reconstruction/tauRecTools/Root/TauGNN.cxx | 14 +++---- .../tauRecTools/tauRecTools/TauGNN.h | 10 ++--- 21 files changed, 139 insertions(+), 139 deletions(-) rename PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/{OnnxUtil.h => SaltModel.h} (73%) rename PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/{OnnxOutput.h => SaltModelOutput.h} (67%) rename PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/{OnnxUtil.cxx => SaltModel.cxx} (85%) rename PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/{OnnxOutput.cxx => SaltModelOutput.cxx} (74%) rename PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/{OnnxUtil.h => SaltModel.h} (78%) diff --git a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/ConstituentsLoader.h b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/ConstituentsLoader.h index 82f1ed164fa5..aa2f42aedc8f 100644 --- a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/ConstituentsLoader.h +++ b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/ConstituentsLoader.h @@ -10,7 +10,7 @@ #define INDET_CONSTITUENTS_LOADER_H // local includes -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" // EDM includes #include "xAODTracking/Vertex.h" diff --git a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/DataPrepUtilities.h b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/DataPrepUtilities.h index 0ef8be167baf..a1949ead20da 100644 --- a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/DataPrepUtilities.h +++ b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/DataPrepUtilities.h @@ -7,7 +7,7 @@ // local includes #include "InDetGNNHardScatterSelection/ConstituentsLoader.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" // EDM includes #include "xAODTracking/Vertex.h" diff --git a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h index 0eaa0ac39be2..05353c48158c 100644 --- a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h +++ b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h @@ -1,10 +1,10 @@ /* Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration - This class is used in conjunction with OnnxUtil to run inference on a GNN model. - Whereas OnnxUtil handles the interfacing with the ONNX runtime, this class handles + This class is used in conjunction with SaltModel to run inference on a GNN model. + Whereas SaltModel handles the interfacing with the ONNX runtime, this class handles the interfacing with the ATLAS EDM. It is responsible for collecting all the inputs - needed for inference, running inference (via OnnxUtil), and decorating the results + needed for inference, running inference (via SaltModel), and decorating the results back to ATLAS EDM. */ @@ -14,7 +14,7 @@ // Tool includes #include "InDetGNNHardScatterSelection/DataPrepUtilities.h" #include "InDetGNNHardScatterSelection/IParticlesLoader.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" // EDM includes #include "xAODTracking/VertexFwd.h" @@ -49,7 +49,7 @@ namespace InDetGNNHardScatterSelection { virtual void decorate(const xAOD::Vertex& verrtex) const; - std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil; + std::shared_ptr<const FlavorTagDiscriminants::SaltModel> m_saltModel; private: // type definitions for ONNX output decorators using TPC = xAOD::TrackParticleContainer; @@ -66,7 +66,7 @@ namespace InDetGNNHardScatterSelection { }; /* create all decorators */ - std::set<std::string> createDecorators(const FlavorTagDiscriminants::OnnxUtil::OutputConfig& outConfig); + std::set<std::string> createDecorators(const FlavorTagDiscriminants::SaltModel::OutputConfig& outConfig); std::string m_input_node_name; std::vector<internal::VarFromVertex> m_varsFromVertex; diff --git a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx index 44dca96d8068..081a5b8d982c 100644 --- a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx +++ b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx @@ -3,7 +3,7 @@ */ #include "InDetGNNHardScatterSelection/GNN.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include "PathResolver/PathResolver.h" @@ -19,15 +19,15 @@ namespace InDetGNNHardScatterSelection { GNN::GNN(const std::string& nn_file): - m_onnxUtil(nullptr) + m_saltModel(nullptr) { // Load and initialize the neural network model from the given file path. std::string fullPathToOnnxFile = PathResolverFindCalibFile(nn_file); - m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(fullPathToOnnxFile); + m_saltModel = std::make_shared<FlavorTagDiscriminants::SaltModel>(fullPathToOnnxFile); // Extract metadata from the ONNX file, primarily about the model's inputs. - auto lwt_config = m_onnxUtil->getLwtConfig(); + auto lwt_config = m_saltModel->getLwtConfig(); // Create configuration objects for data preprocessing. auto [inputs, constituents_configs] = dataprep::createGetterConfig(lwt_config); @@ -58,7 +58,7 @@ namespace InDetGNNHardScatterSelection { m_varsFromVertex = dataprep::createVertexVarGetters(inputs); // Retrieve the configuration for the model outputs. - FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig(); + FlavorTagDiscriminants::SaltModel::OutputConfig gnn_output_config = m_saltModel->getOutputConfig(); for (const auto& outNode : gnn_output_config) { // the node's output name will be used to define the decoration name @@ -95,7 +95,7 @@ namespace InDetGNNHardScatterSelection { // run inference // ------------- - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); // decorate outputs // ---------------- diff --git a/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/LeptonTaggers/DecoratePLIT.h b/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/LeptonTaggers/DecoratePLIT.h index fb7b79948464..af2227128675 100644 --- a/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/LeptonTaggers/DecoratePLIT.h +++ b/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/LeptonTaggers/DecoratePLIT.h @@ -12,7 +12,7 @@ // Tools #include "PathResolver/PathResolver.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" // Athena #include "AsgDataHandles/WriteDecorHandle.h" @@ -39,8 +39,8 @@ namespace Prompt { virtual StatusCode execute (const EventContext&) const override; private: - std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil{}; - std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil_endcap{}; + std::shared_ptr<const FlavorTagDiscriminants::SaltModel> m_saltModel{}; + std::shared_ptr<const FlavorTagDiscriminants::SaltModel> m_saltModel_endcap{}; int m_num_lepton_features{}; int m_num_track_features{}; diff --git a/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/src/DecoratePLIT.cxx b/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/src/DecoratePLIT.cxx index 3d0d3f25d520..2c1df8a5f15f 100644 --- a/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/src/DecoratePLIT.cxx +++ b/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/src/DecoratePLIT.cxx @@ -38,10 +38,10 @@ namespace Prompt { // Load and initialize the neural network model from the given file path. if(m_leptonsName == "Electrons") { std::string fullPathToOnnxFile = PathResolverFindCalibFile(m_configPath.value() + m_configFileVersion.value()); - m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(fullPathToOnnxFile); + m_saltModel = std::make_shared<FlavorTagDiscriminants::SaltModel>(fullPathToOnnxFile); std::string fullPathToOnnxFile_endcap = PathResolverFindCalibFile(m_configPath.value() + m_configFileVersion_endcap.value()); - m_onnxUtil_endcap = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(fullPathToOnnxFile_endcap); + m_saltModel_endcap = std::make_shared<FlavorTagDiscriminants::SaltModel>(fullPathToOnnxFile_endcap); m_num_lepton_features = 15; m_num_track_features = 19; @@ -56,7 +56,7 @@ namespace Prompt { std::vector<int64_t> track_feat_dim = {1, m_num_track_features}; FlavorTagDiscriminants::Inputs track_info(track_feat, track_feat_dim); gnn_input.insert({"track_features", track_info}); - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); std::vector<std::string> output_names; for (auto& singlefloat : out_f){ @@ -67,7 +67,7 @@ namespace Prompt { } else if (m_leptonsName == "Muons") { std::string fullPathToOnnxFile = PathResolverFindCalibFile(m_configPath.value() + m_configFileVersion.value()); - m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(fullPathToOnnxFile); + m_saltModel = std::make_shared<FlavorTagDiscriminants::SaltModel>(fullPathToOnnxFile); m_num_lepton_features = 10; m_num_track_features = 18; @@ -82,7 +82,7 @@ namespace Prompt { std::vector<int64_t> track_feat_dim = {1, m_num_track_features}; FlavorTagDiscriminants::Inputs track_info(track_feat, track_feat_dim); gnn_input.insert({"track_features", track_info}); - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); std::vector<std::string> output_names; for (auto& singlefloat : out_f){ @@ -361,7 +361,7 @@ namespace Prompt { // run inference // ------------- - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); if (msgLvl(MSG::VERBOSE)) { ATH_MSG_VERBOSE("runInference done."); @@ -653,7 +653,7 @@ namespace Prompt { // run inference // ------------- // use different model for endcap electrons - auto [out_f, out_vc, out_vf] = (std::abs(elec_eta) < 1.37) ? m_onnxUtil->runInference(gnn_input) : m_onnxUtil_endcap->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = (std::abs(elec_eta) < 1.37) ? m_saltModel->runInference(gnn_input) : m_saltModel_endcap->runInference(gnn_input); if (msgLvl(MSG::VERBOSE)) { ATH_MSG_VERBOSE("runInference done."); diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/CMakeLists.txt b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/CMakeLists.txt index 2c4224ea9adf..b4ae96b5b3ff 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/CMakeLists.txt +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/CMakeLists.txt @@ -27,8 +27,8 @@ set(FTDSource Root/DL2HighLevel.cxx Root/DL2Tool.cxx Root/DataPrepUtilities.cxx - Root/OnnxUtil.cxx - Root/OnnxOutput.cxx + Root/SaltModel.cxx + Root/SaltModelOutput.cxx Root/GNN.cxx Root/GNNOptions.cxx Root/GNNTool.cxx diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h index f117b63082a3..4b4e1c2d977f 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h @@ -12,7 +12,7 @@ // local includes #include "FlavorTagDiscriminants/FlipTagEnums.h" #include "FlavorTagDiscriminants/AssociationEnums.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include "FlavorTagDiscriminants/FTagDataDependencyNames.h" #include "FlavorTagDiscriminants/StringUtils.h" diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/DataPrepUtilities.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/DataPrepUtilities.h index 9946fb22ea35..cecd25cdb9ca 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/DataPrepUtilities.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/DataPrepUtilities.h @@ -9,7 +9,7 @@ #include "FlavorTagDiscriminants/FlipTagEnums.h" #include "FlavorTagDiscriminants/AssociationEnums.h" #include "FlavorTagDiscriminants/FTagDataDependencyNames.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include "FlavorTagDiscriminants/ConstituentsLoader.h" // EDM includes diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h index 385e371f4824..8e5924786558 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h @@ -1,10 +1,10 @@ /* Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration - This class is used in conjunction with OnnxUtil to run inference on a GNN model. - Whereas OnnxUtil handles the interfacing with the ONNX runtime, this class handles + This class is used in conjunction with SaltModel to run inference on a GNN model. + Whereas SaltModel handles the interfacing with the ONNX runtime, this class handles the interfacing with the ATLAS EDM. It is responsible for collecting all the inputs - needed for inference, running inference (via OnnxUtil), and decorating the results + needed for inference, running inference (via SaltModel), and decorating the results back to ATLAS EDM. */ @@ -30,7 +30,7 @@ namespace FlavorTagDiscriminants { struct GNNOptions; - class OnnxUtil; + class SaltModel; // // Tool to to flavor tag jet/btagging object // using GNN based taggers @@ -60,10 +60,10 @@ namespace FlavorTagDiscriminants { virtual std::set<std::string> getAuxInputKeys() const; virtual std::set<std::string> getConstituentAuxInputKeys() const; - std::shared_ptr<const OnnxUtil> m_onnxUtil; + std::shared_ptr<const SaltModel> m_saltModel; private: // private constructor, delegate of the above public ones - GNN(std::shared_ptr<const OnnxUtil>, const GNNOptions& opts); + GNN(std::shared_ptr<const SaltModel>, const GNNOptions& opts); // type definitions for ONNX output decorators using TPC = xAOD::TrackParticleContainer; using TrackLinks = std::vector<ElementLink<TPC>>; @@ -85,7 +85,7 @@ namespace FlavorTagDiscriminants { /* create all decorators */ std::tuple<FTagDataDependencyNames, std::set<std::string>> - createDecorators(const OnnxUtil::OutputConfig& outConfig, const FTagOptions& options); + createDecorators(const SaltModel::OutputConfig& outConfig, const FTagOptions& options); SG::AuxElement::ConstAccessor<ElementLink<xAOD::JetContainer>> m_jetLink; std::string m_input_node_name; diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModel.h similarity index 73% rename from PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h rename to PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModel.h index f327aa273ffd..405b67e370ef 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModel.h @@ -7,15 +7,15 @@ handles the interaction with the ATLAS EDM. */ -#ifndef FLAVORTAGDISCRIMINANTS_ONNXUTIL_H -#define FLAVORTAGDISCRIMINANTS_ONNXUTIL_H +#ifndef FLAVORTAGDISCRIMINANTS_SALTMODEL_H +#define FLAVORTAGDISCRIMINANTS_SALTMODEL_H #include <onnxruntime_cxx_api.h> #include "nlohmann/json.hpp" #include "lwtnn/parse_json.hh" -#include "FlavorTagDiscriminants/OnnxOutput.h" +#include "FlavorTagDiscriminants/SaltModelOutput.h" #include <map> //also has std::pair #include <vector> @@ -27,25 +27,25 @@ namespace FlavorTagDiscriminants { // the first element is the input data, the second is the shape using Inputs = std::pair<std::vector<float>, std::vector<int64_t>>; - enum class OnnxModelVersion{UNKNOWN, V0, V1, V2}; + enum class SaltModelVersion{UNKNOWN, V0, V1, V2}; - NLOHMANN_JSON_SERIALIZE_ENUM( OnnxModelVersion , { - { OnnxModelVersion::UNKNOWN, "" }, - { OnnxModelVersion::V0, "v0" }, - { OnnxModelVersion::V1, "v1" }, - { OnnxModelVersion::V2, "v2" }, + NLOHMANN_JSON_SERIALIZE_ENUM( SaltModelVersion , { + { SaltModelVersion::UNKNOWN, "" }, + { SaltModelVersion::V0, "v0" }, + { SaltModelVersion::V1, "v1" }, + { SaltModelVersion::V2, "v2" }, }) // // Utility class that loads the onnx model from the given path // and runs inference based on the user given inputs - class OnnxUtil final{ + class SaltModel final{ public: - using OutputConfig = std::vector<OnnxOutput>; + using OutputConfig = std::vector<SaltModelOutput>; - OnnxUtil(const std::string& path_to_onnx); + SaltModel(const std::string& path_to_onnx); void initialize(); @@ -60,7 +60,7 @@ namespace FlavorTagDiscriminants { const lwt::GraphConfig getLwtConfig() const; const nlohmann::json& getMetadata() const; const OutputConfig& getOutputConfig() const; - OnnxModelVersion getOnnxModelVersion() const; + SaltModelVersion getSaltModelVersion() const; const std::string& getModelName() const; private: @@ -79,8 +79,8 @@ namespace FlavorTagDiscriminants { std::vector<std::string> m_input_node_names; OutputConfig m_output_nodes; - OnnxModelVersion m_onnx_model_version = OnnxModelVersion::UNKNOWN; + SaltModelVersion m_onnx_model_version = SaltModelVersion::UNKNOWN; - }; // Class OnnxUtil + }; // Class SaltModel } // end of FlavorTagDiscriminants namespace -#endif //FLAVORTAGDISCRIMINANTS_ONNXUTIL_H +#endif //FLAVORTAGDISCRIMINANTS_SALTMODEL_H diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxOutput.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModelOutput.h similarity index 67% rename from PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxOutput.h rename to PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModelOutput.h index b34d745e3319..00b839e45bb3 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxOutput.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModelOutput.h @@ -4,8 +4,8 @@ Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration This class is used to store the configuration for a ONNX output node. */ -#ifndef OUTPUTNODE_H -#define OUTPUTNODE_H +#ifndef FLAVORTAGDISCRIMINANTS_SALTMODELOUTPUT_H +#define FLAVORTAGDISCRIMINANTS_SALTMODELOUTPUT_H #include <onnxruntime_cxx_api.h> #include "nlohmann/json.hpp" @@ -13,18 +13,18 @@ This class is used to store the configuration for a ONNX output node. namespace FlavorTagDiscriminants { -class OnnxOutput { +class SaltModelOutput { public: enum class OutputType {UNKNOWN, FLOAT, VECCHAR, VECFLOAT}; - /* constructor for OnnxModelVersion::V1 and higher */ - OnnxOutput(const std::string& name, + /* constructor for SaltModelVersion::V1 and higher */ + SaltModelOutput(const std::string& name, ONNXTensorElementDataType type, int rank); - /* constructor for OnnxModelVersion::V0 */ - OnnxOutput(const std::string& name, + /* constructor for SaltModelVersion::V0 */ + SaltModelOutput(const std::string& name, ONNXTensorElementDataType type, const std::string& name_in_model); @@ -36,8 +36,8 @@ class OnnxOutput { OutputType getOutputType(ONNXTensorElementDataType type, int rank) const; const std::string getName(const std::string& name, const std::string& model_name) const; -}; // class OnnxOutput +}; // class SaltModelOutput } // namespace FlavorTagDiscriminants -#endif // OUTPUTNODE_H +#endif // FLAVORTAGDISCRIMINANTS_SALTMODELOUTPUT_H diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx index 20c61980ab25..fe46cea3d700 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx @@ -4,7 +4,7 @@ #include "FlavorTagDiscriminants/GNN.h" #include "FlavorTagDiscriminants/BTagTrackIpAccessor.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include "FlavorTagDiscriminants/GNNOptions.h" #include "FlavorTagDiscriminants/StringUtils.h" @@ -22,10 +22,10 @@ namespace { const std::string jetLinkName = "jetLink"; - auto getOnnxUtil(const std::string& nn_file) { + auto getSaltModel(const std::string& nn_file) { using namespace FlavorTagDiscriminants; std::string fullPathToOnnxFile = PathResolverFindCalibFile(nn_file); - return std::make_shared<const OnnxUtil>(fullPathToOnnxFile); + return std::make_shared<const SaltModel>(fullPathToOnnxFile); } template <typename T> @@ -44,22 +44,22 @@ namespace { namespace FlavorTagDiscriminants { GNN::GNN(const std::string& nn_file, const GNNOptions& o): - GNN(getOnnxUtil(nn_file), o) + GNN(getSaltModel(nn_file), o) { } GNN::GNN(const GNN& old, const GNNOptions& o): - GNN(old.m_onnxUtil, o) + GNN(old.m_saltModel, o) { } - GNN::GNN(std::shared_ptr<const OnnxUtil> util, const GNNOptions& o): - m_onnxUtil(util), + GNN::GNN(std::shared_ptr<const SaltModel> util, const GNNOptions& o): + m_saltModel(util), m_jetLink(jetLinkName) { // Extract metadata from the ONNX file, primarily about the model's inputs. - auto lwt_config = m_onnxUtil->getLwtConfig(); + auto lwt_config = m_saltModel->getLwtConfig(); // Create configuration objects for data preprocessing. auto [inputs, constituents_configs, options] = dataprep::createGetterConfig( @@ -91,7 +91,7 @@ namespace FlavorTagDiscriminants { m_dataDependencyNames = ds; // Retrieve the configuration for the model outputs. - OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig(); + SaltModel::OutputConfig gnn_output_config = m_saltModel->getOutputConfig(); // Create the output decorators. auto [dd, rd] = createDecorators(gnn_output_config, options); @@ -154,7 +154,7 @@ namespace FlavorTagDiscriminants { dec(jet) = v; } // for some networks we need to set a lot of empty vectors as well - if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V1) { + if (m_saltModel->getSaltModelVersion() == SaltModelVersion::V1) { // vector outputs, e.g. track predictions for (const auto& dec: m_decorators.jetVecChar) { dec.second(jet) = {}; @@ -186,7 +186,7 @@ namespace FlavorTagDiscriminants { } std::vector<int64_t> jet_feat_dim = {1, static_cast<int64_t>(jet_feat.size())}; Inputs jet_info(jet_feat, jet_feat_dim); - if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V2) { + if (m_saltModel->getSaltModelVersion() == SaltModelVersion::V2) { gnn_inputs.insert({"jets", jet_info}); } else { gnn_inputs.insert({"jet_features", jet_info}); @@ -197,7 +197,7 @@ namespace FlavorTagDiscriminants { int64_t num_inputs = 0; for (const auto& loader : m_constituentsLoaders){ auto [input_name, input_data, input_objects] = loader->getData(jet, btag); - if (m_onnxUtil->getOnnxModelVersion() != OnnxModelVersion::V2) { + if (m_saltModel->getSaltModelVersion() != SaltModelVersion::V2) { input_name.pop_back(); input_name.append("_features"); } @@ -219,13 +219,13 @@ namespace FlavorTagDiscriminants { this->decorateWithDefaults(btag); return; } - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_inputs); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_inputs); // decorate outputs // ---------------- // with old metadata, doesn't support writing aux tasks - if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V0) { + if (m_saltModel->getSaltModelVersion() == SaltModelVersion::V0) { for (const auto& dec: m_decorators.jetFloat) { if (out_vf.at(dec.first).size() != 1){ throw std::logic_error("expected vectors of length 1 for float decorators"); @@ -234,7 +234,7 @@ namespace FlavorTagDiscriminants { } } // the new metadata format supports writing aux tasks - else if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V1) { + else if (m_saltModel->getSaltModelVersion() == SaltModelVersion::V1) { // float outputs, e.g. jet probabilities for (const auto& dec: m_decorators.jetFloat) { dec.second(btag) = out_f.at(dec.first); @@ -277,7 +277,7 @@ namespace FlavorTagDiscriminants { } std::tuple<FTagDataDependencyNames, std::set<std::string>> - GNN::createDecorators(const OnnxUtil::OutputConfig& outConfig, const FTagOptions& options) { + GNN::createDecorators(const SaltModel::OutputConfig& outConfig, const FTagOptions& options) { FTagDataDependencyNames deps; Decorators decs; @@ -305,13 +305,13 @@ namespace FlavorTagDiscriminants { // Create decorators based on output type and target switch (outNode.type) { - case OnnxOutput::OutputType::FLOAT: + case SaltModelOutput::OutputType::FLOAT: m_decorators.jetFloat.emplace_back(outNode.name, Dec<float>(dec_name)); break; - case OnnxOutput::OutputType::VECCHAR: + case SaltModelOutput::OutputType::VECCHAR: m_decorators.jetVecChar.emplace_back(outNode.name, Dec<std::vector<char>>(dec_name)); break; - case OnnxOutput::OutputType::VECFLOAT: + case SaltModelOutput::OutputType::VECFLOAT: m_decorators.jetVecFloat.emplace_back(outNode.name, Dec<std::vector<float>>(dec_name)); break; default: @@ -321,7 +321,7 @@ namespace FlavorTagDiscriminants { // Create decorators for links to the input tracks if (!m_decorators.jetVecChar.empty() || !m_decorators.jetVecFloat.empty()) { - std::string name = m_onnxUtil->getModelName() + "_TrackLinks"; + std::string name = m_saltModel->getModelName() + "_TrackLinks"; // modify the deco name if we're using flip taggers if (options.flip != FlipTagConfig::STANDARD) { diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxUtil.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModel.cxx similarity index 85% rename from PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxUtil.cxx rename to PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModel.cxx index ea04a2039dda..f707072cc539 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxUtil.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModel.cxx @@ -3,7 +3,7 @@ Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include "CxxUtils/checker_macros.h" #include "lwtnn/parse_json.hh" @@ -13,7 +13,7 @@ Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration namespace FlavorTagDiscriminants { - OnnxUtil::OnnxUtil(const std::string& path_to_onnx) + SaltModel::SaltModel(const std::string& path_to_onnx) //load the onnx model to memory using the path m_path_to_onnx : m_env (std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, "")) { @@ -41,13 +41,13 @@ namespace FlavorTagDiscriminants { // get the onnx model version if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set - m_onnx_model_version = m_metadata["onnx_model_version"].get<OnnxModelVersion>(); - if (m_onnx_model_version == OnnxModelVersion::UNKNOWN){ + m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>(); + if (m_onnx_model_version == SaltModelVersion::UNKNOWN){ throw std::runtime_error("Unknown Onnx model version!"); } } else { // metadata version is not set, infer from the presence of "outputs" key if (m_metadata.contains("outputs")){ - m_onnx_model_version = OnnxModelVersion::V0; + m_onnx_model_version = SaltModelVersion::V0; } else { throw std::runtime_error("Onnx model version not found in metadata"); } @@ -67,26 +67,26 @@ namespace FlavorTagDiscriminants { const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get()); const auto type = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType(); const int rank = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size(); - if (m_onnx_model_version == OnnxModelVersion::V0) { - const OnnxOutput onnxOutput(name, type, m_model_name); - m_output_nodes.push_back(onnxOutput); + if (m_onnx_model_version == SaltModelVersion::V0) { + const SaltModelOutput saltModelOutput(name, type, m_model_name); + m_output_nodes.push_back(saltModelOutput); } else { - const OnnxOutput onnxOutput(name, type, rank); - m_output_nodes.push_back(onnxOutput); + const SaltModelOutput saltModelOutput(name, type, rank); + m_output_nodes.push_back(saltModelOutput); } } } - const nlohmann::json OnnxUtil::loadMetadata(const std::string& key) const { + const nlohmann::json SaltModel::loadMetadata(const std::string& key) const { Ort::AllocatorWithDefaultOptions allocator; Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata(); std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get()); return nlohmann::json::parse(metadataString); } - const std::string OnnxUtil::determineModelName() const { + const std::string SaltModel::determineModelName() const { Ort::AllocatorWithDefaultOptions allocator; - if (m_onnx_model_version == OnnxModelVersion::V0) { + if (m_onnx_model_version == SaltModelVersion::V0) { // get the model name directly from the metadata return std::string(m_metadata["outputs"].begin().key()); } else { @@ -104,14 +104,14 @@ namespace FlavorTagDiscriminants { } } if (model_names.size() != 1) { - throw std::runtime_error("OnnxUtil: model names are not consistent between outputs"); + throw std::runtime_error("SaltModel: model names are not consistent between outputs"); } return *model_names.begin(); } } - const lwt::GraphConfig OnnxUtil::getLwtConfig() const { + const lwt::GraphConfig SaltModel::getLwtConfig() const { /* for the new metadata format (>V0), the outputs are inferred directly from the model graph, rather than being configured as json metadata. however we still need to add an empty "outputs" key to the config so that @@ -119,7 +119,7 @@ namespace FlavorTagDiscriminants { // deep copy the metadata by round tripping through a string stream nlohmann::json metadataCopy = nlohmann::json::parse(m_metadata.dump()); - if (getOnnxModelVersion() != OnnxModelVersion::V0){ + if (getSaltModelVersion() != SaltModelVersion::V0){ metadataCopy["outputs"] = nlohmann::json::object(); } std::stringstream metadataStream; @@ -127,24 +127,24 @@ namespace FlavorTagDiscriminants { return lwt::parse_json_graph(metadataStream); } - const nlohmann::json& OnnxUtil::getMetadata() const { + const nlohmann::json& SaltModel::getMetadata() const { return m_metadata; } - const OnnxUtil::OutputConfig& OnnxUtil::getOutputConfig() const { + const SaltModel::OutputConfig& SaltModel::getOutputConfig() const { return m_output_nodes; } - OnnxModelVersion OnnxUtil::getOnnxModelVersion() const { + SaltModelVersion SaltModel::getSaltModelVersion() const { return m_onnx_model_version; } - const std::string& OnnxUtil::getModelName() const { + const std::string& SaltModel::getModelName() const { return m_model_name; } - OnnxUtil::InferenceOutput OnnxUtil::runInference( + SaltModel::InferenceOutput SaltModel::runInference( std::map<std::string, Inputs>& gnn_inputs) const { std::vector<float> input_tensor_values; diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxOutput.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModelOutput.cxx similarity index 74% rename from PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxOutput.cxx rename to PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModelOutput.cxx index eb7634570d0e..d4f12c1c8731 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxOutput.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModelOutput.cxx @@ -4,27 +4,27 @@ Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration This class is used to store the configuration for a ONNX output node. */ -#include "FlavorTagDiscriminants/OnnxOutput.h" +#include "FlavorTagDiscriminants/SaltModelOutput.h" namespace FlavorTagDiscriminants { -/* constructor for OnnxModelVersion::V1 and higher */ -OnnxOutput::OnnxOutput(const std::string& name, +/* constructor for SaltModelVersion::V1 and higher */ +SaltModelOutput::SaltModelOutput(const std::string& name, const ONNXTensorElementDataType type, int rank) : name(name), name_in_model(name), type(getOutputType(type, rank)){} -/* constructor for OnnxModelVersion::V0 */ -OnnxOutput::OnnxOutput(const std::string& name, +/* constructor for SaltModelVersion::V0 */ +SaltModelOutput::SaltModelOutput(const std::string& name, const ONNXTensorElementDataType type, const std::string& model_name) : name(getName(name, model_name)), name_in_model(name), type(getOutputType(type, 0)){} -const std::string OnnxOutput::getName(const std::string& name, const std::string& model_name) const { +const std::string SaltModelOutput::getName(const std::string& name, const std::string& model_name) const { // unfortunately, this is block is needed to support some taggers that we schedule that don't have // a well defined model name and rely on output remapping. if (model_name == "UnknownModelName") { @@ -33,7 +33,7 @@ const std::string OnnxOutput::getName(const std::string& name, const std::string return model_name + "_" + name; } -OnnxOutput::OutputType OnnxOutput::getOutputType(ONNXTensorElementDataType type, int rank) const { +SaltModelOutput::OutputType SaltModelOutput::getOutputType(ONNXTensorElementDataType type, int rank) const { // Determine the output node type based on the type and shape of the output tensor. using ORT = ONNXTensorElementDataType; if (type == ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingEfficiencyTool.cxx b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingEfficiencyTool.cxx index 5cc606accf3f..d99c66b83c40 100644 --- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingEfficiencyTool.cxx +++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingEfficiencyTool.cxx @@ -13,7 +13,7 @@ #include "xAODBTaggingEfficiency/ToolDefaults.h" // for the onnxtool -#include "xAODBTaggingEfficiency/OnnxUtil.h" +#include "xAODBTaggingEfficiency/SaltModel.h" #include "PATInterfaces/SystematicRegistry.h" #include "PathResolver/PathResolver.h" @@ -716,8 +716,8 @@ StatusCode BTaggingEfficiencyTool::initialize() { ATH_MSG_ERROR("ONNX error: Model file doesn't exist! Please set the property 'pathToONNX' to a valid ONNX file"); return StatusCode::FAILURE; } - m_onnxUtil = std::make_unique<OnnxUtil> (m_pathToONNX); - m_onnxUtil->initialize(); + m_saltModel = std::make_unique<SaltModel> (m_pathToONNX); + m_saltModel->initialize(); } m_initialised = true; @@ -1104,7 +1104,7 @@ BTaggingEfficiencyTool::getMCEfficiency( int flavour, const Analysis::Calibratio CorrectionCode BTaggingEfficiencyTool::getMCEfficiencyONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& effAllJet) { - m_onnxUtil->runInference(node_feat, effAllJet); + m_saltModel->runInference(node_feat, effAllJet); return CorrectionCode::Ok; } @@ -1112,7 +1112,7 @@ BTaggingEfficiencyTool::getMCEfficiencyONNX( const std::vector<std::vector<float CorrectionCode BTaggingEfficiencyTool::getMCEfficiencyONNX( const std::vector<std::vector<float>>& node_feat, std::vector<std::vector<float>>& effAllJetAllWp) { - m_onnxUtil->runInference(node_feat, effAllJetAllWp); + m_saltModel->runInference(node_feat, effAllJetAllWp); return CorrectionCode::Ok; } diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/OnnxUtil.cxx b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/OnnxUtil.cxx index c8183d67db95..f35600f22395 100644 --- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/OnnxUtil.cxx +++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/OnnxUtil.cxx @@ -2,19 +2,19 @@ Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration */ -#include "xAODBTaggingEfficiency/OnnxUtil.h" +#include "xAODBTaggingEfficiency/SaltModel.h" #include "CxxUtils/checker_macros.h" #include "PathResolver/PathResolver.h" //for PathResolverFindCalibFile #include <cstdint> //for int64_t // Constructor -OnnxUtil::OnnxUtil(const std::string& name) +SaltModel::SaltModel(const std::string& name) : m_path_to_onnx (name) { } -void OnnxUtil::initialize(){ +void SaltModel::initialize(){ std::string fullPathToFile = PathResolverFindCalibFile(m_path_to_onnx); @@ -61,7 +61,7 @@ void OnnxUtil::initialize(){ // for fixed cut wp -void OnnxUtil::runInference( +void SaltModel::runInference( const std::vector<std::vector<float>> & node_feat, std::vector<float>& effAllJet) const { @@ -105,7 +105,7 @@ void OnnxUtil::runInference( // for continuous wp -void OnnxUtil::runInference( +void SaltModel::runInference( const std::vector<std::vector<float>> & node_feat, std::vector<std::vector<float>> & effAllJetAllWp) const{ diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingEfficiencyTool.h b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingEfficiencyTool.h index c2a8240908cc..259672107d4a 100644 --- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingEfficiencyTool.h +++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingEfficiencyTool.h @@ -16,7 +16,7 @@ #include "CalibrationDataInterface/CalibrationDataInterfaceROOT.h" // for the onnxtool -#include "xAODBTaggingEfficiency/OnnxUtil.h" +#include "xAODBTaggingEfficiency/SaltModel.h" // #include <fstream> #include <string> @@ -293,7 +293,7 @@ private: //Analysis::CalibrationDataInterfaceROOT* m_CDI = nullptr; std::shared_ptr<Analysis::CalibrationDataInterfaceROOT> m_CDI; /// pointer to the onnx tool - std::unique_ptr<OnnxUtil> m_onnxUtil; + std::unique_ptr<SaltModel> m_saltModel; /// @name core configuration properties (set at initalization time and not modified afterwards) /// @{ diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/OnnxUtil.h b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/SaltModel.h similarity index 78% rename from PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/OnnxUtil.h rename to PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/SaltModel.h index 6ec1c22b07b0..146249dc184e 100644 --- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/OnnxUtil.h +++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/SaltModel.h @@ -2,8 +2,8 @@ Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration */ -#ifndef XAODBTAGGINGEFFICIENCY_ONNXUTIL_H -#define XAODBTAGGINGEFFICIENCY_ONNXUTIL_H +#ifndef XAODBTAGGINGEFFICIENCY_SALTMODEL_H +#define XAODBTAGGINGEFFICIENCY_SALTMODEL_H #include <onnxruntime_cxx_api.h> #include <string> @@ -11,13 +11,13 @@ Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration #include <memory> -class OnnxUtil final{ +class SaltModel final{ public: // Constructor/destructor/init - OnnxUtil(const std::string& name); - ~OnnxUtil() = default; + SaltModel(const std::string& name); + ~SaltModel() = default; void initialize(); @@ -45,7 +45,7 @@ class OnnxUtil final{ // num_wp=1 for fixed cut; int m_num_wp{}; -}; // Class OnnxUtil +}; // Class SaltModel -#endif //XAODBTAGGINGEFFICIENCY_ONNXUTIL_H +#endif //XAODBTAGGINGEFFICIENCY_SALTMODEL_H diff --git a/Reconstruction/tauRecTools/Root/TauGNN.cxx b/Reconstruction/tauRecTools/Root/TauGNN.cxx index 98eb8c0e0673..d51bcbb9ad6f 100644 --- a/Reconstruction/tauRecTools/Root/TauGNN.cxx +++ b/Reconstruction/tauRecTools/Root/TauGNN.cxx @@ -13,7 +13,7 @@ TauGNN::TauGNN(const std::string &nnFile, const Config &config): asg::AsgMessaging("TauGNN"), - m_onnxUtil(std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile)), + m_saltModel(std::make_shared<FlavorTagDiscriminants::SaltModel>(nnFile)), m_config{config} { //==================================================// @@ -21,17 +21,17 @@ TauGNN::TauGNN(const std::string &nnFile, const Config &config): //==================================================// // get the configuration of the model outputs - FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig(); + FlavorTagDiscriminants::SaltModel::OutputConfig gnn_output_config = m_saltModel->getOutputConfig(); //Let's see the output! for (const auto& out_node: gnn_output_config) { - if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name); - if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name); - if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::SaltModelOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::SaltModelOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::SaltModelOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name); } //Get model config (for inputs) - auto lwtnn_config = m_onnxUtil->getLwtConfig(); + auto lwtnn_config = m_saltModel->getLwtConfig(); //===================================================// // This part is ported from tauRecTools TauJetRNN.cxx// @@ -158,7 +158,7 @@ TauGNN::compute(const xAOD::TauJet &tau, //RUN THE INFERENCE!!! ATH_MSG_DEBUG("Prepared inputs, running inference..."); - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); ATH_MSG_DEBUG("Finished compute!"); return std::make_tuple(out_f, out_vc, out_vf); } diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNN.h b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h index dcc25f4c9a98..2bd76e3eeecc 100644 --- a/Reconstruction/tauRecTools/tauRecTools/TauGNN.h +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h @@ -10,7 +10,7 @@ #include "AsgMessaging/AsgMessaging.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include <memory> #include <string> @@ -21,7 +21,7 @@ namespace TauGNNUtils { } /** - * @brief Wrapper around ONNXUtil to compute the output score of a model + * @brief Wrapper around SaltModel to compute the output score of a model * * Configures the network and computes the network outputs given the input * objects. Retrieval of input variables is handled internally. @@ -39,12 +39,12 @@ public: std::string output_node_tau; std::string output_node_jet; }; - std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil; + std::shared_ptr<const FlavorTagDiscriminants::SaltModel> m_saltModel; public: TauGNN(const std::string &nnFile, const Config &config); ~TauGNN(); - // Output the OnnxUtil tuple + // Output the SaltModel tuple std::tuple< std::map<std::string, float>, std::map<std::string, std::vector<char>>, @@ -66,7 +66,7 @@ public: } //Make the output config transparent to external tools - FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config; + FlavorTagDiscriminants::SaltModel::OutputConfig gnn_output_config; private: using Inputs = FlavorTagDiscriminants::Inputs; -- GitLab