Commit 3d2fac2c authored by Bowen Zhang's avatar Bowen Zhang Committed by Walter Lampl
Browse files

tauRec: update input variables for tau decay mode classifier

parent 218bd80e
......@@ -118,7 +118,7 @@ class tauRecDecayModeNNClassifierConfig(JobProperty):
"""
statusOn=True
allowedTypes=['string']
StoredValue='NNDecayModeWeights-20200625.json'
StoredValue='NNDecayMode_R22_v1.json'
class tauRecCalibrateLCConfig(JobProperty):
"""Config file for TauCalibrateLC
......
/*
Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
// local include(s)
......@@ -30,7 +30,7 @@ TauDecayModeNNClassifier::TauDecayModeNNClassifier(const std::string &name)
declareProperty("OutputName", m_outputName = "NNDecayMode");
declareProperty("ProbPrefix", m_probPrefix = "NNDecayModeProb_");
declareProperty("WeightFile", m_weightFile = "");
declareProperty("MaxChargedPFOs", m_maxChargedPFOs = 3);
declareProperty("MaxTauTracks", m_maxTauTracks = 3);
declareProperty("MaxNeutralPFOs", m_maxNeutralPFOs = 8);
declareProperty("MaxShotPFOs", m_maxShotPFOs = 6);
declareProperty("MaxConvTracks", m_maxConvTracks = 4);
......@@ -92,7 +92,7 @@ StatusCode TauDecayModeNNClassifier::execute(xAOD::TauJet &xTau) const
//
InputMap inputMapDummy;
InputSequenceMap inputSeqMap;
std::set<std::string> branches = {"ChargedPFO", "NeutralPFO", "ShotPFO", "ConvTrack"};
std::set<std::string> branches = {"TauTrack", "NeutralPFO", "ShotPFO", "ConvTrack"};
DMHelper::initMapKeys(inputSeqMap, branches);
ATH_CHECK(getInputs(xTau, inputSeqMap));
......@@ -172,7 +172,7 @@ StatusCode TauDecayModeNNClassifier::finalize()
StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSequenceMap &inputSeqMap) const
{
std::vector<PFOPtr> vChargedPFOs;
std::vector<TrkPtr> vTauTracks;
std::vector<PFOPtr> vNeutralPFOs;
std::vector<PFOPtr> vShotPFOs;
std::vector<TrkPtr> vConvTracks;
......@@ -180,12 +180,8 @@ StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSe
// set objects
// -----------
// charged PFOs
for (std::size_t i = 0; i < xTau.nChargedPFOs(); ++i)
{
const auto pfo = xTau.chargedPFO(i);
vChargedPFOs.push_back(pfo);
}
// classified tau tracks
vTauTracks = xTau.tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedCharged);
// neutral PFOs
for (std::size_t i = 0; i < xTau.nNeutralPFOs(); ++i)
......@@ -217,10 +213,10 @@ StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSe
vShotPFOs.push_back(pfo);
}
// conversion tracks
// classified conversion tracks
vConvTracks = xTau.tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedConversion);
DMHelper::sortAndKeep<PFOPtr>(vChargedPFOs, m_maxChargedPFOs);
DMHelper::sortAndKeep<TrkPtr>(vTauTracks, m_maxTauTracks);
DMHelper::sortAndKeep<PFOPtr>(vNeutralPFOs, m_maxNeutralPFOs);
DMHelper::sortAndKeep<PFOPtr>(vShotPFOs, m_maxShotPFOs);
DMHelper::sortAndKeep<TrkPtr>(vConvTracks, m_maxConvTracks);
......@@ -265,6 +261,14 @@ StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSe
in_seq_map["jetpt_log"].push_back(DMHelper::Log10Robust(tau_p4.Pt()));
};
// a function to set the track impact parameter variables
auto setTrackIPVars = [](VectorMap &in_seq_map, const TrkPtr &trk) {
in_seq_map["d0TJVA"].push_back(trk->d0TJVA());
in_seq_map["d0SigTJVA"].push_back(trk->d0SigTJVA());
in_seq_map["z0sinthetaTJVA"].push_back(trk->z0sinthetaTJVA());
in_seq_map["z0sinthetaSigTJVA"].push_back(trk->z0sinthetaSigTJVA());
};
// a function to set the neutral pfo variables
auto setNeutralPFOVars = [](VectorMap &in_seq_map, const PFOPtr &pfo) {
// get the attributes of a given PFO object
......@@ -280,20 +284,23 @@ StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSe
in_seq_map["SECOND_ENG_DENS_log"].push_back(DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_SECOND_ENG_DENS), 1e-6f));
in_seq_map["NPosECells_EM1"].push_back(getAttrInt(PFOAttributes::cellBased_NPosECells_EM1));
in_seq_map["NPosECells_EM2"].push_back(getAttrInt(PFOAttributes::cellBased_NPosECells_EM2));
in_seq_map["energy_EM1"].push_back(getAttr(PFOAttributes::cellBased_energy_EM1));
in_seq_map["energy_EM2"].push_back(getAttr(PFOAttributes::cellBased_energy_EM2));
in_seq_map["EM1CoreFrac"].push_back(getAttr(PFOAttributes::cellBased_EM1CoreFrac));
in_seq_map["firstEtaWRTClusterPosition_EM1"].push_back(getAttr(PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM1));
in_seq_map["firstEtaWRTClusterPosition_EM2"].push_back(getAttr(PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM2));
in_seq_map["secondEtaWRTClusterPosition_EM1_log"].push_back(DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM1), 1e-6f));
in_seq_map["secondEtaWRTClusterPosition_EM2_log"].push_back(DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM2), 1e-6f));
in_seq_map["ptSubRatio_logabs"].push_back(DMHelper::Log10Robust(TMath::Abs(DMVar::ptSubRatio(pfo)), 1e-6f));
in_seq_map["energyfrac_EM2"].push_back(DMVar::energyFracEM2(pfo, getAttr(PFOAttributes::cellBased_energy_EM2)));
};
// set Charged PFOs variables
VectorMap &chrg_map = inputSeqMap.at("ChargedPFO");
// set tau tracks variables
VectorMap &chrg_map = inputSeqMap.at("TauTrack");
DMHelper::initMapKeys(chrg_map, DMVar::sCommonP4Vars);
for (const auto &pfo : vChargedPFOs)
DMHelper::initMapKeys(chrg_map, DMVar::sTrackIPVars);
for (const auto &trk : vTauTracks)
{
setCommonP4Vars(chrg_map, pfo->p4());
setCommonP4Vars(chrg_map, trk->p4());
setTrackIPVars(chrg_map, trk);
}
// set Neutral PFOs variables
......@@ -325,9 +332,11 @@ StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSe
// set Conversion tracks variables
VectorMap &conv_map = inputSeqMap.at("ConvTrack");
DMHelper::initMapKeys(conv_map, DMVar::sCommonP4Vars);
DMHelper::initMapKeys(conv_map, DMVar::sTrackIPVars);
for (const auto &trk : vConvTracks)
{
setCommonP4Vars(conv_map, trk->p4());
setTrackIPVars(conv_map, trk);
}
return StatusCode::SUCCESS;
......@@ -339,10 +348,14 @@ namespace tauRecTools
const std::set<std::string> TauDecayModeNNVariable::sCommonP4Vars = {
"dphiECal", "detaECal", "dphi", "deta", "pt_log", "jetpt_log"};
const std::set<std::string> TauDecayModeNNVariable::sTrackIPVars = {
"d0TJVA", "d0SigTJVA", "z0sinthetaTJVA", "z0sinthetaSigTJVA"};
const std::set<std::string> TauDecayModeNNVariable::sNeutralPFOVars = {
"FIRST_ETA", "SECOND_R_log", "DELTA_THETA", "CENTER_LAMBDA_log", "LONGITUDINAL", "ENG_FRAC_CORE",
"SECOND_ENG_DENS_log", "NPosECells_EM1", "NPosECells_EM2", "EM1CoreFrac", "firstEtaWRTClusterPosition_EM1",
"secondEtaWRTClusterPosition_EM1_log", "secondEtaWRTClusterPosition_EM2_log", "ptSubRatio_logabs", "energyfrac_EM2"};
"SECOND_ENG_DENS_log", "NPosECells_EM1", "NPosECells_EM2", "energy_EM1", "energy_EM2", "EM1CoreFrac",
"firstEtaWRTClusterPosition_EM1", "firstEtaWRTClusterPosition_EM2",
"secondEtaWRTClusterPosition_EM1_log", "secondEtaWRTClusterPosition_EM2_log"};
const std::array<std::string, TauDecayModeNNVariable::nClasses> TauDecayModeNNVariable::sModeNames = {
"1p0n", "1p1n", "1pXn", "3p0n", "3pXn"};
......@@ -375,7 +388,7 @@ namespace tauRecTools
T val{static_cast<T>(0)};
if (!pfo->attribute(attr, val))
{
throw std::runtime_error("Can not retrieve PFO attribute!");
throw std::runtime_error("Can not retrieve PFO attribute! enum = " + std::to_string(static_cast<unsigned>(attr)));
}
return val;
}
......
/*
Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
*/
#ifndef TAURECTOOLS_TAUDECAYMODENNCLASSIFIER_H
......@@ -46,7 +46,7 @@ private:
std::string m_outputName; //!
std::string m_probPrefix; //!
std::string m_weightFile; //!
std::size_t m_maxChargedPFOs; //!
std::size_t m_maxTauTracks; //!
std::size_t m_maxNeutralPFOs; //!
std::size_t m_maxShotPFOs; //!
std::size_t m_maxConvTracks; //!
......@@ -78,6 +78,7 @@ namespace tauRecTools
TauDecayModeNNVariable() = delete;
static const std::size_t nClasses = 5;
static const std::set<std::string> sCommonP4Vars;
static const std::set<std::string> sTrackIPVars;
static const std::set<std::string> sNeutralPFOVars;
static const std::array<std::string, nClasses> sModeNames;
static float deltaPhi(const TLorentzVector &p4, const TLorentzVector &p4_tau);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment