From f49d537cdadc0ad06bdfa7f2507c4386f88f0953 Mon Sep 17 00:00:00 2001
From: Dmitrii Kobylianskii <dmitrii.kobylianskii@cern.ch>
Date: Sun, 4 Feb 2024 17:32:55 +0100
Subject: [PATCH] - Clean code - Add customSequenceGetter - Make work with
 neutrals+tracks

---
 .../ConstituentsLoader.h                      |  30 +-
 .../FlavorTagDiscriminants/GNN.h              |   5 +-
 .../FlavorTagDiscriminants/IParticlesLoader.h |   8 +-
 .../FlavorTagDiscriminants/SequenceGetter.h   |  64 +++-
 .../FlavorTagDiscriminants/TracksLoader.h     |  14 +-
 .../Root/DataPrepUtilities.cxx                |   6 +-
 .../FlavorTagDiscriminants/Root/GNN.cxx       |  90 ++---
 .../Root/IParticlesLoader.cxx                 |  81 +----
 .../Root/SequenceGetter.cxx                   | 338 ++++++++++++------
 .../Root/TracksLoader.cxx                     | 123 ++-----
 10 files changed, 351 insertions(+), 408 deletions(-)

diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h
index de72ae86ec0f..185bdc664250 100644
--- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h
+++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/ConstituentsLoader.h
@@ -8,9 +8,8 @@
 // local includes
 #include "FlavorTagDiscriminants/FlipTagEnums.h"
 #include "FlavorTagDiscriminants/AssociationEnums.h"
-#include "FlavorTagDiscriminants/FTagDataDependencyNames.h"
-#include "FlavorTagDiscriminants/GNNConfig.h"
 #include "FlavorTagDiscriminants/OnnxUtil.h"
+#include "FlavorTagDiscriminants/FTagDataDependencyNames.h"
 
 // EDM includes
 #include "xAODJet/Jet.h"
@@ -54,28 +53,6 @@ namespace FlavorTagDiscriminants {
         std::vector<FTagConstituentsInputConfig> inputs;
     };
 
-    // The sequence getter takes in constituents and calculates arrays of
-    // values which are better suited for inputs to the NNs
-    template <typename T, typename U>
-    class SequenceGetter{
-      private:
-        SG::AuxElement::ConstAccessor<T> m_getter;
-        std::string m_name;
-      public:
-        SequenceGetter(const std::string& name):
-          m_getter(name),
-          m_name(name)
-          {
-          }
-        std::pair<std::string, std::vector<double>> operator()(const xAOD::Jet&, const std::vector<const U*>& consts) const {
-          std::vector<double> seq;
-          for (const U* el: consts) {
-            seq.push_back(m_getter(*el));
-          }
-          return {m_name, seq};
-        }
-    };
-
     // Virtual class to represent loader of any type of constituents
     class ConstituentsLoader {
         public:
@@ -85,10 +62,13 @@ namespace FlavorTagDiscriminants {
             virtual ~ConstituentsLoader() {
             };
             virtual std::pair<std::string, input_pair> getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const = 0;
+            virtual FTagDataDependencyNames getDependencies() const = 0;
+            virtual std::set<std::string> getUsedRemap() const = 0;
 
+        protected:
             FTagDataDependencyNames deps;
-        private:
             FTagConstituentsSequenceConfig config;
+            std::set<std::string> used_remap;
     };
 }
 
diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h
index 9c5588935423..f50bb77aaa6d 100644
--- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h
+++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/GNN.h
@@ -60,7 +60,6 @@ namespace FlavorTagDiscriminants {
     virtual std::set<std::string> getConstituentAuxInputKeys() const;
 
     std::shared_ptr<const OnnxUtil> m_onnxUtil;
-
   private:
     // type definitions for ONNX output decorators
     using TPC = xAOD::TrackParticleContainer;
@@ -90,9 +89,7 @@ namespace FlavorTagDiscriminants {
     std::vector<internal::VarFromBTag> m_varsFromBTag;
     std::vector<internal::VarFromJet> m_varsFromJet;
     std::vector<internal::TrackSequenceBuilder> m_trackSequenceBuilders;
-    // std::vector<std::shared_ptr<ConstituentsLoader>> m_constituentsLoaders;
-    // std::shared_ptr<const TracksLoader> m_trackLoader;
-    std::shared_ptr<const IParticlesLoader> m_flowLoader;
+    std::vector<std::shared_ptr<ConstituentsLoader>> m_constituentsLoaders;
 
     Decorators m_decorators;
     float m_defaultValue;
diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h
index 5d2225da3452..994acfdc9b04 100644
--- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h
+++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/IParticlesLoader.h
@@ -8,6 +8,7 @@
 // local includes
 #include "FlavorTagDiscriminants/ConstituentsLoader.h"
 #include "FlavorTagDiscriminants/DataPrepUtilities.h"
+#include "FlavorTagDiscriminants/SequenceGetter.h"
 
 // EDM includes
 #include "xAODJet/Jet.h"
@@ -32,6 +33,8 @@ namespace FlavorTagDiscriminants {
         // TracksLoader();
         IParticlesLoader(FTagConstituentsSequenceConfig, const FTagOptions& options);
         std::pair<std::string, input_pair> getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const override ;
+        FTagDataDependencyNames getDependencies() const override;
+        std::set<std::string> getUsedRemap() const override;
       protected:
         // typedefs
         typedef xAOD::Jet Jet;
@@ -54,13 +57,14 @@ namespace FlavorTagDiscriminants {
         IParticleSortVar iparticleSortVar(ConstituentsSortOrder, const FTagOptions&);
         
         std::vector<const xAOD::IParticle*> getIParticlesFromJet(const xAOD::Jet& jet) const;
-        std::pair<SeqFromIParticles,std::set<std::string>> seqFromIParticles(
-          const FTagConstituentsInputConfig&, const FTagOptions&);
 
         std::vector<SeqFromIParticles> m_sequencesFromIParticles;
+        sequence_getter::CustomSequenceGetter m_customSequenceGetter;
         IParticleSortVar m_iparticleSortVar;
         std::function<IPV(const Jet&)> m_associator;
         bool m_isCharged;
+
+        // FTagDataDependencyNames deps;
     };
 }
 
diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SequenceGetter.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SequenceGetter.h
index c7c2c5676765..0402192725c7 100644
--- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SequenceGetter.h
+++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/SequenceGetter.h
@@ -13,6 +13,9 @@
 #include "xAODBase/IParticle.h"
 #include "xAODTracking/TrackParticleFwd.h"
 #include "AthContainers/AuxElement.h"
+#include "FlavorTagDiscriminants/ConstituentsLoader.h"
+#include "FlavorTagDiscriminants/DataPrepUtilities.h"
+
 
 #include <functional>
 #include <string>
@@ -23,7 +26,7 @@
 
 namespace FlavorTagDiscriminants {
 
-  /// Factory function to produce TrackParticle -> vector<double> functions
+  /// Class to produce IParticle -> vector<double> functions
   ///
   /// DL2 configures the its inputs when the algorithm is initalized,
   /// meaning that the list of track and jet properties that are used
@@ -44,26 +47,51 @@ namespace FlavorTagDiscriminants {
   ///
 
   namespace sequence_getter {
-
-  using SequenceFromIParticles = std::function<std::vector<double>(
-    const xAOD::Jet&,
-    const std::vector<const xAOD::IParticle*>&)>;
-
-  std::pair<SequenceFromIParticles, std::set<std::string>>
-  customSequenceGetterWithDeps(
-    const std::string& name,   // name of the getter
-    const std::string& prefix  // prefix for track accessor
-    );
-  
     std::function<std::pair<std::string, double>(const xAOD::Jet&)>
     customGetterAndName(const std::string&);
 
-    std::pair<std::function<std::pair<std::string, std::vector<double>>(
-      const xAOD::Jet&,
-      const std::vector<const xAOD::IParticle*>&)>,
-      std::set<std::string>>
-    customNamedSeqGetterWithDeps(const std::string&, const std::string&);
-  }
+    class CustomSequenceGetter {
+        public:
+          using IParticles = std::vector<const xAOD::IParticle*>;
+          using SequenceFromConstituents = std::function<std::vector<double>(
+            const xAOD::Jet&,
+            const IParticles&)>;
+          
+          using NamedSequenceFromConstituents = std::function<std::pair<std::string, std::vector<double>>(
+            const xAOD::Jet&,
+            const IParticles&)>;
+
+          CustomSequenceGetter(std::vector<FTagConstituentsInputConfig> inputs,
+                              const FTagOptions& options);
+
+          std::pair<std::vector<float>, std::vector<int64_t>> getFeats(const xAOD::Jet& jet, const IParticles& constituents) const;
+
+          std::set<std::string> getDependencies() const;
+          std::set<std::string> getUsedRemap() const;
+          
+        private:
+
+          std::pair<NamedSequenceFromConstituents, std::set<std::string>>
+          customNamedSeqGetterWithDeps(const std::string& name, const std::string& prefix);
+
+          std::pair<SequenceFromConstituents, std::set<std::string>>
+          customSequenceGetterWithDeps(const std::string& name,
+                                      const std::string& prefix);
+
+          std::optional<SequenceFromConstituents> sequenceNoIpDep(const std::string& name);
+
+          std::optional<SequenceFromConstituents> sequenceWithIpDep(const std::string& name,
+                                                                  const std::string& prefix);
+
+          std::pair<NamedSequenceFromConstituents, std::set<std::string>> seqFromConsituents(
+          const FTagConstituentsInputConfig& cfg, 
+          const FTagOptions& options);
+
+          std::vector<NamedSequenceFromConstituents> sequencesFromConstituents;
+          std::set<std::string> deps;
+          std::set<std::string> used_remap;        
+        };
+    }
 }
 
 #endif
diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h
index caefe3d21dec..e0efa6ec915c 100644
--- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h
+++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/FlavorTagDiscriminants/TracksLoader.h
@@ -8,11 +8,12 @@
 // local includes
 #include "FlavorTagDiscriminants/FlipTagEnums.h"
 #include "FlavorTagDiscriminants/AssociationEnums.h"
-#include "FlavorTagDiscriminants/FTagDataDependencyNames.h"
+// #include "FlavorTagDiscriminants/FTagDataDependencyNames.h"
 
 #include "FlavorTagDiscriminants/ConstituentsLoader.h"
 #include "FlavorTagDiscriminants/DataPrepUtilities.h"
 #include "FlavorTagDiscriminants/BTagTrackIpAccessor.h"
+#include "FlavorTagDiscriminants/SequenceGetter.h"
 
 // EDM includes
 #include "xAODJet/Jet.h"
@@ -50,7 +51,9 @@ namespace FlavorTagDiscriminants {
       public:
         // TracksLoader();
         TracksLoader(FTagConstituentsSequenceConfig, const FTagOptions& options);
-        std::pair<std::string, input_pair> getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const override ;
+        std::pair<std::string, input_pair> getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const override;
+        FTagDataDependencyNames getDependencies() const override;
+        std::set<std::string> getUsedRemap() const override;
       private:
         // typedefs
         typedef std::pair<std::string, double> NamedVar;
@@ -65,9 +68,6 @@ namespace FlavorTagDiscriminants {
         typedef std::function<Tracks(const Tracks&,
                                     const Jet&)> TrackSequenceFilter;
 
-        // getter function
-        typedef std::function<NamedSeq(const Jet&, const Tracks&)> SeqFromTracks;
-
         // usings for track getter
         using AE = SG::AuxElement;
         using IPC = xAOD::IParticleContainer;
@@ -79,8 +79,6 @@ namespace FlavorTagDiscriminants {
         TrackSortVar trackSortVar(ConstituentsSortOrder, const FTagOptions&);
         std::pair<TrackFilter,std::set<std::string>> trackFilter(
           ConstituentsSelection, const FTagOptions&);
-        std::pair<SeqFromTracks,std::set<std::string>> seqFromTracks(
-          const FTagConstituentsInputConfig&, const FTagOptions&);
         std::pair<TrackSequenceFilter,std::set<std::string>> flipFilter(
           const FTagOptions&);
         
@@ -89,8 +87,8 @@ namespace FlavorTagDiscriminants {
         TrackSortVar m_trackSortVar;
         TrackFilter m_trackFilter;
         TrackSequenceFilter m_flipFilter;
-        std::vector<SeqFromTracks> m_sequencesFromTracks;
         std::function<TPV(const SG::AuxElement&)> m_associator;
+        sequence_getter::CustomSequenceGetter m_customSequenceGetter;
     };
 }
 
diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx
index db9a8432f689..335115a1cc41 100644
--- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx
+++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx
@@ -11,6 +11,8 @@ Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
 
 #include "xAODBTagging/BTaggingUtilities.h"
 
+#include <iostream>
+
 namespace {
   using namespace FlavorTagDiscriminants;
 
@@ -646,10 +648,6 @@ namespace FlavorTagDiscriminants {
       // we rewrite the inputs if we're using flip taggers
       StringRegexes flip_converters = getFlipConverters(flip_config);
 
-      // some sequences also need to be sign-flipped. We apply this by
-      // changing the input scaling and normalizations
-      std::regex flip_sequences(".*signed_[dz]0.*");
-
       if (flip_config != FlipTagConfig::STANDARD) {
         rewriteFlipConfig(config, flip_converters);
       }
diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx
index 3e29cf48c63c..9b5bd7efc2ee 100644
--- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx
+++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/GNN.cxx
@@ -28,9 +28,7 @@ namespace FlavorTagDiscriminants {
     m_onnxUtil(nullptr),
     m_jetLink(jetLinkName),
     m_defaultValue(o.default_output_value),
-    m_decorate_tracks(o.decorate_tracks),
-    // m_trackLoader(nullptr),
-    m_flowLoader(nullptr)
+    m_decorate_tracks(o.decorate_tracks)
   {
     // track decoration is allowed only for non-production builds
     if (m_decorate_tracks) {
@@ -55,22 +53,20 @@ namespace FlavorTagDiscriminants {
     // Create configuration objects for data preprocessing.
     auto [inputs, constituents_configs, options] = dataprep::createGetterConfigNew(
         lwt_config, o.flip_config, o.variable_remapping, o.track_link_type);
-    std::cout << "TEST 1 " << std::endl;
-    // auto [tracksLoaderConfig, tmp_options] = createTracksLoaderConfig(
-    //   config, flip_config, variableRemapping, trackLinkType
-    // );
-    std::cout << constituents_configs.size() << std::endl;
     std::vector<FTagTrackSequenceConfig> track_sequences;
-    if (constituents_configs.size() > 0){
-      auto tracksLoaderConfig = constituents_configs[0];
-      std::cout << "TEST 2 " << std::endl; 
-      // m_trackLoader = std::make_shared<TracksLoader>(tracksLoaderConfig, options);
-      track_sequences = convertTracksConfigBack(constituents_configs[0]);
-      std::cout << "TEST 2F " << std::endl; 
-      m_flowLoader = std::make_shared<IParticlesLoader>(tracksLoaderConfig, options);
-      // m_constituentsLoaders.push_back(std::make_shared<IParticlesLoader>(tracksLoaderConfig, options));
+    for (auto config : constituents_configs){
+      std::cout << "Config name: " << config.name << std::endl;
+      if (config.name.find("tracks") != std::string::npos){
+        m_constituentsLoaders.push_back(std::make_shared<TracksLoader>(config, options));
+        track_sequences = convertTracksConfigBack(config);
+      }
+      else if (config.name.find("flow") != std::string::npos){
+        m_constituentsLoaders.push_back(std::make_shared<IParticlesLoader>(config, options));
+      }
+      else {
+        throw std::runtime_error("Unknown constituent type. Only tracks and neutrals are supported.");
+      }
     }
-    std::cout << "TEST 3 " << std::endl;
     // Initialize jet and b-tagging input getters.
     auto [vb, vj, ds] = dataprep::createBvarGetters(inputs);
     m_varsFromBTag = vb;
@@ -89,8 +85,11 @@ namespace FlavorTagDiscriminants {
     auto [dd, rd] = createDecorators(gnn_output_config, options);
     m_dataDependencyNames += dd;
 
-    // Check that all remaps have been used.
-    rd.merge(rt);
+    // Update dependencies and used remap from the constituents loaders.
+    for (auto loader : m_constituentsLoaders){
+      m_dataDependencyNames += loader->getDependencies();
+      rd.merge(loader->getUsedRemap());
+    }
     dataprep::checkForUnusedRemaps(options.remap_scalar, rd);
   }
 
@@ -157,57 +156,12 @@ namespace FlavorTagDiscriminants {
     }
     Tracks input_tracks;
     for (const auto& builder: m_trackSequenceBuilders) {
-      std::vector<float> track_feat; // (#tracks, #feats).flatten
-      int num_track_vars = static_cast<int>(builder.sequencesFromTracks.size());
-      int num_tracks = 0;
-
       Tracks sorted_tracks = builder.tracksFromJet(jet, btag);
       input_tracks = builder.flipFilter(sorted_tracks, jet);
-
-      int track_var_idx=0;
-      for (const auto& seq_builder: builder.sequencesFromTracks) {
-        auto double_vec = seq_builder(jet, input_tracks).second;
-
-        if (track_var_idx==0){
-          num_tracks = static_cast<int>(double_vec.size());
-          track_feat.resize(num_tracks * num_track_vars);
-        }
-
-        // need to transpose + flatten
-        for (unsigned int track_idx=0; track_idx<double_vec.size(); track_idx++){
-          track_feat.at(track_idx*num_track_vars + track_var_idx)
-            = double_vec.at(track_idx);
-        }
-        track_var_idx++;
-      }
-      std::vector<int64_t> track_feat_dim = {num_tracks, num_track_vars};
-
-      input_pair track_info (track_feat, track_feat_dim);
-      if (m_flowLoader){
-        auto loader_out = m_flowLoader->getData(jet, btag);
-        auto loader_track_feat = loader_out.second.first;
-        std::cout << "Vector size: TRACK " << track_feat.size() << " FLOW " << loader_track_feat.size() << std::endl;
-        // for (uint64_t i = 0; i < track_feat.size(); i++){
-        //   if (std::fabs(loader_track_feat.at(i) - track_feat.at(i)) > 0.001) {
-        //       std::cout << "DIFFERENCE " << i << std::endl;
-        //   }
-        // }
-        gnn_input.insert({"flow_features", loader_out.second});
-      }
-      // if (m_trackLoader){
-      //   auto loader_out = m_trackLoader->getData(jet, btag);
-      //   auto loader_track_feat = loader_out.second.first;
-      //   std::cout << "Vector size: OLD " << track_feat.size() << " NEW " << loader_track_feat.size() << std::endl;
-      //   for (uint64_t i = 0; i < track_feat.size(); i++){
-      //     if (std::fabs(loader_track_feat.at(i) - track_feat.at(i)) > 0.001) {
-      //         std::cout << "DIFFERENCE " << i << std::endl;
-      //     }
-      //   }
-      // }
-      // if (m_constituentsLoaders.size() > 0){
-      //   auto flow_out = m_constituentsLoaders[0]->getData(jet, btag);
-      // }
-      // gnn_input.insert({"flow_features", track_info});
+    }
+    for (auto loader : m_constituentsLoaders){
+      auto loader_out = loader->getData(jet, btag);
+      gnn_input.insert({loader_out.first, loader_out.second});
     }
 
     // run inference
diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx
index e2e6e9ad8d72..73f0a93d1d57 100644
--- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx
+++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/IParticlesLoader.cxx
@@ -3,10 +3,9 @@ Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
 */
 
 #include "FlavorTagDiscriminants/IParticlesLoader.h"
-#include "FlavorTagDiscriminants/FTagDataDependencyNames.h"
+// #include "FlavorTagDiscriminants/FTagDataDependencyNames.h"
 #include "xAODPFlow/FlowElement.h"
 
-#include "FlavorTagDiscriminants/SequenceGetter.h"
 #include <iostream>
 
 namespace {
@@ -132,42 +131,14 @@ namespace FlavorTagDiscriminants {
       }
     } // end of iparticle sort getter
 
-    // factory for functions that build std::vector objects from
-    // iparticle sequences
-    std::pair<IParticlesLoader::SeqFromIParticles,std::set<std::string>> IParticlesLoader::seqFromIParticles(
-        const FTagConstituentsInputConfig& cfg, 
-        const FTagOptions& options)
-    {
-        const std::string prefix = options.track_prefix;
-        switch (cfg.type) {
-          case ConstituentsEDMType::INT: return {
-              SequenceGetter<int, xAOD::IParticle>(cfg.name), {cfg.name}
-            };
-          case ConstituentsEDMType::FLOAT: return {
-              SequenceGetter<float, xAOD::IParticle>(cfg.name), {cfg.name}
-            };
-          case ConstituentsEDMType::CHAR: return {
-              SequenceGetter<char, xAOD::IParticle>(cfg.name), {cfg.name}
-            };
-          case ConstituentsEDMType::UCHAR: return {
-              SequenceGetter<unsigned char, xAOD::IParticle>(cfg.name), {cfg.name}
-            };
-          case ConstituentsEDMType::CUSTOM_GETTER: {
-            return sequence_getter::customNamedSeqGetterWithDeps(
-              cfg.name, options.track_prefix);
-          }
-          default: {
-            throw std::logic_error("Unknown EDM type for iparticles");
-          }
-        }
-    }
-
     IParticlesLoader::IParticlesLoader(
         FTagConstituentsSequenceConfig cfg,
         const FTagOptions& options
     ):
         ConstituentsLoader(cfg),
-        m_iparticleSortVar(IParticlesLoader::iparticleSortVar(cfg.order, options))
+        m_iparticleSortVar(IParticlesLoader::iparticleSortVar(cfg.order, options)),
+        m_customSequenceGetter(sequence_getter::CustomSequenceGetter(
+          cfg.inputs, options))
     {
         SG::AuxElement::ConstAccessor<PartLinks> acc("constituentLinks");
         m_associator = [acc](const xAOD::Jet& jet) -> IPV {
@@ -187,22 +158,7 @@ namespace FlavorTagDiscriminants {
         } else {
             m_isCharged = false;
         }
-
-        std::map<std::string, std::string> remap = options.remap_scalar;
-        std::set<std::string> used_remap;
-        std::set<std::string> iparticle_data_deps;
-
-        for (const FTagConstituentsInputConfig& input_cfg: cfg.inputs) {
-            std::cout << input_cfg.name << std::endl;
-            auto [seqGetter, deps] = seqFromIParticles(
-            input_cfg, options);
-
-            m_sequencesFromIParticles.push_back(seqGetter);
-            iparticle_data_deps.merge(deps);
-            if (auto h = remap.extract(input_cfg.name)){
-              used_remap.insert(h.key());
-            }
-        }
+        used_remap = m_customSequenceGetter.getUsedRemap();
         std::cout << "TEST: IParticlesLoader loaded " << std::endl;
     }
 
@@ -235,31 +191,20 @@ namespace FlavorTagDiscriminants {
           }
           only_particles.push_back(obj);
         }
-        std::cout << "TEST: SIZE OF only_particles: " << only_particles.size() << std::endl;
         return only_particles;
     }
 
     std::pair<std::string, input_pair> IParticlesLoader::getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const {
-        std::vector<float> particle_feat;
-        int num_iparticle_vars = static_cast<int>(m_sequencesFromIParticles.size());
-        int num_iparticles = 0;
-
         IParticles sorted_particles = getIParticlesFromJet(jet);
-        int iparticle_var_idx=0;
-        for (const auto& seq_builder: m_sequencesFromIParticles) {
-            auto double_vec = seq_builder(jet, sorted_particles).second;
-            if (iparticle_var_idx == 0){
-              num_iparticles = static_cast<int>(double_vec.size());
-              particle_feat.resize(num_iparticles * num_iparticle_vars);
-            }
-            for (unsigned int particle_idx=0; particle_idx < double_vec.size(); particle_idx++){
-                particle_feat[iparticle_var_idx * num_iparticles + particle_idx] = double_vec[particle_idx];
-            }
-            iparticle_var_idx++;
-        }
 
-        std::vector<int64_t> particle_feat_dim = {num_iparticles, num_iparticle_vars};
+        return std::make_pair("flow_features", m_customSequenceGetter.getFeats(jet, sorted_particles));
+    }
 
-        return std::make_pair("flow_features", std::make_pair(particle_feat, particle_feat_dim));
+    FTagDataDependencyNames IParticlesLoader::getDependencies() const {
+        return deps;
     }
+    std::set<std::string> IParticlesLoader::getUsedRemap() const {
+        return used_remap;
+    }
+
 }
\ No newline at end of file
diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SequenceGetter.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SequenceGetter.cxx
index e0d1789b4740..b34ceeaad1ac 100644
--- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SequenceGetter.cxx
+++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/SequenceGetter.cxx
@@ -8,6 +8,7 @@
 #include "xAODTracking/TrackParticle.h"
 
 #include <optional>
+#include <iostream>
 
 namespace {
   // ______________________________________________________________________
@@ -51,249 +52,275 @@ namespace {
   private:
     T m_getter;
   public:
-    TJGetter(T getter):
+      TJGetter(T getter):
       m_getter(getter)
       {}
     std::vector<double> operator()(
       const xAOD::Jet& jet,
-      const std::vector<const xAOD::IParticle*>& tracks) const {
+      const std::vector<const xAOD::IParticle*>& particles) const {
       std::vector<double> sequence;
-      sequence.reserve(tracks.size());
-      for (const auto* track: tracks) {
-        sequence.push_back(m_getter(track, jet));
+      sequence.reserve(particles.size());
+      for (const auto* particle: particles) {
+        sequence.push_back(m_getter(*particle, jet));
       }
       return sequence;
     }
   };
 
-  std::optional<FlavorTagDiscriminants::sequence_getter::SequenceFromIParticles>
-  sequenceWithIpDep(
+
+  // The sequence getter takes in constituents and calculates arrays of
+  // values which are better suited for inputs to the NNs
+  template <typename T, typename U>
+  class SequenceGetter{
+    private:
+      SG::AuxElement::ConstAccessor<T> m_getter;
+      std::string m_name;
+    public:
+      SequenceGetter(const std::string& name):
+        m_getter(name),
+        m_name(name)
+        {
+        }
+      std::pair<std::string, std::vector<double>> operator()(const xAOD::Jet&, const std::vector<const U*>& consts) const {
+        std::vector<double> seq;
+        for (const U* el: consts) {
+          seq.push_back(m_getter(*el));
+        }
+        return {m_name, seq};
+      }
+  };
+}
+  namespace FlavorTagDiscriminants {
+  namespace sequence_getter {
+
+  
+  std::optional<typename CustomSequenceGetter::SequenceFromConstituents>
+  CustomSequenceGetter::sequenceWithIpDep(
     const std::string& name,
     const std::string& prefix)
   {
-
     using Ip = xAOD::IParticle;
     using Tp = xAOD::TrackParticle;
     using Jet = xAOD::Jet;
 
     BTagTrackIpAccessor a(prefix);
     if (name == "IP3D_signed_d0_significance") {
-      return TJGetter([a](const Ip* t, const Jet& j){
-          auto tp = (Tp*)t;
-          return a.getSignedIp(*tp, j).ip3d_signed_d0_significance;
+      return TJGetter([a](const Ip& p, const Jet& j){
+        auto tp = dynamic_cast<const Tp&>(p);
+        return a.getSignedIp(tp, j).ip3d_signed_d0_significance;
       });
     }
     if (name == "IP3D_signed_z0_significance") {
-      return TJGetter([a](const Ip* t, const Jet& j){
-        auto tp = (Tp*)t;
-        return a.getSignedIp(*tp, j).ip3d_signed_z0_significance;
+      return TJGetter([a](const Ip& p, const Jet& j){
+        auto tp = dynamic_cast<const Tp&>(p);
+        return a.getSignedIp(tp, j).ip3d_signed_z0_significance;
       });
     }
     if (name == "IP2D_signed_d0") {
-      return TJGetter([a](const Ip* t, const Jet& j){
-        auto tp = (Tp*)t;
-        return a.getSignedIp(*tp, j).ip2d_signed_d0;
+      return TJGetter([a](const Ip& p, const Jet& j){
+        auto tp = dynamic_cast<const Tp&>(p);
+        return a.getSignedIp(tp, j).ip2d_signed_d0;
       });
     }
     if (name == "IP3D_signed_d0") {
-      return TJGetter([a](const Ip* t, const Jet& j){
-        auto tp = (Tp*)t;
-        return a.getSignedIp(*tp, j).ip3d_signed_d0;
+      return TJGetter([a](const Ip& p, const Jet& j){
+        auto tp = dynamic_cast<const Tp&>(p);
+        return a.getSignedIp(tp, j).ip3d_signed_d0;
       });
     }
     if (name == "IP3D_signed_z0") {
-      return TJGetter([a](const Ip* t, const Jet& j){
-        auto tp = (Tp*)t;
-        return a.getSignedIp(*tp, j).ip3d_signed_z0;
+      return TJGetter([a](const Ip& p, const Jet& j){
+        auto tp = dynamic_cast<const Tp&>(p);
+        return a.getSignedIp(tp, j).ip3d_signed_z0;
       });
     }
     if (name == "d0" || name == "btagIp_d0") {
-      return TJGetter([a](const Ip* t, const Jet&){
-        auto tp = (Tp*)t;
-        return a.d0(*tp);
+      return TJGetter([a](const Ip& p, const Jet&){
+        auto tp = dynamic_cast<const Tp&>(p);
+        return a.d0(tp);
       });
     }
     if (name == "z0SinTheta" || name == "btagIp_z0SinTheta") {
-      return TJGetter([a](const Ip* t, const Jet&){
-        auto tp = (Tp*)t;
-        return a.z0SinTheta(*tp);
+      return TJGetter([a](const Ip& p, const Jet&){
+        auto tp = dynamic_cast<const Tp&>(p);
+        return a.z0SinTheta(tp);
       });
     }
     if (name == "d0Uncertainty") {
-      return TJGetter([a](const Ip* t, const Jet&){
-        auto tp = (Tp*)t;
-        return a.d0Uncertainty(*tp);
+      return TJGetter([a](const Ip& p, const Jet&){
+        auto tp = dynamic_cast<const Tp&>(p);
+        return a.d0Uncertainty(tp);
       });
     }
     if (name == "z0SinThetaUncertainty") {
-      return TJGetter([a](const Ip* t, const Jet&){
-        auto tp = (Tp*)t;
-        return a.z0SinThetaUncertainty(*tp);
+      return TJGetter([a](const Ip& p, const Jet&){
+        auto tp = dynamic_cast<const Tp&>(p);
+        return a.z0SinThetaUncertainty(tp);
       });
     }
     return std::nullopt;
   }
 
-  std::optional<FlavorTagDiscriminants::sequence_getter::SequenceFromIParticles>
-  sequenceNoIpDep(const std::string& name)
+  
+  std::optional<typename CustomSequenceGetter::SequenceFromConstituents>
+  CustomSequenceGetter::sequenceNoIpDep(const std::string& name)
   {
-
     using Ip = xAOD::IParticle;
     using Tp = xAOD::TrackParticle;
     using Jet = xAOD::Jet;
 
     if (name == "pt") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        return t->pt();
+      return TJGetter([](const Ip& p, const Jet&) {
+        return p.pt();
       });
     }
     if (name == "log_pt") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        return std::log(t->pt());
+      return TJGetter([](const Ip& p, const Jet&) {
+        return std::log(p.pt());
       });
     }
     if (name == "ptfrac") {
-      return TJGetter([](const Ip* t, const Jet& j) {
-        return t->pt() / j.pt();
+      return TJGetter([](const Ip& p, const Jet& j) {
+        return p.pt() / j.pt();
       });
     }
     if (name == "log_ptfrac") {
-      return TJGetter([](const Ip* t, const Jet& j) {
-        return std::log(t->pt() / j.pt());
+      return TJGetter([](const Ip& p, const Jet& j) {
+        return std::log(p.pt() / j.pt());
       });
     }
 
     if (name == "eta") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        return t->eta();
+      return TJGetter([](const Ip& p, const Jet&) {
+        return p.eta();
       });
     }
     if (name == "deta") {
-      return TJGetter([](const Ip* t, const Jet& j) {
-        return t->eta() - j.eta();
+      return TJGetter([](const Ip& p, const Jet& j) {
+        return p.eta() - j.eta();
       });
     }
     if (name == "abs_deta") {
-      return TJGetter([](const Ip* t, const Jet& j) {
-        return copysign(1.0, j.eta()) * (t->eta() - j.eta());
+      return TJGetter([](const Ip& p, const Jet& j) {
+        return copysign(1.0, j.eta()) * (p.eta() - j.eta());
       });
     }
 
     if (name == "phi") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        return t->phi();
+      return TJGetter([](const Ip& p, const Jet&) {
+        return p.phi();
       });
     }
     if (name == "dphi") {
-      return TJGetter([](const Ip* t, const Jet& j) {
-        return t->p4().DeltaPhi(j.p4());
+      return TJGetter([](const Ip& p, const Jet& j) {
+        return p.p4().DeltaPhi(j.p4());
       });
     }
 
     if (name == "dr") {
-      return TJGetter([](const Ip* t, const Jet& j) {
-        return t->p4().DeltaR(j.p4());
+      return TJGetter([](const Ip& p, const Jet& j) {
+        return p.p4().DeltaR(j.p4());
       });
     }
     if (name == "log_dr") {
-      return TJGetter([](const Ip* t, const Jet& j) {
-        return std::log(t->p4().DeltaR(j.p4()));
+      return TJGetter([](const Ip& p, const Jet& j) {
+        return std::log(p.p4().DeltaR(j.p4()));
       });
     }
     if (name == "log_dr_nansafe") {
-      return TJGetter([](const Ip* t, const Jet& j) {
-        return std::log(t->p4().DeltaR(j.p4()) + 1e-7);
+      return TJGetter([](const Ip& p, const Jet& j) {
+        return std::log(p.p4().DeltaR(j.p4()) + 1e-7);
       });
     }
 
     if (name == "mass") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        return t->m();
+      return TJGetter([](const Ip& p, const Jet&) {
+        return p.m();
       });
     }
     if (name == "energy") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        return t->e();
+      return TJGetter([](const Ip& p, const Jet&) {
+        return p.e();
       });
     }
 
     if (name == "phiUncertainty") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        auto tp = (Tp*)t;
-        return std::sqrt(tp->definingParametersCovMatrixDiagVec().at(2));
+      return TJGetter([](const Ip& p, const Jet&) {
+          auto tp = dynamic_cast<const Tp&>(p);
+          return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(2));
       });
     }
     if (name == "thetaUncertainty") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        auto tp = (Tp*)t;
-        return std::sqrt(tp->definingParametersCovMatrixDiagVec().at(3));
+      return TJGetter([](const Ip& p, const Jet&) {
+          auto tp = dynamic_cast<const Tp&>(p);
+          return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(3));
       });
     }
     if (name == "qOverPUncertainty") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        auto tp = (Tp*)t;
-        return std::sqrt(tp->definingParametersCovMatrixDiagVec().at(4));
+      return TJGetter([](const Ip& p, const Jet&) {
+          auto tp = dynamic_cast<const Tp&>(p);
+          return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(4));
       });
     }
     if (name == "z0RelativeToBeamspot") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        auto tp = (Tp*)t;
-        return tp->z0();
+      return TJGetter([](const Ip& p, const Jet&) {
+          auto tp = dynamic_cast<const Tp&>(p);
+          return tp.z0();
       });
     }
     if (name == "log_z0RelativeToBeamspotUncertainty") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        auto tp = (Tp*)t;
-        return std::log(std::sqrt(tp->definingParametersCovMatrixDiagVec().at(1)));
+      return TJGetter([](const Ip& p, const Jet&) {
+          auto tp = dynamic_cast<const Tp&>(p);
+          return std::log(std::sqrt(tp.definingParametersCovMatrixDiagVec().at(1)));
       });
     }
     if (name == "z0RelativeToBeamspotUncertainty") {
-      return TJGetter([](const Ip* t, const Jet&) {
-        auto tp = (Tp*)t;
-        return std::sqrt(tp->definingParametersCovMatrixDiagVec().at(1));
+      return TJGetter([](const Ip& p, const Jet&) {
+          auto tp = dynamic_cast<const Tp&>(p);
+          return std::sqrt(tp.definingParametersCovMatrixDiagVec().at(1));
       });
     }
 
     if (name == "numberOfPixelHitsInclDead") {
       SG::AuxElement::ConstAccessor<unsigned char> pix_hits("numberOfPixelHits");
       SG::AuxElement::ConstAccessor<unsigned char> pix_dead("numberOfPixelDeadSensors");
-      return TJGetter([pix_hits, pix_dead](const Ip* t, const Jet&) {
-        return pix_hits(*t) + pix_dead(*t);
+      return TJGetter([pix_hits, pix_dead](const Ip& p, const Jet&) {
+        return pix_hits(p) + pix_dead(p);
       });
     }
     if (name == "numberOfSCTHitsInclDead") {
       SG::AuxElement::ConstAccessor<unsigned char> sct_hits("numberOfSCTHits");
       SG::AuxElement::ConstAccessor<unsigned char> sct_dead("numberOfSCTDeadSensors");
-      return TJGetter([sct_hits, sct_dead](const Ip* t, const Jet&) {
-        return sct_hits(*t) + sct_dead(*t);
+      return TJGetter([sct_hits, sct_dead](const Ip& p, const Jet&) {
+        return sct_hits(p) + sct_dead(p);
       });
     }
     if (name == "numberOfInnermostPixelLayerHits21p9") {
       SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfInnermostPixelLayerHits");
       SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfInnermostPixelLayerEndcapHits");
-      return TJGetter([barrel_hits, endcap_hits](const Ip* t, const Jet&) {
-        return barrel_hits(*t) + endcap_hits(*t);
+      return TJGetter([barrel_hits, endcap_hits](const Ip& p, const Jet&) {
+        return barrel_hits(p) + endcap_hits(p);
       });
     }
     if (name == "numberOfNextToInnermostPixelLayerHits21p9") {
       SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfNextToInnermostPixelLayerHits");
       SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfNextToInnermostPixelLayerEndcapHits");
-      return TJGetter([barrel_hits, endcap_hits](const Ip* t, const Jet&) {
-        return barrel_hits(*t) + endcap_hits(*t);
+      return TJGetter([barrel_hits, endcap_hits](const Ip& p, const Jet&) {
+        return barrel_hits(p) + endcap_hits(p);
       });
     }
     if (name == "numberOfInnermostPixelLayerSharedHits21p9") {
       SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfInnermostPixelLayerSharedHits");
       SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfInnermostPixelLayerSharedEndcapHits");
-      return TJGetter([barrel_hits, endcap_hits](const Ip* t, const Jet&) {
-        return barrel_hits(*t) + endcap_hits(*t);
+      return TJGetter([barrel_hits, endcap_hits](const Ip& p, const Jet&) {
+        return barrel_hits(p) + endcap_hits(p);
       });
     }
     if (name == "numberOfInnermostPixelLayerSplitHits21p9") {
       SG::AuxElement::ConstAccessor<unsigned char> barrel_hits("numberOfInnermostPixelLayerSplitHits");
       SG::AuxElement::ConstAccessor<unsigned char> endcap_hits("numberOfInnermostPixelLayerSplitEndcapHits");
-      return TJGetter([barrel_hits, endcap_hits](const Ip* t, const Jet&) {
-        return barrel_hits(*t) + endcap_hits(*t);
+      return TJGetter([barrel_hits, endcap_hits](const Ip& p, const Jet&) {
+        return barrel_hits(p) + endcap_hits(p);
       });
     }
 
@@ -301,11 +328,6 @@ namespace {
     return std::nullopt;
   }
 
-}
-
-namespace FlavorTagDiscriminants {
-  namespace sequence_getter {
-
     // ________________________________________________________________
     // Interface functions
     //
@@ -323,14 +345,11 @@ namespace FlavorTagDiscriminants {
              };
     }
 
-    // Case for track variables
-    std::pair<std::function<std::pair<std::string, std::vector<double>>(
-      const xAOD::Jet&,
-      const std::vector<const xAOD::IParticle*>&)>,
-      std::set<std::string>>
-    customNamedSeqGetterWithDeps(const std::string& name,
+    // Case for constituents variables
+    std::pair<typename CustomSequenceGetter::NamedSequenceFromConstituents, std::set<std::string>> 
+    CustomSequenceGetter::customNamedSeqGetterWithDeps(const std::string& name,
                                  const std::string& prefix) {
-      auto [getter, deps] = customSequenceGetterWithDeps(name, prefix);
+      auto [getter, deps] = CustomSequenceGetter::customSequenceGetterWithDeps(name, prefix);
       return {
         [n=name, g=getter](const xAOD::Jet& j,
                        const std::vector<const xAOD::IParticle*>& t) {
@@ -345,20 +364,113 @@ namespace FlavorTagDiscriminants {
   // These functions are wrapped by the customNamedSeqGetter function
   // below to become the ones that are actually used in DL2.
   //
-  std::pair<SequenceFromIParticles, std::set<std::string>>
-  customSequenceGetterWithDeps(const std::string& name,
+  std::pair<typename CustomSequenceGetter::SequenceFromConstituents, std::set<std::string>>
+  CustomSequenceGetter::customSequenceGetterWithDeps(const std::string& name,
                                const std::string& prefix) {
 
-    if (auto getter = sequenceWithIpDep(name, prefix)) {
+    if (auto getter = CustomSequenceGetter::sequenceWithIpDep(name, prefix)) {
       auto deps = BTagTrackIpAccessor(prefix).getTrackIpDataDependencyNames();
       return {*getter, deps};
     }
 
-    if (auto getter = sequenceNoIpDep(name)) {
+    if (auto getter = CustomSequenceGetter::sequenceNoIpDep(name)) {
       return {*getter, {}};
     }
     throw std::logic_error("no match for custom getter " + name);
   }
+  
+  // ________________________________________________________________________
+  // Class implementation
+  //
+  std::pair<typename CustomSequenceGetter::NamedSequenceFromConstituents, std::set<std::string>> 
+  CustomSequenceGetter::seqFromConsituents(
+      const FTagConstituentsInputConfig& cfg, 
+      const FTagOptions& options){
+    const std::string prefix = options.track_prefix;
+    switch (cfg.type) {
+      case ConstituentsEDMType::INT: return {
+          SequenceGetter<int, xAOD::IParticle>(cfg.name), {cfg.name}
+        };
+      case ConstituentsEDMType::FLOAT: return {
+          SequenceGetter<float, xAOD::IParticle>(cfg.name), {cfg.name}
+        };
+      case ConstituentsEDMType::CHAR: return {
+          SequenceGetter<char, xAOD::IParticle>(cfg.name), {cfg.name}
+        };
+      case ConstituentsEDMType::UCHAR: return {
+          SequenceGetter<unsigned char, xAOD::IParticle>(cfg.name), {cfg.name}
+        };
+      case ConstituentsEDMType::CUSTOM_GETTER: {
+        return CustomSequenceGetter::customNamedSeqGetterWithDeps(
+          cfg.name, options.track_prefix);
+      }
+      default: {
+        throw std::logic_error("Unknown EDM type for constituent.");
+      }
+    }
   }
-}
 
+  
+  CustomSequenceGetter::CustomSequenceGetter(
+    std::vector<FTagConstituentsInputConfig> inputs,
+    const FTagOptions& options)
+  {
+      std::map<std::string, std::string> remap = options.remap_scalar;
+      for (const FTagConstituentsInputConfig& input_cfg: inputs) {
+          auto [seqGetter, seq_deps] = seqFromConsituents(
+          input_cfg, options);
+
+          if(input_cfg.flip_sign){
+          auto seqGetter_flip=[g=seqGetter](const xAOD::Jet&jet, const IParticles& constituents){
+              auto [n,v] = g(jet,constituents);
+              std::for_each(v.begin(), v.end(), [](double &n){ n=-1.0*n; });
+              return std::make_pair(n,v);
+          };
+              sequencesFromConstituents.push_back(seqGetter_flip);
+          }
+          else{
+              sequencesFromConstituents.push_back(seqGetter);
+          }
+              deps.merge(seq_deps);
+              if (auto h = remap.extract(input_cfg.name)){
+              used_remap.insert(h.key());
+          }
+      }
+  }
+
+  
+  std::pair<std::vector<float>, std::vector<int64_t>> CustomSequenceGetter::getFeats(
+    const xAOD::Jet& jet, const IParticles& constituents) const
+  {
+    std::vector<float> cnsts_feats;
+    int num_vars = static_cast<int>(sequencesFromConstituents.size());
+    int num_cnsts = 0;
+
+    int cnst_var_idx = 0;
+    for (const auto& seq_builder: sequencesFromConstituents){
+      auto double_vec = seq_builder(jet, constituents).second;
+
+      if (cnst_var_idx==0){
+          num_cnsts = static_cast<int>(double_vec.size());
+          cnsts_feats.resize(num_cnsts * num_vars);
+      }
+
+      // need to transpose + flatten
+      for (unsigned int cnst_idx=0; cnst_idx<double_vec.size(); cnst_idx++){
+        cnsts_feats.at(cnst_idx*num_vars + cnst_var_idx)
+            = double_vec.at(cnst_idx);
+      }
+      cnst_var_idx++;
+    }
+    std::vector<int64_t> cnsts_feat_dim = {num_cnsts, num_vars};
+    return {cnsts_feats, cnsts_feat_dim};
+  }
+
+  std::set<std::string> CustomSequenceGetter::getDependencies() const {
+    return deps;
+  }
+  std::set<std::string> CustomSequenceGetter::getUsedRemap() const {
+    return used_remap;
+  }
+  }
+}
\ No newline at end of file
diff --git a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx
index b930ef6d0bae..d97f356917c7 100644
--- a/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx
+++ b/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/TracksLoader.cxx
@@ -5,7 +5,6 @@ Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
 #include "FlavorTagDiscriminants/FlipTagEnums.h"
 #include "FlavorTagDiscriminants/AssociationEnums.h"
 
-#include "FlavorTagDiscriminants/customGetter.h"
 #include "FlavorTagDiscriminants/TracksLoader.h"
 
 #include <iostream>
@@ -33,20 +32,6 @@ namespace {
   typedef std::vector<std::pair<std::regex, ConstituentsSortOrder> > SortRegexes;
   typedef std::vector<std::pair<std::regex, ConstituentsSelection> > TrkSelRegexes;
 
-  // Function to map the regex + list of inputs to variable config,
-  // this time for sequence inputs.
-  std::vector<FTagConstituentsSequenceConfig> get_track_input_config(
-    const std::vector<std::pair<std::string, std::vector<std::string>>>& names,
-    const TypeRegexes& type_regexes,
-    const SortRegexes& sort_regexes,
-    const TrkSelRegexes& select_regexes,
-    const std::regex& re,
-    const FlipTagConfig& flip_config);
-
-
-  //_______________________________________________________________________
-  // Implementation of the above functions
-  //
 
   template <typename T>
   T match_first(const std::vector<std::pair<std::regex, T> >& regexes,
@@ -137,7 +122,13 @@ namespace FlavorTagDiscriminants {
     ){
         // some sequences also need to be sign-flipped. We apply this by
         // changing the input scaling and normalizations
-        std::regex flip_sequences(".*signed_[dz]0.*");
+        std::regex flip_sequences;
+        if (flip_config == FlipTagConfig::FLIP_SIGN || flip_config == FlipTagConfig::NEGATIVE_IP_ONLY){
+          flip_sequences=std::regex(".*signed_[dz]0.*");
+        }
+        if (flip_config == FlipTagConfig::SIMPLE_FLIP){
+          flip_sequences=std::regex("(.*signed_[dz]0.*)|d0|z0SinTheta");
+        }
 
         // build the track inputs
         TypeRegexes trk_type_regexes {
@@ -383,37 +374,6 @@ namespace FlavorTagDiscriminants {
         }
     }
 
-    // factory for functions that build std::vector objects from
-    // track sequences
-    std::pair<TracksLoader::SeqFromTracks,std::set<std::string>> TracksLoader::seqFromTracks(
-        const FTagConstituentsInputConfig& cfg, 
-        const FTagOptions& options)
-    {
-        const std::string prefix = options.track_prefix;
-        switch (cfg.type) {
-          case ConstituentsEDMType::INT: return {
-              SequenceGetter<int, Track>(cfg.name), {cfg.name}
-            };
-          case ConstituentsEDMType::FLOAT: return {
-              SequenceGetter<float, Track>(cfg.name), {cfg.name}
-            };
-          case ConstituentsEDMType::CHAR: return {
-              SequenceGetter<char, Track>(cfg.name), {cfg.name}
-            };
-          case ConstituentsEDMType::UCHAR: return {
-              SequenceGetter<unsigned char, Track>(cfg.name), {cfg.name}
-            };
-          case ConstituentsEDMType::CUSTOM_GETTER: {
-            return internal::customNamedSeqGetterWithDeps(
-              cfg.name, options.track_prefix);
-
-          }
-          default: {
-            throw std::logic_error("Unknown EDM type for tracks");
-          }
-        }
-    }
-
     // here we define filters for the "flip" taggers
     //
     // start by defining the raw functions, there's a factory
@@ -463,7 +423,6 @@ namespace FlavorTagDiscriminants {
         }
     }
 
-    // TracksLoader::TracksLoader() : ConstituentsLoader() {};
     TracksLoader::TracksLoader(
         FTagConstituentsSequenceConfig cfg,
         const FTagOptions& options
@@ -471,7 +430,9 @@ namespace FlavorTagDiscriminants {
         ConstituentsLoader(cfg),
         m_trackSortVar(TracksLoader::trackSortVar(cfg.order, options)),
         m_trackFilter(TracksLoader::trackFilter(cfg.selection, options).first),
-        m_flipFilter(TracksLoader::flipFilter(options).first)
+        m_flipFilter(TracksLoader::flipFilter(options).first),
+        m_customSequenceGetter(sequence_getter::CustomSequenceGetter(
+          cfg.inputs, options))
     {
         // We have several ways to get tracks: either we retrieve an
         // IParticleContainer and cast the pointers to TrackParticle, or
@@ -479,7 +440,6 @@ namespace FlavorTagDiscriminants {
         // the way tracks are stored isn't consistent across the EDM, so
         // we allow configuration for both setups.
         //
-        std::cout << "TEST TRACK 1 " << std::endl;
         if (options.track_link_type == TrackLinkType::IPARTICLE) {
             SG::AuxElement::ConstAccessor<PartLinks> acc(options.track_link_name);
             m_associator = [acc](const SG::AuxElement& btag) -> TPV {
@@ -511,35 +471,12 @@ namespace FlavorTagDiscriminants {
         } else {
             throw std::logic_error("Unknown TrackLinkType");
         }
-        std::cout << "TEST TRACK 2 " << std::endl;
-        std::map<std::string, std::string> remap = options.remap_scalar;
-        std::set<std::string> used_remap;
-
         auto track_data_deps = trackFilter(cfg.selection, options).second;
         track_data_deps.merge(flipFilter(options).second);
-        for (const FTagConstituentsInputConfig& input_cfg: cfg.inputs) {
-            auto [seqGetter, deps] = seqFromTracks(
-            input_cfg, options);
-
-            if(input_cfg.flip_sign){
-            auto seqGetter_flip=[g=seqGetter](const xAOD::Jet&jet, const internal::Tracks& trks){
-                auto [n,v] = g(jet,trks);
-                std::for_each(v.begin(), v.end(), [](double &n){ n=-1.0*n; });
-                return std::make_pair(n,v);
-            };
-                m_sequencesFromTracks.push_back(seqGetter_flip);
-            }
-            else{
-                m_sequencesFromTracks.push_back(seqGetter);
-            }
-                track_data_deps.merge(deps);
-                if (auto h = remap.extract(input_cfg.name)){
-                used_remap.insert(h.key());
-            }
-        }
-        std::cout << "TEST TRACK 3 " << std::endl;
+        track_data_deps.merge(m_customSequenceGetter.getDependencies());
         deps.trackInputs.merge(track_data_deps);
         deps.bTagInputs.insert(options.track_link_name);
+        used_remap = m_customSequenceGetter.getUsedRemap();
     }
 
     std::vector<const xAOD::TrackParticle*> TracksLoader::getTracksFromJet(
@@ -563,32 +500,22 @@ namespace FlavorTagDiscriminants {
 
     std::pair<std::string, input_pair> TracksLoader::getData(const xAOD::Jet& jet, const SG::AuxElement& btag) const {
         Tracks flipped_tracks;
-        std::vector<float> track_feat; // (#tracks, #feats).flatten
-
-        int num_track_vars = static_cast<int>(m_sequencesFromTracks.size());
-        int num_tracks = 0;
-
         Tracks sorted_tracks = getTracksFromJet(jet, btag);
-        flipped_tracks = m_flipFilter(sorted_tracks, jet);
+        std::vector<const xAOD::IParticle*> flipped_tracks_ip;
 
-        int track_var_idx=0;
-        for (const auto& seq_builder: m_sequencesFromTracks) {
-            auto double_vec = seq_builder(jet, flipped_tracks).second;
-
-            if (track_var_idx==0){
-                num_tracks = static_cast<int>(double_vec.size());
-                track_feat.resize(num_tracks * num_track_vars);
-            }
-
-            // need to transpose + flatten
-            for (unsigned int track_idx=0; track_idx<double_vec.size(); track_idx++){
-            track_feat.at(track_idx*num_track_vars + track_var_idx)
-                = double_vec.at(track_idx);
-            }
-            track_var_idx++;
+        flipped_tracks = m_flipFilter(sorted_tracks, jet);
+        
+        for (const auto& trk: flipped_tracks) {
+            flipped_tracks_ip.push_back(dynamic_cast<const xAOD::IParticle*>(trk));
         }
-        std::vector<int64_t> track_feat_dim = {num_tracks, num_track_vars};
 
-        return std::make_pair("track_features", std::make_pair(track_feat, track_feat_dim));
+        return std::make_pair("track_features", m_customSequenceGetter.getFeats(jet, flipped_tracks_ip));
+    }
+
+    FTagDataDependencyNames TracksLoader::getDependencies() const {
+        return deps;
+    }
+    std::set<std::string> TracksLoader::getUsedRemap() const {
+        return used_remap;
     }
 }
\ No newline at end of file
-- 
GitLab