/*
  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
*/

#ifndef DL2_H
#define DL2_H

// local includes
#include "FlavorTagDiscriminants/customGetter.h"
#include "FlavorTagDiscriminants/FlipTagEnums.h"
11
#include "FlavorTagDiscriminants/AssociationEnums.h"
12
#include "FlavorTagDiscriminants/DL2DataDependencyNames.h"
13
#include "xAODBTagging/ftagfloat_t.h"
14
15
16

// EDM includes
#include "xAODJet/Jet.h"
17
#include "xAODBTagging/BTagging.h"
18
19
20
21
22
23
24
25
26

// external libraries
#include "lwtnn/lightweight_network_config.hh"

// STL includes
#include <cmath>
#include <exception>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
Dan Guest's avatar
Dan Guest committed
27
#include <type_traits>
28
29
30
31
32
33
34
35
36
37
38
39

// forward declarations
namespace lwt {
  class NanReplacer;
  class LightweightGraph;
}

namespace FlavorTagDiscriminants {

  // Type tag carried by DL2InputConfig and DL2TrackInputConfig entries.
  // Enumerator order (and therefore the underlying values) is unchanged.
  enum class EDMType {
    UCHAR,
    INT,
    FLOAT,
    DOUBLE,
    CUSTOM_GETTER
  };
  // Ordering key selected per track sequence (see
  // DL2TrackSequenceConfig::order).
  enum class SortOrder {
    ABS_D0_SIGNIFICANCE_DESCENDING,
    D0_SIGNIFICANCE_DESCENDING,
    PT_DESCENDING
  };
40
  // Named track selection chosen per track sequence (see
  // DL2TrackSequenceConfig::selection).
  enum class TrackSelection {
    ALL,
    IP3D_2018,
    DIPS_LOOSE_202102
  };
Dan Guest's avatar
Dan Guest committed
41
42


43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
  // Structures to define DL2 input.
  //
  // Configuration of a single scalar input: a variable name, the type tag
  // used to read it, and the name of an associated default flag.
  // NOTE(review): presumably consumed by internal::get::varFromBTag, which
  // takes the same (name, EDMType, default flag) triple -- confirm in the
  // implementation file.
  struct DL2InputConfig {
    std::string name;
    EDMType type;
    std::string default_flag;
  };
  // Configuration of a single per-track input: a variable name and the type
  // tag used to read it. Unlike DL2InputConfig there is no default flag.
  struct DL2TrackInputConfig {
    std::string name;
    EDMType type;
  };
  // Configuration of one track sequence: its name, how the tracks are
  // ordered and selected, and the list of per-track inputs to extract.
  struct DL2TrackSequenceConfig {
    std::string name;
    SortOrder order;
    TrackSelection selection;
    std::vector<DL2TrackInputConfig> inputs;
  };

64
65
66
67
68
69
70
  // other DL2 options
  struct DL2Options {
    DL2Options();
    std::string track_prefix;
    FlipTagConfig flip;
    std::string track_link_name;
    std::map<std::string,std::string> remap_scalar;
71
    TrackLinkType track_link_type;
72
73
74
  };


75
76
77
78
79
80
81
82
  // _____________________________________________________________________
  // Internal code

  namespace internal {
    // typedefs
    typedef std::pair<std::string, double> NamedVar;
    typedef std::pair<std::string, std::vector<double> > NamedSeq;
    typedef xAOD::Jet Jet;
83
    typedef xAOD::BTagging BTagging;
84
85
86
87
88
89
90
91
    typedef std::vector<const xAOD::TrackParticle*> Tracks;
    typedef std::function<double(const xAOD::TrackParticle*,
                                 const xAOD::Jet&)> TrackSortVar;
    typedef std::function<bool(const xAOD::TrackParticle*)> TrackFilter;
    typedef std::function<Tracks(const Tracks&,
                                 const xAOD::Jet&)> TrackSequenceFilter;

    // getter functions
92
    typedef std::function<NamedVar(const SG::AuxElement&)> VarFromBTag;
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
    typedef std::function<NamedSeq(const Jet&, const Tracks&)> SeqFromTracks;

    // ___________________________________________________________________
    // Getter functions
    //
    // internally we want a bunch of std::functions that return pairs
    // to populate the lwtnn input map. We define a functor here to
    // deal with the b-tagging cases.
    //
    template <typename T>
    class BVarGetter
    {
    private:
      typedef SG::AuxElement AE;
      AE::ConstAccessor<T> m_getter;
      AE::ConstAccessor<char> m_default_flag;
      std::string m_name;
    public:
      BVarGetter(const std::string& name, const std::string& default_flag):
        m_getter(name),
        m_default_flag(default_flag),
        m_name(name)
        {
        }
117
      NamedVar operator()(const SG::AuxElement& btag) const {
118
119
        T ret_value = m_getter(btag);
        bool is_default = m_default_flag(btag);
Dan Guest's avatar
Dan Guest committed
120
121
122
123
124
125
126
127
128
        if constexpr (std::is_floating_point<T>::value) {
          if (std::isnan(ret_value) && !is_default) {
            throw std::runtime_error(
              "Found NAN value for '" + m_name
              + "'. This is only allowed when using a default"
              " value for this input");
          }
        }
        return {m_name, is_default ? NAN : ret_value};
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
      }
    };

    template <typename T>
    class BVarGetterNoDefault
    {
    private:
      typedef SG::AuxElement AE;
      AE::ConstAccessor<T> m_getter;
      std::string m_name;
    public:
      BVarGetterNoDefault(const std::string& name):
        m_getter(name),
        m_name(name)
        {
        }
145
      NamedVar operator()(const SG::AuxElement& btag) const {
146
        T ret_value = m_getter(btag);
Dan Guest's avatar
Dan Guest committed
147
148
149
150
151
152
153
        if constexpr (std::is_floating_point<T>::value) {
          if (std::isnan(ret_value)) {
            throw std::runtime_error(
              "Found NAN value for '" + m_name + "'.");
          }
        }
        return {m_name, ret_value};
154
155
156
157
158
159
160
161
      }
    };

    // The track getter is responsible for getting the tracks from the
    // jet applying a selection, and then sorting the tracks.
    class TracksFromJet
    {
    public:
162
      TracksFromJet(SortOrder, TrackSelection, const DL2Options&);
163
      Tracks operator()(const xAOD::Jet& jet,
164
                        const SG::AuxElement& btag) const;
165
    private:
166
167
168
169
170
171
172
      using AE = SG::AuxElement;
      using IPC = xAOD::IParticleContainer;
      using TPC = xAOD::TrackParticleContainer;
      using TrackLinks = std::vector<ElementLink<TPC>>;
      using PartLinks = std::vector<ElementLink<IPC>>;
      using TPV = std::vector<const xAOD::TrackParticle*>;
      std::function<TPV(const SG::AuxElement&)> m_associator;
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
      TrackSortVar m_trackSortVar;
      TrackFilter m_trackFilter;
    };

    // The sequence getter takes in tracks and calculates arrays of
    // values which are better suited for inputs to the NNs
    // Functor that reads the variable `name` (of type T) from every track
    // in a collection and returns the values, converted to double and in
    // the order of the input tracks, as a (name, values) NamedSeq.
    template <typename T>
    class SequenceGetter
    {
    private:
      SG::AuxElement::ConstAccessor<T> m_getter;
      std::string m_name;
    public:
      // name: key of the per-track variable (also the returned name)
      SequenceGetter(const std::string& name):
        m_getter(name),
        m_name(name)
        {
        }
      // The jet argument is unused; it is part of the SeqFromTracks
      // signature this functor has to satisfy.
      NamedSeq operator()(const xAOD::Jet&, const Tracks& trks) const {
        std::vector<double> seq;
        // The output size is known up front: allocate once.
        seq.reserve(trks.size());
        for (const xAOD::TrackParticle* track: trks) {
          seq.push_back(m_getter(*track));
        }
        // Move the filled vector into the pair instead of copying it.
        return {m_name, std::move(seq)};
      }
    };
  } // end internal namespace
  class DL2
  {
  public:
    DL2(const lwt::GraphConfig&,
        const std::vector<DL2InputConfig>&,
        const std::vector<DL2TrackSequenceConfig>& = {},
206
        const DL2Options& = DL2Options());
207
    void decorate(const xAOD::BTagging& btag) const;
208
209
    void decorate(const xAOD::Jet& jet) const;
    void decorate(const xAOD::Jet& jet, const SG::AuxElement& decorated) const;
210
211
212
213

    // functions to report data depdedencies
    DL2DataDependencyNames getDataDependencyNames() const;

214
215
  private:
    struct TrackSequenceBuilder {
216
217
      TrackSequenceBuilder(SortOrder,
                           TrackSelection,
218
                           const DL2Options&);
219
220
221
222
223
      std::string name;
      internal::TracksFromJet tracksFromJet;
      internal::TrackSequenceFilter flipFilter;
      std::vector<internal::SeqFromTracks> sequencesFromTracks;
    };
224
    typedef SG::AuxElement::Decorator<float> OutputSetter;
225
    typedef std::vector<std::pair<std::string, OutputSetter > > OutNode;
226
    SG::AuxElement::ConstAccessor<ElementLink<xAOD::JetContainer>> m_jetLink;
227
228
229
    std::string m_input_node_name;
    std::unique_ptr<lwt::LightweightGraph> m_graph;
    std::unique_ptr<lwt::NanReplacer> m_variable_cleaner;
230
    std::vector<internal::VarFromBTag> m_varsFromBTag;
231
232
    std::vector<TrackSequenceBuilder> m_trackSequenceBuilders;
    std::map<std::string, OutNode> m_decorators;
233
234
235

    DL2DataDependencyNames m_dataDependencyNames;

236
237
  };

238

239
240
241
242
243
  //
  // Filler functions
  namespace internal {
    // factory functions to produce callable objects that build inputs
    namespace get {
244
      VarFromBTag varFromBTag(const std::string& name,
245
246
247
                              EDMType,
                              const std::string& defaultflag);
      TrackSortVar trackSortVar(SortOrder, const DL2Options&);
248
      std::pair<TrackFilter,std::set<std::string>> trackFilter(
249
        TrackSelection, const DL2Options&);
250
      std::pair<SeqFromTracks,std::set<std::string>> seqFromTracks(
251
        const DL2TrackInputConfig&, const DL2Options&);
252
      std::pair<TrackSequenceFilter,std::set<std::string>> flipFilter(
253
        const DL2Options&);
254
255
256
257
    }
  }
}
#endif