diff --git a/Reconstruction/tauRecTools/Root/BDTHelper.cxx b/Reconstruction/tauRecTools/Root/BDTHelper.cxx new file mode 100644 index 0000000000000000000000000000000000000000..c1050bd3c1b312476b5e3cd07da87c3a3bc154d9 --- /dev/null +++ b/Reconstruction/tauRecTools/Root/BDTHelper.cxx @@ -0,0 +1,109 @@ +/* + Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration +*/ + +#include "tauRecTools/BDTHelper.h" + +#include "TFile.h" +#include "TTree.h" +#include "TObjArray.h" + +namespace tauRecTools { + +BDTHelper::BDTHelper() : + asg::AsgMessaging("BDTHelper"), + m_BDT(nullptr) { +} + + + +BDTHelper::~BDTHelper() { +} + + + +StatusCode BDTHelper::initialize(const TString& weightFileName) { + + std::unique_ptr<TFile> file(TFile::Open(weightFileName)); + if (!file) { + ATH_MSG_ERROR("Cannot find input BDT file: " << weightFileName); + return StatusCode::FAILURE; + } + ATH_MSG_INFO( "Open file: " << weightFileName); + + TTree* tree = dynamic_cast<TTree*> (file->Get("BDT")); + if (!tree) { + ATH_MSG_ERROR("Cannot find input BDT tree"); + return StatusCode::FAILURE; + } + m_BDT = std::make_unique<MVAUtils::BDT>(tree); + + TNamed* varList = dynamic_cast<TNamed*> (file->Get("varList")); + if (!varList) { + ATH_MSG_ERROR("No variable list in file: " << weightFileName); + return StatusCode::FAILURE; + } + TString names = varList->GetTitle(); + delete varList; + + // obtain the list of input variables + m_inputVariableNames = parseString(names); + + file->Close(); + + return StatusCode::SUCCESS; +} + +std::vector<TString> BDTHelper::parseString(const TString& str, const TString& delim/*=","*/) const { + std::vector<TString> parsedString; + + TObjArray* objList = str.Tokenize(delim); + size_t arraySize = objList->GetEntries(); + + // for each token produced by the split: drop spaces, keep only the text after ":=" if present, and put the non-empty results into a vector + for(size_t i = 0; i < arraySize; ++i) { + TString var = dynamic_cast<TObjString*> (objList->At(i))->String(); + var.ReplaceAll(" ", ""); + if(var.Contains(":=")) { + 
var=var(var.Index(":=")+2, var.Length()-var.Index(":=")-2); + } + if(0==var.Length()) continue; + parsedString.push_back(var); + } + + delete objList; + + return parsedString; +} + +std::vector<float> BDTHelper::getInputVariables(const std::map<TString, float> &availableVariables) const { + std::vector<float> values; + + // collect the input values in the order given by varList (from BDT); a missing variable is reported and skipped + for (const TString& name : m_inputVariableNames) { + std::map<TString, float>::const_iterator itr = availableVariables.find(name); + if(itr==availableVariables.end()) { + ATH_MSG_ERROR(name << " not available"); + } + else { + values.push_back(itr->second); + } + } + + return values; +} + + +StatusCode BDTHelper::getGradBoostMVA(const std::map<TString, float> &availableVariables, float& score) const { + std::vector<float> values = getInputVariables(availableVariables); + + if (values.size() < m_inputVariableNames.size()) { + ATH_MSG_ERROR("There are missing variables."); + return StatusCode::FAILURE; + } + + score = m_BDT->GetGradBoostMVA(values); + return StatusCode::SUCCESS; +} + +} // end of namespace tauRecTools diff --git a/Reconstruction/tauRecTools/Root/TauPi0ScoreCalculator.cxx b/Reconstruction/tauRecTools/Root/TauPi0ScoreCalculator.cxx index 42dc919f69c36bea6947f211b3cf8aae9bbe283a..c48e7e2f5ee1376c4ce9c83b314144ccc868008a 100644 --- a/Reconstruction/tauRecTools/Root/TauPi0ScoreCalculator.cxx +++ b/Reconstruction/tauRecTools/Root/TauPi0ScoreCalculator.cxx @@ -16,8 +16,6 @@ #include "tauRecTools/HelperFunctions.h" #include "xAODPFlow/PFO.h" -#include "MVAUtils/BDT.h" - using std::vector; using std::string; @@ -27,20 +25,7 @@ using std::string; TauPi0ScoreCalculator::TauPi0ScoreCalculator( const string& name ) : TauRecToolBase(name), - m_mvaBDT(nullptr), - m_Abs_FIRST_ETA(0), - m_SECOND_R(0), - m_Abs_DELTA_THETA(0), - m_CENTER_LAMBDA_helped(0), - m_LONGITUDINAL(0), - m_ENG_FRAC_EM(0), - m_ENG_FRAC_CORE(0), - m_log_SECOND_ENG_DENS(0), - m_EcoreOverEEM1(0), - m_NPosCells_EM1(0), 
- m_NPosCells_EM2(0), - m_firstEtaWRTCluster_EM1(0), - m_secondEtaWRTCluster_EM2(0) + m_mvaBDT(nullptr) { declareProperty("BDTWeightFile", m_weightfile); } @@ -56,30 +41,10 @@ TauPi0ScoreCalculator::~TauPi0ScoreCalculator() StatusCode TauPi0ScoreCalculator::initialize() { - //--------------------------------------------------------------------- - // Create TMVA reader - //--------------------------------------------------------------------- - m_availableVars.insert( std::make_pair("Pi0Cluster_Abs_FIRST_ETA" ,&m_Abs_FIRST_ETA) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_SECOND_R" ,&m_SECOND_R) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_Abs_DELTA_THETA" ,&m_Abs_DELTA_THETA) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_CENTER_LAMBDA_helped" ,&m_CENTER_LAMBDA_helped) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_LONGITUDINAL" ,&m_LONGITUDINAL) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_ENG_FRAC_EM" ,&m_ENG_FRAC_EM) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_ENG_FRAC_CORE" ,&m_ENG_FRAC_CORE) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_log_SECOND_ENG_DENS" ,&m_log_SECOND_ENG_DENS) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_EcoreOverEEM1" ,&m_EcoreOverEEM1) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_NPosECells_EM1" ,&m_NPosCells_EM1) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_NPosECells_EM2" ,&m_NPosCells_EM2) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_AbsFirstEtaWRTClusterPosition_EM1" ,&m_firstEtaWRTCluster_EM1) ); - m_availableVars.insert( std::make_pair("Pi0Cluster_secondEtaWRTClusterPosition_EM2" ,&m_secondEtaWRTCluster_EM2) ); - std::string weightFile = find_file(m_weightfile); - m_mvaBDT = tauRecTools::configureMVABDT(m_availableVars, weightFile); - if(m_mvaBDT==nullptr) { - ATH_MSG_FATAL("Couldn't configure MVA"); - return StatusCode::FAILURE; - } + m_mvaBDT = std::make_unique<tauRecTools::BDTHelper>(); + 
ATH_CHECK(m_mvaBDT->initialize(weightFile)); return StatusCode::SUCCESS; } @@ -106,7 +71,7 @@ StatusCode TauPi0ScoreCalculator::executePi0nPFO(xAOD::TauJet& pTau, xAOD::PFOCo for( auto neutralPFO : neutralPFOContainer ) { float BDTScore = calculateScore(neutralPFO); - neutralPFO->setBDTPi0Score((float) BDTScore); + neutralPFO->setBDTPi0Score(BDTScore); } ATH_MSG_DEBUG("End of TauPi0ScoreCalculator::execute"); @@ -117,60 +82,105 @@ StatusCode TauPi0ScoreCalculator::executePi0nPFO(xAOD::TauJet& pTau, xAOD::PFOCo float TauPi0ScoreCalculator::calculateScore(const xAOD::PFO* neutralPFO) { - m_Abs_FIRST_ETA=0.; - m_SECOND_R=0.; - m_Abs_DELTA_THETA=0.; - m_CENTER_LAMBDA_helped=0.; - m_LONGITUDINAL=0.; - m_ENG_FRAC_EM=0.; - m_ENG_FRAC_CORE=0.; - m_log_SECOND_ENG_DENS=0.; - m_EcoreOverEEM1=0.; - // Need to convert int variables to floats after retrieving them - int NPosCells_EM1=0; - int NPosCells_EM2=0; - m_firstEtaWRTCluster_EM1=0.; - m_secondEtaWRTCluster_EM2=0.; - - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_FIRST_ETA,m_Abs_FIRST_ETA) == false) + std::map<TString, float> availableVariables; // map of the variable name to its value + + float Abs_FIRST_ETA = 0.; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_FIRST_ETA, Abs_FIRST_ETA) == false) { ATH_MSG_WARNING("Can't find FIRST_ETA. Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_SECOND_R,m_SECOND_R) == false) + } + Abs_FIRST_ETA = fabs(Abs_FIRST_ETA); + availableVariables.insert(std::make_pair("Pi0Cluster_Abs_FIRST_ETA", Abs_FIRST_ETA)); + + float SECOND_R = 0.; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_SECOND_R, SECOND_R) == false) { ATH_MSG_WARNING("Can't find SECOND_R. 
Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_DELTA_THETA,m_Abs_DELTA_THETA) == false) + } + availableVariables.insert(std::make_pair("Pi0Cluster_SECOND_R", SECOND_R)); + + float Abs_DELTA_THETA = 0.; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_DELTA_THETA, Abs_DELTA_THETA) == false) { ATH_MSG_WARNING("Can't find DELTA_THETA. Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_CENTER_LAMBDA,m_CENTER_LAMBDA_helped) == false) + } + Abs_DELTA_THETA = fabs(Abs_DELTA_THETA); + availableVariables.insert(std::make_pair("Pi0Cluster_Abs_DELTA_THETA", Abs_DELTA_THETA)); + + float CENTER_LAMBDA_helped = 0.; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_CENTER_LAMBDA, CENTER_LAMBDA_helped) == false) { ATH_MSG_WARNING("Can't find CENTER_LAMBDA. Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_LONGITUDINAL,m_LONGITUDINAL) == false) + } + CENTER_LAMBDA_helped = fmin(CENTER_LAMBDA_helped, 1000.); + availableVariables.insert(std::make_pair("Pi0Cluster_CENTER_LAMBDA_helped", CENTER_LAMBDA_helped)); + + float LONGITUDINAL = 0.; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_LONGITUDINAL, LONGITUDINAL) == false) { ATH_MSG_WARNING("Can't find LONGITUDINAL. Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_ENG_FRAC_EM,m_ENG_FRAC_EM) == false) + } + availableVariables.insert(std::make_pair("Pi0Cluster_LONGITUDINAL", LONGITUDINAL)); + + float ENG_FRAC_EM = 0.; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_ENG_FRAC_EM, ENG_FRAC_EM) == false) { ATH_MSG_WARNING("Can't find ENG_FRAC_EM. 
Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_ENG_FRAC_CORE,m_ENG_FRAC_CORE) == false) + } + availableVariables.insert(std::make_pair("Pi0Cluster_ENG_FRAC_EM", ENG_FRAC_EM)); + + float ENG_FRAC_CORE = 0.; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_ENG_FRAC_CORE, ENG_FRAC_CORE) == false) { ATH_MSG_WARNING("Can't find ENG_FRAC_CORE. Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_SECOND_ENG_DENS,m_log_SECOND_ENG_DENS) == false) + } + availableVariables.insert(std::make_pair("Pi0Cluster_ENG_FRAC_CORE", ENG_FRAC_CORE)); + + float log_SECOND_ENG_DENS = 0.; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_SECOND_ENG_DENS, log_SECOND_ENG_DENS) == false) { ATH_MSG_WARNING("Can't find SECOND_ENG_DENS. Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_EM1CoreFrac,m_EcoreOverEEM1) == false) + } + if(log_SECOND_ENG_DENS==0.) { + log_SECOND_ENG_DENS=-50.; + } + else { + log_SECOND_ENG_DENS = log(log_SECOND_ENG_DENS); + } + availableVariables.insert(std::make_pair("Pi0Cluster_log_SECOND_ENG_DENS", log_SECOND_ENG_DENS)); + + float EcoreOverEEM1 = 0.; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_EM1CoreFrac, EcoreOverEEM1) == false) { ATH_MSG_WARNING("Can't find EM1CoreFrac. Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM1,NPosCells_EM1) == false) + } + availableVariables.insert(std::make_pair("Pi0Cluster_EcoreOverEEM1", EcoreOverEEM1)); + + int NPosECells_EM1 = 0; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM1, NPosECells_EM1) == false) { ATH_MSG_WARNING("Can't find NPosECells_EM1. 
Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM2,NPosCells_EM2) == false) + } + availableVariables.insert(std::make_pair("Pi0Cluster_NPosECells_EM1", static_cast<float>(NPosECells_EM1))); + + int NPosECells_EM2 = 0; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM2, NPosECells_EM2) == false) { ATH_MSG_WARNING("Can't find NPosECells_EM2. Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM1,m_firstEtaWRTCluster_EM1) == false) + } + availableVariables.insert(std::make_pair("Pi0Cluster_NPosECells_EM2", static_cast<float>(NPosECells_EM2))); + + float AbsFirstEtaWRTClusterPosition_EM1 = 0.; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM1, AbsFirstEtaWRTClusterPosition_EM1) == false) { ATH_MSG_WARNING("Can't find firstEtaWRTClusterPosition_EM1. Set it to 0."); - if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM2,m_secondEtaWRTCluster_EM2) == false) + } + AbsFirstEtaWRTClusterPosition_EM1 = fabs(AbsFirstEtaWRTClusterPosition_EM1); + availableVariables.insert(std::make_pair("Pi0Cluster_AbsFirstEtaWRTClusterPosition_EM1", AbsFirstEtaWRTClusterPosition_EM1)); + + float secondEtaWRTClusterPosition_EM2 = 0.; + if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM2, secondEtaWRTClusterPosition_EM2) == false) { ATH_MSG_WARNING("Can't find secondEtaWRTClusterPosition_EM2. Set it to 0."); - // Apply variable transformations - m_Abs_FIRST_ETA = fabs(m_Abs_FIRST_ETA); - m_Abs_DELTA_THETA = fabs(m_Abs_DELTA_THETA); - m_CENTER_LAMBDA_helped = fmin(m_CENTER_LAMBDA_helped, 1000.); - if(m_log_SECOND_ENG_DENS==0.) 
m_log_SECOND_ENG_DENS=-50.; - else m_log_SECOND_ENG_DENS = log(m_log_SECOND_ENG_DENS); - // Convert ints to floats so they can be read by the TMVA reader - m_NPosCells_EM1 = (float) NPosCells_EM1; - m_NPosCells_EM2 = (float) NPosCells_EM2; - m_firstEtaWRTCluster_EM1 = fabs(m_firstEtaWRTCluster_EM1); + } + availableVariables.insert(std::make_pair("Pi0Cluster_secondEtaWRTClusterPosition_EM2", secondEtaWRTClusterPosition_EM2)); // Calculate BDT score - float BDTScore = m_mvaBDT->GetGradBoostMVA(m_mvaBDT->GetPointers()); - - return BDTScore; + float score = 0; + StatusCode sc = m_mvaBDT->getGradBoostMVA(availableVariables, score); + + // getGradBoostMVA returns failure when availableVariables lacks some variables; fall back to a score of -999 + if (sc.isFailure()) { + ATH_MSG_WARNING("Failed to calculate the BDT score"); + score = -999; + } + + return score; } diff --git a/Reconstruction/tauRecTools/tauRecTools/BDTHelper.h b/Reconstruction/tauRecTools/tauRecTools/BDTHelper.h new file mode 100644 index 0000000000000000000000000000000000000000..690af0ed3c69cfdb943268b62b8ae9104b67f591 --- /dev/null +++ b/Reconstruction/tauRecTools/tauRecTools/BDTHelper.h @@ -0,0 +1,37 @@ +/* + Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef TAURECTOOLS_BDTHELPER_H +#define TAURECTOOLS_BDTHELPER_H + +#include "AsgTools/AsgMessaging.h" +#include "AsgTools/StatusCode.h" + +#include "MVAUtils/BDT.h" + +class TString; + +namespace tauRecTools { + class BDTHelper : public asg::AsgMessaging { + public: + BDTHelper(); + ~BDTHelper(); + + StatusCode initialize(const TString& weightFileName); + + StatusCode getGradBoostMVA(const std::map<TString, float>& availableVariables, float& score) const; + + MVAUtils::BDT* getBDT() const { return m_BDT.get(); } + + private: + std::vector<TString> parseString(const TString& str, const TString& delim = ",") const; + + std::vector<float> getInputVariables(const std::map<TString, float>& availableVariables) const ; + + std::unique_ptr<MVAUtils::BDT> m_BDT; + 
std::vector<TString> m_inputVariableNames; + }; +} + +#endif // not TAURECTOOLS_BDTHELPER_H diff --git a/Reconstruction/tauRecTools/tauRecTools/TauPi0ScoreCalculator.h b/Reconstruction/tauRecTools/tauRecTools/TauPi0ScoreCalculator.h index e859a0454289723aa6ec7a2ba1399c64fbdb8f57..548dd45d70daf0d71b939954e25144d29af56b84 100644 --- a/Reconstruction/tauRecTools/tauRecTools/TauPi0ScoreCalculator.h +++ b/Reconstruction/tauRecTools/tauRecTools/TauPi0ScoreCalculator.h @@ -9,8 +9,7 @@ #include <map> #include "tauRecTools/TauRecToolBase.h" #include "xAODPFlow/PFO.h" - -#include "MVAUtils/BDT.h" +#include "tauRecTools/BDTHelper.h" /** * @brief Selectes pi0Candidates (Pi0 Finder). @@ -32,28 +31,11 @@ public: virtual StatusCode executePi0nPFO(xAOD::TauJet& pTau, xAOD::PFOContainer& pNeutralPFOContainer) override; private: - std::unique_ptr<MVAUtils::BDT> m_mvaBDT; - - std::string m_weightfile; - - float m_Abs_FIRST_ETA; - float m_SECOND_R; - float m_Abs_DELTA_THETA; - float m_CENTER_LAMBDA_helped; - float m_LONGITUDINAL; - float m_ENG_FRAC_EM; - float m_ENG_FRAC_CORE; - float m_log_SECOND_ENG_DENS; - float m_EcoreOverEEM1; - float m_NPosCells_EM1; - float m_NPosCells_EM2; - float m_firstEtaWRTCluster_EM1; - float m_secondEtaWRTCluster_EM2; - - std::map<TString, float*> m_availableVars;//!< keeps track of available of availble floats - /** @brief function used to calculate BDT score */ float calculateScore(const xAOD::PFO* neutralPFO); + + std::string m_weightfile; + std::unique_ptr<tauRecTools::BDTHelper> m_mvaBDT; }; #endif /* TAUPI0SCORECALCULATOR_H */