From acdb169bb6bf321c753c6b167cc480e8e5e00519 Mon Sep 17 00:00:00 2001 From: Samuel Van Stroud <sam.van.stroud@cern.ch> Date: Wed, 15 May 2024 17:43:04 +0200 Subject: [PATCH] FTAG GNN inference cleanup FTAG GNN inference cleanup --- .../ConstituentsLoader.h | 2 +- .../CustomGetterUtils.h | 94 +++--- .../FlavorTagDiscriminants/IParticlesLoader.h | 4 +- .../FlavorTagDiscriminants/OnnxUtil.h | 8 +- .../FlavorTagDiscriminants/TracksLoader.h | 14 +- .../Root/ConstituentsLoader.cxx | 8 +- .../Root/CustomGetterUtils.cxx | 284 ++++++++---------- .../Root/DataPrepUtilities.cxx | 11 +- .../FlavorTagDiscriminants/Root/GNN.cxx | 30 +- .../Root/IParticlesLoader.cxx | 8 +- .../FlavorTagDiscriminants/Root/OnnxUtil.cxx | 8 +- .../Root/TracksLoader.cxx | 53 ++-- 12 files changed, 244 insertions(+), 280 deletions(-) diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h index a493129eb6e3..04e53686681d 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h @@ -76,7 +76,7 @@ namespace FlavorTagDiscriminants { m_config = cfg; }; virtual ~IConstituentsLoader() = default; - virtual std::tuple<std::string, input_pair, std::vector<const xAOD::IParticle*>> getData( + virtual std::tuple<std::string, Inputs, std::vector<const xAOD::IParticle*>> getData( const xAOD::Jet& jet, [[maybe_unused]] const SG::AuxElement& btag) const = 0; virtual FTagDataDependencyNames getDependencies() const = 0; diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/CustomGetterUtils.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/CustomGetterUtils.h index 33d3c4c8fbb2..8afda0481c79 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/CustomGetterUtils.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/CustomGetterUtils.h @@ -2,10 +2,25 @@ Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration */ -// The CustomGetterUtils file is a catch-all for various getter functinos -// that need to be hard coded for whatever reason. Some of these are -// accessing methods like `pt` which have no name in the EDM, others -// can't be stored in the edm directly for various reasons. +/// This file contains "getter" functions used for accessing tagger inputs +/// from the EDM. In the most basic case, inputs can be directly retrieved +/// from the EDM using accessors. In other cases, inputs require custom code +/// called "custom getters" to produce the desired values. +/// +/// - a basic "getter" directly retrieves decorated values from an object +/// - a "custom getter" executes custom code to produce inputs (e.g. particle.pt()). +/// Custom getters are used for inputs that cannot be directly retrieved from the +/// EDM using accessors or which require on-the-fly calculations, such as IP variables. +/// - a "sequence getter" is a wrapper around a getter that broadcasts it over a +/// vector of associated objects (for example over all the tracks in a jet). +/// +/// Inputs to tagging algorithms are configured when the algorithm is initialised +/// which means that the list of track and jet features that are used as inputs is +/// not known at compile time. Instead we build an array of "getter" functions, +/// each of which returns one feature for the tagger. +/// +/// NOTE: This file is for experts only, don't expect support. +/// // EDM includes @@ -15,7 +30,6 @@ #include "AthContainers/AuxElement.h" #include "FlavorTagDiscriminants/DataPrepUtilities.h" - #include <functional> #include <string> #include <set> @@ -25,47 +39,23 @@ namespace FlavorTagDiscriminants { - /// Utils to produce Constituent -> vector<double> functions - /// - /// DL2 configures the its inputs when the algorithm is initalized, - /// meaning that the list of track and jet properties that are used - /// as inputs won't be known at compile time. Instead we build an - /// array of "getter" functions, each of which returns one input for - /// the tagger. The function here returns those getter functions. - /// - /// Many of the getter functions are trivial: they will, for example, - /// read one double of auxdata off of the BTagging object. The - /// sequence input getters tend to be more complicated. Since we'd - /// like to avoid reimplementing the logic in these functions in - /// multiple places, they are exposed here. - /// - /// NOTE: This file is for experts only, don't expect support. - /// - namespace getter_utils { - - using IParticles = std::vector<const xAOD::IParticle*>; - using Tracks = std::vector<const xAOD::TrackParticle*>; - - using SequenceFromIParticles = std::function<std::vector<double>( - const xAOD::Jet&, - const IParticles&)>; - using SequenceFromTracks = std::function<std::vector<double>( - const xAOD::Jet&, - const Tracks&)>; + // ------------------------------------------------------- + // type aliases + template <typename T> + using Constituents = std::vector<const T*>; + + template <typename T> + using SequenceGetterFunc = std::function<std::vector<double>(const xAOD::Jet&, const Constituents<T>&)>; + // ------------------------------------------------------- std::function<std::pair<std::string, double>(const xAOD::Jet&)> - customGetterAndName(const std::string&); + namedCustomJetGetter(const std::string&); template <typename T> - std::pair< - std::function<std::vector<double>( - const xAOD::Jet&, - const std::vector<const T*>&)>, - std::set<std::string>> - customSequenceGetterWithDeps(const std::string& name, - const std::string& prefix); + std::pair<SequenceGetterFunc<T>, std::set<std::string>> + buildCustomSeqGetter(const std::string& name, const std::string& prefix); /** * @brief Template class to extract features from sequence of constituents @@ -77,29 +67,29 @@ namespace FlavorTagDiscriminants { * - xAOD::TrackParticle */ template <typename T> - class CustomSequenceGetter { + class SeqGetter { public: - using Constituents = std::vector<const T*>; - using NamedSequenceFromConstituents = std::function<std::pair<std::string, std::vector<double>>( + using Const = Constituents<T>; + using InputSequence = std::function<std::pair<std::string, std::vector<double>>( const xAOD::Jet&, - const Constituents&)>; - CustomSequenceGetter(std::vector<InputVariableConfig> inputs, - const FTagOptions& options); + const Const&)>; - std::pair<std::vector<float>, std::vector<int64_t>> getFeats(const xAOD::Jet& jet, const Constituents& constituents) const; - std::map<std::string, std::vector<double>> getDL2Feats(const xAOD::Jet& jet, const Constituents& constituents) const; + SeqGetter(std::vector<InputVariableConfig> inputs, const FTagOptions& options); + + std::pair<std::vector<float>, std::vector<int64_t>> getFeats(const xAOD::Jet& jet, const Const& constituents) const; + std::map<std::string, std::vector<double>> getDL2Feats(const xAOD::Jet& jet, const Const& constituents) const; std::set<std::string> getDependencies() const; std::set<std::string> getUsedRemap() const; - + private: - std::pair<NamedSequenceFromConstituents, std::set<std::string>> customNamedSeqGetterWithDeps( + std::pair<InputSequence, std::set<std::string>> getNamedCustomSeqGetter( const std::string& name, const std::string& prefix); - std::pair<NamedSequenceFromConstituents, std::set<std::string>> seqFromConsituents( + std::pair<InputSequence, std::set<std::string>> seqFromConsituents( const InputVariableConfig& cfg, const FTagOptions& options); - std::vector<NamedSequenceFromConstituents> m_sequencesFromConstituents; + std::vector<InputSequence> m_sequence_getters; std::set<std::string> m_deps; std::set<std::string> m_used_remap; }; diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h index b0abdc0f4ad6..5a571c5b1cfd 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h @@ -34,7 +34,7 @@ namespace FlavorTagDiscriminants { class IParticlesLoader : public IConstituentsLoader { public: IParticlesLoader(ConstituentsInputConfig, const FTagOptions& options); - std::tuple<std::string, input_pair, std::vector<const xAOD::IParticle*>> getData( + std::tuple<std::string, Inputs, std::vector<const xAOD::IParticle*>> getData( const xAOD::Jet& jet, [[maybe_unused]] const SG::AuxElement& btag) const override ; FTagDataDependencyNames getDependencies() const override; @@ -65,7 +65,7 @@ namespace FlavorTagDiscriminants { std::vector<const xAOD::IParticle*> getIParticlesFromJet(const xAOD::Jet& jet) const; IParticleSortVar m_iparticleSortVar; - getter_utils::CustomSequenceGetter<xAOD::IParticle> m_customSequenceGetter; + getter_utils::SeqGetter<xAOD::IParticle> m_seqGetter; std::function<IPV(const Jet&)> m_associator; bool m_isCharged; }; diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h index 383ace1b4c83..46074fc4ba6a 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/OnnxUtil.h @@ -24,14 +24,16 @@ namespace FlavorTagDiscriminants { - typedef std::pair<std::vector<float>, std::vector<int64_t>> input_pair; + // the first element is the input data, the second is the shape + using Inputs = std::pair<std::vector<float>, std::vector<int64_t>>; - enum class OnnxModelVersion{UNKNOWN, V0, V1}; + enum class OnnxModelVersion{UNKNOWN, V0, V1, V2}; NLOHMANN_JSON_SERIALIZE_ENUM( OnnxModelVersion , { { OnnxModelVersion::UNKNOWN, "" }, { OnnxModelVersion::V0, "v0" }, { OnnxModelVersion::V1, "v1" }, + { OnnxModelVersion::V2, "v2" }, }) // @@ -53,7 +55,7 @@ namespace FlavorTagDiscriminants { std::map<std::string, std::vector<float>> vecFloat; }; - InferenceOutput runInference(std::map<std::string, input_pair>& gnn_inputs) const; + InferenceOutput runInference(std::map<std::string, Inputs>& gnn_inputs) const; const lwt::GraphConfig getLwtConfig() const; const nlohmann::json& getMetadata() const; diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h index dc222952889f..d7a938929790 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h @@ -33,6 +33,7 @@ #include <regex> namespace FlavorTagDiscriminants { + using Tracks = std::vector<const xAOD::TrackParticle*>; // tracksConfig ConstituentsInputConfig createTracksLoaderConfig( @@ -40,14 +41,12 @@ namespace FlavorTagDiscriminants { FlipTagConfig flip_config ); - // Subclass for Tracks loader inherited from abstract IConstituentsLoader class class TracksLoader : public IConstituentsLoader { public: - typedef std::vector<const xAOD::TrackParticle*> Tracks; TracksLoader(ConstituentsInputConfig, const FTagOptions& options); - std::tuple<std::string, input_pair, std::vector<const xAOD::IParticle*>> getData( + std::tuple<std::string, Inputs, std::vector<const xAOD::IParticle*>> getData( const xAOD::Jet& jet, [[maybe_unused]] const SG::AuxElement& btag) const override; std::tuple<char, std::map<std::string, std::vector<double>>> getDL2Data( @@ -75,21 +74,20 @@ namespace FlavorTagDiscriminants { using TPC = xAOD::TrackParticleContainer; using TrackLinks = std::vector<ElementLink<TPC>>; using PartLinks = std::vector<ElementLink<IPC>>; - using TPV = std::vector<const xAOD::TrackParticle*>; TrackSortVar trackSortVar(ConstituentsSortOrder, const FTagOptions&); std::pair<TrackFilter,std::set<std::string>> trackFilter( ConstituentsSelection, const FTagOptions&); - std::pair<TrackSequenceFilter,std::set<std::string>> flipFilter( + std::pair<TrackSequenceFilter,std::set<std::string>> trackFlipper( const FTagOptions&); Tracks getTracksFromJet(const Jet& jet, const AE& btag) const; TrackSortVar m_trackSortVar; TrackFilter m_trackFilter; - TrackSequenceFilter m_flipFilter; - std::function<TPV(const SG::AuxElement&)> m_associator; - getter_utils::CustomSequenceGetter<xAOD::TrackParticle> m_customSequenceGetter; + TrackSequenceFilter m_trackFlipper; + std::function<Tracks(const SG::AuxElement&)> m_associator; + getter_utils::SeqGetter<xAOD::TrackParticle> m_seqGetter; }; } diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/ConstituentsLoader.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/ConstituentsLoader.cxx index 6433f58e2121..bfd0f4d46c84 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/ConstituentsLoader.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/ConstituentsLoader.cxx @@ -143,18 +143,18 @@ namespace FlavorTagDiscriminants { trk_type_regexes, trk_sort_regexes, trk_select_regexes, flip_sequences, flip_config); config.type = ConstituentsType::TRACK; - config.output_name = "track_features"; + config.output_name = "tracks"; } - else if (name.find("flow") != std::string::npos){ + else if (name.find("flows") != std::string::npos){ config = get_iparticle_input_config( name, input_variables, iparticle_type_regexes); config.type = ConstituentsType::IPARTICLE; - config.output_name = "flow_features"; + config.output_name = "flows"; } else{ throw std::runtime_error( - "Unknown constituent type: " + name + ". Only tracks and neutrals are supported." + "Unknown constituent type: " + name + ". Only tracks and flows are supported." ); } return config; diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/CustomGetterUtils.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/CustomGetterUtils.cxx index fda23bb83458..f9a648c2c950 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/CustomGetterUtils.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/CustomGetterUtils.cxx @@ -8,15 +8,10 @@ namespace { - using FlavorTagDiscriminants::getter_utils::SequenceFromTracks; - using FlavorTagDiscriminants::getter_utils::SequenceFromIParticles; + using FlavorTagDiscriminants::getter_utils::SequenceGetterFunc; // ______________________________________________________________________ - // Custom getters for jet-wise quantities - // - // this function is not at all optimized, but then it doesn't have - // to be since it should only be called in the initialization stage. - // - std::function<double(const xAOD::Jet&)> customGetter( + // Custom getters for jet input features + std::function<double(const xAOD::Jet&)> customJetGetter( const std::string& name) { if (name == "pt") { @@ -41,57 +36,54 @@ namespace { throw std::logic_error("no match for custom getter " + name); } - // _______________________________________________________________________ - // Custom getters for constituent variables (CJGetter -> Constituent and Jet Getter) - template <typename T> - class CJGetter - { - using F = std::function<double(const T&, const xAOD::Jet&)>; + // Custom getters for jet constituents + + // wraps non-custom getters into sequences, also adds a name + template <typename T, typename U> + class NamedSeqGetter{ private: - F m_getter; + SG::AuxElement::ConstAccessor<T> m_getter; + std::string m_name; public: - CJGetter(F getter): - m_getter(getter) + NamedSeqGetter(const std::string& name): + m_getter(name), + m_name(name) {} - std::vector<double> operator()( - const xAOD::Jet& jet, - const std::vector<const T*>& particles) const { + + std::pair<std::string, std::vector<double>> + operator()(const xAOD::Jet&, const std::vector<const U*>& constituents) const { std::vector<double> sequence; - sequence.reserve(particles.size()); - for (const auto* particle: particles) { - sequence.push_back(m_getter(*particle, jet)); + for (const U* el: constituents) { + sequence.push_back(m_getter(*el)); } - return sequence; + return {m_name, sequence}; } }; - - // The sequence getter takes in constituents and calculates arrays of - // values which are better suited for inputs to the NNs - template <typename T, typename U> - class SequenceGetter{ + // wraps custom getters into sequences, doesn't add a name + template <typename Const> + class CustomSeqGetter + { + using F = std::function<double(const Const&, const xAOD::Jet&)>; private: - SG::AuxElement::ConstAccessor<T> m_getter; - std::string m_name; + F m_getter; public: - SequenceGetter(const std::string& name): - m_getter(name), - m_name(name) - { - } - std::pair<std::string, std::vector<double>> operator()(const xAOD::Jet&, const std::vector<const U*>& consts) const { - std::vector<double> seq; - for (const U* el: consts) { - seq.push_back(m_getter(*el)); + CustomSeqGetter(F getter): m_getter(getter) {} + + std::vector<double> + operator()(const xAOD::Jet& jet, const std::vector<const Const*>& constituents) const { + std::vector<double> sequence; + sequence.reserve(constituents.size()); + for (const auto* constituent: constituents) { + sequence.push_back(m_getter(*constituent, jet)); } - return {m_name, seq}; + return sequence; } }; - // Getters from xAOD::TrackParticle with IP dependencies - std::optional<SequenceFromTracks> + std::optional<SequenceGetterFunc<xAOD::TrackParticle>> getterFromTracksWithIpDep( const std::string& name, const std::string& prefix) @@ -101,129 +93,130 @@ namespace { BTagTrackIpAccessor a(prefix); if (name == "IP3D_signed_d0_significance") { - return CJGetter<Tp>([a](const Tp& tp, const Jet& j){ + return CustomSeqGetter<Tp>([a](const Tp& tp, const Jet& j){ return a.getSignedIp(tp, j).ip3d_signed_d0_significance; }); } if (name == "IP3D_signed_z0_significance") { - return CJGetter<Tp>([a](const Tp& tp, const Jet& j){ + return CustomSeqGetter<Tp>([a](const Tp& tp, const Jet& j){ return a.getSignedIp(tp, j).ip3d_signed_z0_significance; }); } if (name == "IP2D_signed_d0") { - return CJGetter<Tp>([a](const Tp& tp, const Jet& j){ + return CustomSeqGetter<Tp>([a](const Tp& tp, const Jet& j){ return a.getSignedIp(tp, j).ip2d_signed_d0; }); } if (name == "IP3D_signed_d0") { - return CJGetter<Tp>([a](const Tp& tp, const Jet& j){ + return CustomSeqGetter<Tp>([a](const Tp& tp, const Jet& j){ return a.getSignedIp(tp, j).ip3d_signed_d0; }); } if (name == "IP3D_signed_z0") { - return CJGetter<Tp>([a](const Tp& tp, const Jet& j){ + return CustomSeqGetter<Tp>([a](const Tp& tp, const Jet& j){ return a.getSignedIp(tp, j).ip3d_signed_z0; }); } if (name == "d0" || name == "btagIp_d0") { - return CJGetter<Tp>([a](const Tp& tp, const Jet&){ + return CustomSeqGetter<Tp>([a](const Tp& tp, const Jet&){ return a.d0(tp); }); } if (name == "z0SinTheta" || name == "btagIp_z0SinTheta") { - return CJGetter<Tp>([a](const Tp& tp, const Jet&){ + return CustomSeqGetter<Tp>([a](const Tp& tp, const Jet&){ return a.z0SinTheta(tp); }); } if (name == "d0Uncertainty") { - return CJGetter<Tp>([a](const Tp& tp, const Jet&){ + return CustomSeqGetter<Tp>([a](const Tp& tp, const Jet&){ return a.d0Uncertainty(tp); }); } if (name == "z0SinThetaUncertainty") { - return CJGetter<Tp>([a](const Tp& tp, const Jet&){ + return CustomSeqGetter<Tp>([a](const Tp& tp, const Jet&){ return a.z0SinThetaUncertainty(tp); }); } return std::nullopt; } + // Getters from xAOD::TrackParticle without IP dependencies - std::optional<SequenceFromTracks> + std::optional<SequenceGetterFunc<xAOD::TrackParticle>> getterFromTracksNoIpDep(const std::string& name) { using Tp = xAOD::TrackParticle; using Jet = xAOD::Jet; if (name == "phiUncertainty") { - return CJGetter<Tp>([](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([](const Tp& tp, const Jet&) { return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(2)); }); } if (name == "thetaUncertainty") { - return CJGetter<Tp>([](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([](const Tp& tp, const Jet&) { return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(3)); }); } if (name == "qOverPUncertainty") { - return CJGetter<Tp>([](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([](const Tp& tp, const Jet&) { return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(4)); }); } if (name == "z0RelativeToBeamspot") { - return CJGetter<Tp>([](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([](const Tp& tp, const Jet&) { return tp.z0(); }); } if (name == "log_z0RelativeToBeamspotUncertainty") { - return CJGetter<Tp>([](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([](const Tp& tp, const Jet&) { return std::log(std::sqrt(tp.definingParametersCovMatrixDiagVec().at(1))); }); } if (name == "z0RelativeToBeamspotUncertainty") { - return CJGetter<Tp>([](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([](const Tp& tp, const Jet&) { return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(1)); }); } if (name == "numberOfPixelHitsInclDead") { SG::AuxElement::ConstAccessor<unsigned char> pix_hits("numberOfPixelHits"); SG::AuxElement::ConstAccessor<unsigned char> pix_dead("numberOfPixelDeadSensors"); - return CJGetter<Tp>([pix_hits, pix_dead](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([pix_hits, pix_dead](const Tp& tp, const Jet&) { return pix_hits(tp) + pix_dead(tp); }); } if (name == "numberOfSCTHitsInclDead") { SG::AuxElement::ConstAccessor<unsigned char> sct_hits("numberOfSCTHits"); SG::AuxElement::ConstAccessor<unsigned char> sct_dead("numberOfSCTDeadSensors"); - return CJGetter<Tp>([sct_hits, sct_dead](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([sct_hits, sct_dead](const Tp& tp, const Jet&) { return sct_hits(tp) + sct_dead(tp); }); } if (name == "numberOfInnermostPixelLayerHits21p9") { SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfInnermostPixelLayerHits"); SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfInnermostPixelLayerEndcapHits"); - return CJGetter<Tp>([barrel_hits, endcap_hits](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([barrel_hits, endcap_hits](const Tp& tp, const Jet&) { return barrel_hits(tp) + endcap_hits(tp); }); } if (name == "numberOfNextToInnermostPixelLayerHits21p9") { SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfNextToInnermostPixelLayerHits"); SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfNextToInnermostPixelLayerEndcapHits"); - return CJGetter<Tp>([barrel_hits, endcap_hits](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([barrel_hits, endcap_hits](const Tp& tp, const Jet&) { return barrel_hits(tp) + endcap_hits(tp); }); } if (name == "numberOfInnermostPixelLayerSharedHits21p9") { SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfInnermostPixelLayerSharedHits"); SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfInnermostPixelLayerSharedEndcapHits"); - return CJGetter<Tp>([barrel_hits, endcap_hits](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([barrel_hits, endcap_hits](const Tp& tp, const Jet&) { return barrel_hits(tp) + endcap_hits(tp); }); } if (name == "numberOfInnermostPixelLayerSplitHits21p9") { SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfInnermostPixelLayerSplitHits"); SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfInnermostPixelLayerSplitEndcapHits"); - return CJGetter<Tp>([barrel_hits, endcap_hits](const Tp& tp, const Jet&) { + return CustomSeqGetter<Tp>([barrel_hits, endcap_hits](const Tp& tp, const Jet&) { return barrel_hits(tp) + endcap_hits(tp); }); } @@ -232,87 +225,77 @@ namespace { // Getters from general xAOD::IParticle and derived classes - template <typename T> - std::optional< - std::function<std::vector<double>( - const xAOD::Jet&, - const std::vector<const T*>&)> - > + template <typename T> std::optional<SequenceGetterFunc<T>> getterFromIParticles(const std::string& name) { using Jet = xAOD::Jet; - if (name == "pt") { - return CJGetter<T>([](const T& p, const Jet&) { + return CustomSeqGetter<T>([](const T& p, const Jet&) { return p.pt(); }); } if (name == "log_pt") { - return CJGetter<T>([](const T& p, const Jet&) { + return CustomSeqGetter<T>([](const T& p, const Jet&) { return std::log(p.pt()); }); } if (name == "ptfrac") { - return CJGetter<T>([](const T& p, const Jet& j) { + return CustomSeqGetter<T>([](const T& p, const Jet& j) { return p.pt() / j.pt(); }); } if (name == "log_ptfrac") { - return CJGetter<T>([](const T& p, const Jet& j) { + return CustomSeqGetter<T>([](const T& p, const Jet& j) { return std::log(p.pt() / j.pt()); }); } - if (name == "eta") { - return CJGetter<T>([](const T& p, const Jet&) { + return CustomSeqGetter<T>([](const T& p, const Jet&) { return p.eta(); }); } if (name == "deta") { - return CJGetter<T>([](const T& p, const Jet& j) { + return CustomSeqGetter<T>([](const T& p, const Jet& j) { return p.eta() - j.eta(); }); } if (name == "abs_deta") { - return CJGetter<T>([](const T& p, const Jet& j) { + return CustomSeqGetter<T>([](const T& p, const Jet& j) { return copysign(1.0, j.eta()) * (p.eta() - j.eta()); }); } - if (name == "phi") { - return CJGetter<T>([](const T& p, const Jet&) { + return CustomSeqGetter<T>([](const T& p, const Jet&) { return p.phi(); }); } if (name == "dphi") { - return CJGetter<T>([](const T& p, const Jet& j) { + return CustomSeqGetter<T>([](const T& p, const Jet& j) { return p.p4().DeltaPhi(j.p4()); }); } - if (name == "dr") { - return CJGetter<T>([](const T& p, const Jet& j) { + return CustomSeqGetter<T>([](const T& p, const Jet& j) { return p.p4().DeltaR(j.p4()); }); } if (name == "log_dr") { - return CJGetter<T>([](const T& p, const Jet& j) { + return CustomSeqGetter<T>([](const T& p, const Jet& j) { return std::log(p.p4().DeltaR(j.p4())); }); } if (name == "log_dr_nansafe") { - return CJGetter<T>([](const T& p, const Jet& j) { + return CustomSeqGetter<T>([](const T& p, const Jet& j) { return std::log(p.p4().DeltaR(j.p4()) + 1e-7); }); } - if (name == "mass") { - return CJGetter<T>([](const T& p, const Jet&) { + return CustomSeqGetter<T>([](const T& p, const Jet&) { return p.m(); }); } if (name == "energy") { - return CJGetter<T>([](const T& p, const Jet&) { + return CustomSeqGetter<T>([](const T& p, const Jet&) { return p.e(); }); } @@ -332,23 +315,18 @@ namespace { // // Case for jet variables std::function<std::pair<std::string, double>(const xAOD::Jet&)> - customGetterAndName(const std::string& name) { - auto getter = customGetter(name); + namedCustomJetGetter(const std::string& name) { + auto getter = customJetGetter(name); return [name, getter](const xAOD::Jet& j) { - return std::make_pair(name, getter(j)); - }; + return std::make_pair(name, getter(j)); + }; } // Case for constituent variables // Returns getter function with dependencies template <typename T> - std::pair< - std::function<std::vector<double>( - const xAOD::Jet&, - const std::vector<const T*>&)>, - std::set<std::string>> - customSequenceGetterWithDeps(const std::string& name, - const std::string& prefix) { + std::pair<SequenceGetterFunc<T>, std::set<std::string>> + buildCustomSeqGetter(const std::string& name, const std::string& prefix) { if constexpr (std::is_same_v<T, xAOD::TrackParticle>) { if (auto getter = getterFromTracksWithIpDep(name, prefix)) { @@ -369,13 +347,11 @@ namespace { // Class implementation // template <typename T> - std::pair<typename CustomSequenceGetter<T>::NamedSequenceFromConstituents, std::set<std::string>> - CustomSequenceGetter<T>::customNamedSeqGetterWithDeps(const std::string& name, - const std::string& prefix) { - auto [getter, deps] = customSequenceGetterWithDeps<T>(name, prefix); + std::pair<typename SeqGetter<T>::InputSequence, std::set<std::string>> + SeqGetter<T>::getNamedCustomSeqGetter(const std::string& name, const std::string& prefix) { + auto [getter, deps] = buildCustomSeqGetter<T>(name, prefix); return { - [n=name, g=getter](const xAOD::Jet& j, - const std::vector<const T*>& t) { + [n=name, g=getter](const xAOD::Jet& j, const std::vector<const T*>& t) { return std::make_pair(n, g(j, t)); }, deps @@ -383,26 +359,24 @@ namespace { } template <typename T> - std::pair<typename CustomSequenceGetter<T>::NamedSequenceFromConstituents, std::set<std::string>> - CustomSequenceGetter<T>::seqFromConsituents( - const InputVariableConfig& cfg, - const FTagOptions& options){ + std::pair<typename SeqGetter<T>::InputSequence, std::set<std::string>> + SeqGetter<T>::seqFromConsituents(const InputVariableConfig& cfg, const FTagOptions& options){ const std::string prefix = options.track_prefix; switch (cfg.type) { case ConstituentsEDMType::INT: return { - SequenceGetter<int, T>(cfg.name), {cfg.name} + NamedSeqGetter<int, T>(cfg.name), {cfg.name} }; case ConstituentsEDMType::FLOAT: return { - SequenceGetter<float, T>(cfg.name), {cfg.name} + NamedSeqGetter<float, T>(cfg.name), {cfg.name} }; case ConstituentsEDMType::CHAR: return { - SequenceGetter<char, T>(cfg.name), {cfg.name} + NamedSeqGetter<char, T>(cfg.name), {cfg.name} }; case ConstituentsEDMType::UCHAR: return { - SequenceGetter<unsigned char, T>(cfg.name), {cfg.name} + NamedSeqGetter<unsigned char, T>(cfg.name), {cfg.name} }; case ConstituentsEDMType::CUSTOM_GETTER: { - return customNamedSeqGetterWithDeps( + return getNamedCustomSeqGetter( cfg.name, options.track_prefix); } default: { @@ -412,54 +386,50 @@ namespace { } template <typename T> - CustomSequenceGetter<T>::CustomSequenceGetter( - std::vector<InputVariableConfig> inputs, - const FTagOptions& options) + SeqGetter<T>::SeqGetter(std::vector<InputVariableConfig> inputs, const FTagOptions& options) { - std::map<std::string, std::string> remap = options.remap_scalar; - for (const InputVariableConfig& input_cfg: inputs) { - auto [seqGetter, seq_deps] = seqFromConsituents( - input_cfg, options); - - if(input_cfg.flip_sign){ - auto seqGetter_flip=[g=seqGetter](const xAOD::Jet&jet, const Constituents& constituents){ - auto [n,v] = g(jet,constituents); - std::for_each(v.begin(), v.end(), [](double &n){ n=-1.0*n; }); - return std::make_pair(n,v); - }; - m_sequencesFromConstituents.push_back(seqGetter_flip); - } - else{ - m_sequencesFromConstituents.push_back(seqGetter); - } - m_deps.merge(seq_deps); - if (auto h = remap.extract(input_cfg.name)){ - m_used_remap.insert(h.key()); - } + std::map<std::string, std::string> remap = options.remap_scalar; + for (const InputVariableConfig& input_cfg: inputs) { + auto [seqGetter, seq_deps] = seqFromConsituents(input_cfg, options); + + if(input_cfg.flip_sign){ + auto seqGetter_flip=[g=seqGetter](const xAOD::Jet&jet, const Const& constituents){ + auto [n,v] = g(jet,constituents); + std::for_each(v.begin(), v.end(), [](double &n){ n=-1.0*n; }); + return std::make_pair(n,v); + }; + m_sequence_getters.push_back(seqGetter_flip); + } + else{ + m_sequence_getters.push_back(seqGetter); } + m_deps.merge(seq_deps); + if (auto h = remap.extract(input_cfg.name)){ + m_used_remap.insert(h.key()); + } + } } template <typename T> - std::pair<std::vector<float>, std::vector<int64_t>> CustomSequenceGetter<T>::getFeats( - const xAOD::Jet& jet, const Constituents& constituents) const + std::pair<std::vector<float>, std::vector<int64_t>> SeqGetter<T>::getFeats( + const xAOD::Jet& jet, const Const& constituents) const { std::vector<float> cnsts_feats; - int num_vars = m_sequencesFromConstituents.size(); + int num_vars = m_sequence_getters.size(); int num_cnsts = 0; int cnst_var_idx = 0; - for (const auto& seq_builder: m_sequencesFromConstituents){ - auto double_vec = seq_builder(jet, constituents).second; + for (const auto& seq_getter: m_sequence_getters){ + auto input_sequence = seq_getter(jet, constituents).second; if (cnst_var_idx==0){ - num_cnsts = static_cast<int>(double_vec.size()); - cnsts_feats.resize(num_cnsts * num_vars); + num_cnsts = static_cast<int>(input_sequence.size()); + cnsts_feats.resize(num_cnsts * num_vars); } // need to transpose + flatten - for (unsigned int cnst_idx=0; cnst_idx<double_vec.size(); cnst_idx++){ - cnsts_feats.at(cnst_idx*num_vars + cnst_var_idx) - = double_vec.at(cnst_idx); + for (unsigned int cnst_idx=0; cnst_idx<input_sequence.size(); cnst_idx++){ + cnsts_feats.at(cnst_idx*num_vars + cnst_var_idx) = input_sequence.at(cnst_idx); } cnst_var_idx++; } @@ -468,28 +438,28 @@ namespace { } template <typename T> - std::map<std::string, std::vector<double>> CustomSequenceGetter<T>::getDL2Feats( - const xAOD::Jet& jet, const Constituents& constituents) const + std::map<std::string, std::vector<double>> SeqGetter<T>::getDL2Feats( + const xAOD::Jet& jet, const Const& constituents) const { std::map<std::string, std::vector<double>> feats; - for (const auto& seq_builder: m_sequencesFromConstituents){ - feats.insert(seq_builder(jet, constituents)); + for (const auto& seq_getter: m_sequence_getters){ + feats.insert(seq_getter(jet, constituents)); } return feats; } template <typename T> - std::set<std::string> CustomSequenceGetter<T>::getDependencies() const { + std::set<std::string> SeqGetter<T>::getDependencies() const { return m_deps; } template <typename T> - std::set<std::string> CustomSequenceGetter<T>::getUsedRemap() const { + std::set<std::string> SeqGetter<T>::getUsedRemap() const { return m_used_remap; } // Explicit instantiations of supported types (IParticle, TrackParticle) - template class CustomSequenceGetter<xAOD::IParticle>; - template class CustomSequenceGetter<xAOD::TrackParticle>; + template class SeqGetter<xAOD::IParticle>; + template class SeqGetter<xAOD::TrackParticle>; } } \ No newline at end of file diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx index 59a878c634e5..60ede87dcb20 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx @@ -245,7 +245,7 @@ namespace FlavorTagDiscriminants { rewriteFlipConfig(config, flip_converters); } - // build the standard inputs + // build the jet inputs // type and default value-finding regexes are hardcoded for now TypeRegexes type_regexes = { @@ -285,8 +285,7 @@ namespace FlavorTagDiscriminants { std::vector<FTagInputConfig> input_config; for (auto& node: config.inputs){ // allow the user to remape some of the inputs - remap_inputs(node.variables, remap_scalar, - node.defaults); + remap_inputs(node.variables, remap_scalar, node.defaults); std::vector<std::string> input_names; for (const auto& var: node.variables) { @@ -299,12 +298,10 @@ namespace FlavorTagDiscriminants { throw std::logic_error( "We don't currently support multiple scalar input nodes"); } - input_config = get_input_config( - input_names, type_regexes, default_flag_regexes); + input_config = get_input_config(input_names, type_regexes, default_flag_regexes); } // build the constituents inputs - std::vector<std::pair<std::string, std::vector<std::string>>> constituent_names; for (auto& node: config.input_sequences) { remap_inputs(node.variables, remap_scalar, @@ -363,7 +360,7 @@ namespace FlavorTagDiscriminants { deps.bTagInputs.insert(input.name); varsFromBTag.push_back(filler); } else { - varsFromJet.push_back(getter_utils::customGetterAndName(input.name)); + varsFromJet.push_back(getter_utils::namedCustomJetGetter(input.name)); } if (input.default_flag.size() > 0) { deps.bTagInputs.insert(input.default_flag); diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx index 156f0b5e008f..040e7470bb7d 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx @@ -144,8 +144,9 @@ namespace FlavorTagDiscriminants { // prepare input // ------------- - std::map<std::string, input_pair> gnn_input; + std::map<std::string, Inputs> gnn_inputs; + // jet level inputs std::vector<float> jet_feat; for (const auto& getter: m_varsFromBTag) { jet_feat.push_back(getter(btag).second); @@ -154,18 +155,27 @@ namespace FlavorTagDiscriminants { jet_feat.push_back(getter(jet).second); } std::vector<int64_t> jet_feat_dim = {1, static_cast<int64_t>(jet_feat.size())}; + Inputs jet_info(jet_feat, jet_feat_dim); + if (m_onnxUtil->getOnnxModelVersion() == OnnxModelVersion::V2) { + gnn_inputs.insert({"jets", jet_info}); + } else { + gnn_inputs.insert({"jet_features", jet_info}); + } - input_pair jet_info (jet_feat, jet_feat_dim); - gnn_input.insert({"jet_features", jet_info}); - + // constituent level inputs Tracks input_tracks; - for (auto loader : m_constituentsLoaders){ - auto [sequence_name, sequence_data, sequence_constituents] = loader->getData(jet, btag); - gnn_input.insert({sequence_name, sequence_data}); - // collect tracks for decoration + auto [input_name, input_data, input_objects] = loader->getData(jet, btag); + if (m_onnxUtil->getOnnxModelVersion() != OnnxModelVersion::V2) { + input_name.pop_back(); + input_name.append("_features"); + } + gnn_inputs.insert({input_name, input_data}); + + // for now we only collect tracks for aux task decoration + // they have to be converted back from IParticle to TrackParticle first if (loader->getType() == ConstituentsType::TRACK){ - for (auto constituent : sequence_constituents){ + for (auto constituent : input_objects){ input_tracks.push_back(dynamic_cast<const xAOD::TrackParticle*>(constituent)); } } @@ -173,7 +183,7 @@ namespace FlavorTagDiscriminants { // run inference // ------------- - auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input); + auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_inputs); // decorate outputs // ---------------- diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx index 725e8b3c27f4..e48da03d3aeb 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx @@ -30,7 +30,7 @@ namespace FlavorTagDiscriminants { ): IConstituentsLoader(cfg), m_iparticleSortVar(IParticlesLoader::iparticleSortVar(cfg.order)), - m_customSequenceGetter(getter_utils::CustomSequenceGetter<xAOD::IParticle>( + m_seqGetter(getter_utils::SeqGetter<xAOD::IParticle>( cfg.inputs, options)) { SG::AuxElement::ConstAccessor<PartLinks> acc("constituentLinks"); @@ -50,7 +50,7 @@ namespace FlavorTagDiscriminants { } else { m_isCharged = false; } - m_used_remap = m_customSequenceGetter.getUsedRemap(); + m_used_remap = m_seqGetter.getUsedRemap(); m_name = cfg.name; } @@ -87,12 +87,12 @@ namespace FlavorTagDiscriminants { return only_particles; } - std::tuple<std::string, input_pair, std::vector<const xAOD::IParticle*>> IParticlesLoader::getData( + std::tuple<std::string, Inputs, std::vector<const xAOD::IParticle*>> IParticlesLoader::getData( const xAOD::Jet& jet, [[maybe_unused]] const SG::AuxElement& btag) const { IParticles sorted_particles = getIParticlesFromJet(jet); - return std::make_tuple(m_config.output_name, m_customSequenceGetter.getFeats(jet, sorted_particles), sorted_particles); + return std::make_tuple(m_config.output_name, m_seqGetter.getFeats(jet, sorted_particles), sorted_particles); } FTagDataDependencyNames IParticlesLoader::getDependencies() const { diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxUtil.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxUtil.cxx index ebcf3703b50d..ea04a2039dda 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxUtil.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/OnnxUtil.cxx @@ -58,8 +58,8 @@ namespace FlavorTagDiscriminants { // iterate over input nodes and get their names for (size_t i = 0; i < m_num_inputs; i++) { - auto input_name = m_session->GetInputNameAllocated(i, allocator); - m_input_node_names.push_back(input_name.get()); + std::string input_name = m_session->GetInputNameAllocated(i, allocator).get(); + m_input_node_names.push_back(input_name); } // iterate over output nodes and get their configuration @@ -145,7 +145,7 @@ namespace FlavorTagDiscriminants { OnnxUtil::InferenceOutput OnnxUtil::runInference( - std::map<std::string, input_pair>& gnn_inputs) const { + std::map<std::string, Inputs>& gnn_inputs) const { std::vector<float> input_tensor_values; @@ -154,7 +154,7 @@ namespace FlavorTagDiscriminants { OrtArenaAllocator, OrtMemTypeDefault ); std::vector<Ort::Value> input_tensors; - for (auto const &node_name : m_input_node_names){ + for (auto& node_name : m_input_node_names) { input_tensors.push_back(Ort::Value::CreateTensor<float>( memory_info, gnn_inputs.at(node_name).first.data(), gnn_inputs.at(node_name).first.size(), gnn_inputs.at(node_name).second.data(), gnn_inputs.at(node_name).second.size()) diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx index 05a406b655af..87757eac7778 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx @@ -8,7 +8,6 @@ Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration #include "FlavorTagDiscriminants/StringUtils.h" namespace FlavorTagDiscriminants { - // factory for functions which return the sort variable we // use to order tracks TracksLoader::TrackSortVar TracksLoader::trackSortVar( @@ -212,12 +211,9 @@ namespace FlavorTagDiscriminants { // start by defining the raw functions, there's a factory // function below to convert the configuration enums to a // std::function - std::vector<const xAOD::TrackParticle*> negativeIpOnly( - BTagTrackIpAccessor& aug, - const std::vector<const xAOD::TrackParticle*>& tracks, - const xAOD::Jet& j) + Tracks negativeIpOnly(BTagTrackIpAccessor& aug, const Tracks& tracks, const xAOD::Jet& j) { - std::vector<const xAOD::TrackParticle*> filtered; + Tracks filtered; // we want to reverse the order of the tracks as part of the // flipping for (auto ti = tracks.crbegin(); ti != tracks.crend(); ti++) { @@ -229,7 +225,7 @@ namespace FlavorTagDiscriminants { } // factory function - std::pair<TracksLoader::TrackSequenceFilter,std::set<std::string>> TracksLoader::flipFilter( + std::pair<TracksLoader::TrackSequenceFilter,std::set<std::string>> TracksLoader::trackFlipper( const FTagOptions& options) { namespace ph = std::placeholders; // for _1, _2, _3 @@ -271,8 +267,8 @@ namespace FlavorTagDiscriminants { IConstituentsLoader(cfg), m_trackSortVar(TracksLoader::trackSortVar(cfg.order, options)), m_trackFilter(TracksLoader::trackFilter(cfg.selection, options).first), - m_flipFilter(TracksLoader::flipFilter(options).first), - m_customSequenceGetter(getter_utils::CustomSequenceGetter<xAOD::TrackParticle>( + m_trackFlipper(TracksLoader::trackFlipper(options).first), + m_seqGetter(getter_utils::SeqGetter<xAOD::TrackParticle>( cfg.inputs, options)) { // We have several ways to get tracks: either we retrieve an @@ -283,8 +279,8 @@ namespace FlavorTagDiscriminants { // if (options.track_link_type == TrackLinkType::IPARTICLE) { SG::AuxElement::ConstAccessor<PartLinks> acc(options.track_link_name); - m_associator = [acc](const SG::AuxElement& btag) -> TPV { - TPV tracks; + m_associator = [acc](const SG::AuxElement& btag) -> Tracks { + Tracks tracks; for (const ElementLink<IPC>& link: acc(btag)) { if (!link.isValid()) { throw std::logic_error("invalid particle link"); @@ -299,8 +295,8 @@ namespace FlavorTagDiscriminants { }; } else if (options.track_link_type == TrackLinkType::TRACK_PARTICLE){ SG::AuxElement::ConstAccessor<TrackLinks> acc(options.track_link_name); - m_associator = [acc](const SG::AuxElement& btag) -> TPV { - TPV tracks; + m_associator = [acc](const SG::AuxElement& btag) -> Tracks { + Tracks tracks; for (const ElementLink<TPC>& link: acc(btag)) { if (!link.isValid()) { throw std::logic_error("invalid track link"); @@ -313,15 +309,15 @@ namespace FlavorTagDiscriminants { throw std::logic_error("Unknown TrackLinkType"); } auto track_data_deps = trackFilter(cfg.selection, options).second; - track_data_deps.merge(flipFilter(options).second); - track_data_deps.merge(m_customSequenceGetter.getDependencies()); + track_data_deps.merge(trackFlipper(options).second); + track_data_deps.merge(m_seqGetter.getDependencies()); m_deps.trackInputs.merge(track_data_deps); m_deps.bTagInputs.insert(options.track_link_name); - m_used_remap = m_customSequenceGetter.getUsedRemap(); + m_used_remap = m_seqGetter.getUsedRemap(); m_name = cfg.name; } - std::vector<const xAOD::TrackParticle*> TracksLoader::getTracksFromJet( + Tracks TracksLoader::getTracksFromJet( const xAOD::Jet& jet, const SG::AuxElement& btag) const { @@ -340,20 +336,21 @@ namespace FlavorTagDiscriminants { return only_tracks; } - std::tuple<std::string, input_pair, std::vector<const xAOD::IParticle*>> TracksLoader::getData( - const xAOD::Jet& jet, - [[maybe_unused]] const SG::AuxElement& btag) const { - Tracks flipped_tracks; + std::tuple<std::string, Inputs, std::vector<const xAOD::IParticle*>> + TracksLoader::getData(const xAOD::Jet& jet, [[maybe_unused]] const SG::AuxElement& btag) const + { Tracks sorted_tracks = getTracksFromJet(jet, btag); - std::vector<const xAOD::IParticle*> flipped_tracks_ip; - - flipped_tracks = m_flipFilter(sorted_tracks, jet); + Tracks flipped_tracks = m_trackFlipper(sorted_tracks, jet); + // cast to IParticle for aux task decoration + // this could probably be templated since we cast back again later + std::vector<const xAOD::IParticle*> flipped_iparticles; for (const auto& trk: flipped_tracks) { - flipped_tracks_ip.push_back(trk); + flipped_iparticles.push_back(trk); } - return std::make_tuple(m_config.output_name, m_customSequenceGetter.getFeats(jet, flipped_tracks), flipped_tracks_ip); + Inputs features = m_seqGetter.getFeats(jet, flipped_tracks); + return std::make_tuple(m_config.output_name, features, flipped_iparticles); } std::tuple<char, std::map<std::string, std::vector<double>>> TracksLoader::getDL2Data( @@ -366,9 +363,9 @@ namespace FlavorTagDiscriminants { Tracks sorted_tracks = getTracksFromJet(jet, btag); if (ip_checker(sorted_tracks)) invalid = 1; - flipped_tracks = m_flipFilter(sorted_tracks, jet); + flipped_tracks = m_trackFlipper(sorted_tracks, jet); - auto feats = m_customSequenceGetter.getDL2Feats(jet, flipped_tracks); + auto feats = m_seqGetter.getDL2Feats(jet, flipped_tracks); return std::make_tuple(invalid, feats); }; -- GitLab