diff --git a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/ConstituentsLoader.h b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/ConstituentsLoader.h index 82f1ed164fa5dda6f874b5283ed58e782dfbd7af..aa2f42aedc8ff18a9ff624c54881bb2c71283285 100644 --- a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/ConstituentsLoader.h +++ b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/ConstituentsLoader.h @@ -10,7 +10,7 @@ #define INDET_CONSTITUENTS_LOADER_H // local includes -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" // EDM includes #include "xAODTracking/Vertex.h" diff --git a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/DataPrepUtilities.h b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/DataPrepUtilities.h index 0ef8be167bafdd351e1e9a56e5211f32e3ab35ef..a1949ead20da5eecaa9b053772936eb2b44607bf 100644 --- a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/DataPrepUtilities.h +++ b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/DataPrepUtilities.h @@ -7,7 +7,7 @@ // local includes #include "InDetGNNHardScatterSelection/ConstituentsLoader.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" // EDM includes #include "xAODTracking/Vertex.h" diff --git a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h index 0eaa0ac39be253f3c8016be0e4aca9659f556f2e..05353c48158c21a8f2b821934f4a0832e577d797 100644 --- a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h +++ b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h @@ -1,10 +1,10 @@ /* Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration - This class is used in conjunction with OnnxUtil to run inference on a GNN model. - Whereas OnnxUtil handles the interfacing with the ONNX runtime, this class handles + This class is used in conjunction with SaltModel to run inference on a GNN model. + Whereas SaltModel handles the interfacing with the ONNX runtime, this class handles the interfacing with the ATLAS EDM. It is responsible for collecting all the inputs - needed for inference, running inference (via OnnxUtil), and decorating the results + needed for inference, running inference (via SaltModel), and decorating the results back to ATLAS EDM. 
*/ @@ -14,7 +14,7 @@ // Tool includes #include "InDetGNNHardScatterSelection/DataPrepUtilities.h" #include "InDetGNNHardScatterSelection/IParticlesLoader.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" // EDM includes #include "xAODTracking/VertexFwd.h" @@ -49,7 +49,7 @@ namespace InDetGNNHardScatterSelection { virtual void decorate(const xAOD::Vertex& verrtex) const; - std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil; + std::shared_ptr<const FlavorTagDiscriminants::SaltModel> m_saltModel; private: // type definitions for ONNX output decorators using TPC = xAOD::TrackParticleContainer; @@ -66,7 +66,7 @@ namespace InDetGNNHardScatterSelection { }; /* create all decorators */ - std::set<std::string> createDecorators(const FlavorTagDiscriminants::OnnxUtil::OutputConfig& outConfig); + std::set<std::string> createDecorators(const FlavorTagDiscriminants::SaltModel::OutputConfig& outConfig); std::string m_input_node_name; std::vector<internal::VarFromVertex> m_varsFromVertex; diff --git a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx index 44dca96d8068144e7d5a960534d081524a2a9ef2..081a5b8d982cd7a1f8c348580bab939726a57c8f 100644 --- a/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx +++ b/InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx @@ -3,7 +3,7 @@ */ #include "InDetGNNHardScatterSelection/GNN.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include "PathResolver/PathResolver.h" @@ -19,15 +19,15 @@ namespace InDetGNNHardScatterSelection { GNN::GNN(const std::string& nn_file): - m_onnxUtil(nullptr) + m_saltModel(nullptr) { // Load and initialize the neural network model from the given file path. std::string fullPathToOnnxFile = PathResolverFindCalibFile(nn_file); - m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(fullPathToOnnxFile); + m_saltModel = std::make_shared<FlavorTagDiscriminants::SaltModel>(fullPathToOnnxFile); // Extract metadata from the ONNX file, primarily about the model's inputs. - auto lwt_config = m_onnxUtil->getLwtConfig(); + auto lwt_config = m_saltModel->getLwtConfig(); // Create configuration objects for data preprocessing. auto [inputs, constituents_configs] = dataprep::createGetterConfig(lwt_config); @@ -58,7 +58,7 @@ namespace InDetGNNHardScatterSelection { m_varsFromVertex = dataprep::createVertexVarGetters(inputs); // Retrieve the configuration for the model outputs. 
- FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig(); + FlavorTagDiscriminants::SaltModel::OutputConfig gnn_output_config = m_saltModel->getOutputConfig(); for (const auto& outNode : gnn_output_config) { // the node's output name will be used to define the decoration name @@ -95,7 +95,7 @@ namespace InDetGNNHardScatterSelection { // run inference // ------------- - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); // decorate outputs // ---------------- diff --git a/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/LeptonTaggers/DecoratePLIT.h b/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/LeptonTaggers/DecoratePLIT.h index fb7b799484644704f41f8d5cc5e50ee8093ca002..af22271286757b47449669d0f6a27489c4b1e46b 100644 --- a/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/LeptonTaggers/DecoratePLIT.h +++ b/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/LeptonTaggers/DecoratePLIT.h @@ -12,7 +12,7 @@ // Tools #include "PathResolver/PathResolver.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" // Athena #include "AsgDataHandles/WriteDecorHandle.h" @@ -39,8 +39,8 @@ namespace Prompt { virtual StatusCode execute (const EventContext&) const override; private: - std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil{}; - std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil_endcap{}; + std::shared_ptr<const FlavorTagDiscriminants::SaltModel> m_saltModel{}; + std::shared_ptr<const FlavorTagDiscriminants::SaltModel> m_saltModel_endcap{}; int m_num_lepton_features{}; int m_num_track_features{}; diff --git a/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/src/DecoratePLIT.cxx b/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/src/DecoratePLIT.cxx index 3d0d3f25d52084a25c8cbe0930e828719f4e495e..2c1df8a5f15f3044dd128c0d25a7afff4c76a793 100644 --- a/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/src/DecoratePLIT.cxx +++ b/PhysicsAnalysis/AnalysisCommon/LeptonTaggers/src/DecoratePLIT.cxx @@ -38,10 +38,10 @@ namespace Prompt { // Load and initialize the neural network model from the given file path. 
if(m_leptonsName == "Electrons") { std::string fullPathToOnnxFile = PathResolverFindCalibFile(m_configPath.value() + m_configFileVersion.value()); - m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(fullPathToOnnxFile); + m_saltModel = std::make_shared<FlavorTagDiscriminants::SaltModel>(fullPathToOnnxFile); std::string fullPathToOnnxFile_endcap = PathResolverFindCalibFile(m_configPath.value() + m_configFileVersion_endcap.value()); - m_onnxUtil_endcap = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(fullPathToOnnxFile_endcap); + m_saltModel_endcap = std::make_shared<FlavorTagDiscriminants::SaltModel>(fullPathToOnnxFile_endcap); m_num_lepton_features = 15; m_num_track_features = 19; @@ -56,7 +56,7 @@ namespace Prompt { std::vector<int64_t> track_feat_dim = {1, m_num_track_features}; FlavorTagDiscriminants::Inputs track_info(track_feat, track_feat_dim); gnn_input.insert({"track_features", track_info}); - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); std::vector<std::string> output_names; for (auto& singlefloat : out_f){ @@ -67,7 +67,7 @@ namespace Prompt { } else if (m_leptonsName == "Muons") { std::string fullPathToOnnxFile = PathResolverFindCalibFile(m_configPath.value() + m_configFileVersion.value()); - m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(fullPathToOnnxFile); + m_saltModel = std::make_shared<FlavorTagDiscriminants::SaltModel>(fullPathToOnnxFile); m_num_lepton_features = 10; m_num_track_features = 18; @@ -82,7 +82,7 @@ namespace Prompt { std::vector<int64_t> track_feat_dim = {1, m_num_track_features}; FlavorTagDiscriminants::Inputs track_info(track_feat, track_feat_dim); gnn_input.insert({"track_features", track_info}); - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); std::vector<std::string> output_names; for (auto& singlefloat : out_f){ @@ -361,7 +361,7 @@ namespace Prompt { // run inference // ------------- - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); if (msgLvl(MSG::VERBOSE)) { ATH_MSG_VERBOSE("runInference done."); @@ -653,7 +653,7 @@ namespace Prompt { // run inference // ------------- // use different model for endcap electrons - auto [out_f, out_vc, out_vf] = (std::abs(elec_eta) < 1.37) ? m_onnxUtil->runInference(gnn_input) : m_onnxUtil_endcap->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = (std::abs(elec_eta) < 1.37) ? 
m_saltModel->runInference(gnn_input) : m_saltModel_endcap->runInference(gnn_input); if (msgLvl(MSG::VERBOSE)) { ATH_MSG_VERBOSE("runInference done."); diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/CMakeLists.txt b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/CMakeLists.txt index 2c4224ea9adf862b79e8fe205321490e956324f8..b4ae96b5b3ffa3a8b76b944d83b394a6d943da94 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/CMakeLists.txt +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/CMakeLists.txt @@ -27,8 +27,8 @@ set(FTDSource Root/DL2HighLevel.cxx Root/DL2Tool.cxx Root/DataPrepUtilities.cxx - Root/OnnxUtil.cxx - Root/OnnxOutput.cxx + Root/SaltModel.cxx + Root/SaltModelOutput.cxx Root/GNN.cxx Root/GNNOptions.cxx Root/GNNTool.cxx diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h index f117b63082a3bb6e41a32206156fbe152d41013a..4b4e1c2d977f7ef027f45129656b0625a65dcf9c 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h @@ -12,7 +12,7 @@ // local includes #include "FlavorTagDiscriminants/FlipTagEnums.h" #include "FlavorTagDiscriminants/AssociationEnums.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include "FlavorTagDiscriminants/FTagDataDependencyNames.h" #include "FlavorTagDiscriminants/StringUtils.h" diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/DataPrepUtilities.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/DataPrepUtilities.h index 9946fb22ea352bebfff36237de41b7cd86843ec7..cecd25cdb9caeb2a2daa2994be2a02a90db8c111 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/DataPrepUtilities.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/DataPrepUtilities.h @@ -9,7 +9,7 @@ #include "FlavorTagDiscriminants/FlipTagEnums.h" #include "FlavorTagDiscriminants/AssociationEnums.h" #include "FlavorTagDiscriminants/FTagDataDependencyNames.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include "FlavorTagDiscriminants/ConstituentsLoader.h" // EDM includes diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h index 385e371f48245713547acf4abaf3af823d5943f2..8e592478655817f6ccba913059ef831925712d45 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h @@ -1,10 +1,10 @@ /* Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration - This class is used in conjunction with OnnxUtil to run inference on a GNN model. - Whereas OnnxUtil handles the interfacing with the ONNX runtime, this class handles + This class is used in conjunction with SaltModel to run inference on a GNN model. + Whereas SaltModel handles the interfacing with the ONNX runtime, this class handles the interfacing with the ATLAS EDM. 
It is responsible for collecting all the inputs - needed for inference, running inference (via OnnxUtil), and decorating the results + needed for inference, running inference (via SaltModel), and decorating the results back to ATLAS EDM. */ @@ -30,7 +30,7 @@ namespace FlavorTagDiscriminants { struct GNNOptions; - class OnnxUtil; + class SaltModel; // // Tool to to flavor tag jet/btagging object // using GNN based taggers @@ -60,10 +60,10 @@ namespace FlavorTagDiscriminants { virtual std::set<std::string> getAuxInputKeys() const; virtual std::set<std::string> getConstituentAuxInputKeys() const; - std::shared_ptr<const OnnxUtil> m_onnxUtil; + std::shared_ptr<const SaltModel> m_saltModel; private: // private constructor, delegate of the above public ones - GNN(std::shared_ptr<const OnnxUtil>, const GNNOptions& opts); + GNN(std::shared_ptr<const SaltModel>, const GNNOptions& opts); // type definitions for ONNX output decorators using TPC = xAOD::TrackParticleContainer; using TrackLinks = std::vector<ElementLink<TPC>>; @@ -85,7 +85,7 @@ namespace FlavorTagDiscriminants { /* create all decorators */ std::tuple<FTagDataDependencyNames, std::set<std::string>> - createDecorators(const OnnxUtil::OutputConfig& outConfig, const FTagOptions& options); + createDecorators(const SaltModel::OutputConfig& outConfig, const FTagOptions& options); SG::AuxElement::ConstAccessor<ElementLink<xAOD::JetContainer>> m_jetLink; std::string m_input_node_name; diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModel.h similarity index 73% rename from PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h rename to PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModel.h index f327aa273ffdd195785b8a590521db0b874aafd0..405b67e370ef6f63fd74479f38fdaf0e8366b639 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModel.h @@ -7,15 +7,15 @@ handles the interaction with the ATLAS EDM. 
*/ -#ifndef FLAVORTAGDISCRIMINANTS_ONNXUTIL_H -#define FLAVORTAGDISCRIMINANTS_ONNXUTIL_H +#ifndef FLAVORTAGDISCRIMINANTS_SALTMODEL_H +#define FLAVORTAGDISCRIMINANTS_SALTMODEL_H #include <onnxruntime_cxx_api.h> #include "nlohmann/json.hpp" #include "lwtnn/parse_json.hh" -#include "FlavorTagDiscriminants/OnnxOutput.h" +#include "FlavorTagDiscriminants/SaltModelOutput.h" #include <map> //also has std::pair #include <vector> @@ -27,25 +27,25 @@ namespace FlavorTagDiscriminants { // the first element is the input data, the second is the shape using Inputs = std::pair<std::vector<float>, std::vector<int64_t>>; - enum class OnnxModelVersion{UNKNOWN, V0, V1, V2}; + enum class SaltModelVersion{UNKNOWN, V0, V1, V2}; - NLOHMANN_JSON_SERIALIZE_ENUM( OnnxModelVersion , { - { OnnxModelVersion::UNKNOWN, "" }, - { OnnxModelVersion::V0, "v0" }, - { OnnxModelVersion::V1, "v1" }, - { OnnxModelVersion::V2, "v2" }, + NLOHMANN_JSON_SERIALIZE_ENUM( SaltModelVersion , { + { SaltModelVersion::UNKNOWN, "" }, + { SaltModelVersion::V0, "v0" }, + { SaltModelVersion::V1, "v1" }, + { SaltModelVersion::V2, "v2" }, }) // // Utility class that loads the onnx model from the given path // and runs inference based on the user given inputs - class OnnxUtil final{ + class SaltModel final{ public: - using OutputConfig = std::vector<OnnxOutput>; + using OutputConfig = std::vector<SaltModelOutput>; - OnnxUtil(const std::string& path_to_onnx); + SaltModel(const std::string& path_to_onnx); void initialize(); @@ -60,7 +60,7 @@ namespace FlavorTagDiscriminants { const lwt::GraphConfig getLwtConfig() const; const nlohmann::json& getMetadata() const; const OutputConfig& getOutputConfig() const; - OnnxModelVersion getOnnxModelVersion() const; + SaltModelVersion getSaltModelVersion() const; const std::string& getModelName() const; private: @@ -79,8 +79,8 @@ namespace FlavorTagDiscriminants { std::vector<std::string> m_input_node_names; OutputConfig m_output_nodes; - OnnxModelVersion m_onnx_model_version = OnnxModelVersion::UNKNOWN; + SaltModelVersion m_onnx_model_version = SaltModelVersion::UNKNOWN; - }; // Class OnnxUtil + }; // Class SaltModel } // end of FlavorTagDiscriminants namespace -#endif //FLAVORTAGDISCRIMINANTS_ONNXUTIL_H +#endif //FLAVORTAGDISCRIMINANTS_SALTMODEL_H diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxOutput.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModelOutput.h similarity index 67% rename from PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxOutput.h rename to PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModelOutput.h index b34d745e3319206f6f8f3d6fb4c3355b397f795e..00b839e45bb38860e36e3c259a76cabeb865e042 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxOutput.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SaltModelOutput.h @@ -4,8 +4,8 @@ Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration This class is used to store the configuration for a ONNX output node. */ -#ifndef OUTPUTNODE_H -#define OUTPUTNODE_H +#ifndef FLAVORTAGDISCRIMINANTS_SALTMODELOUTPUT_H +#define FLAVORTAGDISCRIMINANTS_SALTMODELOUTPUT_H #include <onnxruntime_cxx_api.h> #include "nlohmann/json.hpp" @@ -13,18 +13,18 @@ This class is used to store the configuration for a ONNX output node. 
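For reference on the renamed version handling: SaltModel.h above registers SaltModelVersion with NLOHMANN_JSON_SERIALIZE_ENUM, so the "onnx_model_version" string stored in the ONNX metadata deserializes directly into the enum. A minimal sketch of that round trip follows; it is not part of this patch, the helper and variable names are made up, and only the enum, the macro, and the nlohmann::json calls come from the code above.

#include "FlavorTagDiscriminants/SaltModel.h"
#include "nlohmann/json.hpp"
#include <string>

// Hypothetical helper illustrating the metadata-string -> enum mapping used by SaltModel.
FlavorTagDiscriminants::SaltModelVersion parseVersion(const std::string& metadataString) {
  nlohmann::json meta = nlohmann::json::parse(metadataString);
  // e.g. {"onnx_model_version": "v2"} -> SaltModelVersion::V2. Any unrecognised string
  // falls back to the first entry of the NLOHMANN_JSON_SERIALIZE_ENUM list,
  // SaltModelVersion::UNKNOWN, which the SaltModel constructor rejects with a
  // std::runtime_error.
  return meta["onnx_model_version"].get<FlavorTagDiscriminants::SaltModelVersion>();
}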
namespace FlavorTagDiscriminants { -class OnnxOutput { +class SaltModelOutput { public: enum class OutputType {UNKNOWN, FLOAT, VECCHAR, VECFLOAT}; - /* constructor for OnnxModelVersion::V1 and higher */ - OnnxOutput(const std::string& name, + /* constructor for SaltModelVersion::V1 and higher */ + SaltModelOutput(const std::string& name, ONNXTensorElementDataType type, int rank); - /* constructor for OnnxModelVersion::V0 */ - OnnxOutput(const std::string& name, + /* constructor for SaltModelVersion::V0 */ + SaltModelOutput(const std::string& name, ONNXTensorElementDataType type, const std::string& name_in_model); @@ -36,8 +36,8 @@ class OnnxOutput { OutputType getOutputType(ONNXTensorElementDataType type, int rank) const; const std::string getName(const std::string& name, const std::string& model_name) const; -}; // class OnnxOutput +}; // class SaltModelOutput } // namespace FlavorTagDiscriminants -#endif // OUTPUTNODE_H +#endif // FLAVORTAGDISCRIMINANTS_SALTMODELOUTPUT_H diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx index 20c61980ab25263b55f100f0ebf7b6a864afe2a3..fe46cea3d700d6a43f40d814f9f159e018a496c3 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx @@ -4,7 +4,7 @@ #include "FlavorTagDiscriminants/GNN.h" #include "FlavorTagDiscriminants/BTagTrackIpAccessor.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include "FlavorTagDiscriminants/GNNOptions.h" #include "FlavorTagDiscriminants/StringUtils.h" @@ -22,10 +22,10 @@ namespace { const std::string jetLinkName = "jetLink"; - auto getOnnxUtil(const std::string& nn_file) { + auto getSaltModel(const std::string& nn_file) { using namespace FlavorTagDiscriminants; std::string fullPathToOnnxFile = PathResolverFindCalibFile(nn_file); - return std::make_shared<const OnnxUtil>(fullPathToOnnxFile); + return std::make_shared<const SaltModel>(fullPathToOnnxFile); } template <typename T> @@ -44,22 +44,22 @@ namespace { namespace FlavorTagDiscriminants { GNN::GNN(const std::string& nn_file, const GNNOptions& o): - GNN(getOnnxUtil(nn_file), o) + GNN(getSaltModel(nn_file), o) { } GNN::GNN(const GNN& old, const GNNOptions& o): - GNN(old.m_onnxUtil, o) + GNN(old.m_saltModel, o) { } - GNN::GNN(std::shared_ptr<const OnnxUtil> util, const GNNOptions& o): - m_onnxUtil(util), + GNN::GNN(std::shared_ptr<const SaltModel> util, const GNNOptions& o): + m_saltModel(util), m_jetLink(jetLinkName) { // Extract metadata from the ONNX file, primarily about the model's inputs. - auto lwt_config = m_onnxUtil->getLwtConfig(); + auto lwt_config = m_saltModel->getLwtConfig(); // Create configuration objects for data preprocessing. auto [inputs, constituents_configs, options] = dataprep::createGetterConfig( @@ -91,7 +91,7 @@ namespace FlavorTagDiscriminants { m_dataDependencyNames = ds; // Retrieve the configuration for the model outputs. - OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig(); + SaltModel::OutputConfig gnn_output_config = m_saltModel->getOutputConfig(); // Create the output decorators. 
auto [dd, rd] = createDecorators(gnn_output_config, options); @@ -154,7 +154,7 @@ namespace FlavorTagDiscriminants { dec(jet) = v; } // for some networks we need to set a lot of empty vectors as well - if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V1) { + if (m_saltModel->getSaltModelVersion() == SaltModelVersion::V1) { // vector outputs, e.g. track predictions for (const auto& dec: m_decorators.jetVecChar) { dec.second(jet) = {}; @@ -186,7 +186,7 @@ namespace FlavorTagDiscriminants { } std::vector<int64_t> jet_feat_dim = {1, static_cast<int64_t>(jet_feat.size())}; Inputs jet_info(jet_feat, jet_feat_dim); - if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V2) { + if (m_saltModel->getSaltModelVersion() == SaltModelVersion::V2) { gnn_inputs.insert({"jets", jet_info}); } else { gnn_inputs.insert({"jet_features", jet_info}); @@ -197,7 +197,7 @@ namespace FlavorTagDiscriminants { int64_t num_inputs = 0; for (const auto& loader : m_constituentsLoaders){ auto [input_name, input_data, input_objects] = loader->getData(jet, btag); - if (m_onnxUtil->getOnnxModelVersion() != OnnxModelVersion::V2) { + if (m_saltModel->getSaltModelVersion() != SaltModelVersion::V2) { input_name.pop_back(); input_name.append("_features"); } @@ -219,13 +219,13 @@ namespace FlavorTagDiscriminants { this->decorateWithDefaults(btag); return; } - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_inputs); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_inputs); // decorate outputs // ---------------- // with old metadata, doesn't support writing aux tasks - if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V0) { + if (m_saltModel->getSaltModelVersion() == SaltModelVersion::V0) { for (const auto& dec: m_decorators.jetFloat) { if (out_vf.at(dec.first).size() != 1){ throw std::logic_error("expected vectors of length 1 for float decorators"); @@ -234,7 +234,7 @@ namespace FlavorTagDiscriminants { } } // the new metadata format supports writing aux tasks - else if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V1) { + else if (m_saltModel->getSaltModelVersion() == SaltModelVersion::V1) { // float outputs, e.g. 
jet probabilities for (const auto& dec: m_decorators.jetFloat) { dec.second(btag) = out_f.at(dec.first); @@ -277,7 +277,7 @@ namespace FlavorTagDiscriminants { } std::tuple<FTagDataDependencyNames, std::set<std::string>> - GNN::createDecorators(const OnnxUtil::OutputConfig& outConfig, const FTagOptions& options) { + GNN::createDecorators(const SaltModel::OutputConfig& outConfig, const FTagOptions& options) { FTagDataDependencyNames deps; Decorators decs; @@ -305,13 +305,13 @@ namespace FlavorTagDiscriminants { // Create decorators based on output type and target switch (outNode.type) { - case OnnxOutput::OutputType::FLOAT: + case SaltModelOutput::OutputType::FLOAT: m_decorators.jetFloat.emplace_back(outNode.name, Dec<float>(dec_name)); break; - case OnnxOutput::OutputType::VECCHAR: + case SaltModelOutput::OutputType::VECCHAR: m_decorators.jetVecChar.emplace_back(outNode.name, Dec<std::vector<char>>(dec_name)); break; - case OnnxOutput::OutputType::VECFLOAT: + case SaltModelOutput::OutputType::VECFLOAT: m_decorators.jetVecFloat.emplace_back(outNode.name, Dec<std::vector<float>>(dec_name)); break; default: @@ -321,7 +321,7 @@ namespace FlavorTagDiscriminants { // Create decorators for links to the input tracks if (!m_decorators.jetVecChar.empty() || !m_decorators.jetVecFloat.empty()) { - std::string name = m_onnxUtil->getModelName() + "_TrackLinks"; + std::string name = m_saltModel->getModelName() + "_TrackLinks"; // modify the deco name if we're using flip taggers if (options.flip != FlipTagConfig::STANDARD) { diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxUtil.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModel.cxx similarity index 85% rename from PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxUtil.cxx rename to PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModel.cxx index ea04a2039dda28ad3cb7d1ecae1d2ebd35319d81..f707072cc539b722b070cd436f636080fadc2e02 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxUtil.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModel.cxx @@ -3,7 +3,7 @@ Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include "CxxUtils/checker_macros.h" #include "lwtnn/parse_json.hh" @@ -13,7 +13,7 @@ Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration namespace FlavorTagDiscriminants { - OnnxUtil::OnnxUtil(const std::string& path_to_onnx) + SaltModel::SaltModel(const std::string& path_to_onnx) //load the onnx model to memory using the path m_path_to_onnx : m_env (std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, "")) { @@ -41,13 +41,13 @@ namespace FlavorTagDiscriminants { // get the onnx model version if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set - m_onnx_model_version = m_metadata["onnx_model_version"].get<OnnxModelVersion>(); - if (m_onnx_model_version == OnnxModelVersion::UNKNOWN){ + m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>(); + if (m_onnx_model_version == SaltModelVersion::UNKNOWN){ throw std::runtime_error("Unknown Onnx model version!"); } } else { // metadata version is not set, infer from the presence of "outputs" key if (m_metadata.contains("outputs")){ - m_onnx_model_version = OnnxModelVersion::V0; + m_onnx_model_version = SaltModelVersion::V0; } else { throw std::runtime_error("Onnx model version not found in metadata"); } @@ -67,26 +67,26 
@@ namespace FlavorTagDiscriminants { const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get()); const auto type = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType(); const int rank = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size(); - if (m_onnx_model_version == OnnxModelVersion::V0) { - const OnnxOutput onnxOutput(name, type, m_model_name); - m_output_nodes.push_back(onnxOutput); + if (m_onnx_model_version == SaltModelVersion::V0) { + const SaltModelOutput saltModelOutput(name, type, m_model_name); + m_output_nodes.push_back(saltModelOutput); } else { - const OnnxOutput onnxOutput(name, type, rank); - m_output_nodes.push_back(onnxOutput); + const SaltModelOutput saltModelOutput(name, type, rank); + m_output_nodes.push_back(saltModelOutput); } } } - const nlohmann::json OnnxUtil::loadMetadata(const std::string& key) const { + const nlohmann::json SaltModel::loadMetadata(const std::string& key) const { Ort::AllocatorWithDefaultOptions allocator; Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata(); std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get()); return nlohmann::json::parse(metadataString); } - const std::string OnnxUtil::determineModelName() const { + const std::string SaltModel::determineModelName() const { Ort::AllocatorWithDefaultOptions allocator; - if (m_onnx_model_version == OnnxModelVersion::V0) { + if (m_onnx_model_version == SaltModelVersion::V0) { // get the model name directly from the metadata return std::string(m_metadata["outputs"].begin().key()); } else { @@ -104,14 +104,14 @@ namespace FlavorTagDiscriminants { } } if (model_names.size() != 1) { - throw std::runtime_error("OnnxUtil: model names are not consistent between outputs"); + throw std::runtime_error("SaltModel: model names are not consistent between outputs"); } return *model_names.begin(); } } - const lwt::GraphConfig OnnxUtil::getLwtConfig() const { + const lwt::GraphConfig SaltModel::getLwtConfig() const { /* for the new metadata format (>V0), the outputs are inferred directly from the model graph, rather than being configured as json metadata. 
however we still need to add an empty "outputs" key to the config so that @@ -119,7 +119,7 @@ namespace FlavorTagDiscriminants { // deep copy the metadata by round tripping through a string stream nlohmann::json metadataCopy = nlohmann::json::parse(m_metadata.dump()); - if (getOnnxModelVersion() != OnnxModelVersion::V0){ + if (getSaltModelVersion() != SaltModelVersion::V0){ metadataCopy["outputs"] = nlohmann::json::object(); } std::stringstream metadataStream; @@ -127,24 +127,24 @@ namespace FlavorTagDiscriminants { return lwt::parse_json_graph(metadataStream); } - const nlohmann::json& OnnxUtil::getMetadata() const { + const nlohmann::json& SaltModel::getMetadata() const { return m_metadata; } - const OnnxUtil::OutputConfig& OnnxUtil::getOutputConfig() const { + const SaltModel::OutputConfig& SaltModel::getOutputConfig() const { return m_output_nodes; } - OnnxModelVersion OnnxUtil::getOnnxModelVersion() const { + SaltModelVersion SaltModel::getSaltModelVersion() const { return m_onnx_model_version; } - const std::string& OnnxUtil::getModelName() const { + const std::string& SaltModel::getModelName() const { return m_model_name; } - OnnxUtil::InferenceOutput OnnxUtil::runInference( + SaltModel::InferenceOutput SaltModel::runInference( std::map<std::string, Inputs>& gnn_inputs) const { std::vector<float> input_tensor_values; diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxOutput.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModelOutput.cxx similarity index 74% rename from PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxOutput.cxx rename to PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModelOutput.cxx index eb7634570d0ea9221a5d217061e8dd9af032c743..d4f12c1c8731d6ec189424d92bbb2571381c5fa1 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxOutput.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SaltModelOutput.cxx @@ -4,27 +4,27 @@ Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration This class is used to store the configuration for a ONNX output node. */ -#include "FlavorTagDiscriminants/OnnxOutput.h" +#include "FlavorTagDiscriminants/SaltModelOutput.h" namespace FlavorTagDiscriminants { -/* constructor for OnnxModelVersion::V1 and higher */ -OnnxOutput::OnnxOutput(const std::string& name, +/* constructor for SaltModelVersion::V1 and higher */ +SaltModelOutput::SaltModelOutput(const std::string& name, const ONNXTensorElementDataType type, int rank) : name(name), name_in_model(name), type(getOutputType(type, rank)){} -/* constructor for OnnxModelVersion::V0 */ -OnnxOutput::OnnxOutput(const std::string& name, +/* constructor for SaltModelVersion::V0 */ +SaltModelOutput::SaltModelOutput(const std::string& name, const ONNXTensorElementDataType type, const std::string& model_name) : name(getName(name, model_name)), name_in_model(name), type(getOutputType(type, 0)){} -const std::string OnnxOutput::getName(const std::string& name, const std::string& model_name) const { +const std::string SaltModelOutput::getName(const std::string& name, const std::string& model_name) const { // unfortunately, this is block is needed to support some taggers that we schedule that don't have // a well defined model name and rely on output remapping. 
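Side note on the renamed output wrapper: the two SaltModelOutput constructors declared in SaltModelOutput.h and defined around this hunk build the decoration name differently depending on the metadata version. A brief sketch, not part of this patch; the tagger and node names are made up, and only the constructor signatures and the model_name + "_" + name rule come from the surrounding code.

#include "FlavorTagDiscriminants/SaltModelOutput.h"
#include <string>

void sketchOutputNaming() {
  using FlavorTagDiscriminants::SaltModelOutput;
  using ORT = ONNXTensorElementDataType;

  // V1 and higher: the ONNX graph output name is used as-is for the decoration.
  SaltModelOutput v1_node("GN2_pb", ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, /*rank=*/0);
  // v1_node.name == "GN2_pb"

  // V0: the node name is prefixed with the model name taken from the metadata,
  // i.e. model_name + "_" + name (with a special case for "UnknownModelName").
  SaltModelOutput v0_node("pb", ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, std::string("GN2"));
  // v0_node.name == "GN2_pb"
}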
if (model_name == "UnknownModelName") { @@ -33,7 +33,7 @@ const std::string OnnxOutput::getName(const std::string& name, const std::string return model_name + "_" + name; } -OnnxOutput::OutputType OnnxOutput::getOutputType(ONNXTensorElementDataType type, int rank) const { +SaltModelOutput::OutputType SaltModelOutput::getOutputType(ONNXTensorElementDataType type, int rank) const { // Determine the output node type based on the type and shape of the output tensor. using ORT = ONNXTensorElementDataType; if (type == ORT::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingEfficiencyTool.cxx b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingEfficiencyTool.cxx index 5cc606accf3f4bbc23d572f2d721193ef59af37a..d99c66b83c4065f9864fcffb6a9e611d03fa0768 100644 --- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingEfficiencyTool.cxx +++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingEfficiencyTool.cxx @@ -13,7 +13,7 @@ #include "xAODBTaggingEfficiency/ToolDefaults.h" // for the onnxtool -#include "xAODBTaggingEfficiency/OnnxUtil.h" +#include "xAODBTaggingEfficiency/SaltModel.h" #include "PATInterfaces/SystematicRegistry.h" #include "PathResolver/PathResolver.h" @@ -716,8 +716,8 @@ StatusCode BTaggingEfficiencyTool::initialize() { ATH_MSG_ERROR("ONNX error: Model file doesn't exist! Please set the property 'pathToONNX' to a valid ONNX file"); return StatusCode::FAILURE; } - m_onnxUtil = std::make_unique<OnnxUtil> (m_pathToONNX); - m_onnxUtil->initialize(); + m_saltModel = std::make_unique<SaltModel> (m_pathToONNX); + m_saltModel->initialize(); } m_initialised = true; @@ -1104,7 +1104,7 @@ BTaggingEfficiencyTool::getMCEfficiency( int flavour, const Analysis::Calibratio CorrectionCode BTaggingEfficiencyTool::getMCEfficiencyONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& effAllJet) { - m_onnxUtil->runInference(node_feat, effAllJet); + m_saltModel->runInference(node_feat, effAllJet); return CorrectionCode::Ok; } @@ -1112,7 +1112,7 @@ BTaggingEfficiencyTool::getMCEfficiencyONNX( const std::vector<std::vector<float CorrectionCode BTaggingEfficiencyTool::getMCEfficiencyONNX( const std::vector<std::vector<float>>& node_feat, std::vector<std::vector<float>>& effAllJetAllWp) { - m_onnxUtil->runInference(node_feat, effAllJetAllWp); + m_saltModel->runInference(node_feat, effAllJetAllWp); return CorrectionCode::Ok; } diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/OnnxUtil.cxx b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/OnnxUtil.cxx index c8183d67db9593b0c09fdb8bdd027244423d9f67..f35600f223956475faa509ae5b208a1a954f107c 100644 --- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/OnnxUtil.cxx +++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/OnnxUtil.cxx @@ -2,19 +2,19 @@ Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration */ -#include "xAODBTaggingEfficiency/OnnxUtil.h" +#include "xAODBTaggingEfficiency/SaltModel.h" #include "CxxUtils/checker_macros.h" #include "PathResolver/PathResolver.h" //for PathResolverFindCalibFile #include <cstdint> //for int64_t // Constructor -OnnxUtil::OnnxUtil(const std::string& name) +SaltModel::SaltModel(const std::string& name) : m_path_to_onnx 
(name) { } -void OnnxUtil::initialize(){ +void SaltModel::initialize(){ std::string fullPathToFile = PathResolverFindCalibFile(m_path_to_onnx); @@ -61,7 +61,7 @@ void OnnxUtil::initialize(){ // for fixed cut wp -void OnnxUtil::runInference( +void SaltModel::runInference( const std::vector<std::vector<float>> & node_feat, std::vector<float>& effAllJet) const { @@ -105,7 +105,7 @@ void OnnxUtil::runInference( // for continuous wp -void OnnxUtil::runInference( +void SaltModel::runInference( const std::vector<std::vector<float>> & node_feat, std::vector<std::vector<float>> & effAllJetAllWp) const{ diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingEfficiencyTool.h b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingEfficiencyTool.h index c2a8240908cccab7f40880221589104ea1e1c9f6..259672107d4a08675a2c6884695584b1ca10ef9b 100644 --- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingEfficiencyTool.h +++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingEfficiencyTool.h @@ -16,7 +16,7 @@ #include "CalibrationDataInterface/CalibrationDataInterfaceROOT.h" // for the onnxtool -#include "xAODBTaggingEfficiency/OnnxUtil.h" +#include "xAODBTaggingEfficiency/SaltModel.h" // #include <fstream> #include <string> @@ -293,7 +293,7 @@ private: //Analysis::CalibrationDataInterfaceROOT* m_CDI = nullptr; std::shared_ptr<Analysis::CalibrationDataInterfaceROOT> m_CDI; /// pointer to the onnx tool - std::unique_ptr<OnnxUtil> m_onnxUtil; + std::unique_ptr<SaltModel> m_saltModel; /// @name core configuration properties (set at initalization time and not modified afterwards) /// @{ diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/OnnxUtil.h b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/SaltModel.h similarity index 78% rename from PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/OnnxUtil.h rename to PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/SaltModel.h index 6ec1c22b07b044b7bcc4b9e16608238d45760c03..146249dc184e1c20057c7cae50f2850aa4faba0f 100644 --- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/OnnxUtil.h +++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/SaltModel.h @@ -2,8 +2,8 @@ Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration */ -#ifndef XAODBTAGGINGEFFICIENCY_ONNXUTIL_H -#define XAODBTAGGINGEFFICIENCY_ONNXUTIL_H +#ifndef XAODBTAGGINGEFFICIENCY_SALTMODEL_H +#define XAODBTAGGINGEFFICIENCY_SALTMODEL_H #include <onnxruntime_cxx_api.h> #include <string> @@ -11,13 +11,13 @@ Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration #include <memory> -class OnnxUtil final{ +class SaltModel final{ public: // Constructor/destructor/init - OnnxUtil(const std::string& name); - ~OnnxUtil() = default; + SaltModel(const std::string& name); + ~SaltModel() = default; void initialize(); @@ -45,7 +45,7 @@ class OnnxUtil final{ // num_wp=1 for fixed cut; int m_num_wp{}; -}; // Class OnnxUtil +}; // Class SaltModel -#endif //XAODBTAGGINGEFFICIENCY_ONNXUTIL_H +#endif 
//XAODBTAGGINGEFFICIENCY_SALTMODEL_H diff --git a/Reconstruction/tauRecTools/Root/TauGNN.cxx b/Reconstruction/tauRecTools/Root/TauGNN.cxx index 98eb8c0e06732bf3199ff537ae5dee456f38ce81..d51bcbb9ad6f5a87ad180037a23c0c24de3f56fa 100644 --- a/Reconstruction/tauRecTools/Root/TauGNN.cxx +++ b/Reconstruction/tauRecTools/Root/TauGNN.cxx @@ -13,7 +13,7 @@ TauGNN::TauGNN(const std::string &nnFile, const Config &config): asg::AsgMessaging("TauGNN"), - m_onnxUtil(std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile)), + m_saltModel(std::make_shared<FlavorTagDiscriminants::SaltModel>(nnFile)), m_config{config} { //==================================================// @@ -21,17 +21,17 @@ TauGNN::TauGNN(const std::string &nnFile, const Config &config): //==================================================// // get the configuration of the model outputs - FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig(); + FlavorTagDiscriminants::SaltModel::OutputConfig gnn_output_config = m_saltModel->getOutputConfig(); //Let's see the output! for (const auto& out_node: gnn_output_config) { - if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name); - if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name); - if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::SaltModelOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::SaltModelOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name); + if(out_node.type==FlavorTagDiscriminants::SaltModelOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name); } //Get model config (for inputs) - auto lwtnn_config = m_onnxUtil->getLwtConfig(); + auto lwtnn_config = m_saltModel->getLwtConfig(); //===================================================// // This part is ported from tauRecTools TauJetRNN.cxx// @@ -158,7 +158,7 @@ TauGNN::compute(const xAOD::TauJet &tau, //RUN THE INFERENCE!!! ATH_MSG_DEBUG("Prepared inputs, running inference..."); - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); ATH_MSG_DEBUG("Finished compute!"); return std::make_tuple(out_f, out_vc, out_vf); } diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNN.h b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h index dcc25f4c9a9810f19955551a4190bff5a57270b2..2bd76e3eeecc2fd863692f5475e1360548807553 100644 --- a/Reconstruction/tauRecTools/tauRecTools/TauGNN.h +++ b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h @@ -10,7 +10,7 @@ #include "AsgMessaging/AsgMessaging.h" -#include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/SaltModel.h" #include <memory> #include <string> @@ -21,7 +21,7 @@ namespace TauGNNUtils { } /** - * @brief Wrapper around ONNXUtil to compute the output score of a model + * @brief Wrapper around SaltModel to compute the output score of a model * * Configures the network and computes the network outputs given the input * objects. Retrieval of input variables is handled internally. 
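Looking back at the xAODBTaggingEfficiency package above: it keeps its own standalone SaltModel class (the former OnnxUtil used for the MC-efficiency maps), unrelated to the FlavorTagDiscriminants one apart from the new name. A minimal usage sketch in the style of BTaggingEfficiencyTool follows; it is not part of this patch, the file path and feature values are placeholders, and only the constructor, initialize(), and the two runInference overloads come from the diff.

#include "xAODBTaggingEfficiency/SaltModel.h"
#include <vector>

void sketchEfficiencyInference() {
  // Placeholder calib path: BTaggingEfficiencyTool passes its 'pathToONNX' property here.
  SaltModel model("path/to/efficiency_map.onnx");
  model.initialize();  // resolves the file via PathResolver and builds the ONNX Runtime session

  // One row of input features per jet (values are illustrative only).
  std::vector<std::vector<float>> node_feat = {{0.1f, 0.2f, 0.3f}};

  std::vector<float> effAllJet;                    // fixed-cut working point
  model.runInference(node_feat, effAllJet);

  std::vector<std::vector<float>> effAllJetAllWp;  // continuous working points
  model.runInference(node_feat, effAllJetAllWp);
}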
@@ -39,12 +39,12 @@ public: std::string output_node_tau; std::string output_node_jet; }; - std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil; + std::shared_ptr<const FlavorTagDiscriminants::SaltModel> m_saltModel; public: TauGNN(const std::string &nnFile, const Config &config); ~TauGNN(); - // Output the OnnxUtil tuple + // Output the SaltModel tuple std::tuple< std::map<std::string, float>, std::map<std::string, std::vector<char>>, @@ -66,7 +66,7 @@ public: } //Make the output config transparent to external tools - FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config; + FlavorTagDiscriminants::SaltModel::OutputConfig gnn_output_config; private: using Inputs = FlavorTagDiscriminants::Inputs;
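Taken together, the renamed classes are driven exactly as the old OnnxUtil was. Below is a minimal, self-contained sketch of using FlavorTagDiscriminants::SaltModel directly, in the style of DecoratePLIT and TauGNN above; it is not part of this patch. The model path, input node name, and feature values are placeholders (production callers resolve the path with PathResolverFindCalibFile and take node names from the model metadata); only the SaltModel/SaltModelOutput API and the Inputs pair come from the diff.

#include "FlavorTagDiscriminants/SaltModel.h"

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  using namespace FlavorTagDiscriminants;

  // Placeholder path: real callers pass the output of PathResolverFindCalibFile.
  SaltModel model("/path/to/salt_model.onnx");

  // The model advertises its outputs via its metadata / graph.
  for (const SaltModelOutput& node : model.getOutputConfig()) {
    std::cout << "output node: " << node.name << '\n';
  }

  // Inputs is a pair of {flat feature vector, shape}. The node name and feature layout
  // here are illustrative; DecoratePLIT, for example, feeds a {1, n_track_features}
  // block under "track_features".
  std::vector<float> feat = {0.1f, 0.2f, 0.3f};
  std::vector<int64_t> feat_dim = {1, static_cast<int64_t>(feat.size())};
  std::map<std::string, Inputs> gnn_input;
  gnn_input.insert({"track_features", Inputs(feat, feat_dim)});

  // Returns {scalar float outputs, per-constituent char outputs, per-constituent float outputs}.
  auto [out_f, out_vc, out_vf] = model.runInference(gnn_input);
  for (const auto& [name, value] : out_f) {
    std::cout << name << " = " << value << '\n';
  }
  return 0;
}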