Commit 69ea65de authored by Vakhtang Tsulaia's avatar Vakhtang Tsulaia
Browse files

Merge branch 'btag-the-jet' into 'master'

Btag the jet: run b-tagging NNs without the BTagging object

See merge request atlas/athena!47312
parents 337900b1 0ec2c72d
......@@ -27,6 +27,7 @@ atlas_add_library( FlavorTagDiscriminants
Root/DL2Tool.cxx
Root/customGetter.cxx
Root/FlipTagEnums.cxx
Root/AssociationEnums.cxx
Root/VRJetOverlapDecorator.cxx
Root/VRJetOverlapDecoratorTool.cxx
Root/HbbTag.cxx
......@@ -44,6 +45,7 @@ atlas_add_library( FlavorTagDiscriminants
if (NOT XAOD_STANDALONE)
atlas_add_component( FlavorTagDiscriminantsLib
src/BTagDecoratorAlg.cxx
src/JetTagDecoratorAlg.cxx
src/BTagToJetLinkerAlg.cxx
src/JetToBTagLinkerAlg.cxx
src/BTagTrackLinkCopyAlg.cxx
......
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef ASSOCIATION_ENUMS_HH
#define ASSOCIATION_ENUMS_HH
#include <string>
namespace FlavorTagDiscriminants {
enum class TrackLinkType {
TRACK_PARTICLE,
IPARTICLE
};
TrackLinkType trackLinkTypeFromString(const std::string&);
}
#endif
......@@ -5,41 +5,26 @@
#ifndef BTAG_DECORATOR_ALG_H
#define BTAG_DECORATOR_ALG_H
#include "AthenaBaseComps/AthReentrantAlgorithm.h"
#include "FlavorTagDiscriminants/DecoratorAlg.h"
#include "FlavorTagDiscriminants/IBTagDecorator.h"
#include "xAODBTagging/BTaggingContainer.h"
#include "xAODTracking/TrackParticleContainer.h"
#include "StoreGate/WriteDecorHandleKey.h"
#include "StoreGate/ReadDecorHandleKey.h"
namespace detail {
using BTag_t = FlavorTagDiscriminants::DecoratorAlg<
xAOD::BTaggingContainer,
IBTagDecorator,
xAOD::TrackParticleContainer
>;
}
namespace FlavorTagDiscriminants {
class BTagDecoratorAlg : public AthReentrantAlgorithm
class BTagDecoratorAlg : public detail::BTag_t
{
public:
BTagDecoratorAlg(const std::string& name, ISvcLocator* svcloc);
virtual StatusCode initialize() override;
virtual StatusCode execute(const EventContext& cxt) const override;
virtual StatusCode finalize() override;
private:
SG::ReadHandleKey<xAOD::BTaggingContainer> m_btagContainerKey {
this, "btagContainer", "", "Key for the input btag collection"};
SG::ReadHandleKey<xAOD::TrackParticleContainer> m_trackContainerKey {
this, "trackContainer", "InDetTrackParticles",
"Key for track particle container"};
Gaudi::Property<std::vector<std::string>> m_undeclaredReadDecorKeys {
this, "undeclaredReadDecorKeys", {},
"List of read handles that we don't read, e.g. static variables" };
ToolHandle<IBTagDecorator> m_decorator{
this, "decorator", "", "Decorator tool"};
// Keys to keep track of the inputs / outputs
std::vector<SG::ReadDecorHandleKey<xAOD::BTaggingContainer>> m_btagAux;
std::vector<SG::ReadDecorHandleKey<xAOD::TrackParticleContainer>> m_trkAux;
std::vector<SG::WriteDecorHandleKey<xAOD::BTaggingContainer>> m_btagDecor;
};
}
......
......@@ -8,6 +8,7 @@
// local includes
#include "FlavorTagDiscriminants/customGetter.h"
#include "FlavorTagDiscriminants/FlipTagEnums.h"
#include "FlavorTagDiscriminants/AssociationEnums.h"
#include "FlavorTagDiscriminants/DL2DataDependencyNames.h"
#include "xAODBTagging/ftagfloat_t.h"
......@@ -67,6 +68,7 @@ namespace FlavorTagDiscriminants {
FlipTagConfig flip;
std::string track_link_name;
std::map<std::string,std::string> remap_scalar;
TrackLinkType track_link_type;
};
......@@ -87,7 +89,7 @@ namespace FlavorTagDiscriminants {
const xAOD::Jet&)> TrackSequenceFilter;
// getter functions
typedef std::function<NamedVar(const BTagging&)> VarFromBTag;
typedef std::function<NamedVar(const SG::AuxElement&)> VarFromBTag;
typedef std::function<NamedSeq(const Jet&, const Tracks&)> SeqFromTracks;
// ___________________________________________________________________
......@@ -112,7 +114,7 @@ namespace FlavorTagDiscriminants {
m_name(name)
{
}
NamedVar operator()(const xAOD::BTagging& btag) const {
NamedVar operator()(const SG::AuxElement& btag) const {
T ret_value = m_getter(btag);
bool is_default = m_default_flag(btag);
if constexpr (std::is_floating_point<T>::value) {
......@@ -140,7 +142,7 @@ namespace FlavorTagDiscriminants {
m_name(name)
{
}
NamedVar operator()(const xAOD::BTagging& btag) const {
NamedVar operator()(const SG::AuxElement& btag) const {
T ret_value = m_getter(btag);
if constexpr (std::is_floating_point<T>::value) {
if (std::isnan(ret_value)) {
......@@ -159,11 +161,15 @@ namespace FlavorTagDiscriminants {
public:
TracksFromJet(SortOrder, TrackSelection, const DL2Options&);
Tracks operator()(const xAOD::Jet& jet,
const xAOD::BTagging& btag) const;
const SG::AuxElement& btag) const;
private:
typedef SG::AuxElement AE;
typedef std::vector<ElementLink<xAOD::TrackParticleContainer>> TrackLinks;
AE::ConstAccessor<TrackLinks> m_trackAssociator;
using AE = SG::AuxElement;
using IPC = xAOD::IParticleContainer;
using TPC = xAOD::TrackParticleContainer;
using TrackLinks = std::vector<ElementLink<TPC>>;
using PartLinks = std::vector<ElementLink<IPC>>;
using TPV = std::vector<const xAOD::TrackParticle*>;
std::function<TPV(const SG::AuxElement&)> m_associator;
TrackSortVar m_trackSortVar;
TrackFilter m_trackFilter;
};
......@@ -199,6 +205,8 @@ namespace FlavorTagDiscriminants {
const std::vector<DL2TrackSequenceConfig>& = {},
const DL2Options& = DL2Options());
void decorate(const xAOD::BTagging& btag) const;
void decorate(const xAOD::Jet& jet) const;
void decorate(const xAOD::Jet& jet, const SG::AuxElement& decorated) const;
// functions to report data depdedencies
DL2DataDependencyNames getDataDependencyNames() const;
......
......@@ -6,10 +6,12 @@
#define DL2_HIGH_LEVEL_HH
#include "FlavorTagDiscriminants/FlipTagEnums.h"
#include "FlavorTagDiscriminants/AssociationEnums.h"
#include "FlavorTagDiscriminants/DL2DataDependencyNames.h"
// EDM includes
#include "xAODBTagging/BTaggingFwd.h"
#include "xAODJet/JetFwd.h"
#include <memory>
#include <string>
......@@ -24,10 +26,12 @@ namespace FlavorTagDiscriminants {
public:
DL2HighLevel(const std::string& nn_file_name,
FlipTagConfig = FlipTagConfig::STANDARD,
std::map<std::string, std::string> remap_scalar = {});
std::map<std::string, std::string> remap_scalar = {},
TrackLinkType = TrackLinkType::TRACK_PARTICLE);
DL2HighLevel(DL2HighLevel&&);
~DL2HighLevel();
void decorate(const xAOD::BTagging& btag) const;
void decorate(const xAOD::Jet& jet) const;
DL2DataDependencyNames getDataDependencyNames() const;
private:
std::unique_ptr<DL2> m_dl2;
......
// for text editors: this file is -*- C++ -*-
/*
Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef DL2_TOOL_H
......@@ -8,6 +8,7 @@
#include "AsgTools/AsgTool.h"
#include "FlavorTagDiscriminants/IBTagDecorator.h"
#include "FlavorTagDiscriminants/IJetTagDecorator.h"
namespace FlavorTagDiscriminants {
......@@ -17,11 +18,14 @@ namespace FlavorTagDiscriminants {
std::string nnFile;
std::string flipTagConfig;
std::map<std::string,std::string> variableRemapping;
std::string trackLinkType;
};
class DL2Tool : public asg::AsgTool, virtual public IBTagDecorator
class DL2Tool : public asg::AsgTool,
virtual public IBTagDecorator,
virtual public IJetTagDecorator
{
ASG_TOOL_CLASS(DL2Tool, IBTagDecorator )
ASG_TOOL_CLASS2(DL2Tool, IBTagDecorator, IJetTagDecorator )
public:
DL2Tool(const std::string& name);
~DL2Tool();
......@@ -29,7 +33,8 @@ namespace FlavorTagDiscriminants {
StatusCode initialize() override;
// returns 0 for success
virtual void decorate(const xAOD::BTagging& jet) const override;
virtual void decorate(const xAOD::BTagging& btag) const override;
virtual void decorate(const xAOD::Jet& jet) const override;
virtual std::set<std::string> getDecoratorKeys() const override;
virtual std::set<std::string> getAuxInputKeys() const override;
......
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef DECORATOR_ALG_H
#define DECORATOR_ALG_H
#include "AthenaBaseComps/AthReentrantAlgorithm.h"
#include "StoreGate/WriteDecorHandleKey.h"
#include "StoreGate/ReadDecorHandleKey.h"
namespace FlavorTagDiscriminants {
template <typename CONTAINER, typename DECORATOR, typename CONSTITUENTS>
class DecoratorAlg : public AthReentrantAlgorithm
{
public:
DecoratorAlg(const std::string& name, ISvcLocator* svcloc);
virtual StatusCode initialize() override;
virtual StatusCode execute(const EventContext& cxt) const override;
virtual StatusCode finalize() override;
private:
SG::ReadHandleKey<CONTAINER> m_containerKey {
this, "container", "", "Key for the input collection"};
SG::ReadHandleKey<CONSTITUENTS> m_constituentKey {
this, "constituentContainer", "",
"Key for track inputs container"};
Gaudi::Property<std::vector<std::string>> m_undeclaredReadDecorKeys {
this, "undeclaredReadDecorKeys", {},
"List of read handles that we don't read, e.g. static variables" };
ToolHandle<DECORATOR> m_decorator{
this, "decorator", "", "Decorator tool"};
// Keys to keep track of the inputs / outputs
std::vector<SG::ReadDecorHandleKey<CONTAINER>> m_aux;
std::vector<SG::ReadDecorHandleKey<CONSTITUENTS>> m_constituentAux;
std::vector<SG::WriteDecorHandleKey<CONTAINER>> m_decor;
};
}
#include "DecoratorAlg.icc"
#endif
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#include "StoreGate/ReadDecorHandle.h"
#include <exception>
namespace FlavorTagDiscriminants {
template <typename C, typename D, typename N>
DecoratorAlg<C,D,N>::DecoratorAlg(
const std::string& name, ISvcLocator* svcloc):
AthReentrantAlgorithm(name, svcloc)
{
}
template <typename C, typename D, typename N>
StatusCode DecoratorAlg<C,D,N>::initialize() {
ATH_CHECK(m_containerKey.initialize());
ATH_CHECK(m_constituentKey.initialize());
ATH_CHECK(m_decorator.retrieve());
std::set<std::string> veto(
m_undeclaredReadDecorKeys.begin(),
m_undeclaredReadDecorKeys.end());
// now we build data dependencies from the internal tools. We have
// to reserve the vectors here to prevent a segfault since read /
// write handles aren't movable once declaired as a property.
m_aux.reserve(m_decorator->getAuxInputKeys().size());
for (const std::string& key: m_decorator->getAuxInputKeys()) {
const std::string full = m_containerKey.key() + "." + key;
if (veto.count(full)) {
ATH_MSG_DEBUG("Not declaring accessor: " + full);
continue;
}
ATH_MSG_DEBUG("Adding accessor: " + full);
m_aux.emplace_back(this, key, full, "");
ATH_CHECK(m_aux.back().initialize());
}
m_constituentAux.reserve(m_decorator->getConstituentAuxInputKeys().size());
for (const std::string& key: m_decorator->getConstituentAuxInputKeys()) {
const std::string full = m_constituentKey.key() + "." + key;
if (veto.count(full)) {
ATH_MSG_DEBUG("Not declaring accessor: " + full);
continue;
}
ATH_MSG_DEBUG("Adding constituent accessor: " + full);
m_constituentAux.emplace_back(this, key, full, "");
ATH_CHECK(m_constituentAux.back().initialize());
}
m_decor.reserve(m_decorator->getDecoratorKeys().size());
for (const std::string& key: m_decorator->getDecoratorKeys()) {
const std::string full = m_containerKey.key() + "." + key;
ATH_MSG_DEBUG("Adding decorator: " + full);
m_decor.emplace_back(this, key, full, "");
ATH_CHECK(m_decor.back().initialize());
}
ATH_MSG_DEBUG("Finished setting up");
return StatusCode::SUCCESS;
}
template <typename C, typename D, typename N>
StatusCode DecoratorAlg<C,D,N>::execute(const EventContext& cxt ) const {
SG::ReadHandle<C> container(
m_containerKey, cxt);
if (!container.isValid()) {
ATH_MSG_ERROR("no container " << container.key());
return StatusCode::FAILURE;
}
ATH_MSG_DEBUG(
"Decorating " + std::to_string(container->size()) + " elements");
for (const auto* element: *container) {
m_decorator->decorate(*element);
}
return StatusCode::SUCCESS;
}
template <typename C, typename D, typename N>
StatusCode DecoratorAlg<C,D,N>::finalize() {
return StatusCode::SUCCESS;
}
}
......@@ -6,10 +6,14 @@
#ifndef I_BTAG_DECORATOR_H
#define I_BTAG_DECORATOR_H
#include "IDependencyReporter.h"
#include "AsgTools/IAsgTool.h"
#include "xAODBTagging/BTaggingFwd.h"
class IBTagDecorator : virtual public asg::IAsgTool {
class IBTagDecorator : virtual public asg::IAsgTool,
virtual public IDependencyReporter
{
ASG_TOOL_INTERFACE(IBTagDecorator)
public:
......@@ -19,11 +23,6 @@ public:
/// Method to decorate a jet.
virtual void decorate(const xAOD::BTagging& btag) const = 0;
// Names of the decorations being added
virtual std::set<std::string> getDecoratorKeys() const = 0;
virtual std::set<std::string> getAuxInputKeys() const = 0;
virtual std::set<std::string> getConstituentAuxInputKeys() const = 0;
};
......
// for text editors: this file is -*- C++ -*-
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef I_DEPENDENCY_REPORTER
#define I_DEPENDENCY_REPORTER
#include <set>
#include <string>
class IDependencyReporter {
public:
/// Destructor.
virtual ~IDependencyReporter() { };
// Names of the decorations being added
virtual std::set<std::string> getDecoratorKeys() const = 0;
virtual std::set<std::string> getAuxInputKeys() const = 0;
virtual std::set<std::string> getConstituentAuxInputKeys() const = 0;
};
#endif
// for text editors: this file is -*- C++ -*-
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef I_JETTAG_DECORATOR_H
#define I_JETTAG_DECORATOR_H
#include "IDependencyReporter.h"
#include "AsgTools/IAsgTool.h"
#include "xAODJet/JetFwd.h"
class IJetTagDecorator : virtual public asg::IAsgTool,
virtual public IDependencyReporter
{
ASG_TOOL_INTERFACE(IJetTagDecorator)
public:
/// Destructor.
virtual ~IJetTagDecorator() { };
/// Method to decorate a jet.
virtual void decorate(const xAOD::Jet& jet) const = 0;
};
#endif
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef JET_TAG_DECORATOR_ALG_H
#define JET_TAG_DECORATOR_ALG_H
#include "FlavorTagDiscriminants/DecoratorAlg.h"
#include "FlavorTagDiscriminants/IJetTagDecorator.h"
#include "xAODJet/JetContainer.h"
#include "xAODTracking/TrackParticleContainer.h"
namespace detail {
using JetTag_t = FlavorTagDiscriminants::DecoratorAlg<
xAOD::JetContainer,
IJetTagDecorator,
xAOD::TrackParticleContainer
>;
}
namespace FlavorTagDiscriminants {
class JetTagDecoratorAlg : public detail::JetTag_t
{
public:
JetTagDecoratorAlg(const std::string& name, ISvcLocator* svcloc);
};
}
#endif
......@@ -11,7 +11,7 @@
// EDM includes
#include "xAODJet/JetFwd.h"
#include "xAODTracking/TrackParticleFwd.h"
#include "xAODBTagging/BTaggingFwd.h"
#include "AthContainers/AuxElement.h"
#include <functional>
#include <string>
......@@ -54,7 +54,7 @@ namespace FlavorTagDiscriminants {
// internal functions
namespace internal {
std::function<std::pair<std::string, double>(const xAOD::BTagging&)>
std::function<std::pair<std::string, double>(const SG::AuxElement&)>
customGetterAndName(const std::string&);
std::pair<std::function<std::pair<std::string, std::vector<double>>(
......
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#include "FlavorTagDiscriminants/AssociationEnums.h"
#include <stdexcept>
namespace FlavorTagDiscriminants{
#define RETURN_CONFIG(cfg) \
if (name == std::string(#cfg)) return TrackLinkType::cfg
TrackLinkType trackLinkTypeFromString(const std::string& name) {
RETURN_CONFIG(TRACK_PARTICLE);
RETURN_CONFIG(IPARTICLE);
throw std::logic_error("DL2 association scheme '" + name + "' unknown");
}
#undef RETURN_CONFIG
}
......@@ -20,6 +20,7 @@ namespace FlavorTagDiscriminants {
track_prefix = "btagIp_";
flip = FlipTagConfig::STANDARD;
track_link_name = "BTagTrackToJetAssociator";
track_link_type = TrackLinkType::TRACK_PARTICLE;
}
// DL2
......@@ -109,6 +110,18 @@ namespace FlavorTagDiscriminants {
}
void DL2::decorate(const xAOD::BTagging& btag) const {
auto jetLink = m_jetLink(btag);
if (!jetLink.isValid()) {
throw std::runtime_error("invalid jetLink");
}
const xAOD::Jet& jet = **jetLink;
decorate(jet, btag);
}
void DL2::decorate(const xAOD::Jet& jet) const {
decorate(jet, jet);
}
void DL2::decorate(const xAOD::Jet& jet, const SG::AuxElement& btag) const {
using namespace internal;
std::vector<NamedVar> vvec;
for (const auto& getter: m_varsFromBTag) {
......@@ -126,11 +139,6 @@ namespace FlavorTagDiscriminants {
}
// add track sequences
auto jetLink = m_jetLink(btag);
if (!jetLink.isValid()) {
throw std::runtime_error("invalid jetLink");
}
const xAOD::Jet& jet = **jetLink;
std::map<std::string,std::map<std::string, std::vector<double>>> seqs;
for (const auto& builder: m_trackSequenceBuilders) {
Tracks sorted_tracks = builder.tracksFromJet(jet, btag);
......@@ -174,19 +182,51 @@ namespace FlavorTagDiscriminants {
TracksFromJet::TracksFromJet(SortOrder order,
TrackSelection selection,
const DL2Options& options):
m_trackAssociator(options.track_link_name),
m_trackSortVar(get::trackSortVar(order, options)),
m_trackFilter(get::trackFilter(selection, options).first)
{
// We have several ways to get tracks: either we retrieve an
// IParticleContainer and cast the pointers to TrackParticle, or
// we retrieve a TrackParticleContainer directly. Unfortunately
// the way tracks are stored isn't consistent across the EDM, so
// we allow configuration for both setups.
//
if (options.track_link_type == TrackLinkType::IPARTICLE) {