Commit 69ea65de authored by Vakhtang Tsulaia's avatar Vakhtang Tsulaia
Browse files

Merge branch 'btag-the-jet' into 'master'

Btag the jet: run b-tagging NNs without the BTagging object

See merge request atlas/athena!47312
parents 337900b1 0ec2c72d
...@@ -27,6 +27,7 @@ atlas_add_library( FlavorTagDiscriminants ...@@ -27,6 +27,7 @@ atlas_add_library( FlavorTagDiscriminants
Root/DL2Tool.cxx Root/DL2Tool.cxx
Root/customGetter.cxx Root/customGetter.cxx
Root/FlipTagEnums.cxx Root/FlipTagEnums.cxx
Root/AssociationEnums.cxx
Root/VRJetOverlapDecorator.cxx Root/VRJetOverlapDecorator.cxx
Root/VRJetOverlapDecoratorTool.cxx Root/VRJetOverlapDecoratorTool.cxx
Root/HbbTag.cxx Root/HbbTag.cxx
...@@ -44,6 +45,7 @@ atlas_add_library( FlavorTagDiscriminants ...@@ -44,6 +45,7 @@ atlas_add_library( FlavorTagDiscriminants
if (NOT XAOD_STANDALONE) if (NOT XAOD_STANDALONE)
atlas_add_component( FlavorTagDiscriminantsLib atlas_add_component( FlavorTagDiscriminantsLib
src/BTagDecoratorAlg.cxx src/BTagDecoratorAlg.cxx
src/JetTagDecoratorAlg.cxx
src/BTagToJetLinkerAlg.cxx src/BTagToJetLinkerAlg.cxx
src/JetToBTagLinkerAlg.cxx src/JetToBTagLinkerAlg.cxx
src/BTagTrackLinkCopyAlg.cxx src/BTagTrackLinkCopyAlg.cxx
......
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef ASSOCIATION_ENUMS_HH
#define ASSOCIATION_ENUMS_HH
#include <string>
namespace FlavorTagDiscriminants {
enum class TrackLinkType {
TRACK_PARTICLE,
IPARTICLE
};
TrackLinkType trackLinkTypeFromString(const std::string&);
}
#endif
...@@ -5,41 +5,26 @@ ...@@ -5,41 +5,26 @@
#ifndef BTAG_DECORATOR_ALG_H #ifndef BTAG_DECORATOR_ALG_H
#define BTAG_DECORATOR_ALG_H #define BTAG_DECORATOR_ALG_H
#include "AthenaBaseComps/AthReentrantAlgorithm.h"
#include "FlavorTagDiscriminants/DecoratorAlg.h"
#include "FlavorTagDiscriminants/IBTagDecorator.h" #include "FlavorTagDiscriminants/IBTagDecorator.h"
#include "xAODBTagging/BTaggingContainer.h" #include "xAODBTagging/BTaggingContainer.h"
#include "xAODTracking/TrackParticleContainer.h" #include "xAODTracking/TrackParticleContainer.h"
#include "StoreGate/WriteDecorHandleKey.h"
#include "StoreGate/ReadDecorHandleKey.h" namespace detail {
using BTag_t = FlavorTagDiscriminants::DecoratorAlg<
xAOD::BTaggingContainer,
IBTagDecorator,
xAOD::TrackParticleContainer
>;
}
namespace FlavorTagDiscriminants { namespace FlavorTagDiscriminants {
class BTagDecoratorAlg : public AthReentrantAlgorithm class BTagDecoratorAlg : public detail::BTag_t
{ {
public: public:
BTagDecoratorAlg(const std::string& name, ISvcLocator* svcloc); BTagDecoratorAlg(const std::string& name, ISvcLocator* svcloc);
virtual StatusCode initialize() override;
virtual StatusCode execute(const EventContext& cxt) const override;
virtual StatusCode finalize() override;
private:
SG::ReadHandleKey<xAOD::BTaggingContainer> m_btagContainerKey {
this, "btagContainer", "", "Key for the input btag collection"};
SG::ReadHandleKey<xAOD::TrackParticleContainer> m_trackContainerKey {
this, "trackContainer", "InDetTrackParticles",
"Key for track particle container"};
Gaudi::Property<std::vector<std::string>> m_undeclaredReadDecorKeys {
this, "undeclaredReadDecorKeys", {},
"List of read handles that we don't read, e.g. static variables" };
ToolHandle<IBTagDecorator> m_decorator{
this, "decorator", "", "Decorator tool"};
// Keys to keep track of the inputs / outputs
std::vector<SG::ReadDecorHandleKey<xAOD::BTaggingContainer>> m_btagAux;
std::vector<SG::ReadDecorHandleKey<xAOD::TrackParticleContainer>> m_trkAux;
std::vector<SG::WriteDecorHandleKey<xAOD::BTaggingContainer>> m_btagDecor;
}; };
} }
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
// local includes // local includes
#include "FlavorTagDiscriminants/customGetter.h" #include "FlavorTagDiscriminants/customGetter.h"
#include "FlavorTagDiscriminants/FlipTagEnums.h" #include "FlavorTagDiscriminants/FlipTagEnums.h"
#include "FlavorTagDiscriminants/AssociationEnums.h"
#include "FlavorTagDiscriminants/DL2DataDependencyNames.h" #include "FlavorTagDiscriminants/DL2DataDependencyNames.h"
#include "xAODBTagging/ftagfloat_t.h" #include "xAODBTagging/ftagfloat_t.h"
...@@ -67,6 +68,7 @@ namespace FlavorTagDiscriminants { ...@@ -67,6 +68,7 @@ namespace FlavorTagDiscriminants {
FlipTagConfig flip; FlipTagConfig flip;
std::string track_link_name; std::string track_link_name;
std::map<std::string,std::string> remap_scalar; std::map<std::string,std::string> remap_scalar;
TrackLinkType track_link_type;
}; };
...@@ -87,7 +89,7 @@ namespace FlavorTagDiscriminants { ...@@ -87,7 +89,7 @@ namespace FlavorTagDiscriminants {
const xAOD::Jet&)> TrackSequenceFilter; const xAOD::Jet&)> TrackSequenceFilter;
// getter functions // getter functions
typedef std::function<NamedVar(const BTagging&)> VarFromBTag; typedef std::function<NamedVar(const SG::AuxElement&)> VarFromBTag;
typedef std::function<NamedSeq(const Jet&, const Tracks&)> SeqFromTracks; typedef std::function<NamedSeq(const Jet&, const Tracks&)> SeqFromTracks;
// ___________________________________________________________________ // ___________________________________________________________________
...@@ -112,7 +114,7 @@ namespace FlavorTagDiscriminants { ...@@ -112,7 +114,7 @@ namespace FlavorTagDiscriminants {
m_name(name) m_name(name)
{ {
} }
NamedVar operator()(const xAOD::BTagging& btag) const { NamedVar operator()(const SG::AuxElement& btag) const {
T ret_value = m_getter(btag); T ret_value = m_getter(btag);
bool is_default = m_default_flag(btag); bool is_default = m_default_flag(btag);
if constexpr (std::is_floating_point<T>::value) { if constexpr (std::is_floating_point<T>::value) {
...@@ -140,7 +142,7 @@ namespace FlavorTagDiscriminants { ...@@ -140,7 +142,7 @@ namespace FlavorTagDiscriminants {
m_name(name) m_name(name)
{ {
} }
NamedVar operator()(const xAOD::BTagging& btag) const { NamedVar operator()(const SG::AuxElement& btag) const {
T ret_value = m_getter(btag); T ret_value = m_getter(btag);
if constexpr (std::is_floating_point<T>::value) { if constexpr (std::is_floating_point<T>::value) {
if (std::isnan(ret_value)) { if (std::isnan(ret_value)) {
...@@ -159,11 +161,15 @@ namespace FlavorTagDiscriminants { ...@@ -159,11 +161,15 @@ namespace FlavorTagDiscriminants {
public: public:
TracksFromJet(SortOrder, TrackSelection, const DL2Options&); TracksFromJet(SortOrder, TrackSelection, const DL2Options&);
Tracks operator()(const xAOD::Jet& jet, Tracks operator()(const xAOD::Jet& jet,
const xAOD::BTagging& btag) const; const SG::AuxElement& btag) const;
private: private:
typedef SG::AuxElement AE; using AE = SG::AuxElement;
typedef std::vector<ElementLink<xAOD::TrackParticleContainer>> TrackLinks; using IPC = xAOD::IParticleContainer;
AE::ConstAccessor<TrackLinks> m_trackAssociator; using TPC = xAOD::TrackParticleContainer;
using TrackLinks = std::vector<ElementLink<TPC>>;
using PartLinks = std::vector<ElementLink<IPC>>;
using TPV = std::vector<const xAOD::TrackParticle*>;
std::function<TPV(const SG::AuxElement&)> m_associator;
TrackSortVar m_trackSortVar; TrackSortVar m_trackSortVar;
TrackFilter m_trackFilter; TrackFilter m_trackFilter;
}; };
...@@ -199,6 +205,8 @@ namespace FlavorTagDiscriminants { ...@@ -199,6 +205,8 @@ namespace FlavorTagDiscriminants {
const std::vector<DL2TrackSequenceConfig>& = {}, const std::vector<DL2TrackSequenceConfig>& = {},
const DL2Options& = DL2Options()); const DL2Options& = DL2Options());
void decorate(const xAOD::BTagging& btag) const; void decorate(const xAOD::BTagging& btag) const;
void decorate(const xAOD::Jet& jet) const;
void decorate(const xAOD::Jet& jet, const SG::AuxElement& decorated) const;
// functions to report data depdedencies // functions to report data depdedencies
DL2DataDependencyNames getDataDependencyNames() const; DL2DataDependencyNames getDataDependencyNames() const;
......
...@@ -6,10 +6,12 @@ ...@@ -6,10 +6,12 @@
#define DL2_HIGH_LEVEL_HH #define DL2_HIGH_LEVEL_HH
#include "FlavorTagDiscriminants/FlipTagEnums.h" #include "FlavorTagDiscriminants/FlipTagEnums.h"
#include "FlavorTagDiscriminants/AssociationEnums.h"
#include "FlavorTagDiscriminants/DL2DataDependencyNames.h" #include "FlavorTagDiscriminants/DL2DataDependencyNames.h"
// EDM includes // EDM includes
#include "xAODBTagging/BTaggingFwd.h" #include "xAODBTagging/BTaggingFwd.h"
#include "xAODJet/JetFwd.h"
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -24,10 +26,12 @@ namespace FlavorTagDiscriminants { ...@@ -24,10 +26,12 @@ namespace FlavorTagDiscriminants {
public: public:
DL2HighLevel(const std::string& nn_file_name, DL2HighLevel(const std::string& nn_file_name,
FlipTagConfig = FlipTagConfig::STANDARD, FlipTagConfig = FlipTagConfig::STANDARD,
std::map<std::string, std::string> remap_scalar = {}); std::map<std::string, std::string> remap_scalar = {},
TrackLinkType = TrackLinkType::TRACK_PARTICLE);
DL2HighLevel(DL2HighLevel&&); DL2HighLevel(DL2HighLevel&&);
~DL2HighLevel(); ~DL2HighLevel();
void decorate(const xAOD::BTagging& btag) const; void decorate(const xAOD::BTagging& btag) const;
void decorate(const xAOD::Jet& jet) const;
DL2DataDependencyNames getDataDependencyNames() const; DL2DataDependencyNames getDataDependencyNames() const;
private: private:
std::unique_ptr<DL2> m_dl2; std::unique_ptr<DL2> m_dl2;
......
// for text editors: this file is -*- C++ -*- // for text editors: this file is -*- C++ -*-
/* /*
Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/ */
#ifndef DL2_TOOL_H #ifndef DL2_TOOL_H
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "AsgTools/AsgTool.h" #include "AsgTools/AsgTool.h"
#include "FlavorTagDiscriminants/IBTagDecorator.h" #include "FlavorTagDiscriminants/IBTagDecorator.h"
#include "FlavorTagDiscriminants/IJetTagDecorator.h"
namespace FlavorTagDiscriminants { namespace FlavorTagDiscriminants {
...@@ -17,11 +18,14 @@ namespace FlavorTagDiscriminants { ...@@ -17,11 +18,14 @@ namespace FlavorTagDiscriminants {
std::string nnFile; std::string nnFile;
std::string flipTagConfig; std::string flipTagConfig;
std::map<std::string,std::string> variableRemapping; std::map<std::string,std::string> variableRemapping;
std::string trackLinkType;
}; };
class DL2Tool : public asg::AsgTool, virtual public IBTagDecorator class DL2Tool : public asg::AsgTool,
virtual public IBTagDecorator,
virtual public IJetTagDecorator
{ {
ASG_TOOL_CLASS(DL2Tool, IBTagDecorator ) ASG_TOOL_CLASS2(DL2Tool, IBTagDecorator, IJetTagDecorator )
public: public:
DL2Tool(const std::string& name); DL2Tool(const std::string& name);
~DL2Tool(); ~DL2Tool();
...@@ -29,7 +33,8 @@ namespace FlavorTagDiscriminants { ...@@ -29,7 +33,8 @@ namespace FlavorTagDiscriminants {
StatusCode initialize() override; StatusCode initialize() override;
// returns 0 for success // returns 0 for success
virtual void decorate(const xAOD::BTagging& jet) const override; virtual void decorate(const xAOD::BTagging& btag) const override;
virtual void decorate(const xAOD::Jet& jet) const override;
virtual std::set<std::string> getDecoratorKeys() const override; virtual std::set<std::string> getDecoratorKeys() const override;
virtual std::set<std::string> getAuxInputKeys() const override; virtual std::set<std::string> getAuxInputKeys() const override;
......
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef DECORATOR_ALG_H
#define DECORATOR_ALG_H
#include "AthenaBaseComps/AthReentrantAlgorithm.h"
#include "StoreGate/WriteDecorHandleKey.h"
#include "StoreGate/ReadDecorHandleKey.h"
namespace FlavorTagDiscriminants {
template <typename CONTAINER, typename DECORATOR, typename CONSTITUENTS>
class DecoratorAlg : public AthReentrantAlgorithm
{
public:
DecoratorAlg(const std::string& name, ISvcLocator* svcloc);
virtual StatusCode initialize() override;
virtual StatusCode execute(const EventContext& cxt) const override;
virtual StatusCode finalize() override;
private:
SG::ReadHandleKey<CONTAINER> m_containerKey {
this, "container", "", "Key for the input collection"};
SG::ReadHandleKey<CONSTITUENTS> m_constituentKey {
this, "constituentContainer", "",
"Key for track inputs container"};
Gaudi::Property<std::vector<std::string>> m_undeclaredReadDecorKeys {
this, "undeclaredReadDecorKeys", {},
"List of read handles that we don't read, e.g. static variables" };
ToolHandle<DECORATOR> m_decorator{
this, "decorator", "", "Decorator tool"};
// Keys to keep track of the inputs / outputs
std::vector<SG::ReadDecorHandleKey<CONTAINER>> m_aux;
std::vector<SG::ReadDecorHandleKey<CONSTITUENTS>> m_constituentAux;
std::vector<SG::WriteDecorHandleKey<CONTAINER>> m_decor;
};
}
#include "DecoratorAlg.icc"
#endif
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#include "StoreGate/ReadDecorHandle.h"
#include <exception>
namespace FlavorTagDiscriminants {
template <typename C, typename D, typename N>
DecoratorAlg<C,D,N>::DecoratorAlg(
const std::string& name, ISvcLocator* svcloc):
AthReentrantAlgorithm(name, svcloc)
{
}
template <typename C, typename D, typename N>
StatusCode DecoratorAlg<C,D,N>::initialize() {
ATH_CHECK(m_containerKey.initialize());
ATH_CHECK(m_constituentKey.initialize());
ATH_CHECK(m_decorator.retrieve());
std::set<std::string> veto(
m_undeclaredReadDecorKeys.begin(),
m_undeclaredReadDecorKeys.end());
// now we build data dependencies from the internal tools. We have
// to reserve the vectors here to prevent a segfault since read /
// write handles aren't movable once declaired as a property.
m_aux.reserve(m_decorator->getAuxInputKeys().size());
for (const std::string& key: m_decorator->getAuxInputKeys()) {
const std::string full = m_containerKey.key() + "." + key;
if (veto.count(full)) {
ATH_MSG_DEBUG("Not declaring accessor: " + full);
continue;
}
ATH_MSG_DEBUG("Adding accessor: " + full);
m_aux.emplace_back(this, key, full, "");
ATH_CHECK(m_aux.back().initialize());
}
m_constituentAux.reserve(m_decorator->getConstituentAuxInputKeys().size());
for (const std::string& key: m_decorator->getConstituentAuxInputKeys()) {
const std::string full = m_constituentKey.key() + "." + key;
if (veto.count(full)) {
ATH_MSG_DEBUG("Not declaring accessor: " + full);
continue;
}
ATH_MSG_DEBUG("Adding constituent accessor: " + full);
m_constituentAux.emplace_back(this, key, full, "");
ATH_CHECK(m_constituentAux.back().initialize());
}
m_decor.reserve(m_decorator->getDecoratorKeys().size());
for (const std::string& key: m_decorator->getDecoratorKeys()) {
const std::string full = m_containerKey.key() + "." + key;
ATH_MSG_DEBUG("Adding decorator: " + full);
m_decor.emplace_back(this, key, full, "");
ATH_CHECK(m_decor.back().initialize());
}
ATH_MSG_DEBUG("Finished setting up");
return StatusCode::SUCCESS;
}
template <typename C, typename D, typename N>
StatusCode DecoratorAlg<C,D,N>::execute(const EventContext& cxt ) const {
SG::ReadHandle<C> container(
m_containerKey, cxt);
if (!container.isValid()) {
ATH_MSG_ERROR("no container " << container.key());
return StatusCode::FAILURE;
}
ATH_MSG_DEBUG(
"Decorating " + std::to_string(container->size()) + " elements");
for (const auto* element: *container) {
m_decorator->decorate(*element);
}
return StatusCode::SUCCESS;
}
template <typename C, typename D, typename N>
StatusCode DecoratorAlg<C,D,N>::finalize() {
return StatusCode::SUCCESS;
}
}
...@@ -6,10 +6,14 @@ ...@@ -6,10 +6,14 @@
#ifndef I_BTAG_DECORATOR_H #ifndef I_BTAG_DECORATOR_H
#define I_BTAG_DECORATOR_H #define I_BTAG_DECORATOR_H
#include "IDependencyReporter.h"
#include "AsgTools/IAsgTool.h" #include "AsgTools/IAsgTool.h"
#include "xAODBTagging/BTaggingFwd.h" #include "xAODBTagging/BTaggingFwd.h"
class IBTagDecorator : virtual public asg::IAsgTool { class IBTagDecorator : virtual public asg::IAsgTool,
virtual public IDependencyReporter
{
ASG_TOOL_INTERFACE(IBTagDecorator) ASG_TOOL_INTERFACE(IBTagDecorator)
public: public:
...@@ -19,11 +23,6 @@ public: ...@@ -19,11 +23,6 @@ public:
/// Method to decorate a jet. /// Method to decorate a jet.
virtual void decorate(const xAOD::BTagging& btag) const = 0; virtual void decorate(const xAOD::BTagging& btag) const = 0;
// Names of the decorations being added
virtual std::set<std::string> getDecoratorKeys() const = 0;
virtual std::set<std::string> getAuxInputKeys() const = 0;
virtual std::set<std::string> getConstituentAuxInputKeys() const = 0;
}; };
......
// for text editors: this file is -*- C++ -*-
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef I_DEPENDENCY_REPORTER
#define I_DEPENDENCY_REPORTER
#include <set>
#include <string>
class IDependencyReporter {
public:
/// Destructor.
virtual ~IDependencyReporter() { };
// Names of the decorations being added
virtual std::set<std::string> getDecoratorKeys() const = 0;
virtual std::set<std::string> getAuxInputKeys() const = 0;
virtual std::set<std::string> getConstituentAuxInputKeys() const = 0;
};
#endif
// for text editors: this file is -*- C++ -*-
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef I_JETTAG_DECORATOR_H
#define I_JETTAG_DECORATOR_H
#include "IDependencyReporter.h"
#include "AsgTools/IAsgTool.h"
#include "xAODJet/JetFwd.h"
class IJetTagDecorator : virtual public asg::IAsgTool,
virtual public IDependencyReporter
{
ASG_TOOL_INTERFACE(IJetTagDecorator)
public:
/// Destructor.
virtual ~IJetTagDecorator() { };
/// Method to decorate a jet.
virtual void decorate(const xAOD::Jet& jet) const = 0;
};
#endif
/*
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef JET_TAG_DECORATOR_ALG_H
#define JET_TAG_DECORATOR_ALG_H
#include "FlavorTagDiscriminants/DecoratorAlg.h"
#include "FlavorTagDiscriminants/IJetTagDecorator.h"
#include "xAODJet/JetContainer.h"
#include "xAODTracking/TrackParticleContainer.h"
namespace detail {
using JetTag_t = FlavorTagDiscriminants::DecoratorAlg<
xAOD::JetContainer,
IJetTagDecorator,
xAOD::TrackParticleContainer
>;
}
namespace FlavorTagDiscriminants {