From f49d537cdadc0ad06bdfa7f2507c4386f88f0953 Mon Sep 17 00:00:00 2001 From: Dmitrii Kobylianskii <dmitrii.kobylianskii@cern.ch> Date: Sun, 4 Feb 2024 17:32:55 +0100 Subject: [PATCH] - Clean code - Add customSequenceGetter - Make work with neutrals+tracks --- .../ConstituentsLoader.h | 30 +- .../FlavorTagDiscriminants/GNN.h | 5 +- .../FlavorTagDiscriminants/IParticlesLoader.h | 8 +- .../FlavorTagDiscriminants/SequenceGetter.h | 64 +++- .../FlavorTagDiscriminants/TracksLoader.h | 14 +- .../Root/DataPrepUtilities.cxx | 6 +- .../FlavorTagDiscriminants/Root/GNN.cxx | 90 ++--- .../Root/IParticlesLoader.cxx | 81 +---- .../Root/SequenceGetter.cxx | 338 ++++++++++++------ .../Root/TracksLoader.cxx | 123 ++----- 10 files changed, 351 insertions(+), 408 deletions(-) diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h index de72ae86ec0f..185bdc664250 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h @@ -8,9 +8,8 @@ // local includes #include "FlavorTagDiscriminants/FlipTagEnums.h" #include "FlavorTagDiscriminants/AssociationEnums.h" -#include "FlavorTagDiscriminants/FTagDataDependencyNames.h" -#include "FlavorTagDiscriminants/GNNConfig.h" #include "FlavorTagDiscriminants/OnnxUtil.h" +#include "FlavorTagDiscriminants/FTagDataDependencyNames.h" // EDM includes #include "xAODJet/Jet.h" @@ -54,28 +53,6 @@ namespace FlavorTagDiscriminants { std::vector<FTagConstituentsInputConfig> inputs; }; - // The sequence getter takes in constituents and calculates arrays of - // values which are better suited for inputs to the NNs - template <typename T, typename U> - class SequenceGetter{ - private: - SG::AuxElement::ConstAccessor<T> m_getter; - std::string m_name; - public: - SequenceGetter(const std::string& name): - m_getter(name), - m_name(name) - { - } - std::pair<std::string, std::vector<double>> operator()(const xAOD::Jet&, const std::vector<const U*>& consts) const { - std::vector<double> seq; - for (const U* el: consts) { - seq.push_back(m_getter(*el)); - } - return {m_name, seq}; - } - }; - // Virtual class to represent loader of any type of constituents class ConstituentsLoader { public: @@ -85,10 +62,13 @@ namespace FlavorTagDiscriminants { virtual ~ConstituentsLoader() { }; virtual std::pair<std::string, input_pair> getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const = 0; + virtual FTagDataDependencyNames getDependencies() const = 0; + virtual std::set<std::string> getUsedRemap() const = 0; + protected: FTagDataDependencyNames deps; - private: FTagConstituentsSequenceConfig config; + std::set<std::string> used_remap; }; } diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h index 9c5588935423..f50bb77aaa6d 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h @@ -60,7 +60,6 @@ namespace FlavorTagDiscriminants { virtual std::set<std::string> getConstituentAuxInputKeys() const; std::shared_ptr<const OnnxUtil> m_onnxUtil; - private: // type definitions for ONNX output decorators using TPC = xAOD::TrackParticleContainer; @@ -90,9 +89,7 @@ namespace FlavorTagDiscriminants { std::vector<internal::VarFromBTag> m_varsFromBTag; std::vector<internal::VarFromJet> m_varsFromJet; std::vector<internal::TrackSequenceBuilder> m_trackSequenceBuilders; - // std::vector<std::shared_ptr<ConstituentsLoader>> m_constituentsLoaders; - // std::shared_ptr<const TracksLoader> m_trackLoader; - std::shared_ptr<const IParticlesLoader> m_flowLoader; + std::vector<std::shared_ptr<ConstituentsLoader>> m_constituentsLoaders; Decorators m_decorators; float m_defaultValue; diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h index 5d2225da3452..994acfdc9b04 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h @@ -8,6 +8,7 @@ // local includes #include "FlavorTagDiscriminants/ConstituentsLoader.h" #include "FlavorTagDiscriminants/DataPrepUtilities.h" +#include "FlavorTagDiscriminants/SequenceGetter.h" // EDM includes #include "xAODJet/Jet.h" @@ -32,6 +33,8 @@ namespace FlavorTagDiscriminants { // TracksLoader(); IParticlesLoader(FTagConstituentsSequenceConfig, const FTagOptions& options); std::pair<std::string, input_pair> getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const override ; + FTagDataDependencyNames getDependencies() const override; + std::set<std::string> getUsedRemap() const override; protected: // typedefs typedef xAOD::Jet Jet; @@ -54,13 +57,14 @@ namespace FlavorTagDiscriminants { IParticleSortVar iparticleSortVar(ConstituentsSortOrder, const FTagOptions&); std::vector<const xAOD::IParticle*> getIParticlesFromJet(const xAOD::Jet& jet) const; - std::pair<SeqFromIParticles,std::set<std::string>> seqFromIParticles( - const FTagConstituentsInputConfig&, const FTagOptions&); std::vector<SeqFromIParticles> m_sequencesFromIParticles; + sequence_getter::CustomSequenceGetter m_customSequenceGetter; IParticleSortVar m_iparticleSortVar; std::function<IPV(const Jet&)> m_associator; bool m_isCharged; + + // FTagDataDependencyNames deps; }; } diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SequenceGetter.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SequenceGetter.h index c7c2c5676765..0402192725c7 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SequenceGetter.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SequenceGetter.h @@ -13,6 +13,9 @@ #include "xAODBase/IParticle.h" #include "xAODTracking/TrackParticleFwd.h" #include "AthContainers/AuxElement.h" +#include "FlavorTagDiscriminants/ConstituentsLoader.h" +#include "FlavorTagDiscriminants/DataPrepUtilities.h" + #include <functional> #include <string> @@ -23,7 +26,7 @@ namespace FlavorTagDiscriminants { - /// Factory function to produce TrackParticle -> vector<double> functions + /// Class to produce IParticle -> vector<double> functions /// /// DL2 configures the its inputs when the algorithm is initalized, /// meaning that the list of track and jet properties that are used @@ -44,26 +47,51 @@ namespace FlavorTagDiscriminants { /// namespace sequence_getter { - - using SequenceFromIParticles = std::function<std::vector<double>( - const xAOD::Jet&, - const std::vector<const xAOD::IParticle*>&)>; - - std::pair<SequenceFromIParticles, std::set<std::string>> - customSequenceGetterWithDeps( - const std::string& name, // name of the getter - const std::string& prefix // prefix for track accessor - ); - std::function<std::pair<std::string, double>(const xAOD::Jet&)> customGetterAndName(const std::string&); - std::pair<std::function<std::pair<std::string, std::vector<double>>( - const xAOD::Jet&, - const std::vector<const xAOD::IParticle*>&)>, - std::set<std::string>> - customNamedSeqGetterWithDeps(const std::string&, const std::string&); - } + class CustomSequenceGetter { + public: + using IParticles = std::vector<const xAOD::IParticle*>; + using SequenceFromConstituents = std::function<std::vector<double>( + const xAOD::Jet&, + const IParticles&)>; + + using NamedSequenceFromConstituents = std::function<std::pair<std::string, std::vector<double>>( + const xAOD::Jet&, + const IParticles&)>; + + CustomSequenceGetter(std::vector<FTagConstituentsInputConfig> inputs, + const FTagOptions& options); + + std::pair<std::vector<float>, std::vector<int64_t>> getFeats(const xAOD::Jet& jet, const IParticles& constituents) const; + + std::set<std::string> getDependencies() const; + std::set<std::string> getUsedRemap() const; + + private: + + std::pair<NamedSequenceFromConstituents, std::set<std::string>> + customNamedSeqGetterWithDeps(const std::string& name, const std::string& prefix); + + std::pair<SequenceFromConstituents, std::set<std::string>> + customSequenceGetterWithDeps(const std::string& name, + const std::string& prefix); + + std::optional<SequenceFromConstituents> sequenceNoIpDep(const std::string& name); + + std::optional<SequenceFromConstituents> sequenceWithIpDep(const std::string& name, + const std::string& prefix); + + std::pair<NamedSequenceFromConstituents, std::set<std::string>> seqFromConsituents( + const FTagConstituentsInputConfig& cfg, + const FTagOptions& options); + + std::vector<NamedSequenceFromConstituents> sequencesFromConstituents; + std::set<std::string> deps; + std::set<std::string> used_remap; + }; + } } #endif diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h index caefe3d21dec..e0efa6ec915c 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h @@ -8,11 +8,12 @@ // local includes #include "FlavorTagDiscriminants/FlipTagEnums.h" #include "FlavorTagDiscriminants/AssociationEnums.h" -#include "FlavorTagDiscriminants/FTagDataDependencyNames.h" +// #include "FlavorTagDiscriminants/FTagDataDependencyNames.h" #include "FlavorTagDiscriminants/ConstituentsLoader.h" #include "FlavorTagDiscriminants/DataPrepUtilities.h" #include "FlavorTagDiscriminants/BTagTrackIpAccessor.h" +#include "FlavorTagDiscriminants/SequenceGetter.h" // EDM includes #include "xAODJet/Jet.h" @@ -50,7 +51,9 @@ namespace FlavorTagDiscriminants { public: // TracksLoader(); TracksLoader(FTagConstituentsSequenceConfig, const FTagOptions& options); - std::pair<std::string, input_pair> getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const override ; + std::pair<std::string, input_pair> getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const override; + FTagDataDependencyNames getDependencies() const override; + std::set<std::string> getUsedRemap() const override; private: // typedefs typedef std::pair<std::string, double> NamedVar; @@ -65,9 +68,6 @@ namespace FlavorTagDiscriminants { typedef std::function<Tracks(const Tracks&, const Jet&)> TrackSequenceFilter; - // getter function - typedef std::function<NamedSeq(const Jet&, const Tracks&)> SeqFromTracks; - // usings for track getter using AE = SG::AuxElement; using IPC = xAOD::IParticleContainer; @@ -79,8 +79,6 @@ namespace FlavorTagDiscriminants { TrackSortVar trackSortVar(ConstituentsSortOrder, const FTagOptions&); std::pair<TrackFilter,std::set<std::string>> trackFilter( ConstituentsSelection, const FTagOptions&); - std::pair<SeqFromTracks,std::set<std::string>> seqFromTracks( - const FTagConstituentsInputConfig&, const FTagOptions&); std::pair<TrackSequenceFilter,std::set<std::string>> flipFilter( const FTagOptions&); @@ -89,8 +87,8 @@ namespace FlavorTagDiscriminants { TrackSortVar m_trackSortVar; TrackFilter m_trackFilter; TrackSequenceFilter m_flipFilter; - std::vector<SeqFromTracks> m_sequencesFromTracks; std::function<TPV(const SG::AuxElement&)> m_associator; + sequence_getter::CustomSequenceGetter m_customSequenceGetter; }; } diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx index db9a8432f689..335115a1cc41 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx @@ -11,6 +11,8 @@ Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration #include "xAODBTagging/BTaggingUtilities.h" +#include <iostream> + namespace { using namespace FlavorTagDiscriminants; @@ -646,10 +648,6 @@ namespace FlavorTagDiscriminants { // we rewrite the inputs if we're using flip taggers StringRegexes flip_converters = getFlipConverters(flip_config); - // some sequences also need to be sign-flipped. We apply this by - // changing the input scaling and normalizations - std::regex flip_sequences(".*signed_[dz]0.*"); - if (flip_config != FlipTagConfig::STANDARD) { rewriteFlipConfig(config, flip_converters); } diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx index 3e29cf48c63c..9b5bd7efc2ee 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx @@ -28,9 +28,7 @@ namespace FlavorTagDiscriminants { m_onnxUtil(nullptr), m_jetLink(jetLinkName), m_defaultValue(o.default_output_value), - m_decorate_tracks(o.decorate_tracks), - // m_trackLoader(nullptr), - m_flowLoader(nullptr) + m_decorate_tracks(o.decorate_tracks) { // track decoration is allowed only for non-production builds if (m_decorate_tracks) { @@ -55,22 +53,20 @@ namespace FlavorTagDiscriminants { // Create configuration objects for data preprocessing. auto [inputs, constituents_configs, options] = dataprep::createGetterConfigNew( lwt_config, o.flip_config, o.variable_remapping, o.track_link_type); - std::cout << "TEST 1 " << std::endl; - // auto [tracksLoaderConfig, tmp_options] = createTracksLoaderConfig( - // config, flip_config, variableRemapping, trackLinkType - // ); - std::cout << constituents_configs.size() << std::endl; std::vector<FTagTrackSequenceConfig> track_sequences; - if (constituents_configs.size() > 0){ - auto tracksLoaderConfig = constituents_configs[0]; - std::cout << "TEST 2 " << std::endl; - // m_trackLoader = std::make_shared<TracksLoader>(tracksLoaderConfig, options); - track_sequences = convertTracksConfigBack(constituents_configs[0]); - std::cout << "TEST 2F " << std::endl; - m_flowLoader = std::make_shared<IParticlesLoader>(tracksLoaderConfig, options); - // m_constituentsLoaders.push_back(std::make_shared<IParticlesLoader>(tracksLoaderConfig, options)); + for (auto config : constituents_configs){ + std::cout << "Config name: " << config.name << std::endl; + if (config.name.find("tracks") != std::string::npos){ + m_constituentsLoaders.push_back(std::make_shared<TracksLoader>(config, options)); + track_sequences = convertTracksConfigBack(config); + } + else if (config.name.find("flow") != std::string::npos){ + m_constituentsLoaders.push_back(std::make_shared<IParticlesLoader>(config, options)); + } + else { + throw std::runtime_error("Unknown constituent type. Only tracks and neutrals are supported."); + } } - std::cout << "TEST 3 " << std::endl; // Initialize jet and b-tagging input getters. auto [vb, vj, ds] = dataprep::createBvarGetters(inputs); m_varsFromBTag = vb; @@ -89,8 +85,11 @@ namespace FlavorTagDiscriminants { auto [dd, rd] = createDecorators(gnn_output_config, options); m_dataDependencyNames += dd; - // Check that all remaps have been used. - rd.merge(rt); + // Update dependencies and used remap from the constituents loaders. + for (auto loader : m_constituentsLoaders){ + m_dataDependencyNames += loader->getDependencies(); + rd.merge(loader->getUsedRemap()); + } dataprep::checkForUnusedRemaps(options.remap_scalar, rd); } @@ -157,57 +156,12 @@ namespace FlavorTagDiscriminants { } Tracks input_tracks; for (const auto& builder: m_trackSequenceBuilders) { - std::vector<float> track_feat; // (#tracks, #feats).flatten - int num_track_vars = static_cast<int>(builder.sequencesFromTracks.size()); - int num_tracks = 0; - Tracks sorted_tracks = builder.tracksFromJet(jet, btag); input_tracks = builder.flipFilter(sorted_tracks, jet); - - int track_var_idx=0; - for (const auto& seq_builder: builder.sequencesFromTracks) { - auto double_vec = seq_builder(jet, input_tracks).second; - - if (track_var_idx==0){ - num_tracks = static_cast<int>(double_vec.size()); - track_feat.resize(num_tracks * num_track_vars); - } - - // need to transpose + flatten - for (unsigned int track_idx=0; track_idx<double_vec.size(); track_idx++){ - track_feat.at(track_idx*num_track_vars + track_var_idx) - = double_vec.at(track_idx); - } - track_var_idx++; - } - std::vector<int64_t> track_feat_dim = {num_tracks, num_track_vars}; - - input_pair track_info (track_feat, track_feat_dim); - if (m_flowLoader){ - auto loader_out = m_flowLoader->getData(jet, btag); - auto loader_track_feat = loader_out.second.first; - std::cout << "Vector size: TRACK " << track_feat.size() << " FLOW " << loader_track_feat.size() << std::endl; - // for (uint64_t i = 0; i < track_feat.size(); i++){ - // if (std::fabs(loader_track_feat.at(i) - track_feat.at(i)) > 0.001) { - // std::cout << "DIFFERENCE " << i << std::endl; - // } - // } - gnn_input.insert({"flow_features", loader_out.second}); - } - // if (m_trackLoader){ - // auto loader_out = m_trackLoader->getData(jet, btag); - // auto loader_track_feat = loader_out.second.first; - // std::cout << "Vector size: OLD " << track_feat.size() << " NEW " << loader_track_feat.size() << std::endl; - // for (uint64_t i = 0; i < track_feat.size(); i++){ - // if (std::fabs(loader_track_feat.at(i) - track_feat.at(i)) > 0.001) { - // std::cout << "DIFFERENCE " << i << std::endl; - // } - // } - // } - // if (m_constituentsLoaders.size() > 0){ - // auto flow_out = m_constituentsLoaders[0]->getData(jet, btag); - // } - // gnn_input.insert({"flow_features", track_info}); + } + for (auto loader : m_constituentsLoaders){ + auto loader_out = loader->getData(jet, btag); + gnn_input.insert({loader_out.first, loader_out.second}); } // run inference diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx index e2e6e9ad8d72..73f0a93d1d57 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx @@ -3,10 +3,9 @@ Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration */ #include "FlavorTagDiscriminants/IParticlesLoader.h" -#include "FlavorTagDiscriminants/FTagDataDependencyNames.h" +// #include "FlavorTagDiscriminants/FTagDataDependencyNames.h" #include "xAODPFlow/FlowElement.h" -#include "FlavorTagDiscriminants/SequenceGetter.h" #include <iostream> namespace { @@ -132,42 +131,14 @@ namespace FlavorTagDiscriminants { } } // end of iparticle sort getter - // factory for functions that build std::vector objects from - // iparticle sequences - std::pair<IParticlesLoader::SeqFromIParticles,std::set<std::string>> IParticlesLoader::seqFromIParticles( - const FTagConstituentsInputConfig& cfg, - const FTagOptions& options) - { - const std::string prefix = options.track_prefix; - switch (cfg.type) { - case ConstituentsEDMType::INT: return { - SequenceGetter<int, xAOD::IParticle>(cfg.name), {cfg.name} - }; - case ConstituentsEDMType::FLOAT: return { - SequenceGetter<float, xAOD::IParticle>(cfg.name), {cfg.name} - }; - case ConstituentsEDMType::CHAR: return { - SequenceGetter<char, xAOD::IParticle>(cfg.name), {cfg.name} - }; - case ConstituentsEDMType::UCHAR: return { - SequenceGetter<unsigned char, xAOD::IParticle>(cfg.name), {cfg.name} - }; - case ConstituentsEDMType::CUSTOM_GETTER: { - return sequence_getter::customNamedSeqGetterWithDeps( - cfg.name, options.track_prefix); - } - default: { - throw std::logic_error("Unknown EDM type for iparticles"); - } - } - } - IParticlesLoader::IParticlesLoader( FTagConstituentsSequenceConfig cfg, const FTagOptions& options ): ConstituentsLoader(cfg), - m_iparticleSortVar(IParticlesLoader::iparticleSortVar(cfg.order, options)) + m_iparticleSortVar(IParticlesLoader::iparticleSortVar(cfg.order, options)), + m_customSequenceGetter(sequence_getter::CustomSequenceGetter( + cfg.inputs, options)) { SG::AuxElement::ConstAccessor<PartLinks> acc("constituentLinks"); m_associator = [acc](const xAOD::Jet& jet) -> IPV { @@ -187,22 +158,7 @@ namespace FlavorTagDiscriminants { } else { m_isCharged = false; } - - std::map<std::string, std::string> remap = options.remap_scalar; - std::set<std::string> used_remap; - std::set<std::string> iparticle_data_deps; - - for (const FTagConstituentsInputConfig& input_cfg: cfg.inputs) { - std::cout << input_cfg.name << std::endl; - auto [seqGetter, deps] = seqFromIParticles( - input_cfg, options); - - m_sequencesFromIParticles.push_back(seqGetter); - iparticle_data_deps.merge(deps); - if (auto h = remap.extract(input_cfg.name)){ - used_remap.insert(h.key()); - } - } + used_remap = m_customSequenceGetter.getUsedRemap(); std::cout << "TEST: IParticlesLoader loaded " << std::endl; } @@ -235,31 +191,20 @@ namespace FlavorTagDiscriminants { } only_particles.push_back(obj); } - std::cout << "TEST: SIZE OF only_particles: " << only_particles.size() << std::endl; return only_particles; } std::pair<std::string, input_pair> IParticlesLoader::getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const { - std::vector<float> particle_feat; - int num_iparticle_vars = static_cast<int>(m_sequencesFromIParticles.size()); - int num_iparticles = 0; - IParticles sorted_particles = getIParticlesFromJet(jet); - int iparticle_var_idx=0; - for (const auto& seq_builder: m_sequencesFromIParticles) { - auto double_vec = seq_builder(jet, sorted_particles).second; - if (iparticle_var_idx == 0){ - num_iparticles = static_cast<int>(double_vec.size()); - particle_feat.resize(num_iparticles * num_iparticle_vars); - } - for (unsigned int particle_idx=0; particle_idx < double_vec.size(); particle_idx++){ - particle_feat[iparticle_var_idx * num_iparticles + particle_idx] = double_vec[particle_idx]; - } - iparticle_var_idx++; - } - std::vector<int64_t> particle_feat_dim = {num_iparticles, num_iparticle_vars}; + return std::make_pair("flow_features", m_customSequenceGetter.getFeats(jet, sorted_particles)); + } - return std::make_pair("flow_features", std::make_pair(particle_feat, particle_feat_dim)); + FTagDataDependencyNames IParticlesLoader::getDependencies() const { + return deps; } + std::set<std::string> IParticlesLoader::getUsedRemap() const { + return used_remap; + } + } \ No newline at end of file diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SequenceGetter.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SequenceGetter.cxx index e0d1789b4740..b34ceeaad1ac 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SequenceGetter.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SequenceGetter.cxx @@ -8,6 +8,7 @@ #include "xAODTracking/TrackParticle.h" #include <optional> +#include <iostream> namespace { // ______________________________________________________________________ @@ -51,249 +52,275 @@ namespace { private: T m_getter; public: - TJGetter(T getter): + TJGetter(T getter): m_getter(getter) {} std::vector<double> operator()( const xAOD::Jet& jet, - const std::vector<const xAOD::IParticle*>& tracks) const { + const std::vector<const xAOD::IParticle*>& particles) const { std::vector<double> sequence; - sequence.reserve(tracks.size()); - for (const auto* track: tracks) { - sequence.push_back(m_getter(track, jet)); + sequence.reserve(particles.size()); + for (const auto* particle: particles) { + sequence.push_back(m_getter(*particle, jet)); } return sequence; } }; - std::optional<FlavorTagDiscriminants::sequence_getter::SequenceFromIParticles> - sequenceWithIpDep( + + // The sequence getter takes in constituents and calculates arrays of + // values which are better suited for inputs to the NNs + template <typename T, typename U> + class SequenceGetter{ + private: + SG::AuxElement::ConstAccessor<T> m_getter; + std::string m_name; + public: + SequenceGetter(const std::string& name): + m_getter(name), + m_name(name) + { + } + std::pair<std::string, std::vector<double>> operator()(const xAOD::Jet&, const std::vector<const U*>& consts) const { + std::vector<double> seq; + for (const U* el: consts) { + seq.push_back(m_getter(*el)); + } + return {m_name, seq}; + } + }; +} + namespace FlavorTagDiscriminants { + namespace sequence_getter { + + + std::optional<typename CustomSequenceGetter::SequenceFromConstituents> + CustomSequenceGetter::sequenceWithIpDep( const std::string& name, const std::string& prefix) { - using Ip = xAOD::IParticle; using Tp = xAOD::TrackParticle; using Jet = xAOD::Jet; BTagTrackIpAccessor a(prefix); if (name == "IP3D_signed_d0_significance") { - return TJGetter([a](const Ip* t, const Jet& j){ - auto tp = (Tp*)t; - return a.getSignedIp(*tp, j).ip3d_signed_d0_significance; + return TJGetter([a](const Ip& p, const Jet& j){ + auto tp = dynamic_cast<const Tp&>(p); + return a.getSignedIp(tp, j).ip3d_signed_d0_significance; }); } if (name == "IP3D_signed_z0_significance") { - return TJGetter([a](const Ip* t, const Jet& j){ - auto tp = (Tp*)t; - return a.getSignedIp(*tp, j).ip3d_signed_z0_significance; + return TJGetter([a](const Ip& p, const Jet& j){ + auto tp = dynamic_cast<const Tp&>(p); + return a.getSignedIp(tp, j).ip3d_signed_z0_significance; }); } if (name == "IP2D_signed_d0") { - return TJGetter([a](const Ip* t, const Jet& j){ - auto tp = (Tp*)t; - return a.getSignedIp(*tp, j).ip2d_signed_d0; + return TJGetter([a](const Ip& p, const Jet& j){ + auto tp = dynamic_cast<const Tp&>(p); + return a.getSignedIp(tp, j).ip2d_signed_d0; }); } if (name == "IP3D_signed_d0") { - return TJGetter([a](const Ip* t, const Jet& j){ - auto tp = (Tp*)t; - return a.getSignedIp(*tp, j).ip3d_signed_d0; + return TJGetter([a](const Ip& p, const Jet& j){ + auto tp = dynamic_cast<const Tp&>(p); + return a.getSignedIp(tp, j).ip3d_signed_d0; }); } if (name == "IP3D_signed_z0") { - return TJGetter([a](const Ip* t, const Jet& j){ - auto tp = (Tp*)t; - return a.getSignedIp(*tp, j).ip3d_signed_z0; + return TJGetter([a](const Ip& p, const Jet& j){ + auto tp = dynamic_cast<const Tp&>(p); + return a.getSignedIp(tp, j).ip3d_signed_z0; }); } if (name == "d0" || name == "btagIp_d0") { - return TJGetter([a](const Ip* t, const Jet&){ - auto tp = (Tp*)t; - return a.d0(*tp); + return TJGetter([a](const Ip& p, const Jet&){ + auto tp = dynamic_cast<const Tp&>(p); + return a.d0(tp); }); } if (name == "z0SinTheta" || name == "btagIp_z0SinTheta") { - return TJGetter([a](const Ip* t, const Jet&){ - auto tp = (Tp*)t; - return a.z0SinTheta(*tp); + return TJGetter([a](const Ip& p, const Jet&){ + auto tp = dynamic_cast<const Tp&>(p); + return a.z0SinTheta(tp); }); } if (name == "d0Uncertainty") { - return TJGetter([a](const Ip* t, const Jet&){ - auto tp = (Tp*)t; - return a.d0Uncertainty(*tp); + return TJGetter([a](const Ip& p, const Jet&){ + auto tp = dynamic_cast<const Tp&>(p); + return a.d0Uncertainty(tp); }); } if (name == "z0SinThetaUncertainty") { - return TJGetter([a](const Ip* t, const Jet&){ - auto tp = (Tp*)t; - return a.z0SinThetaUncertainty(*tp); + return TJGetter([a](const Ip& p, const Jet&){ + auto tp = dynamic_cast<const Tp&>(p); + return a.z0SinThetaUncertainty(tp); }); } return std::nullopt; } - std::optional<FlavorTagDiscriminants::sequence_getter::SequenceFromIParticles> - sequenceNoIpDep(const std::string& name) + + std::optional<typename CustomSequenceGetter::SequenceFromConstituents> + CustomSequenceGetter::sequenceNoIpDep(const std::string& name) { - using Ip = xAOD::IParticle; using Tp = xAOD::TrackParticle; using Jet = xAOD::Jet; if (name == "pt") { - return TJGetter([](const Ip* t, const Jet&) { - return t->pt(); + return TJGetter([](const Ip& p, const Jet&) { + return p.pt(); }); } if (name == "log_pt") { - return TJGetter([](const Ip* t, const Jet&) { - return std::log(t->pt()); + return TJGetter([](const Ip& p, const Jet&) { + return std::log(p.pt()); }); } if (name == "ptfrac") { - return TJGetter([](const Ip* t, const Jet& j) { - return t->pt() / j.pt(); + return TJGetter([](const Ip& p, const Jet& j) { + return p.pt() / j.pt(); }); } if (name == "log_ptfrac") { - return TJGetter([](const Ip* t, const Jet& j) { - return std::log(t->pt() / j.pt()); + return TJGetter([](const Ip& p, const Jet& j) { + return std::log(p.pt() / j.pt()); }); } if (name == "eta") { - return TJGetter([](const Ip* t, const Jet&) { - return t->eta(); + return TJGetter([](const Ip& p, const Jet&) { + return p.eta(); }); } if (name == "deta") { - return TJGetter([](const Ip* t, const Jet& j) { - return t->eta() - j.eta(); + return TJGetter([](const Ip& p, const Jet& j) { + return p.eta() - j.eta(); }); } if (name == "abs_deta") { - return TJGetter([](const Ip* t, const Jet& j) { - return copysign(1.0, j.eta()) * (t->eta() - j.eta()); + return TJGetter([](const Ip& p, const Jet& j) { + return copysign(1.0, j.eta()) * (p.eta() - j.eta()); }); } if (name == "phi") { - return TJGetter([](const Ip* t, const Jet&) { - return t->phi(); + return TJGetter([](const Ip& p, const Jet&) { + return p.phi(); }); } if (name == "dphi") { - return TJGetter([](const Ip* t, const Jet& j) { - return t->p4().DeltaPhi(j.p4()); + return TJGetter([](const Ip& p, const Jet& j) { + return p.p4().DeltaPhi(j.p4()); }); } if (name == "dr") { - return TJGetter([](const Ip* t, const Jet& j) { - return t->p4().DeltaR(j.p4()); + return TJGetter([](const Ip& p, const Jet& j) { + return p.p4().DeltaR(j.p4()); }); } if (name == "log_dr") { - return TJGetter([](const Ip* t, const Jet& j) { - return std::log(t->p4().DeltaR(j.p4())); + return TJGetter([](const Ip& p, const Jet& j) { + return std::log(p.p4().DeltaR(j.p4())); }); } if (name == "log_dr_nansafe") { - return TJGetter([](const Ip* t, const Jet& j) { - return std::log(t->p4().DeltaR(j.p4()) + 1e-7); + return TJGetter([](const Ip& p, const Jet& j) { + return std::log(p.p4().DeltaR(j.p4()) + 1e-7); }); } if (name == "mass") { - return TJGetter([](const Ip* t, const Jet&) { - return t->m(); + return TJGetter([](const Ip& p, const Jet&) { + return p.m(); }); } if (name == "energy") { - return TJGetter([](const Ip* t, const Jet&) { - return t->e(); + return TJGetter([](const Ip& p, const Jet&) { + return p.e(); }); } if (name == "phiUncertainty") { - return TJGetter([](const Ip* t, const Jet&) { - auto tp = (Tp*)t; - return std::sqrt(tp->definingParametersCovMatrixDiagVec().at(2)); + return TJGetter([](const Ip& p, const Jet&) { + auto tp = dynamic_cast<const Tp&>(p); + return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(2)); }); } if (name == "thetaUncertainty") { - return TJGetter([](const Ip* t, const Jet&) { - auto tp = (Tp*)t; - return std::sqrt(tp->definingParametersCovMatrixDiagVec().at(3)); + return TJGetter([](const Ip& p, const Jet&) { + auto tp = dynamic_cast<const Tp&>(p); + return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(3)); }); } if (name == "qOverPUncertainty") { - return TJGetter([](const Ip* t, const Jet&) { - auto tp = (Tp*)t; - return std::sqrt(tp->definingParametersCovMatrixDiagVec().at(4)); + return TJGetter([](const Ip& p, const Jet&) { + auto tp = dynamic_cast<const Tp&>(p); + return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(4)); }); } if (name == "z0RelativeToBeamspot") { - return TJGetter([](const Ip* t, const Jet&) { - auto tp = (Tp*)t; - return tp->z0(); + return TJGetter([](const Ip& p, const Jet&) { + auto tp = dynamic_cast<const Tp&>(p); + return tp.z0(); }); } if (name == "log_z0RelativeToBeamspotUncertainty") { - return TJGetter([](const Ip* t, const Jet&) { - auto tp = (Tp*)t; - return std::log(std::sqrt(tp->definingParametersCovMatrixDiagVec().at(1))); + return TJGetter([](const Ip& p, const Jet&) { + auto tp = dynamic_cast<const Tp&>(p); + return std::log(std::sqrt(tp.definingParametersCovMatrixDiagVec().at(1))); }); } if (name == "z0RelativeToBeamspotUncertainty") { - return TJGetter([](const Ip* t, const Jet&) { - auto tp = (Tp*)t; - return std::sqrt(tp->definingParametersCovMatrixDiagVec().at(1)); + return TJGetter([](const Ip& p, const Jet&) { + auto tp = dynamic_cast<const Tp&>(p); + return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(1)); }); } if (name == "numberOfPixelHitsInclDead") { SG::AuxElement::ConstAccessor<unsigned char> pix_hits("numberOfPixelHits"); SG::AuxElement::ConstAccessor<unsigned char> pix_dead("numberOfPixelDeadSensors"); - return TJGetter([pix_hits, pix_dead](const Ip* t, const Jet&) { - return pix_hits(*t) + pix_dead(*t); + return TJGetter([pix_hits, pix_dead](const Ip& p, const Jet&) { + return pix_hits(p) + pix_dead(p); }); } if (name == "numberOfSCTHitsInclDead") { SG::AuxElement::ConstAccessor<unsigned char> sct_hits("numberOfSCTHits"); SG::AuxElement::ConstAccessor<unsigned char> sct_dead("numberOfSCTDeadSensors"); - return TJGetter([sct_hits, sct_dead](const Ip* t, const Jet&) { - return sct_hits(*t) + sct_dead(*t); + return TJGetter([sct_hits, sct_dead](const Ip& p, const Jet&) { + return sct_hits(p) + sct_dead(p); }); } if (name == "numberOfInnermostPixelLayerHits21p9") { SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfInnermostPixelLayerHits"); SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfInnermostPixelLayerEndcapHits"); - return TJGetter([barrel_hits, endcap_hits](const Ip* t, const Jet&) { - return barrel_hits(*t) + endcap_hits(*t); + return TJGetter([barrel_hits, endcap_hits](const Ip& p, const Jet&) { + return barrel_hits(p) + endcap_hits(p); }); } if (name == "numberOfNextToInnermostPixelLayerHits21p9") { SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfNextToInnermostPixelLayerHits"); SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfNextToInnermostPixelLayerEndcapHits"); - return TJGetter([barrel_hits, endcap_hits](const Ip* t, const Jet&) { - return barrel_hits(*t) + endcap_hits(*t); + return TJGetter([barrel_hits, endcap_hits](const Ip& p, const Jet&) { + return barrel_hits(p) + endcap_hits(p); }); } if (name == "numberOfInnermostPixelLayerSharedHits21p9") { SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfInnermostPixelLayerSharedHits"); SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfInnermostPixelLayerSharedEndcapHits"); - return TJGetter([barrel_hits, endcap_hits](const Ip* t, const Jet&) { - return barrel_hits(*t) + endcap_hits(*t); + return TJGetter([barrel_hits, endcap_hits](const Ip& p, const Jet&) { + return barrel_hits(p) + endcap_hits(p); }); } if (name == "numberOfInnermostPixelLayerSplitHits21p9") { SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfInnermostPixelLayerSplitHits"); SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfInnermostPixelLayerSplitEndcapHits"); - return TJGetter([barrel_hits, endcap_hits](const Ip* t, const Jet&) { - return barrel_hits(*t) + endcap_hits(*t); + return TJGetter([barrel_hits, endcap_hits](const Ip& p, const Jet&) { + return barrel_hits(p) + endcap_hits(p); }); } @@ -301,11 +328,6 @@ namespace { return std::nullopt; } -} - -namespace FlavorTagDiscriminants { - namespace sequence_getter { - // ________________________________________________________________ // Interface functions // @@ -323,14 +345,11 @@ namespace FlavorTagDiscriminants { }; } - // Case for track variables - std::pair<std::function<std::pair<std::string, std::vector<double>>( - const xAOD::Jet&, - const std::vector<const xAOD::IParticle*>&)>, - std::set<std::string>> - customNamedSeqGetterWithDeps(const std::string& name, + // Case for constituents variables + std::pair<typename CustomSequenceGetter::NamedSequenceFromConstituents, std::set<std::string>> + CustomSequenceGetter::customNamedSeqGetterWithDeps(const std::string& name, const std::string& prefix) { - auto [getter, deps] = customSequenceGetterWithDeps(name, prefix); + auto [getter, deps] = CustomSequenceGetter::customSequenceGetterWithDeps(name, prefix); return { [n=name, g=getter](const xAOD::Jet& j, const std::vector<const xAOD::IParticle*>& t) { @@ -345,20 +364,113 @@ namespace FlavorTagDiscriminants { // These functions are wrapped by the customNamedSeqGetter function // below to become the ones that are actually used in DL2. // - std::pair<SequenceFromIParticles, std::set<std::string>> - customSequenceGetterWithDeps(const std::string& name, + std::pair<typename CustomSequenceGetter::SequenceFromConstituents, std::set<std::string>> + CustomSequenceGetter::customSequenceGetterWithDeps(const std::string& name, const std::string& prefix) { - if (auto getter = sequenceWithIpDep(name, prefix)) { + if (auto getter = CustomSequenceGetter::sequenceWithIpDep(name, prefix)) { auto deps = BTagTrackIpAccessor(prefix).getTrackIpDataDependencyNames(); return {*getter, deps}; } - if (auto getter = sequenceNoIpDep(name)) { + if (auto getter = CustomSequenceGetter::sequenceNoIpDep(name)) { return {*getter, {}}; } throw std::logic_error("no match for custom getter " + name); } + + // ________________________________________________________________________ + // Class implementation + // + std::pair<typename CustomSequenceGetter::NamedSequenceFromConstituents, std::set<std::string>> + CustomSequenceGetter::seqFromConsituents( + const FTagConstituentsInputConfig& cfg, + const FTagOptions& options){ + const std::string prefix = options.track_prefix; + switch (cfg.type) { + case ConstituentsEDMType::INT: return { + SequenceGetter<int, xAOD::IParticle>(cfg.name), {cfg.name} + }; + case ConstituentsEDMType::FLOAT: return { + SequenceGetter<float, xAOD::IParticle>(cfg.name), {cfg.name} + }; + case ConstituentsEDMType::CHAR: return { + SequenceGetter<char, xAOD::IParticle>(cfg.name), {cfg.name} + }; + case ConstituentsEDMType::UCHAR: return { + SequenceGetter<unsigned char, xAOD::IParticle>(cfg.name), {cfg.name} + }; + case ConstituentsEDMType::CUSTOM_GETTER: { + return CustomSequenceGetter::customNamedSeqGetterWithDeps( + cfg.name, options.track_prefix); + } + default: { + throw std::logic_error("Unknown EDM type for constituent."); + } + } } -} + + CustomSequenceGetter::CustomSequenceGetter( + std::vector<FTagConstituentsInputConfig> inputs, + const FTagOptions& options) + { + std::map<std::string, std::string> remap = options.remap_scalar; + for (const FTagConstituentsInputConfig& input_cfg: inputs) { + auto [seqGetter, seq_deps] = seqFromConsituents( + input_cfg, options); + + if(input_cfg.flip_sign){ + auto seqGetter_flip=[g=seqGetter](const xAOD::Jet&jet, const IParticles& constituents){ + auto [n,v] = g(jet,constituents); + std::for_each(v.begin(), v.end(), [](double &n){ n=-1.0*n; }); + return std::make_pair(n,v); + }; + sequencesFromConstituents.push_back(seqGetter_flip); + } + else{ + sequencesFromConstituents.push_back(seqGetter); + } + deps.merge(seq_deps); + if (auto h = remap.extract(input_cfg.name)){ + used_remap.insert(h.key()); + } + } + } + + + std::pair<std::vector<float>, std::vector<int64_t>> CustomSequenceGetter::getFeats( + const xAOD::Jet& jet, const IParticles& constituents) const + { + std::vector<float> cnsts_feats; + int num_vars = static_cast<int>(sequencesFromConstituents.size()); + int num_cnsts = 0; + + int cnst_var_idx = 0; + for (const auto& seq_builder: sequencesFromConstituents){ + auto double_vec = seq_builder(jet, constituents).second; + + if (cnst_var_idx==0){ + num_cnsts = static_cast<int>(double_vec.size()); + cnsts_feats.resize(num_cnsts * num_vars); + } + + // need to transpose + flatten + for (unsigned int cnst_idx=0; cnst_idx<double_vec.size(); cnst_idx++){ + cnsts_feats.at(cnst_idx*num_vars + cnst_var_idx) + = double_vec.at(cnst_idx); + } + cnst_var_idx++; + } + std::vector<int64_t> cnsts_feat_dim = {num_cnsts, num_vars}; + return {cnsts_feats, cnsts_feat_dim}; + } + + std::set<std::string> CustomSequenceGetter::getDependencies() const { + return deps; + } + std::set<std::string> CustomSequenceGetter::getUsedRemap() const { + return used_remap; + } + } +} \ No newline at end of file diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx index b930ef6d0bae..d97f356917c7 100644 --- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx +++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx @@ -5,7 +5,6 @@ Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration #include "FlavorTagDiscriminants/FlipTagEnums.h" #include "FlavorTagDiscriminants/AssociationEnums.h" -#include "FlavorTagDiscriminants/customGetter.h" #include "FlavorTagDiscriminants/TracksLoader.h" #include <iostream> @@ -33,20 +32,6 @@ namespace { typedef std::vector<std::pair<std::regex, ConstituentsSortOrder> > SortRegexes; typedef std::vector<std::pair<std::regex, ConstituentsSelection> > TrkSelRegexes; - // Function to map the regex + list of inputs to variable config, - // this time for sequence inputs. - std::vector<FTagConstituentsSequenceConfig> get_track_input_config( - const std::vector<std::pair<std::string, std::vector<std::string>>>& names, - const TypeRegexes& type_regexes, - const SortRegexes& sort_regexes, - const TrkSelRegexes& select_regexes, - const std::regex& re, - const FlipTagConfig& flip_config); - - - //_______________________________________________________________________ - // Implementation of the above functions - // template <typename T> T match_first(const std::vector<std::pair<std::regex, T> >& regexes, @@ -137,7 +122,13 @@ namespace FlavorTagDiscriminants { ){ // some sequences also need to be sign-flipped. We apply this by // changing the input scaling and normalizations - std::regex flip_sequences(".*signed_[dz]0.*"); + std::regex flip_sequences; + if (flip_config == FlipTagConfig::FLIP_SIGN || flip_config == FlipTagConfig::NEGATIVE_IP_ONLY){ + flip_sequences=std::regex(".*signed_[dz]0.*"); + } + if (flip_config == FlipTagConfig::SIMPLE_FLIP){ + flip_sequences=std::regex("(.*signed_[dz]0.*)|d0|z0SinTheta"); + } // build the track inputs TypeRegexes trk_type_regexes { @@ -383,37 +374,6 @@ namespace FlavorTagDiscriminants { } } - // factory for functions that build std::vector objects from - // track sequences - std::pair<TracksLoader::SeqFromTracks,std::set<std::string>> TracksLoader::seqFromTracks( - const FTagConstituentsInputConfig& cfg, - const FTagOptions& options) - { - const std::string prefix = options.track_prefix; - switch (cfg.type) { - case ConstituentsEDMType::INT: return { - SequenceGetter<int, Track>(cfg.name), {cfg.name} - }; - case ConstituentsEDMType::FLOAT: return { - SequenceGetter<float, Track>(cfg.name), {cfg.name} - }; - case ConstituentsEDMType::CHAR: return { - SequenceGetter<char, Track>(cfg.name), {cfg.name} - }; - case ConstituentsEDMType::UCHAR: return { - SequenceGetter<unsigned char, Track>(cfg.name), {cfg.name} - }; - case ConstituentsEDMType::CUSTOM_GETTER: { - return internal::customNamedSeqGetterWithDeps( - cfg.name, options.track_prefix); - - } - default: { - throw std::logic_error("Unknown EDM type for tracks"); - } - } - } - // here we define filters for the "flip" taggers // // start by defining the raw functions, there's a factory @@ -463,7 +423,6 @@ namespace FlavorTagDiscriminants { } } - // TracksLoader::TracksLoader() : ConstituentsLoader() {}; TracksLoader::TracksLoader( FTagConstituentsSequenceConfig cfg, const FTagOptions& options @@ -471,7 +430,9 @@ namespace FlavorTagDiscriminants { ConstituentsLoader(cfg), m_trackSortVar(TracksLoader::trackSortVar(cfg.order, options)), m_trackFilter(TracksLoader::trackFilter(cfg.selection, options).first), - m_flipFilter(TracksLoader::flipFilter(options).first) + m_flipFilter(TracksLoader::flipFilter(options).first), + m_customSequenceGetter(sequence_getter::CustomSequenceGetter( + cfg.inputs, options)) { // We have several ways to get tracks: either we retrieve an // IParticleContainer and cast the pointers to TrackParticle, or @@ -479,7 +440,6 @@ namespace FlavorTagDiscriminants { // the way tracks are stored isn't consistent across the EDM, so // we allow configuration for both setups. // - std::cout << "TEST TRACK 1 " << std::endl; if (options.track_link_type == TrackLinkType::IPARTICLE) { SG::AuxElement::ConstAccessor<PartLinks> acc(options.track_link_name); m_associator = [acc](const SG::AuxElement& btag) -> TPV { @@ -511,35 +471,12 @@ namespace FlavorTagDiscriminants { } else { throw std::logic_error("Unknown TrackLinkType"); } - std::cout << "TEST TRACK 2 " << std::endl; - std::map<std::string, std::string> remap = options.remap_scalar; - std::set<std::string> used_remap; - auto track_data_deps = trackFilter(cfg.selection, options).second; track_data_deps.merge(flipFilter(options).second); - for (const FTagConstituentsInputConfig& input_cfg: cfg.inputs) { - auto [seqGetter, deps] = seqFromTracks( - input_cfg, options); - - if(input_cfg.flip_sign){ - auto seqGetter_flip=[g=seqGetter](const xAOD::Jet&jet, const internal::Tracks& trks){ - auto [n,v] = g(jet,trks); - std::for_each(v.begin(), v.end(), [](double &n){ n=-1.0*n; }); - return std::make_pair(n,v); - }; - m_sequencesFromTracks.push_back(seqGetter_flip); - } - else{ - m_sequencesFromTracks.push_back(seqGetter); - } - track_data_deps.merge(deps); - if (auto h = remap.extract(input_cfg.name)){ - used_remap.insert(h.key()); - } - } - std::cout << "TEST TRACK 3 " << std::endl; + track_data_deps.merge(m_customSequenceGetter.getDependencies()); deps.trackInputs.merge(track_data_deps); deps.bTagInputs.insert(options.track_link_name); + used_remap = m_customSequenceGetter.getUsedRemap(); } std::vector<const xAOD::TrackParticle*> TracksLoader::getTracksFromJet( @@ -563,32 +500,22 @@ namespace FlavorTagDiscriminants { std::pair<std::string, input_pair> TracksLoader::getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const { Tracks flipped_tracks; - std::vector<float> track_feat; // (#tracks, #feats).flatten - - int num_track_vars = static_cast<int>(m_sequencesFromTracks.size()); - int num_tracks = 0; - Tracks sorted_tracks = getTracksFromJet(jet, btag); - flipped_tracks = m_flipFilter(sorted_tracks, jet); + std::vector<const xAOD::IParticle*> flipped_tracks_ip; - int track_var_idx=0; - for (const auto& seq_builder: m_sequencesFromTracks) { - auto double_vec = seq_builder(jet, flipped_tracks).second; - - if (track_var_idx==0){ - num_tracks = static_cast<int>(double_vec.size()); - track_feat.resize(num_tracks * num_track_vars); - } - - // need to transpose + flatten - for (unsigned int track_idx=0; track_idx<double_vec.size(); track_idx++){ - track_feat.at(track_idx*num_track_vars + track_var_idx) - = double_vec.at(track_idx); - } - track_var_idx++; + flipped_tracks = m_flipFilter(sorted_tracks, jet); + + for (const auto& trk: flipped_tracks) { + flipped_tracks_ip.push_back(dynamic_cast<const xAOD::IParticle*>(trk)); } - std::vector<int64_t> track_feat_dim = {num_tracks, num_track_vars}; - return std::make_pair("track_features", std::make_pair(track_feat, track_feat_dim)); + return std::make_pair("track_features", m_customSequenceGetter.getFeats(jet, flipped_tracks_ip)); + } + + FTagDataDependencyNames TracksLoader::getDependencies() const { + return deps; + } + std::set<std::string> TracksLoader::getUsedRemap() const { + return used_remap; } } \ No newline at end of file -- GitLab