diff --git a/Reconstruction/tauRec/python/tauRecFlags.py b/Reconstruction/tauRec/python/tauRecFlags.py
index 68195039e145c2e0948960ace54f4d16fe96b575..50fa8fd26ef5fb1c2327f84767880a4880346ef1 100644
--- a/Reconstruction/tauRec/python/tauRecFlags.py
+++ b/Reconstruction/tauRec/python/tauRecFlags.py
@@ -118,7 +118,7 @@ class tauRecDecayModeNNClassifierConfig(JobProperty):
     """
     statusOn=True
     allowedTypes=['string']
-    StoredValue='NNDecayModeWeights-20200625.json'
+    StoredValue='NNDecayMode_R22_v1.json'
 
 class tauRecCalibrateLCConfig(JobProperty):
     """Config file for TauCalibrateLC
diff --git a/Reconstruction/tauRecTools/Root/TauDecayModeNNClassifier.cxx b/Reconstruction/tauRecTools/Root/TauDecayModeNNClassifier.cxx
index 0abe3c721103ba3fd19b3a1904184698661d29fd..47db6f668bf98892d8012c70ad0e28ea13ae657a 100644
--- a/Reconstruction/tauRecTools/Root/TauDecayModeNNClassifier.cxx
+++ b/Reconstruction/tauRecTools/Root/TauDecayModeNNClassifier.cxx
@@ -1,5 +1,5 @@
 /*
-  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
+  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
 */
 
 // local include(s)
@@ -30,7 +30,7 @@ TauDecayModeNNClassifier::TauDecayModeNNClassifier(const std::string &name)
   declareProperty("OutputName", m_outputName = "NNDecayMode");
   declareProperty("ProbPrefix", m_probPrefix = "NNDecayModeProb_");
   declareProperty("WeightFile", m_weightFile = "");
-  declareProperty("MaxChargedPFOs", m_maxChargedPFOs = 3);
+  declareProperty("MaxTauTracks", m_maxTauTracks = 3);
   declareProperty("MaxNeutralPFOs", m_maxNeutralPFOs = 8);
   declareProperty("MaxShotPFOs", m_maxShotPFOs = 6);
   declareProperty("MaxConvTracks", m_maxConvTracks = 4);
@@ -92,7 +92,7 @@ StatusCode TauDecayModeNNClassifier::execute(xAOD::TauJet &xTau) const
   //
   InputMap inputMapDummy;
   InputSequenceMap inputSeqMap;
-  std::set<std::string> branches = {"ChargedPFO", "NeutralPFO", "ShotPFO", "ConvTrack"};
+  std::set<std::string> branches = {"TauTrack", "NeutralPFO", "ShotPFO", "ConvTrack"};
   DMHelper::initMapKeys(inputSeqMap, branches);
 
   ATH_CHECK(getInputs(xTau, inputSeqMap));
@@ -172,7 +172,7 @@ StatusCode TauDecayModeNNClassifier::finalize()
 
 StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSequenceMap &inputSeqMap) const
 {
-  std::vector<PFOPtr> vChargedPFOs;
+  std::vector<TrkPtr> vTauTracks;
   std::vector<PFOPtr> vNeutralPFOs;
   std::vector<PFOPtr> vShotPFOs;
   std::vector<TrkPtr> vConvTracks;
@@ -180,12 +180,8 @@ StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSe
   // set objects
   // -----------
 
-  // charged PFOs
-  for (std::size_t i = 0; i < xTau.nChargedPFOs(); ++i)
-  {
-    const auto pfo = xTau.chargedPFO(i);
-    vChargedPFOs.push_back(pfo);
-  }
+  // classified tau tracks
+  vTauTracks = xTau.tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedCharged);
 
   // neutral PFOs
   for (std::size_t i = 0; i < xTau.nNeutralPFOs(); ++i)
@@ -217,10 +213,10 @@ StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSe
     vShotPFOs.push_back(pfo);
   }
 
-  // conversion tracks
+  // classified conversion tracks
   vConvTracks = xTau.tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedConversion);
 
-  DMHelper::sortAndKeep<PFOPtr>(vChargedPFOs, m_maxChargedPFOs);
+  DMHelper::sortAndKeep<TrkPtr>(vTauTracks, m_maxTauTracks);
   DMHelper::sortAndKeep<PFOPtr>(vNeutralPFOs, m_maxNeutralPFOs);
   DMHelper::sortAndKeep<PFOPtr>(vShotPFOs, m_maxShotPFOs);
   DMHelper::sortAndKeep<TrkPtr>(vConvTracks, m_maxConvTracks);
@@ -265,6 +261,14 @@ StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSe
     in_seq_map["jetpt_log"].push_back(DMHelper::Log10Robust(tau_p4.Pt()));
   };
 
+  // a function to set the track impact parameter variables
+  auto setTrackIPVars = [](VectorMap &in_seq_map, const TrkPtr &trk) {
+    in_seq_map["d0TJVA"].push_back(trk->d0TJVA());
+    in_seq_map["d0SigTJVA"].push_back(trk->d0SigTJVA());
+    in_seq_map["z0sinthetaTJVA"].push_back(trk->z0sinthetaTJVA());
+    in_seq_map["z0sinthetaSigTJVA"].push_back(trk->z0sinthetaSigTJVA());
+  };
+
   // a function to set the neutral pfo variables
   auto setNeutralPFOVars = [](VectorMap &in_seq_map, const PFOPtr &pfo) {
     // get the attributes of a given PFO object
@@ -280,20 +284,23 @@ StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSe
     in_seq_map["SECOND_ENG_DENS_log"].push_back(DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_SECOND_ENG_DENS), 1e-6f));
     in_seq_map["NPosECells_EM1"].push_back(getAttrInt(PFOAttributes::cellBased_NPosECells_EM1));
     in_seq_map["NPosECells_EM2"].push_back(getAttrInt(PFOAttributes::cellBased_NPosECells_EM2));
+    in_seq_map["energy_EM1"].push_back(getAttr(PFOAttributes::cellBased_energy_EM1));
+    in_seq_map["energy_EM2"].push_back(getAttr(PFOAttributes::cellBased_energy_EM2));
     in_seq_map["EM1CoreFrac"].push_back(getAttr(PFOAttributes::cellBased_EM1CoreFrac));
     in_seq_map["firstEtaWRTClusterPosition_EM1"].push_back(getAttr(PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM1));
+    in_seq_map["firstEtaWRTClusterPosition_EM2"].push_back(getAttr(PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM2));
     in_seq_map["secondEtaWRTClusterPosition_EM1_log"].push_back(DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM1), 1e-6f));
     in_seq_map["secondEtaWRTClusterPosition_EM2_log"].push_back(DMHelper::Log10Robust(getAttr(PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM2), 1e-6f));
-    in_seq_map["ptSubRatio_logabs"].push_back(DMHelper::Log10Robust(TMath::Abs(DMVar::ptSubRatio(pfo)), 1e-6f));
-    in_seq_map["energyfrac_EM2"].push_back(DMVar::energyFracEM2(pfo, getAttr(PFOAttributes::cellBased_energy_EM2)));
   };
 
-  // set Charged PFOs variables
-  VectorMap &chrg_map = inputSeqMap.at("ChargedPFO");
+  // set tau tracks variables
+  VectorMap &chrg_map = inputSeqMap.at("TauTrack");
   DMHelper::initMapKeys(chrg_map, DMVar::sCommonP4Vars);
-  for (const auto &pfo : vChargedPFOs)
+  DMHelper::initMapKeys(chrg_map, DMVar::sTrackIPVars);
+  for (const auto &trk : vTauTracks)
   {
-    setCommonP4Vars(chrg_map, pfo->p4());
+    setCommonP4Vars(chrg_map, trk->p4());
+    setTrackIPVars(chrg_map, trk);
   }
 
   // set Neutral PFOs variables
