Commit 325b5414 authored by Walter Lampl's avatar Walter Lampl Committed by Atlas Nightlybuild
Browse files

Merge branch 'dl2-data-deps-master' into 'master'

Add methods to get data dependencies out of DL2

See merge request atlas/athena!32627

(cherry picked from commit ac7e7bc7)

dc1fca8c Add methods to get data dependencies out of DL2
parent 3289ab68
......@@ -8,6 +8,7 @@
#include <vector>
#include <set>
#include "AthContainers/AuxElement.h"
#include "AthLinks/ElementLink.h"
......@@ -27,7 +28,7 @@ struct BTagSignedIP {
class BTagTrackAugmenter {
public:
BTagTrackAugmenter();
BTagTrackAugmenter(const std::string& prefix = "btagIp_" );
void augment(const xAOD::TrackParticle &track, const xAOD::Jet &jet);
// NOTE: this should be called in the derivations if possible,
......@@ -48,6 +49,7 @@ public:
double z0SinThetaUncertainty(const xAOD::TrackParticle &track) const;
BTagSignedIP get_signed_ip(const xAOD::TrackParticle &track, const xAOD::Jet &jet) const;
std::set<std::string> getTrackIpDataDependencyNames() const;
private:
typedef SG::AuxElement AE;
......@@ -68,6 +70,8 @@ private:
AE::Decorator<float> m_ip3d_signed_z0_significance;
AE::Decorator<int> m_ip2d_grade;
AE::Decorator<int> m_ip3d_grade;
std::string m_prefix;
};
#endif
......@@ -8,6 +8,7 @@
// local includes
#include "FlavorTagDiscriminants/customGetter.h"
#include "FlavorTagDiscriminants/FlipTagEnums.h"
#include "FlavorTagDiscriminants/DL2DataDependencyNames.h"
// EDM includes
#include "xAODJet/Jet.h"
......@@ -72,7 +73,7 @@ namespace FlavorTagDiscriminants {
const xAOD::Jet&)> TrackSequenceFilter;
// getter functions
typedef std::function<NamedVar(const Jet&)> VarFromJet;
typedef std::function<NamedVar(const Jet&)> VarFromBTag;
typedef std::function<NamedSeq(const Jet&, const Tracks&)> SeqFromTracks;
// ___________________________________________________________________
......@@ -172,6 +173,10 @@ namespace FlavorTagDiscriminants {
std::map<std::string, std::string> out_remap = {},
OutputType = OutputType::DOUBLE);
void decorate(const xAOD::Jet& jet) const;
// functions to report data depdedencies
DL2DataDependencyNames getDataDependencyNames() const;
private:
struct TrackSequenceBuilder {
TrackSequenceBuilder(SortOrder, TrackSelection, FlipTagConfig);
......@@ -185,23 +190,30 @@ namespace FlavorTagDiscriminants {
std::string m_input_node_name;
std::unique_ptr<lwt::LightweightGraph> m_graph;
std::unique_ptr<lwt::NanReplacer> m_variable_cleaner;
std::vector<internal::VarFromJet> m_varsFromJet;
std::vector<internal::VarFromBTag> m_varsFromBTag;
std::vector<TrackSequenceBuilder> m_trackSequenceBuilders;
std::map<std::string, OutNode> m_decorators;
DL2DataDependencyNames m_dataDependencyNames;
};
//
// Filler functions
namespace internal {
// factory functions to produce callable objects that build inputs
namespace get {
VarFromJet varFromJet(const std::string& name,
VarFromBTag varFromBTag(const std::string& name,
EDMType,
const std::string& defaultflag);
TrackSortVar trackSortVar(SortOrder);
TrackFilter trackFilter(TrackSelection);
SeqFromTracks seqFromTracks(const DL2TrackInputConfig&);
TrackSequenceFilter flipFilter(FlipTagConfig);
std::pair<TrackFilter,std::set<std::string>> trackFilter(
TrackSelection);
std::pair<SeqFromTracks,std::set<std::string>> seqFromTracks(
const DL2TrackInputConfig&);
std::pair<TrackSequenceFilter,std::set<std::string>> flipFilter(
FlipTagConfig);
}
}
}
......
#ifndef DL2_DATA_DEPENDENCY_NAMES_H
#define DL2_DATA_DEPENDENCY_NAMES_H
#include <set>
#include <string>
namespace FlavorTagDiscriminants {
struct DL2DataDependencyNames {
std::set<std::string> trackInputs;
std::set<std::string> bTagInputs;
std::set<std::string> bTagOutputs;
};
}
#endif
......@@ -6,6 +6,7 @@
#define DL2_HIGH_LEVEL_HH
#include "FlavorTagDiscriminants/FlipTagEnums.h"
#include "FlavorTagDiscriminants/DL2DataDependencyNames.h"
// EDM includes
#include "xAODJet/Jet.h"
......@@ -27,6 +28,7 @@ namespace FlavorTagDiscriminants {
DL2HighLevel(DL2HighLevel&&);
~DL2HighLevel();
void decorate(const xAOD::Jet& jet) const;
DL2DataDependencyNames getDataDependencyNames() const;
private:
std::unique_ptr<DL2> m_dl2;
};
......
......@@ -7,13 +7,22 @@
#include <cmath>
#include <cstddef>
BTagTrackAugmenter::BTagTrackAugmenter():
m_ip_d0("btagIp_d0"),
m_ip_z0("btagIp_z0SinTheta"),
m_ip_d0_sigma("btagIp_d0Uncertainty"),
m_ip_z0_sigma("btagIp_z0SinThetaUncertainty"),
m_track_displacement("btagIp_trackDisplacement"),
m_track_momentum("btagIp_trackMomentum"),
namespace str {
const std::string d0 = "d0";
const std::string z0SinTheta = "z0SinTheta";
const std::string d0Uncertainty = "d0Uncertainty";
const std::string z0SinThetaUncertainty = "z0SinThetaUncertainty";
const std::string trackDisplacement = "trackDisplacement";
const std::string trackMomentum = "trackMomentum";
}
BTagTrackAugmenter::BTagTrackAugmenter(const std::string& prefix):
m_ip_d0(prefix + str::d0),
m_ip_z0(prefix + str::z0SinTheta),
m_ip_d0_sigma(prefix + str::d0Uncertainty),
m_ip_z0_sigma(prefix + str::z0SinThetaUncertainty),
m_track_displacement(prefix + str::trackDisplacement),
m_track_momentum(prefix + str::trackMomentum),
m_ip2d_trackParticleLinks("IP2D_TrackParticleLinks"),
m_ip3d_trackParticleLinks("IP3D_TrackParticleLinks"),
m_ip2d_gradeOfTracks("IP2D_gradeOfTracks"),
......@@ -24,7 +33,8 @@ BTagTrackAugmenter::BTagTrackAugmenter():
m_ip3d_signed_d0_significance("IP3D_signed_d0_significance"),
m_ip3d_signed_z0_significance("IP3D_signed_z0_significance"),
m_ip2d_grade("IP2D_grade"),
m_ip3d_grade("IP3D_grade")
m_ip3d_grade("IP3D_grade"),
m_prefix(prefix)
{
}
......@@ -108,3 +118,15 @@ void BTagTrackAugmenter::augment_with_grades(const xAOD::TrackParticle &track, c
}
m_ip2d_grade(track) = ip2d_grade;
}
std::set<std::string> BTagTrackAugmenter::getTrackIpDataDependencyNames() const
{
return {
m_prefix + str::d0,
m_prefix + str::z0SinTheta,
m_prefix + str::d0Uncertainty,
m_prefix + str::z0SinThetaUncertainty,
m_prefix + str::trackDisplacement,
m_prefix + str::trackMomentum};
}
......@@ -35,9 +35,15 @@ namespace FlavorTagDiscriminants {
lwt::rep::all));
}
for (const auto& input: inputs) {
auto filler = get::varFromJet(input.name, input.type,
input.default_flag);
m_varsFromJet.push_back(filler);
auto filler = get::varFromBTag(input.name, input.type,
input.default_flag);
if (input.type != EDMType::CUSTOM_GETTER) {
m_dataDependencyNames.bTagInputs.insert(input.name);
}
if (input.default_flag.size() > 0) {
m_dataDependencyNames.bTagInputs.insert(input.default_flag);
}
m_varsFromBTag.push_back(filler);
}
// set up sequence inputs
......@@ -45,12 +51,18 @@ namespace FlavorTagDiscriminants {
TrackSequenceBuilder track_getter(track_cfg.order,
track_cfg.selection,
flipConfig);
// add the tracking data dependencies
auto track_data_deps = get::trackFilter(track_cfg.selection).second;
track_data_deps.merge(get::flipFilter(flipConfig).second);
track_getter.name = track_cfg.name;
for (const DL2TrackInputConfig& input_cfg: track_cfg.inputs) {
track_getter.sequencesFromTracks.push_back(
get::seqFromTracks(input_cfg));
auto [seqGetter, deps] = get::seqFromTracks(input_cfg);
track_getter.sequencesFromTracks.push_back(seqGetter);
track_data_deps.merge(deps);
}
m_trackSequenceBuilders.push_back(track_getter);
m_dataDependencyNames.trackInputs = track_data_deps;
}
// set up outputs
......@@ -67,6 +79,7 @@ namespace FlavorTagDiscriminants {
name = replacement_itr->second;
out_remap.erase(replacement_itr);
}
m_dataDependencyNames.bTagOutputs.insert(name);
// for the spring 2019 retraining campaign we're stuck with
// doubles. Hopefully at some point we can move to using
......@@ -96,10 +109,11 @@ namespace FlavorTagDiscriminants {
throw std::logic_error("found unused output remapping(s): " + outputs);
}
}
void DL2::decorate(const xAOD::Jet& jet) const {
using namespace internal;
std::vector<NamedVar> vvec;
for (const auto& getter: m_varsFromJet) {
for (const auto& getter: m_varsFromBTag) {
vvec.push_back(getter(jet));
}
std::map<std::string, std::map<std::string, double> > nodes;
......@@ -134,11 +148,15 @@ namespace FlavorTagDiscriminants {
}
}
DL2DataDependencyNames DL2::getDataDependencyNames() const {
return m_dataDependencyNames;
}
DL2::TrackSequenceBuilder::TrackSequenceBuilder(SortOrder order,
TrackSelection selection,
FlipTagConfig flipcfg):
tracksFromJet(order, selection),
flipFilter(internal::get::flipFilter(flipcfg))
flipFilter(internal::get::flipFilter(flipcfg).first)
{
}
......@@ -152,7 +170,7 @@ namespace FlavorTagDiscriminants {
TracksFromJet::TracksFromJet(SortOrder order, TrackSelection selection):
m_trackAssociator("BTagTrackToJetAssociator"),
m_trackSortVar(get::trackSortVar(order)),
m_trackFilter(get::trackFilter(selection))
m_trackFilter(get::trackFilter(selection).first)
{
}
Tracks TracksFromJet::operator()(const xAOD::Jet& jet) const {
......@@ -184,7 +202,7 @@ namespace FlavorTagDiscriminants {
namespace get {
// factory for functions that get variables out of the b-tagging
// object
VarFromJet varFromJet(const std::string& name, EDMType type,
VarFromBTag varFromBTag(const std::string& name, EDMType type,
const std::string& default_flag) {
if(default_flag.size() == 0 || name==default_flag)
{
......@@ -238,20 +256,31 @@ namespace FlavorTagDiscriminants {
// factory for functions that return true for tracks we want to
// use, false for those we don't want
TrackFilter trackFilter(TrackSelection selection) {
std::pair<TrackFilter,std::set<std::string>> trackFilter(
TrackSelection selection) {
typedef xAOD::TrackParticle Tp;
typedef SG::AuxElement AE;
BTagTrackAugmenter aug;
AE::ConstAccessor<unsigned char> pix_hits("numberOfPixelHits");
AE::ConstAccessor<unsigned char> pix_holes("numberOfPixelHoles");
AE::ConstAccessor<unsigned char> pix_shared("numberOfPixelSharedHits");
AE::ConstAccessor<unsigned char> pix_dead("numberOfPixelDeadSensors");
AE::ConstAccessor<unsigned char> sct_hits("numberOfSCTHits");
AE::ConstAccessor<unsigned char> sct_holes("numberOfSCTHoles");
AE::ConstAccessor<unsigned char> sct_shared("numberOfSCTSharedHits");
AE::ConstAccessor<unsigned char> sct_dead("numberOfSCTDeadSensors");
auto data_deps = aug.getTrackIpDataDependencyNames();
// make sure we record accessors as data dependencies
auto addAccessor = [&data_deps](const std::string& n) {
AE::ConstAccessor<unsigned char> a(n);
data_deps.insert(n);
return a;
};
auto pix_hits = addAccessor("numberOfPixelHits");
auto pix_holes = addAccessor("numberOfPixelHoles");
auto pix_shared = addAccessor("numberOfPixelSharedHits");
auto pix_dead = addAccessor("numberOfPixelDeadSensors");
auto sct_hits = addAccessor("numberOfSCTHits");
auto sct_holes = addAccessor("numberOfSCTHoles");
auto sct_shared = addAccessor("numberOfSCTSharedHits");
auto sct_dead = addAccessor("numberOfSCTDeadSensors");
switch (selection) {
case TrackSelection::ALL: return [](const Tp*) {return true;};
case TrackSelection::ALL: return {[](const Tp*) {return true;}, {} };
// the following numbers come from Nicole, Dec 2018:
// pt > 1 GeV
// abs(d0) < 1 mm
......@@ -260,21 +289,23 @@ namespace FlavorTagDiscriminants {
// <= 2 si holes
// <= 1 pix holes
case TrackSelection::IP3D_2018:
return [=](const Tp* tp) {
// from the track selector tool
if (std::abs(tp->eta()) > 2.5) return false;
double n_module_shared = (
pix_shared(*tp) + sct_shared(*tp) / 2);
if (n_module_shared > 1) return false;
if (tp->pt() <= 1e3) return false;
if (std::abs(aug.d0(*tp)) >= 1.0) return false;
if (std::abs(aug.z0SinTheta(*tp)) >= 1.5) return false;
if (pix_hits(*tp) + pix_dead(*tp) + sct_hits(*tp)
+ sct_dead(*tp) < 7) return false;
if ((pix_holes(*tp) + sct_holes(*tp)) > 2) return false;
if (pix_holes(*tp) > 1) return false;
return true;
};
return {
[=](const Tp* tp) {
// from the track selector tool
if (std::abs(tp->eta()) > 2.5) return false;
double n_module_shared = (
pix_shared(*tp) + sct_shared(*tp) / 2);
if (n_module_shared > 1) return false;
if (tp->pt() <= 1e3) return false;
if (std::abs(aug.d0(*tp)) >= 1.0) return false;
if (std::abs(aug.z0SinTheta(*tp)) >= 1.5) return false;
if (pix_hits(*tp) + pix_dead(*tp) + sct_hits(*tp)
+ sct_dead(*tp) < 7) return false;
if ((pix_holes(*tp) + sct_holes(*tp)) > 2) return false;
if (pix_holes(*tp) > 1) return false;
return true;
}, data_deps
};
default:
throw std::logic_error("unknown track selection function");
}
......@@ -282,11 +313,20 @@ namespace FlavorTagDiscriminants {
// factory for functions that build std::vector objects from
// track sequences
SeqFromTracks seqFromTracks(const DL2TrackInputConfig& cfg) {
std::pair<SeqFromTracks,std::set<std::string>> seqFromTracks(
const DL2TrackInputConfig& cfg)
{
switch (cfg.type) {
case EDMType::FLOAT: return SequenceGetter<float>(cfg.name);
case EDMType::UCHAR: return SequenceGetter<unsigned char>(cfg.name);
case EDMType::CUSTOM_GETTER: return customNamedSeqGetter(cfg.name);
case EDMType::FLOAT: return {
SequenceGetter<float>(cfg.name), {cfg.name}
};
case EDMType::UCHAR: return {
SequenceGetter<unsigned char>(cfg.name), {cfg.name}
};
case EDMType::CUSTOM_GETTER: return {
customNamedSeqGetter(cfg.name),
BTagTrackAugmenter().getTrackIpDataDependencyNames()
};
default: {
throw std::logic_error("Unknown EDM type");
}
......@@ -313,20 +353,27 @@ namespace FlavorTagDiscriminants {
}
// factory function
TrackSequenceFilter flipFilter(FlipTagConfig cfg) {
std::pair<TrackSequenceFilter,std::set<std::string>> flipFilter(
FlipTagConfig cfg)
{
namespace ph = std::placeholders; // for _1, _2, _3
BTagTrackAugmenter aug;
switch(cfg) {
case FlipTagConfig::NEGATIVE_IP_ONLY:
// flips order and removes tracks with negative IP
return std::bind(&negativeIpOnly, aug, ph::_1, ph::_2);
return {
std::bind(&negativeIpOnly, aug, ph::_1, ph::_2),
aug.getTrackIpDataDependencyNames()
};
case FlipTagConfig::FLIP_SIGN:
// Just flips the order
return [](const Tracks& tr, const xAOD::Jet& ) {
return Tracks(tr.crbegin(), tr.crend());
};
return {
[](const Tracks& tr, const xAOD::Jet& ) {
return Tracks(tr.crbegin(), tr.crend());},
{}
};
case FlipTagConfig::STANDARD:
return [](const Tracks& tr, const xAOD::Jet& ) { return tr; };
return {[](const Tracks& tr, const xAOD::Jet& ) { return tr; }, {}};
default: {
throw std::logic_error("Unknown flip config");
}
......
......@@ -165,4 +165,9 @@ namespace FlavorTagDiscriminants {
m_dl2->decorate(jet);
}
DL2DataDependencyNames DL2HighLevel::getDataDependencyNames() const
{
return m_dl2->getDataDependencyNames();
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment