From 76789fe99b229010584a0f3b0b957ff1d6deea54 Mon Sep 17 00:00:00 2001
From: Marcin Piotr Wandas <marcin.piotr.wandas@cern.ch>
Date: Mon, 20 Apr 2020 11:38:52 +0000
Subject: [PATCH] [ATLASRECTS-5396][ATLASRECTS-5380] Move initReader method
 from execute to initialize method in MvaTESEvaluator

---
 Reconstruction/tauRecTools/Root/BDTHelper.cxx | 27 ++++++++++++++
 .../tauRecTools/Root/MvaTESEvaluator.cxx      | 36 ++++++-------------
 .../tauRecTools/tauRecTools/BDTHelper.h       |  4 +++
 .../tauRecTools/tauRecTools/MvaTESEvaluator.h | 16 ++++-----
 4 files changed, 49 insertions(+), 34 deletions(-)

diff --git a/Reconstruction/tauRecTools/Root/BDTHelper.cxx b/Reconstruction/tauRecTools/Root/BDTHelper.cxx
index 177e5a8030904..1623f035d2786 100644
--- a/Reconstruction/tauRecTools/Root/BDTHelper.cxx
+++ b/Reconstruction/tauRecTools/Root/BDTHelper.cxx
@@ -93,7 +93,22 @@ std::vector<float> BDTHelper::getInputVariables(const std::map<TString, float> &
   return values;
 }
 
+std::vector<float> BDTHelper::getInputVariables(const std::map<TString, float*> &availableVariables) const {  
+  std::vector<float> values;
 
+  // sort the input variables by the order in varList (from BDT)
+  for (const TString& name : m_inputVariableNames) {
+    std::map<TString, float*>::const_iterator itr = availableVariables.find(name);
+    if(itr==availableVariables.end()) {
+      ATH_MSG_ERROR(name << " not available");
+    }
+    else {
+      values.push_back(*itr->second);
+    }
+  }
+
+  return values;
+}
 
 std::vector<float> BDTHelper::getInputVariables(const xAOD::TauJet& tau) const {
   std::vector<float> values;
@@ -130,7 +145,19 @@ float BDTHelper::getGradBoostMVA(const std::map<TString, float> &availableVariab
   return score;
 }
 
+float BDTHelper::getResponse(const std::map<TString, float*> &availableVariables) const {
+  std::vector<float> values = getInputVariables(availableVariables);
 
+  float score = -999;
+  if (values.size() < m_inputVariableNames.size()) {
+    ATH_MSG_ERROR("There are missing variables when calculating the BDT score, will return -999");
+  }
+  else {  
+    score = m_BDT->GetResponse(values);
+  }
+
+  return score;
+}
 
 float BDTHelper::getGradBoostMVA(const xAOD::TauJet& tau) const {
   std::vector<float> values = getInputVariables(tau);
diff --git a/Reconstruction/tauRecTools/Root/MvaTESEvaluator.cxx b/Reconstruction/tauRecTools/Root/MvaTESEvaluator.cxx
index aa5d9b0e4eff6..24285b5f26717 100644
--- a/Reconstruction/tauRecTools/Root/MvaTESEvaluator.cxx
+++ b/Reconstruction/tauRecTools/Root/MvaTESEvaluator.cxx
@@ -6,6 +6,8 @@
 #include "tauRecTools/MvaTESEvaluator.h"
 #include "tauRecTools/HelperFunctions.h"
 
+#include <TTree.h>
+
 #include <vector>
 
 //_____________________________________________________________________________
@@ -22,13 +24,17 @@ MvaTESEvaluator::~MvaTESEvaluator()
 //_____________________________________________________________________________
 StatusCode MvaTESEvaluator::initialize(){
   
+  const std::string weightFile = find_file(m_sWeightFileName);
+  m_bdtHelper = std::make_unique<tauRecTools::BDTHelper>();
+  ATH_CHECK(m_bdtHelper->initialize(weightFile));
   return StatusCode::SUCCESS;
 }
 
 //_____________________________________________________________________________
-StatusCode MvaTESEvaluator::initReader(std::unique_ptr<MVAUtils::BDT>& reader,
-                                       std::map<TString, float*>& availableVars,
-                                       MvaInputVariables& vars) const {
+StatusCode MvaTESEvaluator::execute(xAOD::TauJet& xTau) const {
+
+  std::map<TString, float*> availableVars;
+  MvaInputVariables vars;
 
   // Declare input variables to the reader
   if(!m_in_trigger) {
@@ -65,26 +71,6 @@ StatusCode MvaTESEvaluator::initReader(std::unique_ptr<MVAUtils::BDT>& reader,
     availableVars.insert( std::make_pair("TrigTauJetsAuxDyn.etaDetectorAxis", &vars.etaDetectorAxis) );
   }
 
-  std::string weightFile = find_file(m_sWeightFileName);
-
-  reader = tauRecTools::configureMVABDT( availableVars, weightFile.c_str() );
-  if(reader==nullptr) {
-    ATH_MSG_FATAL("Couldn't configure MVA");
-    return StatusCode::FAILURE;
-  }
-
-  return StatusCode::SUCCESS;
-}
-
-//_____________________________________________________________________________
-StatusCode MvaTESEvaluator::execute(xAOD::TauJet& xTau) const {
-
-  std::unique_ptr<MVAUtils::BDT> reader{nullptr};
-  std::map<TString, float*> availableVars;
-  MvaInputVariables vars;
-  if (initReader(reader, availableVars, vars) == StatusCode::FAILURE)
-   return StatusCode::FAILURE;
-
   // Retrieve event info
   const SG::AuxElement::ConstAccessor<float> acc_mu("mu");
   const SG::AuxElement::ConstAccessor<int> acc_nVtxPU("nVtxPU");
@@ -128,7 +114,7 @@ StatusCode MvaTESEvaluator::execute(xAOD::TauJet& xTau) const {
     vars.nTracks = (float)xTau.nTracks();
     xTau.detail(xAOD::TauJetParameters::PFOEngRelDiff, vars.PFOEngRelDiff);
     
-    float ptMVA = float( vars.ptCombined * reader->GetResponse() );
+    float ptMVA = float( vars.ptCombined * m_bdtHelper->getResponse(availableVars) );
     if(ptMVA<1) ptMVA=1;
     xTau.setP4(xAOD::TauJetParameters::FinalCalib, ptMVA, vars.etaConstituent, xTau.phiPanTauCellBased(), 0);
 
@@ -145,7 +131,7 @@ StatusCode MvaTESEvaluator::execute(xAOD::TauJet& xTau) const {
     vars.upsilon_cluster = acc_UpsilonCluster(xTau);
     vars.lead_cluster_frac = acc_LeadClusterFrac(xTau);
 
-    float ptMVA = float( vars.ptDetectorAxis * reader->GetResponse() );
+    float ptMVA = float( vars.ptDetectorAxis * m_bdtHelper->getResponse(availableVars) );
     if(ptMVA<1) ptMVA=1;
 
     // this may have to be changed if we apply a calo-only MVA calibration first, followed by a calo+track MVA calibration
diff --git a/Reconstruction/tauRecTools/tauRecTools/BDTHelper.h b/Reconstruction/tauRecTools/tauRecTools/BDTHelper.h
index e0f2d6e398ba0..fdf0ca111238d 100644
--- a/Reconstruction/tauRecTools/tauRecTools/BDTHelper.h
+++ b/Reconstruction/tauRecTools/tauRecTools/BDTHelper.h
@@ -22,6 +22,8 @@ namespace tauRecTools {
 
       float getGradBoostMVA(const std::map<TString, float>& availableVariables) const; 
 
+      float getResponse(const std::map<TString, float*>& availableVariables) const; 
+
       float getGradBoostMVA(const xAOD::TauJet& tau) const; 
 
       MVAUtils::BDT* getBDT() const { return m_BDT.get(); }
@@ -31,6 +33,8 @@ namespace tauRecTools {
       
       std::vector<float> getInputVariables(const std::map<TString, float>& availableVariables) const ;
       
+      std::vector<float> getInputVariables(const std::map<TString, float*>& availableVariables) const;
+
       std::vector<float> getInputVariables(const xAOD::TauJet& tau) const ;
 
       std::unique_ptr<MVAUtils::BDT> m_BDT;
diff --git a/Reconstruction/tauRecTools/tauRecTools/MvaTESEvaluator.h b/Reconstruction/tauRecTools/tauRecTools/MvaTESEvaluator.h
index da95d38c57561..c25972d598958 100644
--- a/Reconstruction/tauRecTools/tauRecTools/MvaTESEvaluator.h
+++ b/Reconstruction/tauRecTools/tauRecTools/MvaTESEvaluator.h
@@ -8,7 +8,7 @@
 // tauRecTools include(s)
 #include "tauRecTools/TauRecToolBase.h"
 
-#include "MVAUtils/BDT.h"
+#include "tauRecTools/BDTHelper.h"
 
 #include <map>
 
@@ -35,24 +35,24 @@ class MvaTESEvaluator
   {
     float mu{0.0}; //!
     float nVtxPU{0.0}; //!
-  
+
     float center_lambda{0.0}; //!
     float first_eng_dens{0.0}; //!
     float second_lambda{0.0}; //!
     float presampler_frac{0.0}; //!
     float eprobability{0.0}; //!
-  
+
     float ptCombined{0.0}; //!
     float ptLC_D_ptCombined{0.0}; //!
     float ptConstituent_D_ptCombined{0.0};//!
     float etaConstituent{0.0}; //!
-  
+
     float PanTauBDT_1p0n_vs_1p1n{0.0}; //!
     float PanTauBDT_1p1n_vs_1pXn{0.0}; //!
     float PanTauBDT_3p0n_vs_3pXn{0.0}; //!
     float nTracks{0.0}; //!
     float PFOEngRelDiff{0.0}; //!
-  
+
     // Spectators
     float truthPtVis{0.0}; //!
     float pt{0.0}; //!
@@ -66,10 +66,8 @@ class MvaTESEvaluator
     float upsilon_cluster{0.0}; //!
     float lead_cluster_frac{0.0}; //!
   };
- 
-  StatusCode initReader(std::unique_ptr<MVAUtils::BDT>& reader,
-                        std::map<TString, float*>& availableVars,
-                        MvaInputVariables& vars) const;
+
+  std::unique_ptr<tauRecTools::BDTHelper> m_bdtHelper;
 
   // Configurable properties
   Gaudi::Property<std::string> m_sWeightFileName{this, "WeightFileName", "MvaTES_20170207_v2_BDTG.weights.root"};
-- 
GitLab