Skip to content
Snippets Groups Projects
Commit 76789fe9 authored by Marcin Piotr Wandas's avatar Marcin Piotr Wandas Committed by Frank Winklmeier
Browse files

[ATLASRECTS-5396][ATLASRECTS-5380] Move initReader method from execute to...

[ATLASRECTS-5396][ATLASRECTS-5380] Move initReader method from execute to initialize method in MvaTESEvaluator
parent 546864d7
No related branches found
No related tags found
No related merge requests found
......@@ -93,7 +93,22 @@ std::vector<float> BDTHelper::getInputVariables(const std::map<TString, float> &
return values;
}
std::vector<float> BDTHelper::getInputVariables(const std::map<TString, float*> &availableVariables) const {
std::vector<float> values;
// sort the input variables by the order in varList (from BDT)
for (const TString& name : m_inputVariableNames) {
std::map<TString, float*>::const_iterator itr = availableVariables.find(name);
if(itr==availableVariables.end()) {
ATH_MSG_ERROR(name << " not available");
}
else {
values.push_back(*itr->second);
}
}
return values;
}
std::vector<float> BDTHelper::getInputVariables(const xAOD::TauJet& tau) const {
std::vector<float> values;
......@@ -130,7 +145,19 @@ float BDTHelper::getGradBoostMVA(const std::map<TString, float> &availableVariab
return score;
}
float BDTHelper::getResponse(const std::map<TString, float*> &availableVariables) const {
std::vector<float> values = getInputVariables(availableVariables);
float score = -999;
if (values.size() < m_inputVariableNames.size()) {
ATH_MSG_ERROR("There are missing variables when calculating the BDT score, will return -999");
}
else {
score = m_BDT->GetResponse(values);
}
return score;
}
float BDTHelper::getGradBoostMVA(const xAOD::TauJet& tau) const {
std::vector<float> values = getInputVariables(tau);
......
......@@ -6,6 +6,8 @@
#include "tauRecTools/MvaTESEvaluator.h"
#include "tauRecTools/HelperFunctions.h"
#include <TTree.h>
#include <vector>
//_____________________________________________________________________________
......@@ -22,13 +24,17 @@ MvaTESEvaluator::~MvaTESEvaluator()
//_____________________________________________________________________________
StatusCode MvaTESEvaluator::initialize(){
const std::string weightFile = find_file(m_sWeightFileName);
m_bdtHelper = std::make_unique<tauRecTools::BDTHelper>();
ATH_CHECK(m_bdtHelper->initialize(weightFile));
return StatusCode::SUCCESS;
}
//_____________________________________________________________________________
StatusCode MvaTESEvaluator::initReader(std::unique_ptr<MVAUtils::BDT>& reader,
std::map<TString, float*>& availableVars,
MvaInputVariables& vars) const {
StatusCode MvaTESEvaluator::execute(xAOD::TauJet& xTau) const {
std::map<TString, float*> availableVars;
MvaInputVariables vars;
// Declare input variables to the reader
if(!m_in_trigger) {
......@@ -65,26 +71,6 @@ StatusCode MvaTESEvaluator::initReader(std::unique_ptr<MVAUtils::BDT>& reader,
availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.etaDetectorAxis", &vars.etaDetectorAxis) );
}
std::string weightFile = find_file(m_sWeightFileName);
reader = tauRecTools::configureMVABDT( availableVars, weightFile.c_str() );
if(reader==nullptr) {
ATH_MSG_FATAL("Couldn't configure MVA");
return StatusCode::FAILURE;
}
return StatusCode::SUCCESS;
}
//_____________________________________________________________________________
StatusCode MvaTESEvaluator::execute(xAOD::TauJet& xTau) const {
std::unique_ptr<MVAUtils::BDT> reader{nullptr};
std::map<TString, float*> availableVars;
MvaInputVariables vars;
if (initReader(reader, availableVars, vars) == StatusCode::FAILURE)
return StatusCode::FAILURE;
// Retrieve event info
const SG::AuxElement::ConstAccessor<float> acc_mu("mu");
const SG::AuxElement::ConstAccessor<int> acc_nVtxPU("nVtxPU");
......@@ -128,7 +114,7 @@ StatusCode MvaTESEvaluator::execute(xAOD::TauJet& xTau) const {
vars.nTracks = (float)xTau.nTracks();
xTau.detail(xAOD::TauJetParameters::PFOEngRelDiff, vars.PFOEngRelDiff);
float ptMVA = float( vars.ptCombined * reader->GetResponse() );
float ptMVA = float( vars.ptCombined * m_bdtHelper->getResponse(availableVars) );
if(ptMVA<1) ptMVA=1;
xTau.setP4(xAOD::TauJetParameters::FinalCalib, ptMVA, vars.etaConstituent, xTau.phiPanTauCellBased(), 0);
......@@ -145,7 +131,7 @@ StatusCode MvaTESEvaluator::execute(xAOD::TauJet& xTau) const {
vars.upsilon_cluster = acc_UpsilonCluster(xTau);
vars.lead_cluster_frac = acc_LeadClusterFrac(xTau);
float ptMVA = float( vars.ptDetectorAxis * reader->GetResponse() );
float ptMVA = float( vars.ptDetectorAxis * m_bdtHelper->getResponse(availableVars) );
if(ptMVA<1) ptMVA=1;
// this may have to be changed if we apply a calo-only MVA calibration first, followed by a calo+track MVA calibration
......
......@@ -22,6 +22,8 @@ namespace tauRecTools {
float getGradBoostMVA(const std::map<TString, float>& availableVariables) const;
float getResponse(const std::map<TString, float*>& availableVariables) const;
float getGradBoostMVA(const xAOD::TauJet& tau) const;
MVAUtils::BDT* getBDT() const { return m_BDT.get(); }
......@@ -31,6 +33,8 @@ namespace tauRecTools {
std::vector<float> getInputVariables(const std::map<TString, float>& availableVariables) const ;
std::vector<float> getInputVariables(const std::map<TString, float*>& availableVariables) const;
std::vector<float> getInputVariables(const xAOD::TauJet& tau) const ;
std::unique_ptr<MVAUtils::BDT> m_BDT;
......
......@@ -8,7 +8,7 @@
// tauRecTools include(s)
#include "tauRecTools/TauRecToolBase.h"
#include "MVAUtils/BDT.h"
#include "tauRecTools/BDTHelper.h"
#include <map>
......@@ -35,24 +35,24 @@ class MvaTESEvaluator
{
float mu{0.0}; //!
float nVtxPU{0.0}; //!
float center_lambda{0.0}; //!
float first_eng_dens{0.0}; //!
float second_lambda{0.0}; //!
float presampler_frac{0.0}; //!
float eprobability{0.0}; //!
float ptCombined{0.0}; //!
float ptLC_D_ptCombined{0.0}; //!
float ptConstituent_D_ptCombined{0.0};//!
float etaConstituent{0.0}; //!
float PanTauBDT_1p0n_vs_1p1n{0.0}; //!
float PanTauBDT_1p1n_vs_1pXn{0.0}; //!
float PanTauBDT_3p0n_vs_3pXn{0.0}; //!
float nTracks{0.0}; //!
float PFOEngRelDiff{0.0}; //!
// Spectators
float truthPtVis{0.0}; //!
float pt{0.0}; //!
......@@ -66,10 +66,8 @@ class MvaTESEvaluator
float upsilon_cluster{0.0}; //!
float lead_cluster_frac{0.0}; //!
};
StatusCode initReader(std::unique_ptr<MVAUtils::BDT>& reader,
std::map<TString, float*>& availableVars,
MvaInputVariables& vars) const;
std::unique_ptr<tauRecTools::BDTHelper> m_bdtHelper;
// Configurable properties
Gaudi::Property<std::string> m_sWeightFileName{this, "WeightFileName", "MvaTES_20170207_v2_BDTG.weights.root"};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment