Skip to content
Snippets Groups Projects
Commit f03d1fdb authored by Dan Guest's avatar Dan Guest
Browse files

AFT-422: Add rnn track sort function

The RNN was missing the functionality to sort tracks by the signed d0
significance. This commit adds this.
parent 673ea2b4
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,8 @@ namespace lwt {
namespace FlavorTagDiscriminants {
enum class EDMType {UCHAR, INT, FLOAT, DOUBLE, CUSTOM_GETTER};
enum class SortOrder {ABS_D0_SIGNIFICANCE_DESCENDING, PT_DESCENDING};
enum class SortOrder {
ABS_D0_SIGNIFICANCE_DESCENDING, D0_SIGNIFICANCE_DESCENDING, PT_DESCENDING};
enum class TrackSelection {ALL, IP3D_2018};
// Structures to define DL2 input.
......@@ -89,8 +90,8 @@ namespace FlavorTagDiscriminants {
typedef std::pair<std::string, std::vector<double> > NamedSeq;
typedef xAOD::Jet Jet;
typedef std::vector<const xAOD::TrackParticle*> Tracks;
typedef std::function<bool(const xAOD::TrackParticle*,
const xAOD::TrackParticle*)> TrackSort;
typedef std::function<double(const xAOD::TrackParticle*,
const xAOD::Jet&)> TrackSortVar;
typedef std::function<bool(const xAOD::TrackParticle*)> TrackSelect;
// getter functions
......@@ -137,7 +138,7 @@ namespace FlavorTagDiscriminants {
typedef SG::AuxElement AE;
typedef std::vector<ElementLink<xAOD::TrackParticleContainer>> TrackLinks;
AE::ConstAccessor<TrackLinks> m_track_associator;
TrackSort m_sort_function;
TrackSortVar m_sort_var_getter;
TrackSelect m_select_function;
};
......@@ -192,7 +193,7 @@ namespace FlavorTagDiscriminants {
// Filler functions
namespace internal {
Getter get_filler(std::string name, EDMType, std::string default_flag);
TrackSort get_track_sort(SortOrder);
TrackSortVar get_track_sort(SortOrder);
TrackSelect get_track_select(TrackSelection);
SeqGetter get_seq_getter(const DL2TrackInputConfig&);
}
......
......@@ -3,6 +3,7 @@
*/
#include "FlavorTagDiscriminants/DL2.h"
#include "FlavorTagDiscriminants/BTagTrackAugmenter.h"
#include "lwtnn/LightweightGraph.hh"
#include "lwtnn/NanReplacer.hh"
......@@ -97,8 +98,7 @@ namespace FlavorTagDiscriminants {
// save out things
for (const auto& dec: m_decorators) {
// the second argument to compute(...) is for sequences, we
// don't currently have any.
// the second argument to compute(...) is for sequences
auto out_vals = m_graph->compute(nodes, seqs, dec.first);
for (const auto& node: dec.second) {
node.second(*jet.btagging()) = out_vals.at(node.first);
......@@ -161,23 +161,27 @@ namespace FlavorTagDiscriminants {
// Track Getter Class
TrackGetter::TrackGetter(SortOrder order, TrackSelection selection):
m_track_associator("BTagTrackToJetAssociator"),
m_sort_function(get_track_sort(order)),
m_sort_var_getter(get_track_sort(order)),
m_select_function(get_track_select(selection))
{
}
Tracks TrackGetter::operator()(const xAOD::Jet& jet) const {
const xAOD::BTagging *btagging = jet.btagging();
if (!btagging) throw std::runtime_error("can't find btagging object");
std::vector<const xAOD::TrackParticle*> tracks;
std::vector<std::pair<double, const xAOD::TrackParticle*>> tracks;
for (const auto &link : m_track_associator(*btagging)) {
if(!link.isValid()) {
throw std::logic_error("invalid track link");
}
const xAOD::TrackParticle *tp = *link;
if (m_select_function(tp)) tracks.push_back(tp);
if (m_select_function(tp)) {
tracks.push_back({m_sort_var_getter(tp, jet), tp});
};
}
std::sort(tracks.begin(), tracks.end(), m_sort_function);
return tracks;
std::sort(tracks.begin(), tracks.end(), std::greater<>());
std::vector<const xAOD::TrackParticle*> only_tracks;
for (const auto& trk: tracks) only_tracks.push_back(trk.second);
return only_tracks;
}
......@@ -196,22 +200,24 @@ namespace FlavorTagDiscriminants {
}
}
}
TrackSort get_track_sort(SortOrder order) {
TrackSortVar get_track_sort(SortOrder order) {
typedef xAOD::TrackParticle Tp;
typedef xAOD::Jet Jet;
typedef SG::AuxElement AE;
AE::ConstAccessor<float> d0("btag_ip_d0");
AE::ConstAccessor<float> d0_sigma("btag_ip_d0_sigma");
BTagTrackAugmenter aug;
switch(order) {
case SortOrder::ABS_D0_SIGNIFICANCE_DESCENDING:
return [d0, d0_sigma](const Tp* tp1, const Tp* tp2) {
double sd01 = std::abs(d0(*tp1) / d0_sigma(*tp1));
double sd02 = std::abs(d0(*tp2) / d0_sigma(*tp2));
return sd01 > sd02;
return [d0, d0_sigma](const Tp* tp, const Jet&) {
return std::abs(d0(*tp) / d0_sigma(*tp));
};
case SortOrder::PT_DESCENDING:
return [](const Tp* tp1, const Tp* tp2) {
return tp1->pt() > tp2->pt();
case SortOrder::D0_SIGNIFICANCE_DESCENDING:
return [aug](const Tp* tp, const Jet& j) {
return aug.get_signed_ip(*tp, j).ip3d_signed_d0_significance;
};
case SortOrder::PT_DESCENDING:
return [](const Tp* tp, const Jet&) {return tp->pt();};
default: {
throw std::logic_error("Unknown sort function");
}
......
......@@ -82,7 +82,8 @@ namespace FlavorTagDiscriminants {
{"(log_)?(ptfrac|dr)"_r, EDMType::CUSTOM_GETTER}
};
SortRegexes trk_sort_regexes {
{".*sd0sort"_r, SortOrder::ABS_D0_SIGNIFICANCE_DESCENDING},
{".*absSd0sort"_r, SortOrder::ABS_D0_SIGNIFICANCE_DESCENDING},
{".*sd0sort"_r, SortOrder::D0_SIGNIFICANCE_DESCENDING},
{".*ptsort"_r, SortOrder::PT_DESCENDING},
};
TrkSelRegexes trk_select_regexes {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment