Commit 01112782 authored by Lukas Ehrke's avatar Lukas Ehrke Committed by Edward Moyse
Browse files

Merge branch '21.2-electronMulticlass' into '21.2'

Add option to use a multiclass DNN for electron ID.

See merge request !42221

Changes when going to master are mostly from the different
lwtnn interface which uses Eigen Vectors/Matrices as input
and output.
parent f08611f6
......@@ -18,7 +18,7 @@ atlas_add_library( ElectronPhotonSelectorToolsLib
xAODHIEvent PATCoreAcceptLib AsgDataHandlesLib
PRIVATE_LINK_LIBRARIES ${ROOT_LIBRARIES} AsgMessagingLib FourMomUtils
xAODCaloEvent xAODEventInfo PathResolver EgammaAnalysisHelpersLib
${LWTNN_LIBRARIES} ${EIGEN_LIBRARIES} )
${LWTNN_LIBRARIES} ${EIGEN_LIBRARIES} EventPrimitives )
if( NOT XAOD_STANDALONE )
atlas_add_component( ElectronPhotonSelectorTools
......
......@@ -7,11 +7,16 @@
#ifndef __ASGELECTRONSELECTORTOOL__
#define __ASGELECTRONSELECTORTOOL__
// This include is needed at the top before any includes regarding Eigen
// since it includes Eigen in a specific way which causes compilation errors
// if not included before Eigen
#include "EventPrimitives/EventPrimitives.h"
// Atlas includes
#include "AsgTools/AsgTool.h"
#include "EgammaAnalysisInterfaces/IAsgElectronLikelihoodTool.h"
#include "xAODEgamma/ElectronFwd.h"
#include <Eigen/Dense>
class EventContext;
......@@ -95,10 +100,13 @@ private:
/** Applies a logit transformation to the score returned by the underlying MVA tool*/
double transformMLOutput( double score ) const;
/** Combines the six output nodes of a multiclass model into one discriminant. */
double combineOutputs(const Eigen::Matrix<float, -1, 1>& mvaScores, double eta) const;
/** Gets the Discriminant Eta bin [0,s_fnDiscEtaBins-1] given the eta*/
unsigned int getDiscEtaBin( double eta ) const;
/** Gets the Descriminant Et bin the et (MeV) [0,s_fnDiscEtBinsOneExtra-1] or [0,s_fnDiscEtBins-1]*/
/** Gets the Descriminant Et bin the et (MeV) [0,s_fnDiscEtBins-1]*/
unsigned int getDiscEtBin( double et ) const;
// NOTE that this will only perform the cut interpolation up to ~45 GeV, so
......@@ -130,6 +138,13 @@ private:
std::vector<std::string> m_variables;
/// Multiclass model or not
bool m_multiClass;
/// Use the CF output node in the numerator or the denominator
bool m_cfSignal;
/// Fractions to combine the output nodes of a multiclass model into one discriminant.
std::vector<double> m_fractions;
/// do cut on ambiguity bit
std::vector<int> m_cutAmbiguity;
/// cut min on b-layer hits
......
......@@ -49,6 +49,12 @@ AsgElectronSelectorTool::AsgElectronSelectorTool( const std::string& myname ) :
declareProperty("inputModelFileName", m_modelFileName="", "The input file name that holds the model" );
// QuantileTransformer file name ( required for preprocessing ). Managed in the ElectronDNNCalculator.
declareProperty("quantileFileName", m_quantileFileName="", "The input file name that holds the QuantileTransformer");
// Model used is a multiclass or a binary model
declareProperty("multiClass", m_multiClass, "Whether the given model is multiclass or not");
// If multiclass, how to treat the chargeflip output node when combining into one discriminant
declareProperty("cfSignal", m_cfSignal, "Whether to include the CF fraction in the numerator or denominator");
// If multiclass, fractions with which the different output nodes get multiplied before combining them
declareProperty("Fractions", m_fractions, "Fractions to combine the single outputs into one discriminant");
// Variable list
declareProperty("Variables", m_variables, "Variables used in the MVA tool");
// The mva cut values
......@@ -142,8 +148,14 @@ StatusCode AsgElectronSelectorTool::initialize()
m_variables.push_back( substr );
}
// Model is multiclass or not, default is binary model
m_multiClass = env.GetValue("multiClass", false);
// Create an instance of the class calculating the DNN score
m_mvaTool = std::make_unique<ElectronDNNCalculator>(this, filename.c_str(), qfilename.c_str(), m_variables);
m_mvaTool = std::make_unique<ElectronDNNCalculator>(this, filename.c_str(), qfilename.c_str(), m_variables, m_multiClass);
// Include cf node in numerator or denominator when combining different outputs
m_cfSignal = env.GetValue("cfSignal", true);
// Fractions to multiply different outputs with before combining
m_fractions = AsgConfigHelper::HelperDouble("Fractions", env);
// cut on MVA discriminant
m_cutSelector = AsgConfigHelper::HelperDouble("CutSelector", env);
......@@ -170,6 +182,16 @@ StatusCode AsgElectronSelectorTool::initialize()
return StatusCode::FAILURE;
}
if (m_multiClass){
// Fractions are only needed if multiclass model is used
// There are five fractions for the combination, the signal fraction is either one (cfSignal == false) or 1 - cf fraction (cfSignal == true)
if (m_fractions.size() != numberOfExpectedEtaBins * 5){
ATH_MSG_ERROR("Configuration issue : multiclass but not the right amount of fractions." << m_fractions.size());
return StatusCode::FAILURE;
}
}
if (!m_cutSCT.empty()){
if (m_cutSCT.size() != numberOfExpectedEtaBins){
ATH_MSG_ERROR("Configuration issue : cutSCT expected size " << numberOfExpectedEtaBins <<
......@@ -631,10 +653,19 @@ double AsgElectronSelectorTool::calculate( const EventContext& ctx, const xAOD::
vars.nPixHitsPlusDeadSensors = nPixHitsPlusDeadSensors;
vars.nSCTHitsPlusDeadSensors = nSCTHitsPlusDeadSensors;
double mvaScore = m_mvaTool->calculate(vars);
Eigen::Matrix<float, -1, 1> mvaScores = m_mvaTool->calculate(vars);
return transformMLOutput(mvaScore);
double discriminant = 0;
// If a binary model is used, vector will have one entry, if multiclass is used vector will have six entries
if (!m_multiClass){
discriminant = transformMLOutput(mvaScores(0, 0));
}
else{
// combine the six output nodes into one discriminant to cut on, any necessary transformation is applied within combineOutputs()
discriminant = combineOutputs(mvaScores, eta);
}
return discriminant;
}
//=============================================================================
......@@ -749,6 +780,32 @@ double AsgElectronSelectorTool::transformMLOutput( double score ) const
return score;
}
double AsgElectronSelectorTool::combineOutputs( const Eigen::Matrix<float, -1, 1>& mvaScores, double eta ) const{
unsigned int etaBin = getDiscEtaBin(eta);
double disc = 0;
if (m_cfSignal){
// Put cf node into numerator
disc = (mvaScores(0, 0) * (1 - m_fractions.at(5 * etaBin + 0)) +
(mvaScores(1, 0) * m_fractions.at(5 * etaBin + 0))) /
((mvaScores(2, 0) * m_fractions.at(5 * etaBin + 1)) +
(mvaScores(3, 0) * m_fractions.at(5 * etaBin + 2)) +
(mvaScores(4, 0) * m_fractions.at(5 * etaBin + 3)) +
(mvaScores(5, 0) * m_fractions.at(5 * etaBin + 4)));
}
else{
// Put cf node in denominator
disc = mvaScores(0, 0) /
((mvaScores(1, 0) * m_fractions.at(5 * etaBin + 0)) +
(mvaScores(2, 0) * m_fractions.at(5 * etaBin + 1)) +
(mvaScores(3, 0) * m_fractions.at(5 * etaBin + 2)) +
(mvaScores(4, 0) * m_fractions.at(5 * etaBin + 3)) +
(mvaScores(5, 0) * m_fractions.at(5 * etaBin + 4)));
}
// Log transform to have values in reasonable range
return std::log(disc);
}
// Gets the Discriminant Eta bin [0,s_fnDiscEtaBins-1] given the eta
unsigned int AsgElectronSelectorTool::getDiscEtaBin( double eta ) const
......
......@@ -26,8 +26,10 @@
ElectronDNNCalculator::ElectronDNNCalculator(AsgElectronSelectorTool* owner,
const std::string& modelFileName,
const std::string& quantileFileName,
const std::vector<std::string>& variables) :
asg::AsgMessagingForward(owner)
const std::vector<std::string>& variables,
const bool multiClass) :
asg::AsgMessagingForward(owner),
m_multiClass(multiClass)
{
ATH_MSG_INFO("Initializing ElectronDNNCalculator...");
......@@ -53,8 +55,19 @@ ElectronDNNCalculator::ElectronDNNCalculator(AsgElectronSelectorTool* owner,
// create the model
inputFile.open(modelFileName);
auto parsedGraph = lwt::parse_json_graph(inputFile);
m_graph = std::make_unique<lwt::generic::FastGraph<float>>(parsedGraph, order);
// Test whether the number of outputs of the given network corresponds to the expected number
size_t nOutputs = parsedGraph.outputs.begin()->second.labels.size();
if (nOutputs != 6 && nOutputs != 1){
throw std::runtime_error("Given model does not have 1 or 6 outputs. Something seems to be wrong with the model file.");
}
else if (nOutputs == 1 && m_multiClass){
throw std::runtime_error("Given model has 1 output but config file specifies mutliclass. Something is wrong");
}
else if (nOutputs == 6 && !m_multiClass){
throw std::runtime_error("Given model has 6 output but config file does not specify mutliclass. Something is wrong");
}
m_graph = std::make_unique<lwt::generic::FastGraph<float>>(parsedGraph, order);
if (quantileFileName.empty()){
throw std::runtime_error("No file found at '" + quantileFileName + "'");
......@@ -64,14 +77,14 @@ ElectronDNNCalculator::ElectronDNNCalculator(AsgElectronSelectorTool* owner,
ATH_MSG_INFO("Loading QuantileTransformer " << quantileFileName);
TFile* qtfile = TFile::Open(quantileFileName.data());
if (readQuantileTransformer((TTree*)qtfile->Get("tree"), variables) == 0){
throw std::runtime_error("Could not load all variables for the QuantileTransformer");
throw std::runtime_error("Could not load all variables for the QuantileTransformer");
}
}
// takes the input variables, transforms them according to the given QuantileTransformer and predicts the DNN value
double ElectronDNNCalculator::calculate( const MVAEnum::MVACalcVars& varsStruct ) const
// takes the input variables, transforms them according to the given QuantileTransformer and predicts the DNN value(s)
Eigen::Matrix<float, -1, 1> ElectronDNNCalculator::calculate( const MVAEnum::MVACalcVars& varsStruct ) const
{
// Create the input for the model
Eigen::VectorXf inputVector(20);
......@@ -99,11 +112,10 @@ double ElectronDNNCalculator::calculate( const MVAEnum::MVACalcVars& varsStruct
inputVector(19) = transformInput( m_quantiles.wtots1, varsStruct.wtots1);
std::vector<Eigen::VectorXf> inp;
inp.push_back(inputVector);
inp.emplace_back(std::move(inputVector));
auto output = m_graph->compute(inp);
double score = output(0);
return score;
return output;
}
......@@ -151,7 +163,7 @@ int ElectronDNNCalculator::readQuantileTransformer( TTree* tree, const std::vect
std::map<std::string, double> readVars;
for ( const auto& var : variables ){
sc = tree->SetBranchAddress(TString(var), &readVars[var]) == -5 ? 0 : 1;
sc = tree->SetBranchAddress(TString(var), &readVars[var]) == -5 ? 0 : 1;
}
for (int i = 0; i < tree->GetEntries(); i++){
tree->GetEntry(i);
......
......@@ -10,7 +10,7 @@
// This include is needed at the top before any includes regarding lwtnn
// since it includes Eigen in a specific way which causes compilation errors
// if not included before lwtnn
#include "GeoPrimitives/GeoPrimitives.h"
#include "EventPrimitives/EventPrimitives.h"
#include "AsgMessaging/AsgMessagingForward.h"
#include "ElectronPhotonSelectorTools/AsgElectronSelectorTool.h"
......@@ -77,13 +77,14 @@ public:
ElectronDNNCalculator( AsgElectronSelectorTool* owner,
const std::string& modelFileName,
const std::string& quantileFileName,
const std::vector<std::string>& variablesName);
const std::vector<std::string>& variablesName,
const bool multiClass);
/** Standard destructor*/
~ElectronDNNCalculator() {};
/** Get the prediction of the DNN model*/
double calculate( const MVAEnum::MVACalcVars& varsStruct ) const;
Eigen::Matrix<float, -1, 1> calculate( const MVAEnum::MVACalcVars& varsStruct ) const;
private:
/** transform the input variables according to a given QuantileTransformer.*/
......@@ -98,6 +99,8 @@ private:
MVAEnum::QTVars m_quantiles;
/// Reference values for the QuantileTransformer. Basically just equidistant bins between 0 and 1.
std::vector<double> m_references;
/// Whether the used model is a multiclass model or not.
bool m_multiClass;
};
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment