From 3ebdd3961097d634bdf7ef83a0774b6ca2a4dfd5 Mon Sep 17 00:00:00 2001
From: Nilotpal Kakati <nilotpal.kakati@cern.ch>
Date: Thu, 24 Jun 2021 14:00:15 +0300
Subject: [PATCH 1/5] adding strategy and PCBT to fixed cut

---
 .../IBTaggingTruthTaggingTool.h               |   4 +-
 .../Root/BTaggingTruthTaggingTool.cxx         | 134 +++++++++++-------
 .../BTaggingTruthTaggingTool.h                |  13 +-
 3 files changed, 94 insertions(+), 57 deletions(-)

diff --git a/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h b/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h
index e210eb1d9b58..1080dcefb021 100644
--- a/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h
+++ b/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h
@@ -47,9 +47,9 @@ class IBTaggingTruthTaggingTool : virtual public CP::ISystematicsTool {
    ...
   }
   */
-  virtual StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results,int rand_seed=-1)=0 ;
+  virtual StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results, TString strategy="", int rand_seed=-1)=0 ;
     
-  virtual StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results,int rand_seed=-1)=0;
+  virtual StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, TString strategy="", int rand_seed=-1)=0;
 
 };
 #endif // CPIBTAGGINGTRUTHTAGGINGTOOL_H
diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
index 48b6ef943354..ff45613629a2 100644
--- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
+++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
@@ -189,13 +189,23 @@ StatusCode BTaggingTruthTaggingTool::initialize() {
     m_OperatingPoint_index = find(m_availableOP.begin(), m_availableOP.end(), m_cutBenchmark) - m_availableOP.begin();
   }
   else{
-    if(m_useQuntile){
-      m_OperatingPoint_index = find(m_availableOP.begin(), m_availableOP.end(), m_OP) - m_availableOP.begin();
-      if(m_OperatingPoint_index >= m_availableOP.size()) {
-        ATH_MSG_ERROR(m_OP << " not in the list of available OPs");
-        return StatusCode::FAILURE;
+    if (m_pathToONNX != ""){
+      if (m_useQuntile){
+        ATH_MSG_ERROR("BTaggingTruthTaggingTool::TruthTagging with GNN doesn't support m_useQuntile=true yet");
+        return StatusCode::FAILURE;      
+      } else {
+        // 60% = 4, 70% = 3, 77% = 2, 85% = 1, 100% = 0
+        m_OP_index_for_GNN = find(m_availableOP.begin(), m_availableOP.end(), m_cutBenchmark) - m_availableOP.begin() + 1; // GNN predicts 5 bins        
       }
-    }
+    } else {
+      if(m_useQuntile){
+        m_OperatingPoint_index = find(m_availableOP.begin(), m_availableOP.end(), m_OP) - m_availableOP.begin();
+        if(m_OperatingPoint_index >= m_availableOP.size()) {
+          ATH_MSG_ERROR(m_OP << " not in the list of available OPs");
+          return StatusCode::FAILURE;
+        }
+      } // m_useQuantile
+    } // !ONNX
   }
   
   m_eff_syst.clear();
@@ -465,11 +475,11 @@ float BTaggingTruthTaggingTool::getPermutationRW(TRFinfo &trfinf,bool isIncl, un
 
 
 
-StatusCode BTaggingTruthTaggingTool::GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys){
+StatusCode BTaggingTruthTaggingTool::GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys, TString strategy){
 
   ANA_CHECK_SET_TYPE (StatusCode);
 
-  if(sys==0) ANA_CHECK(getAllEffMC(trfinf));
+  if(sys==0) ANA_CHECK(getAllEffMC(trfinf, strategy));
   ANA_CHECK(check_syst_range(sys));
   if(trfinf.trfwsys_ex.size()==0)  trfinf.trfwsys_ex.resize(m_eff_syst.size());
   if(trfinf.trfwsys_in.size()==0) trfinf.trfwsys_in.resize(m_eff_syst.size());
@@ -505,7 +515,7 @@ StatusCode BTaggingTruthTaggingTool::GetTruthTagWeights(TRFinfo &trfinf, std::ve
 }
 
 
-StatusCode BTaggingTruthTaggingTool::CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results, int rand_seed){
+StatusCode BTaggingTruthTaggingTool::CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results, int rand_seed, TString strategy){
   ANA_CHECK_SET_TYPE (StatusCode);
   results.clear();
 
@@ -594,7 +604,7 @@ StatusCode BTaggingTruthTaggingTool::CalculateResults(const xAOD::JetContainer&
 }
 
 // setting inputs that the onnx tool will use
-StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw, Analysis::TruthTagResults& results,int rand_seed){
+StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw, Analysis::TruthTagResults& results, TString strategy, int rand_seed){
 
   ANA_CHECK_SET_TYPE (StatusCode);
 
@@ -602,11 +612,11 @@ StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const std::vector<std:
     
   ANA_CHECK(setJets(trfinf, node_feat, tagw));
 
-  return CalculateResults(trfinf,results,rand_seed);
+  return CalculateResults(trfinf, results, rand_seed, strategy);
 }
 
 // setting inputs that the onnx tool will use
-StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results,int rand_seed){
+StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, TString strategy, int rand_seed){
 
   ANA_CHECK_SET_TYPE (StatusCode);
 
@@ -614,20 +624,20 @@ StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const xAOD::JetContain
 
   ANA_CHECK(setJets(trfinf, jets, node_feat));
 
-  return CalculateResults(trfinf,results,rand_seed);
+  return CalculateResults(trfinf, results, rand_seed, strategy);
 }
 
-StatusCode BTaggingTruthTaggingTool::getAllEffMC(TRFinfo &trfinf){
-  if ( m_pathToONNX == "" ){
+StatusCode BTaggingTruthTaggingTool::getAllEffMC(TRFinfo &trfinf, TString strategy){
+  if (trfinf.node_feat.size() == 0){
     return getAllEffMCCDI(trfinf);
   } else {
 
-    return getAllEffMCGNN(trfinf);
+    return getAllEffMCGNN(trfinf, strategy);
   }
 }
 
 // uses onnx tool (no support for m_useQuantile now)
-StatusCode BTaggingTruthTaggingTool::getAllEffMCGNN(TRFinfo &trfinf){
+StatusCode BTaggingTruthTaggingTool::getAllEffMCGNN(TRFinfo &trfinf, TString strategy){
 
   trfinf.effMC.clear();
   if(m_useQuntile == true || m_continuous == true){
@@ -648,19 +658,45 @@ StatusCode BTaggingTruthTaggingTool::getAllEffMCGNN(TRFinfo &trfinf){
       
       // need to transpose
       std::vector<float> tmp_effMC_oneOP; // shape:{num_jet}
-      for (int jet_index=0; jet_index<static_cast<int>(tmp_effMC_allOP.size()); jet_index++){
-        tmp_effMC_oneOP.push_back(tmp_effMC_allOP[jet_index][OP_index]);
+      if (strategy == "Leading2SignalJets"){
+        for (int jet_index=0; jet_index<2; jet_index++){
+          tmp_effMC_oneOP.push_back(tmp_effMC_allOP[jet_index][OP_index]);
+        }
+      } else {
+        for (int jet_index=0; jet_index<static_cast<int>(tmp_effMC_allOP.size()); jet_index++){
+          tmp_effMC_oneOP.push_back(tmp_effMC_allOP[jet_index][OP_index]);
+        }
       }
       trfinf.effMC_allOP[op_appo] = tmp_effMC_oneOP;
       OP_index++;
     }
-  } else {
-    CorrectionCode code = m_effTool->getMCEfficiencyONNX(trfinf.node_feat, trfinf.effMC);
-    if(!(code==CorrectionCode::Ok || code==CorrectionCode::OutOfValidityRange)){
-      ATH_MSG_ERROR("BTaggingEfficiencyTool::getMCEfficiencyONNX returned CorrectionCode::Error");
-      return StatusCode::FAILURE;
-    }
-  }
+  } // m_continuous
+  else {
+    if (m_useQuntile){
+      ATH_MSG_ERROR("BTaggingTruthTaggingTool::getMCEfficiencyONNX doesn't support m_useQuntile=true yet");
+      return StatusCode::FAILURE;      
+    } // m_useQuantile
+    else {
+      std::vector<std::vector<float>> tmp_effMC_allOP; // shape:{num_jets, num_wp}
+      CorrectionCode code = m_effTool->getMCEfficiencyONNX(trfinf.node_feat, tmp_effMC_allOP);
+      if(!(code==CorrectionCode::Ok || code==CorrectionCode::OutOfValidityRange)){
+        ATH_MSG_ERROR("BTaggingEfficiencyTool::getMCEfficiencyONNX returned CorrectionCode::Error");
+        return StatusCode::FAILURE;
+      }
+
+      if (strategy == "Leading2SignalJets"){
+        for (int jet_index=0; jet_index<2; jet_index++){
+          float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0);
+          trfinf.effMC.push_back(tmp_effMC);
+        }
+      } else {
+        for (int jet_index=0; jet_index<static_cast<int>(tmp_effMC_allOP.size()); jet_index++){
+          float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0);
+          trfinf.effMC.push_back(tmp_effMC);
+        }
+      }
+    } // !m_useQuantile
+  } // !m_continuous
         
   return StatusCode::SUCCESS;
 }
@@ -809,36 +845,36 @@ StatusCode BTaggingTruthTaggingTool::getAllEffSF(TRFinfo &trfinf,int sys){
     for(int iop = static_cast<int>(m_availableOP.size())-1; iop >= 0; iop--) {
       std::string op_appo = m_availableOP.at(iop);
       if(!m_useQuntile &&  iop < static_cast<int>(m_OperatingPoint_index)) continue;
-      for(size_t i=0; i<trfinf.jets.size(); i++){
-      SF=1.;
-      //set a dumb value of the truth tag weight to get the different efficiency maps for each bin. to be improved..
-      if(iop+1 < static_cast<int>(m_availableOP.size())){
-        trfinf.jets.at(i).vars.jetTagWeight = (m_binEdges.at(iop)+m_binEdges.at(iop+1))/2.; //to-do: make it fancy? random distribution for the tagger score
-      }
-      else{
-        trfinf.jets.at(i).vars.jetTagWeight = (m_binEdges.at(iop)+1.)/2.;
-      }
-    
-      CorrectionCode code = m_effTool->getScaleFactor(trfinf.jets.at(i).flav, trfinf.jets.at(i).vars, SF) ;
-      if(!(code==CorrectionCode::Ok || code==CorrectionCode::OutOfValidityRange)){
-        ATH_MSG_ERROR("BTaggingEfficiencyTool::getScaleFactor returned CorrectionCode::Error");
-        return StatusCode::FAILURE;
-      }
+      for(size_t i=0; i<trfinf.effMC_allOP[op_appo].size(); i++){
+        SF=1.;
+        //set a dumb value of the truth tag weight to get the different efficiency maps for each bin. to be improved..
+        if(iop+1 < static_cast<int>(m_availableOP.size())){
+          trfinf.jets.at(i).vars.jetTagWeight = (m_binEdges.at(iop)+m_binEdges.at(iop+1))/2.; //to-do: make it fancy? random distribution for the tagger score
+        }
+        else{
+          trfinf.jets.at(i).vars.jetTagWeight = (m_binEdges.at(iop)+1.)/2.;
+        }
+      
+        CorrectionCode code = m_effTool->getScaleFactor(trfinf.jets.at(i).flav, trfinf.jets.at(i).vars, SF) ;
+        if(!(code==CorrectionCode::Ok || code==CorrectionCode::OutOfValidityRange)){
+          ATH_MSG_ERROR("BTaggingEfficiencyTool::getScaleFactor returned CorrectionCode::Error");
+          return StatusCode::FAILURE;
+        }
 
-      trfinf.eff_allOP[op_appo].at(i)=trfinf.effMC_allOP[op_appo].at(i)*SF;
+        trfinf.eff_allOP[op_appo].at(i)=trfinf.effMC_allOP[op_appo].at(i)*SF;
 
-      //now sum all the corrected MC Eff together
-      if(iop+1 < static_cast<int>(m_availableOP.size())){
-          trfinf.eff_allOP[op_appo].at(i)+=trfinf.eff_allOP[m_availableOP.at(iop+1)].at(i); //they are already corrected for SF
-      }
-      if( op_appo == m_cutBenchmark)
-        trfinf.eff.at(i) = trfinf.eff_allOP[m_cutBenchmark].at(i);
+        //now sum all the corrected MC Eff together
+        if(iop+1 < static_cast<int>(m_availableOP.size())){
+            trfinf.eff_allOP[op_appo].at(i)+=trfinf.eff_allOP[m_availableOP.at(iop+1)].at(i); //they are already corrected for SF
+        }
+        if( op_appo == m_cutBenchmark)
+          trfinf.eff.at(i) = trfinf.eff_allOP[m_cutBenchmark].at(i);
       } //jets
     } //OP
   } //continuous
   
   else{
-    for(unsigned int i=0; i<trfinf.jets.size(); i++){
+    for(unsigned int i=0; i<trfinf.effMC.size(); i++){
       SF=1.;
       CorrectionCode code = m_effTool->getScaleFactor(trfinf.jets.at(i).flav, trfinf.jets.at(i).vars, SF) ;
       if(!(code==CorrectionCode::Ok || code==CorrectionCode::OutOfValidityRange)){
diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h
index e4354fba0c37..adeb54089aa1 100644
--- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h
+++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h
@@ -92,15 +92,15 @@ class BTaggingTruthTaggingTool: public asg::AsgTool,
   BTaggingTruthTaggingTool( const std::string& name );
 
   private:
-  StatusCode CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results,int rand_seed = -1);
+  StatusCode CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results,int rand_seed = -1, TString strategy="");
             
   public:
   StatusCode CalculateResults( std::vector<float>& pt, std::vector<float>& eta, std::vector<int>& flav, std::vector<float>& tagw, Analysis::TruthTagResults& results,int rand_seed = -1);
   StatusCode CalculateResults( const xAOD::JetContainer& jets, Analysis::TruthTagResults& results,int rand_seed = -1);
         
   // will use onnxtool
-  StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results, int rand_seed=-1);
-  StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results,int rand_seed = -1);
+  StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results, TString strategy="", int rand_seed=-1);
+  StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, TString strategy="", int rand_seed = -1);
 
   StatusCode setEffMapIndex(const std::string& flavour, unsigned int index);
   void setUseSystematics(bool useSystematics);
@@ -130,7 +130,7 @@ class BTaggingTruthTaggingTool: public asg::AsgTool,
             
   // get truth tagging weights
   // for one single systematic (including "Nominal")
-  StatusCode GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys=0);
+  StatusCode GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys=0, TString strategy="");
 
   // tag permutation: trf_chosen_perm_ex.at(ntag).at(i) tells if the i-th jet is tagged in a selection requiring == ntag tags
   StatusCode getTagPermutation(TRFinfo &trfinf, std::vector<std::vector<bool> > &trf_chosen_perm_ex, std::vector<std::vector<bool> > &trf_chosen_perm_in);
@@ -163,9 +163,9 @@ class BTaggingTruthTaggingTool: public asg::AsgTool,
 
   StatusCode getTRFweight(TRFinfo &trfinf,unsigned int nbtag, bool isInclusive, int sys);
 
-  StatusCode getAllEffMC(TRFinfo &trfinf);
+  StatusCode getAllEffMC(TRFinfo &trfinf, TString strategy="");
   StatusCode getAllEffMCCDI(TRFinfo &trfinf);
-  StatusCode getAllEffMCGNN(TRFinfo &trfinf);
+  StatusCode getAllEffMCGNN(TRFinfo &trfinf, TString strategy="");
             
   StatusCode getAllEffSF(TRFinfo &trfinf,int =0);
   std::vector<CP::SystematicSet> m_eff_syst;
@@ -254,6 +254,7 @@ class BTaggingTruthTaggingTool: public asg::AsgTool,
   int m_nbtag;
 
   unsigned int m_OperatingPoint_index;
+  unsigned int m_OP_index_for_GNN;
 
   std::map<std::string, asg::AnaToolHandle<IBTaggingEfficiencyTool> > m_effTool_allOP;
 
-- 
GitLab


From bd498106e2c037bcc6fc9b0d4311a768615c08d3 Mon Sep 17 00:00:00 2001
From: Nilotpal Kakati <nilotpal.kakati@cern.ch>
Date: Thu, 24 Jun 2021 16:58:11 +0300
Subject: [PATCH 2/5] fixed warning - unused parameter strategy

---
 .../xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx    | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
index ff45613629a2..c70fcd8cf4bc 100644
--- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
+++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
@@ -541,7 +541,7 @@ StatusCode BTaggingTruthTaggingTool::CalculateResults(TRFinfo &trfinf, Analysis:
     trf_weight_ex.clear();
     trf_weight_in.clear();
 
-    ANA_CHECK(GetTruthTagWeights(trfinf, trf_weight_ex, trf_weight_in, i));
+    ANA_CHECK(GetTruthTagWeights(trfinf, trf_weight_ex, trf_weight_in, i, strategy));
 
   }
 
-- 
GitLab


From 3a763c562a3330282b17e574a70c7aebc7103c7a Mon Sep 17 00:00:00 2001
From: Nilotpal Kakati <nilotpal.kakati@cern.ch>
Date: Sat, 26 Jun 2021 23:07:27 +0300
Subject: [PATCH 3/5] passing TString strategy as a const reference

---
 .../IBTaggingTruthTaggingTool.h                      |  4 ++--
 .../Root/BTaggingTruthTaggingTool.cxx                | 12 ++++++------
 .../BTaggingTruthTaggingTool.h                       | 12 ++++++------
 3 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h b/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h
index 1080dcefb021..71e15c0b9342 100644
--- a/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h
+++ b/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h
@@ -47,9 +47,9 @@ class IBTaggingTruthTaggingTool : virtual public CP::ISystematicsTool {
    ...
   }
   */
-  virtual StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results, TString strategy="", int rand_seed=-1)=0 ;
+  virtual StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results, const TString &strategy="", int rand_seed=-1)=0 ;
     
-  virtual StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, TString strategy="", int rand_seed=-1)=0;
+  virtual StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, const TString &strategy="", int rand_seed=-1)=0;
 
 };
 #endif // CPIBTAGGINGTRUTHTAGGINGTOOL_H
diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
index c70fcd8cf4bc..377ea2eb528d 100644
--- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
+++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
@@ -475,7 +475,7 @@ float BTaggingTruthTaggingTool::getPermutationRW(TRFinfo &trfinf,bool isIncl, un
 
 
 
-StatusCode BTaggingTruthTaggingTool::GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys, TString strategy){
+StatusCode BTaggingTruthTaggingTool::GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys, const TString &strategy){
 
   ANA_CHECK_SET_TYPE (StatusCode);
 
@@ -515,7 +515,7 @@ StatusCode BTaggingTruthTaggingTool::GetTruthTagWeights(TRFinfo &trfinf, std::ve
 }
 
 
-StatusCode BTaggingTruthTaggingTool::CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results, int rand_seed, TString strategy){
+StatusCode BTaggingTruthTaggingTool::CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results, int rand_seed, const TString &strategy){
   ANA_CHECK_SET_TYPE (StatusCode);
   results.clear();
 
@@ -604,7 +604,7 @@ StatusCode BTaggingTruthTaggingTool::CalculateResults(const xAOD::JetContainer&
 }
 
 // setting inputs that the onnx tool will use
-StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw, Analysis::TruthTagResults& results, TString strategy, int rand_seed){
+StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw, Analysis::TruthTagResults& results, const TString &strategy, int rand_seed){
 
   ANA_CHECK_SET_TYPE (StatusCode);
 
@@ -616,7 +616,7 @@ StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const std::vector<std:
 }
 
 // setting inputs that the onnx tool will use
-StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, TString strategy, int rand_seed){
+StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, const TString &strategy, int rand_seed){
 
   ANA_CHECK_SET_TYPE (StatusCode);
 
@@ -627,7 +627,7 @@ StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const xAOD::JetContain
   return CalculateResults(trfinf, results, rand_seed, strategy);
 }
 
-StatusCode BTaggingTruthTaggingTool::getAllEffMC(TRFinfo &trfinf, TString strategy){
+StatusCode BTaggingTruthTaggingTool::getAllEffMC(TRFinfo &trfinf, const TString &strategy){
   if (trfinf.node_feat.size() == 0){
     return getAllEffMCCDI(trfinf);
   } else {
@@ -637,7 +637,7 @@ StatusCode BTaggingTruthTaggingTool::getAllEffMC(TRFinfo &trfinf, TString strate
 }
 
 // uses onnx tool (no support for m_useQuantile now)
-StatusCode BTaggingTruthTaggingTool::getAllEffMCGNN(TRFinfo &trfinf, TString strategy){
+StatusCode BTaggingTruthTaggingTool::getAllEffMCGNN(TRFinfo &trfinf, const TString &strategy){
 
   trfinf.effMC.clear();
   if(m_useQuntile == true || m_continuous == true){
diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h
index adeb54089aa1..30a7658fa95c 100644
--- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h
+++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h
@@ -92,15 +92,15 @@ class BTaggingTruthTaggingTool: public asg::AsgTool,
   BTaggingTruthTaggingTool( const std::string& name );
 
   private:
-  StatusCode CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results,int rand_seed = -1, TString strategy="");
+  StatusCode CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results,int rand_seed = -1, const TString &strategy="");
             
   public:
   StatusCode CalculateResults( std::vector<float>& pt, std::vector<float>& eta, std::vector<int>& flav, std::vector<float>& tagw, Analysis::TruthTagResults& results,int rand_seed = -1);
   StatusCode CalculateResults( const xAOD::JetContainer& jets, Analysis::TruthTagResults& results,int rand_seed = -1);
         
   // will use onnxtool
-  StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results, TString strategy="", int rand_seed=-1);
-  StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, TString strategy="", int rand_seed = -1);
+  StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results, const TString &strategy="", int rand_seed=-1);
+  StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, const TString &strategy="", int rand_seed = -1);
 
   StatusCode setEffMapIndex(const std::string& flavour, unsigned int index);
   void setUseSystematics(bool useSystematics);
@@ -130,7 +130,7 @@ class BTaggingTruthTaggingTool: public asg::AsgTool,
             
   // get truth tagging weights
   // for one single systematic (including "Nominal")
-  StatusCode GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys=0, TString strategy="");
+  StatusCode GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys=0, const TString &strategy="");
 
   // tag permutation: trf_chosen_perm_ex.at(ntag).at(i) tells if the i-th jet is tagged in a selection requiring == ntag tags
   StatusCode getTagPermutation(TRFinfo &trfinf, std::vector<std::vector<bool> > &trf_chosen_perm_ex, std::vector<std::vector<bool> > &trf_chosen_perm_in);
@@ -163,9 +163,9 @@ class BTaggingTruthTaggingTool: public asg::AsgTool,
 
   StatusCode getTRFweight(TRFinfo &trfinf,unsigned int nbtag, bool isInclusive, int sys);
 
-  StatusCode getAllEffMC(TRFinfo &trfinf, TString strategy="");
+  StatusCode getAllEffMC(TRFinfo &trfinf, const TString &strategy="");
   StatusCode getAllEffMCCDI(TRFinfo &trfinf);
-  StatusCode getAllEffMCGNN(TRFinfo &trfinf, TString strategy="");
+  StatusCode getAllEffMCGNN(TRFinfo &trfinf, const TString &strategy="");
             
   StatusCode getAllEffSF(TRFinfo &trfinf,int =0);
   std::vector<CP::SystematicSet> m_eff_syst;
-- 
GitLab


From db105cb164cfe87c71fe6d89fa166a111cb1613e Mon Sep 17 00:00:00 2001
From: Nilotpal Kakati <nilotpal.kakati@cern.ch>
Date: Tue, 29 Jun 2021 16:48:30 +0300
Subject: [PATCH 4/5]  adding taggingStrategy as a property

---
 .../IBTaggingTruthTaggingTool.h               |  4 +-
 .../Root/BTaggingTruthTaggingTool.cxx         | 45 ++++++++++++-------
 .../BTaggingTruthTaggingTool.h                | 16 +++----
 3 files changed, 39 insertions(+), 26 deletions(-)

diff --git a/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h b/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h
index 71e15c0b9342..d1acfeba4b77 100644
--- a/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h
+++ b/PhysicsAnalysis/Interfaces/FTagAnalysisInterfaces/FTagAnalysisInterfaces/IBTaggingTruthTaggingTool.h
@@ -47,9 +47,9 @@ class IBTaggingTruthTaggingTool : virtual public CP::ISystematicsTool {
    ...
   }
   */
-  virtual StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results, const TString &strategy="", int rand_seed=-1)=0 ;
+  virtual StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results, int rand_seed=-1)=0 ;
     
-  virtual StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, const TString &strategy="", int rand_seed=-1)=0;
+  virtual StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, int rand_seed=-1)=0;
 
 };
 #endif // CPIBTAGGINGTRUTHTAGGINGTOOL_H
diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
index 377ea2eb528d..74df2f1d0901 100644
--- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
+++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
@@ -80,7 +80,8 @@ BTaggingTruthTaggingTool::BTaggingTruthTaggingTool( const std::string & name)
   declareProperty("doDirectTagging",                  m_doDirectTag = false ,    "If set to true it also computes and stores the direct tagging choice and the related SFs for each jet");
       
   // if it is empty, the onnx tool won't be initialised
-  declareProperty( "pathToONNX",                      m_pathToONNX = "",          "path to the onnx file that will be used for inference");
+  declareProperty( "pathToONNX",                     m_pathToONNX = "",          "path to the onnx file that will be used for inference");
+  declareProperty( "TaggingStrategy",                m_taggingStrategy = "AllJets",     "tagging strategy in the Analysis (eg. 'leading2SignalJets' in boosted VHbb). Required to do TT with GNN");
 }
 
 StatusCode BTaggingTruthTaggingTool::setEffMapIndex(const std::string& flavour, unsigned int index){
@@ -304,6 +305,12 @@ StatusCode BTaggingTruthTaggingTool::initialize() {
     } //loop
   } //quantile
 
+  // "AllJets" is the default strategy
+  if ((m_taggingStrategy != "AllJets") && (m_taggingStrategy != "Leading2SignalJets")){
+    ATH_MSG_ERROR("BTaggingTruthTaggingTool::tagging strategy " << m_taggingStrategy << " is not implemented in the TruthTagging Tool");
+    return StatusCode::FAILURE;    
+  }
+
   return StatusCode::SUCCESS;
 }
 
@@ -475,11 +482,11 @@ float BTaggingTruthTaggingTool::getPermutationRW(TRFinfo &trfinf,bool isIncl, un
 
 
 
-StatusCode BTaggingTruthTaggingTool::GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys, const TString &strategy){
+StatusCode BTaggingTruthTaggingTool::GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys){
 
   ANA_CHECK_SET_TYPE (StatusCode);
 
-  if(sys==0) ANA_CHECK(getAllEffMC(trfinf, strategy));
+  if(sys==0) ANA_CHECK(getAllEffMC(trfinf));
   ANA_CHECK(check_syst_range(sys));
   if(trfinf.trfwsys_ex.size()==0)  trfinf.trfwsys_ex.resize(m_eff_syst.size());
   if(trfinf.trfwsys_in.size()==0) trfinf.trfwsys_in.resize(m_eff_syst.size());
@@ -515,7 +522,7 @@ StatusCode BTaggingTruthTaggingTool::GetTruthTagWeights(TRFinfo &trfinf, std::ve
 }
 
 
-StatusCode BTaggingTruthTaggingTool::CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results, int rand_seed, const TString &strategy){
+StatusCode BTaggingTruthTaggingTool::CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results, int rand_seed){
   ANA_CHECK_SET_TYPE (StatusCode);
   results.clear();
 
@@ -541,7 +548,7 @@ StatusCode BTaggingTruthTaggingTool::CalculateResults(TRFinfo &trfinf, Analysis:
     trf_weight_ex.clear();
     trf_weight_in.clear();
 
-    ANA_CHECK(GetTruthTagWeights(trfinf, trf_weight_ex, trf_weight_in, i, strategy));
+    ANA_CHECK(GetTruthTagWeights(trfinf, trf_weight_ex, trf_weight_in, i));
 
   }
 
@@ -604,7 +611,7 @@ StatusCode BTaggingTruthTaggingTool::CalculateResults(const xAOD::JetContainer&
 }
 
 // setting inputs that the onnx tool will use
-StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw, Analysis::TruthTagResults& results, const TString &strategy, int rand_seed){
+StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw, Analysis::TruthTagResults& results, int rand_seed){
 
   ANA_CHECK_SET_TYPE (StatusCode);
 
@@ -612,11 +619,11 @@ StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const std::vector<std:
     
   ANA_CHECK(setJets(trfinf, node_feat, tagw));
 
-  return CalculateResults(trfinf, results, rand_seed, strategy);
+  return CalculateResults(trfinf, results, rand_seed);
 }
 
 // setting inputs that the onnx tool will use
-StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, const TString &strategy, int rand_seed){
+StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, int rand_seed){
 
   ANA_CHECK_SET_TYPE (StatusCode);
 
@@ -624,20 +631,20 @@ StatusCode BTaggingTruthTaggingTool::CalculateResultsONNX(const xAOD::JetContain
 
   ANA_CHECK(setJets(trfinf, jets, node_feat));
 
-  return CalculateResults(trfinf, results, rand_seed, strategy);
+  return CalculateResults(trfinf, results, rand_seed);
 }
 
-StatusCode BTaggingTruthTaggingTool::getAllEffMC(TRFinfo &trfinf, const TString &strategy){
+StatusCode BTaggingTruthTaggingTool::getAllEffMC(TRFinfo &trfinf){
   if (trfinf.node_feat.size() == 0){
     return getAllEffMCCDI(trfinf);
   } else {
 
-    return getAllEffMCGNN(trfinf, strategy);
+    return getAllEffMCGNN(trfinf);
   }
 }
 
 // uses onnx tool (no support for m_useQuantile now)
-StatusCode BTaggingTruthTaggingTool::getAllEffMCGNN(TRFinfo &trfinf, const TString &strategy){
+StatusCode BTaggingTruthTaggingTool::getAllEffMCGNN(TRFinfo &trfinf){
 
   trfinf.effMC.clear();
   if(m_useQuntile == true || m_continuous == true){
@@ -658,14 +665,17 @@ StatusCode BTaggingTruthTaggingTool::getAllEffMCGNN(TRFinfo &trfinf, const TStri
       
       // need to transpose
       std::vector<float> tmp_effMC_oneOP; // shape:{num_jet}
-      if (strategy == "Leading2SignalJets"){
+      if (m_taggingStrategy == "Leading2SignalJets"){
         for (int jet_index=0; jet_index<2; jet_index++){
           tmp_effMC_oneOP.push_back(tmp_effMC_allOP[jet_index][OP_index]);
         }
-      } else {
+      } else if (m_taggingStrategy == "AllJets") {
         for (int jet_index=0; jet_index<static_cast<int>(tmp_effMC_allOP.size()); jet_index++){
           tmp_effMC_oneOP.push_back(tmp_effMC_allOP[jet_index][OP_index]);
         }
+      } else {
+        ATH_MSG_ERROR("BTaggingTruthTaggingTool::tagging strategy " << m_taggingStrategy << " is not implemented in the TruthTagging Tool");
+        return StatusCode::FAILURE;    
       }
       trfinf.effMC_allOP[op_appo] = tmp_effMC_oneOP;
       OP_index++;
@@ -684,16 +694,19 @@ StatusCode BTaggingTruthTaggingTool::getAllEffMCGNN(TRFinfo &trfinf, const TStri
         return StatusCode::FAILURE;
       }
 
-      if (strategy == "Leading2SignalJets"){
+      if (m_taggingStrategy == "Leading2SignalJets"){
         for (int jet_index=0; jet_index<2; jet_index++){
           float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0);
           trfinf.effMC.push_back(tmp_effMC);
         }
-      } else {
+      } else if (m_taggingStrategy == "AllJets") {
         for (int jet_index=0; jet_index<static_cast<int>(tmp_effMC_allOP.size()); jet_index++){
           float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0);
           trfinf.effMC.push_back(tmp_effMC);
         }
+      } else {
+        ATH_MSG_ERROR("BTaggingTruthTaggingTool::tagging strategy " << m_taggingStrategy << " is not implemented in the TruthTagging Tool");
+        return StatusCode::FAILURE;    
       }
     } // !m_useQuantile
   } // !m_continuous
diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h
index 30a7658fa95c..06fbb74ccbd8 100644
--- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h
+++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/BTaggingTruthTaggingTool.h
@@ -92,15 +92,15 @@ class BTaggingTruthTaggingTool: public asg::AsgTool,
   BTaggingTruthTaggingTool( const std::string& name );
 
   private:
-  StatusCode CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results,int rand_seed = -1, const TString &strategy="");
+  StatusCode CalculateResults(TRFinfo &trfinf, Analysis::TruthTagResults& results,int rand_seed = -1);
             
   public:
   StatusCode CalculateResults( std::vector<float>& pt, std::vector<float>& eta, std::vector<int>& flav, std::vector<float>& tagw, Analysis::TruthTagResults& results,int rand_seed = -1);
   StatusCode CalculateResults( const xAOD::JetContainer& jets, Analysis::TruthTagResults& results,int rand_seed = -1);
         
   // will use onnxtool
-  StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results, const TString &strategy="", int rand_seed=-1);
-  StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, const TString &strategy="", int rand_seed = -1);
+  StatusCode CalculateResultsONNX( const std::vector<std::vector<float>>& node_feat, std::vector<float>& tagw,  Analysis::TruthTagResults& results, int rand_seed=-1);
+  StatusCode CalculateResultsONNX( const xAOD::JetContainer& jets, const std::vector<std::vector<float>>& node_feat, Analysis::TruthTagResults& results, int rand_seed = -1);
 
   StatusCode setEffMapIndex(const std::string& flavour, unsigned int index);
   void setUseSystematics(bool useSystematics);
@@ -130,7 +130,7 @@ class BTaggingTruthTaggingTool: public asg::AsgTool,
             
   // get truth tagging weights
   // for one single systematic (including "Nominal")
-  StatusCode GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys=0, const TString &strategy="");
+  StatusCode GetTruthTagWeights(TRFinfo &trfinf, std::vector<float> &trf_weight_ex, std::vector<float> &trf_weight_in, int sys=0);
 
   // tag permutation: trf_chosen_perm_ex.at(ntag).at(i) tells if the i-th jet is tagged in a selection requiring == ntag tags
   StatusCode getTagPermutation(TRFinfo &trfinf, std::vector<std::vector<bool> > &trf_chosen_perm_ex, std::vector<std::vector<bool> > &trf_chosen_perm_in);
@@ -163,9 +163,9 @@ class BTaggingTruthTaggingTool: public asg::AsgTool,
 
   StatusCode getTRFweight(TRFinfo &trfinf,unsigned int nbtag, bool isInclusive, int sys);
 
-  StatusCode getAllEffMC(TRFinfo &trfinf, const TString &strategy="");
+  StatusCode getAllEffMC(TRFinfo &trfinf);
   StatusCode getAllEffMCCDI(TRFinfo &trfinf);
-  StatusCode getAllEffMCGNN(TRFinfo &trfinf, const TString &strategy="");
+  StatusCode getAllEffMCGNN(TRFinfo &trfinf);
             
   StatusCode getAllEffSF(TRFinfo &trfinf,int =0);
   std::vector<CP::SystematicSet> m_eff_syst;
@@ -232,8 +232,8 @@ class BTaggingTruthTaggingTool: public asg::AsgTool,
   bool m_doDirectTag;
   /// if this string is empty, the onnx tool won't be used
   std::string m_pathToONNX;
-            
-            
+  /// tagging strategy is required to do TT with GNN, when we don't want to truth tag all the jets (eg. 'leading2SignalJets')          
+  std::string m_taggingStrategy;            
 
   //*********************************//
   // Prop. of BTaggingSelectionTool  //
-- 
GitLab


From 850160c320ee9b2cea82f844fc5b7176b5d21680 Mon Sep 17 00:00:00 2001
From: Nilotpal Kakati <nilotpal.kakati@cern.ch>
Date: Tue, 27 Jul 2021 10:57:24 +0300
Subject: [PATCH 5/5] useQuantile support for GNN

---
 .../Root/BTaggingTruthTaggingTool.cxx         | 106 +++++++++++-------
 1 file changed, 63 insertions(+), 43 deletions(-)

diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
index 74df2f1d0901..f3b81b49c61e 100644
--- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
+++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx
@@ -190,23 +190,18 @@ StatusCode BTaggingTruthTaggingTool::initialize() {
     m_OperatingPoint_index = find(m_availableOP.begin(), m_availableOP.end(), m_cutBenchmark) - m_availableOP.begin();
   }
   else{
-    if (m_pathToONNX != ""){
-      if (m_useQuntile){
-        ATH_MSG_ERROR("BTaggingTruthTaggingTool::TruthTagging with GNN doesn't support m_useQuntile=true yet");
-        return StatusCode::FAILURE;      
-      } else {
-        // 60% = 4, 70% = 3, 77% = 2, 85% = 1, 100% = 0
-        m_OP_index_for_GNN = find(m_availableOP.begin(), m_availableOP.end(), m_cutBenchmark) - m_availableOP.begin() + 1; // GNN predicts 5 bins        
+    if(m_useQuntile){
+      m_OperatingPoint_index = find(m_availableOP.begin(), m_availableOP.end(), m_OP) - m_availableOP.begin();
+      if(m_OperatingPoint_index >= m_availableOP.size()) {
+        ATH_MSG_ERROR(m_OP << " not in the list of available OPs");
+        return StatusCode::FAILURE;
       }
-    } else {
-      if(m_useQuntile){
-        m_OperatingPoint_index = find(m_availableOP.begin(), m_availableOP.end(), m_OP) - m_availableOP.begin();
-        if(m_OperatingPoint_index >= m_availableOP.size()) {
-          ATH_MSG_ERROR(m_OP << " not in the list of available OPs");
-          return StatusCode::FAILURE;
-        }
-      } // m_useQuantile
-    } // !ONNX
+    } // m_useQuantile
+
+    if (m_pathToONNX != ""){
+      // 60% = 4, 70% = 3, 77% = 2, 85% = 1, 100% = 0
+      m_OP_index_for_GNN = find(m_availableOP.begin(), m_availableOP.end(), m_cutBenchmark) - m_availableOP.begin() + 1; // GNN predicts 5 bins        
+    }
   }
   
   m_eff_syst.clear();
@@ -431,14 +426,22 @@ StatusCode BTaggingTruthTaggingTool::setJets(TRFinfo &trfinf,std::vector<int>& f
     return StatusCode::FAILURE;
   }
   trfinf.jets.clear();
-  for(unsigned int i =0; i<vars.size(); i++){
+
+  if (m_taggingStrategy == "Leading2SignalJets"){
+    trfinf.njets = std::min(2, static_cast<int>(vars.size()));
+  } else if (m_taggingStrategy == "AllJets"){
+    trfinf.njets = vars.size();
+  } else {
+    ATH_MSG_ERROR("BTaggingTruthTaggingTool::setJets tagging strategy " << m_taggingStrategy << " is not implemented in the TruthTagging Tool");
+    return StatusCode::FAILURE;  
+  }
+
+  for(unsigned int i =0; i<trfinf.njets; i++){
     jetVariable jetVar_appo;
     jetVar_appo.flav=flav.at(i);
     jetVar_appo.vars=vars.at(i);
     trfinf.jets.push_back(jetVar_appo);
-
   }
-  trfinf.njets=trfinf.jets.size();
   trfinf.node_feat = node_feat;
   return StatusCode::SUCCESS;
 }
@@ -682,35 +685,52 @@ StatusCode BTaggingTruthTaggingTool::getAllEffMCGNN(TRFinfo &trfinf){
     }
   } // m_continuous
   else {
-    if (m_useQuntile){
-      ATH_MSG_ERROR("BTaggingTruthTaggingTool::getMCEfficiencyONNX doesn't support m_useQuntile=true yet");
-      return StatusCode::FAILURE;      
-    } // m_useQuantile
-    else {
-      std::vector<std::vector<float>> tmp_effMC_allOP; // shape:{num_jets, num_wp}
-      CorrectionCode code = m_effTool->getMCEfficiencyONNX(trfinf.node_feat, tmp_effMC_allOP);
-      if(!(code==CorrectionCode::Ok || code==CorrectionCode::OutOfValidityRange)){
-        ATH_MSG_ERROR("BTaggingEfficiencyTool::getMCEfficiencyONNX returned CorrectionCode::Error");
-        return StatusCode::FAILURE;
+    std::vector<std::vector<float>> tmp_effMC_allOP; // shape:{num_jets, num_wp}
+    CorrectionCode code = m_effTool->getMCEfficiencyONNX(trfinf.node_feat, tmp_effMC_allOP);
+    if(!(code==CorrectionCode::Ok || code==CorrectionCode::OutOfValidityRange)){
+      ATH_MSG_ERROR("BTaggingEfficiencyTool::getMCEfficiencyONNX returned CorrectionCode::Error");
+      return StatusCode::FAILURE;
+    }
+    
+    if (m_taggingStrategy == "Leading2SignalJets"){
+      for (int jet_index=0; jet_index<2; jet_index++){
+        float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0);
+        trfinf.effMC.push_back(tmp_effMC);
+      }
+    } else if (m_taggingStrategy == "AllJets") {
+      for (int jet_index=0; jet_index<static_cast<int>(tmp_effMC_allOP.size()); jet_index++){
+        float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0);
+        trfinf.effMC.push_back(tmp_effMC);
       }
+    } else {
+      ATH_MSG_ERROR("BTaggingTruthTaggingTool::tagging strategy " << m_taggingStrategy << " is not implemented in the TruthTagging Tool");
+      return StatusCode::FAILURE;    
+    }
 
-      if (m_taggingStrategy == "Leading2SignalJets"){
-        for (int jet_index=0; jet_index<2; jet_index++){
-          float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0);
-          trfinf.effMC.push_back(tmp_effMC);
-        }
-      } else if (m_taggingStrategy == "AllJets") {
-        for (int jet_index=0; jet_index<static_cast<int>(tmp_effMC_allOP.size()); jet_index++){
-          float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0);
-          trfinf.effMC.push_back(tmp_effMC);
+    if (m_useQuntile){
+      int OP_index=0;
+      for(const auto & op_appo: m_availableOP){
+        
+        // need to transpose
+        std::vector<float> tmp_effMC_oneOP; // shape:{num_jet}
+        if (m_taggingStrategy == "Leading2SignalJets"){
+          for (int jet_index=0; jet_index<2; jet_index++){
+            tmp_effMC_oneOP.push_back(tmp_effMC_allOP[jet_index][OP_index]);
+          }
+        } else if (m_taggingStrategy == "AllJets") {
+          for (int jet_index=0; jet_index<static_cast<int>(tmp_effMC_allOP.size()); jet_index++){
+            tmp_effMC_oneOP.push_back(tmp_effMC_allOP[jet_index][OP_index]);
+          }
+        } else {
+          ATH_MSG_ERROR("BTaggingTruthTaggingTool::tagging strategy " << m_taggingStrategy << " is not implemented in the TruthTagging Tool");
+          return StatusCode::FAILURE;    
         }
-      } else {
-        ATH_MSG_ERROR("BTaggingTruthTaggingTool::tagging strategy " << m_taggingStrategy << " is not implemented in the TruthTagging Tool");
-        return StatusCode::FAILURE;    
+        trfinf.effMC_allOP[op_appo] = tmp_effMC_oneOP;
+        OP_index++;
       }
-    } // !m_useQuantile
+    } // m_useQuantile
   } // !m_continuous
-        
+         
   return StatusCode::SUCCESS;
 }
 
-- 
GitLab