@@ -325,9 +332,11 @@ StatusCode TauDecayModeNNClassifier::getInputs(const xAOD::TauJet &xTau, InputSe
   // set Conversion tracks variables
   VectorMap &conv_map = inputSeqMap.at("ConvTrack");
   DMHelper::initMapKeys(conv_map, DMVar::sCommonP4Vars);
+  DMHelper::initMapKeys(conv_map, DMVar::sTrackIPVars);
   for (const auto &trk : vConvTracks)
   {
     setCommonP4Vars(conv_map, trk->p4());
+    setTrackIPVars(conv_map, trk);
   }
 
   return StatusCode::SUCCESS;
@@ -339,10 +348,14 @@ namespace tauRecTools
   const std::set<std::string> TauDecayModeNNVariable::sCommonP4Vars = {
       "dphiECal", "detaECal", "dphi", "deta", "pt_log", "jetpt_log"};
 
+  const std::set<std::string> TauDecayModeNNVariable::sTrackIPVars = {
+      "d0TJVA", "d0SigTJVA", "z0sinthetaTJVA", "z0sinthetaSigTJVA"};
+
   const std::set<std::string> TauDecayModeNNVariable::sNeutralPFOVars = {
       "FIRST_ETA", "SECOND_R_log", "DELTA_THETA", "CENTER_LAMBDA_log", "LONGITUDINAL", "ENG_FRAC_CORE",
-      "SECOND_ENG_DENS_log", "NPosECells_EM1", "NPosECells_EM2", "EM1CoreFrac", "firstEtaWRTClusterPosition_EM1",
-      "secondEtaWRTClusterPosition_EM1_log", "secondEtaWRTClusterPosition_EM2_log", "ptSubRatio_logabs", "energyfrac_EM2"};
+      "SECOND_ENG_DENS_log", "NPosECells_EM1", "NPosECells_EM2", "energy_EM1", "energy_EM2", "EM1CoreFrac", 
+      "firstEtaWRTClusterPosition_EM1", "firstEtaWRTClusterPosition_EM2", 
+      "secondEtaWRTClusterPosition_EM1_log", "secondEtaWRTClusterPosition_EM2_log"};
 
   const std::array<std::string, TauDecayModeNNVariable::nClasses> TauDecayModeNNVariable::sModeNames = {
       "1p0n", "1p1n", "1pXn", "3p0n", "3pXn"};
@@ -375,7 +388,7 @@ namespace tauRecTools
     T val{static_cast<T>(0)};
     if (!pfo->attribute(attr, val))
     {
-      throw std::runtime_error("Can not retrieve PFO attribute!");
+      throw std::runtime_error("Can not retrieve PFO attribute! enum = " + std::to_string(static_cast<unsigned>(attr)));
     }
     return val;
   }
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauDecayModeNNClassifier.h b/Reconstruction/tauRecTools/tauRecTools/TauDecayModeNNClassifier.h
index a83cd9b0be870f48cfc3c1b892f8e26e051b9e8b..ef6e5ede2d7d2acc1821b2206d5b3878098cdd10 100644
--- a/Reconstruction/tauRecTools/tauRecTools/TauDecayModeNNClassifier.h
+++ b/Reconstruction/tauRecTools/tauRecTools/TauDecayModeNNClassifier.h
@@ -1,5 +1,5 @@
 /*
-  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
+  Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration
 */
 
 #ifndef TAURECTOOLS_TAUDECAYMODENNCLASSIFIER_H
@@ -46,7 +46,7 @@ private:
   std::string m_outputName;      //!
   std::string m_probPrefix;      //!
   std::string m_weightFile;      //!
-  std::size_t m_maxChargedPFOs;  //!
+  std::size_t m_maxTauTracks;  //!
   std::size_t m_maxNeutralPFOs;  //!
   std::size_t m_maxShotPFOs;     //!
   std::size_t m_maxConvTracks;   //!
@@ -78,6 +78,7 @@ namespace tauRecTools
     TauDecayModeNNVariable() = delete;
     static const std::size_t nClasses = 5;
     static const std::set<std::string> sCommonP4Vars;
+    static const std::set<std::string> sTrackIPVars;
     static const std::set<std::string> sNeutralPFOVars;
     static const std::array<std::string, nClasses> sModeNames;
     static float deltaPhi(const TLorentzVector &p4, const TLorentzVector &p4_tau);