diff --git a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx index 74df2f1d09018a8940c05f66d70a163f268d028f..f3b81b49c61e63629e08ec76cdde67c1095b7ef1 100644 --- a/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx +++ b/PhysicsAnalysis/JetTagging/JetTagPerformanceCalibration/xAODBTaggingEfficiency/Root/BTaggingTruthTaggingTool.cxx @@ -190,23 +190,18 @@ StatusCode BTaggingTruthTaggingTool::initialize() { m_OperatingPoint_index = find(m_availableOP.begin(), m_availableOP.end(), m_cutBenchmark) - m_availableOP.begin(); } else{ - if (m_pathToONNX != ""){ - if (m_useQuntile){ - ATH_MSG_ERROR("BTaggingTruthTaggingTool::TruthTagging with GNN doesn't support m_useQuntile=true yet"); - return StatusCode::FAILURE; - } else { - // 60% = 4, 70% = 3, 77% = 2, 85% = 1, 100% = 0 - m_OP_index_for_GNN = find(m_availableOP.begin(), m_availableOP.end(), m_cutBenchmark) - m_availableOP.begin() + 1; // GNN predicts 5 bins + if(m_useQuntile){ + m_OperatingPoint_index = find(m_availableOP.begin(), m_availableOP.end(), m_OP) - m_availableOP.begin(); + if(m_OperatingPoint_index >= m_availableOP.size()) { + ATH_MSG_ERROR(m_OP << " not in the list of available OPs"); + return StatusCode::FAILURE; } - } else { - if(m_useQuntile){ - m_OperatingPoint_index = find(m_availableOP.begin(), m_availableOP.end(), m_OP) - m_availableOP.begin(); - if(m_OperatingPoint_index >= m_availableOP.size()) { - ATH_MSG_ERROR(m_OP << " not in the list of available OPs"); - return StatusCode::FAILURE; - } - } // m_useQuantile - } // !ONNX + } // m_useQuantile + + if (m_pathToONNX != ""){ + // 60% = 4, 70% = 3, 77% = 2, 85% = 1, 100% = 0 + m_OP_index_for_GNN = find(m_availableOP.begin(), m_availableOP.end(), m_cutBenchmark) - m_availableOP.begin() + 1; // GNN predicts 5 bins + } } m_eff_syst.clear(); @@ -431,14 +426,22 @@ StatusCode BTaggingTruthTaggingTool::setJets(TRFinfo &trfinf,std::vector<int>& f return StatusCode::FAILURE; } trfinf.jets.clear(); - for(unsigned int i =0; i<vars.size(); i++){ + + if (m_taggingStrategy == "Leading2SignalJets"){ + trfinf.njets = std::min(2, static_cast<int>(vars.size())); + } else if (m_taggingStrategy == "AllJets"){ + trfinf.njets = vars.size(); + } else { + ATH_MSG_ERROR("BTaggingTruthTaggingTool::setJets tagging strategy " << m_taggingStrategy << " is not implemented in the TruthTagging Tool"); + return StatusCode::FAILURE; + } + + for(unsigned int i =0; i<trfinf.njets; i++){ jetVariable jetVar_appo; jetVar_appo.flav=flav.at(i); jetVar_appo.vars=vars.at(i); trfinf.jets.push_back(jetVar_appo); - } - trfinf.njets=trfinf.jets.size(); trfinf.node_feat = node_feat; return StatusCode::SUCCESS; } @@ -682,35 +685,52 @@ StatusCode BTaggingTruthTaggingTool::getAllEffMCGNN(TRFinfo &trfinf){ } } // m_continuous else { - if (m_useQuntile){ - ATH_MSG_ERROR("BTaggingTruthTaggingTool::getMCEfficiencyONNX doesn't support m_useQuntile=true yet"); - return StatusCode::FAILURE; - } // m_useQuantile - else { - std::vector<std::vector<float>> tmp_effMC_allOP; // shape:{num_jets, num_wp} - CorrectionCode code = m_effTool->getMCEfficiencyONNX(trfinf.node_feat, tmp_effMC_allOP); - if(!(code==CorrectionCode::Ok || code==CorrectionCode::OutOfValidityRange)){ - ATH_MSG_ERROR("BTaggingEfficiencyTool::getMCEfficiencyONNX returned CorrectionCode::Error"); - return StatusCode::FAILURE; + std::vector<std::vector<float>> tmp_effMC_allOP; // shape:{num_jets, num_wp} + CorrectionCode code = m_effTool->getMCEfficiencyONNX(trfinf.node_feat, tmp_effMC_allOP); + if(!(code==CorrectionCode::Ok || code==CorrectionCode::OutOfValidityRange)){ + ATH_MSG_ERROR("BTaggingEfficiencyTool::getMCEfficiencyONNX returned CorrectionCode::Error"); + return StatusCode::FAILURE; + } + + if (m_taggingStrategy == "Leading2SignalJets"){ + for (int jet_index=0; jet_index<2; jet_index++){ + float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0); + trfinf.effMC.push_back(tmp_effMC); + } + } else if (m_taggingStrategy == "AllJets") { + for (int jet_index=0; jet_index<static_cast<int>(tmp_effMC_allOP.size()); jet_index++){ + float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0); + trfinf.effMC.push_back(tmp_effMC); } + } else { + ATH_MSG_ERROR("BTaggingTruthTaggingTool::tagging strategy " << m_taggingStrategy << " is not implemented in the TruthTagging Tool"); + return StatusCode::FAILURE; + } - if (m_taggingStrategy == "Leading2SignalJets"){ - for (int jet_index=0; jet_index<2; jet_index++){ - float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0); - trfinf.effMC.push_back(tmp_effMC); - } - } else if (m_taggingStrategy == "AllJets") { - for (int jet_index=0; jet_index<static_cast<int>(tmp_effMC_allOP.size()); jet_index++){ - float tmp_effMC = std::accumulate(tmp_effMC_allOP[jet_index].begin()+m_OP_index_for_GNN, tmp_effMC_allOP[jet_index].end(), 0.0); - trfinf.effMC.push_back(tmp_effMC); + if (m_useQuntile){ + int OP_index=0; + for(const auto & op_appo: m_availableOP){ + + // need to transpose + std::vector<float> tmp_effMC_oneOP; // shape:{num_jet} + if (m_taggingStrategy == "Leading2SignalJets"){ + for (int jet_index=0; jet_index<2; jet_index++){ + tmp_effMC_oneOP.push_back(tmp_effMC_allOP[jet_index][OP_index]); + } + } else if (m_taggingStrategy == "AllJets") { + for (int jet_index=0; jet_index<static_cast<int>(tmp_effMC_allOP.size()); jet_index++){ + tmp_effMC_oneOP.push_back(tmp_effMC_allOP[jet_index][OP_index]); + } + } else { + ATH_MSG_ERROR("BTaggingTruthTaggingTool::tagging strategy " << m_taggingStrategy << " is not implemented in the TruthTagging Tool"); + return StatusCode::FAILURE; } - } else { - ATH_MSG_ERROR("BTaggingTruthTaggingTool::tagging strategy " << m_taggingStrategy << " is not implemented in the TruthTagging Tool"); - return StatusCode::FAILURE; + trfinf.effMC_allOP[op_appo] = tmp_effMC_oneOP; + OP_index++; } - } // !m_useQuantile + } // m_useQuantile } // !m_continuous - + return StatusCode::SUCCESS; }