From 354de4fa6f7995cb7195bfd2ef93e52c34881916 Mon Sep 17 00:00:00 2001
From: Christian Grefe <christian.grefe@cern.ch>
Date: Sat, 31 Oct 2020 18:14:09 +0100
Subject: [PATCH] Add TRT NN PID to track variables. Clean up of TRT track
 decorations.

---
 .../TRT_ConditionsData/TRTPIDNN.h             |  18 +-
 .../TRT_ConditionsData/src/TRTPIDNN.cxx       |  23 +-
 .../InDetRecExample/python/TrackingCommon.py  |   7 +-
 .../src/TRT_ElectronPidToolRun2.cxx           | 212 ++++++++++++------
 .../TRT_ElectronPidTools/src/TRT_ToT_dEdx.cxx |   5 +
 .../TrkTrackSummary/TrackSummary.h            |  14 +-
 .../src/TrackParticleCreatorTool.cxx          |   5 +-
 .../TrkTrackSummaryTool/TrackSummaryTool.h    |  11 -
 .../src/TrackSummaryTool.cxx                  |  57 +----
 9 files changed, 208 insertions(+), 144 deletions(-)

diff --git a/InnerDetector/InDetConditions/TRT_ConditionsData/TRT_ConditionsData/TRTPIDNN.h b/InnerDetector/InDetConditions/TRT_ConditionsData/TRT_ConditionsData/TRTPIDNN.h
index 2f44f176456..a916d1bd924 100644
--- a/InnerDetector/InDetConditions/TRT_ConditionsData/TRT_ConditionsData/TRTPIDNN.h
+++ b/InnerDetector/InDetConditions/TRT_ConditionsData/TRT_ConditionsData/TRTPIDNN.h
@@ -30,6 +30,14 @@ namespace InDet {
     TRTPIDNN()=default;
     virtual ~TRTPIDNN()=default;
 
+    std::string getDefaultOutputNode() const {
+      return m_outputNode;
+    }
+
+    std::string getDefaultOutputLabel() const {
+      return m_outputLabel;
+    }
+
     // get the structure of the scalar inputs to the NN
     std::map<std::string, std::map<std::string, double>> getScalarInputs() const {
       return m_scalarInputs;
@@ -40,9 +48,16 @@ namespace InDet {
       return m_vectorInputs;
     }
 
+    // calculate NN response for default output node and label
+    double evaluate(std::map<std::string, std::map<std::string, double>>& scalarInputs,
+             std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const {
+      return evaluate(scalarInputs, vectorInputs, m_outputNode, m_outputLabel);
+    }
+
     // calculate NN response
     double evaluate(std::map<std::string, std::map<std::string, double>>& scalarInputs,
-             std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const;
+             std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs,
+             const std::string& outputNode, const std::string& outputLabel) const;
 
     // set up the NN
     StatusCode configure(const std::string& json);
@@ -53,6 +68,7 @@ namespace InDet {
     std::map<std::string, std::map<std::string, double>> m_scalarInputs;  // template for the structure of the scalar inputs to the NN
     std::map<std::string, std::map<std::string, std::vector<double>>> m_vectorInputs;  // template for the structure of the vector inputs to the NN
     std::string m_outputNode;  // name of the output node of the NN
+    std::string m_outputLabel;  // name of the output label of the NN
 };
 }
 CLASS_DEF(InDet::TRTPIDNN,341715853,1)
diff --git a/InnerDetector/InDetConditions/TRT_ConditionsData/src/TRTPIDNN.cxx b/InnerDetector/InDetConditions/TRT_ConditionsData/src/TRTPIDNN.cxx
index 79ae8c9fdbb..55221248e0e 100644
--- a/InnerDetector/InDetConditions/TRT_ConditionsData/src/TRTPIDNN.cxx
+++ b/InnerDetector/InDetConditions/TRT_ConditionsData/src/TRTPIDNN.cxx
@@ -17,9 +17,16 @@
 #include "boost/property_tree/exceptions.hpp"
 
 double InDet::TRTPIDNN::evaluate(std::map<std::string, std::map<std::string, double>>& scalarInputs,
-        std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const {
-  const auto result = m_nn->compute(scalarInputs, vectorInputs);
-  return result.at("e_prob_0");
+        std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs,
+        const std::string& outputNode, const std::string& outputLabel) const {
+  MsgStream log(Athena::getMessageSvc(),"TRTPIDNN");
+  const auto result = m_nn->compute(scalarInputs, vectorInputs, outputNode);
+  const auto itResult = result.find(outputLabel);
+  if (itResult == result.end()) {
+    log << MSG::ERROR << " unable to find output: node=" << outputNode << ", label=" << outputLabel << endmsg;
+    return 0.5;
+  }
+  return itResult->second;
 }
 
 StatusCode InDet::TRTPIDNN::configure(const std::string& json) {
@@ -39,6 +46,14 @@ StatusCode InDet::TRTPIDNN::configure(const std::string& json) {
     return StatusCode::FAILURE;
   }
 
+  // set the default output node name
+  if (m_nnConfig.outputs.empty() or m_nnConfig.outputs.begin()->second.labels.empty()) {
+    log << MSG::ERROR << " unable to define NN output." << endmsg;
+    return StatusCode::FAILURE;
+  }
+  m_outputNode = m_nnConfig.outputs.begin()->first;
+  m_outputLabel = *(m_nnConfig.outputs[m_outputNode].labels.begin());
+
   // store templates of the structure of the inputs to the NN
   m_scalarInputs.clear();
   for (auto input : m_nnConfig.inputs) {
@@ -47,7 +62,6 @@ StatusCode InDet::TRTPIDNN::configure(const std::string& json) {
       m_scalarInputs[input.name][variable.name] = input.defaults[variable.name];
     }
   }
-
   m_vectorInputs.clear();
   for (auto input : m_nnConfig.input_sequences) {
     m_vectorInputs[input.name] = {};
@@ -56,6 +70,5 @@ StatusCode InDet::TRTPIDNN::configure(const std::string& json) {
     }
   }
 
-
   return StatusCode::SUCCESS;
 }
diff --git a/InnerDetector/InDetExample/InDetRecExample/python/TrackingCommon.py b/InnerDetector/InDetExample/InDetRecExample/python/TrackingCommon.py
index 9e51a95f703..68302473537 100644
--- a/InnerDetector/InDetExample/InDetRecExample/python/TrackingCommon.py
+++ b/InnerDetector/InDetExample/InDetRecExample/python/TrackingCommon.py
@@ -1088,7 +1088,6 @@ def getInDetTrackSummaryTool(name='InDetTrackSummaryTool',**kwargs) :
                          doSharedHits           = False,
                          doHolesInDet           = do_holes,
                          TRT_ElectronPidTool    = None,         # we don't want to use those tools during pattern
-                         TRT_ToT_dEdxTool       = None,         # dito
                          PixelToTPIDTool        = None)         # we don't want to use those tools during pattern
     from TrkTrackSummaryTool.TrkTrackSummaryToolConf import Trk__TrackSummaryTool
     return Trk__TrackSummaryTool(name = the_name, **kwargs)
@@ -1111,16 +1110,12 @@ def getInDetTrackSummaryToolSharedHits(name='InDetTrackSummaryToolSharedHits',**
     if 'TRT_ElectronPidTool' not in kwargs :
         kwargs = setDefaults( kwargs, TRT_ElectronPidTool    = getInDetTRT_ElectronPidTool())
 
-    if 'TRT_ToT_dEdxTool' not in kwargs :
-        kwargs = setDefaults( kwargs, TRT_ToT_dEdxTool       = getInDetTRT_dEdxTool())
-
     if 'PixelToTPIDTool' not in kwargs :
         kwargs = setDefaults( kwargs, PixelToTPIDTool        = getInDetPixelToTPIDTool())
 
     from InDetRecExample.InDetJobProperties import InDetFlags
     kwargs = setDefaults(kwargs,
-                         doSharedHits           = InDetFlags.doSharedHits(),
-                         minTRThitsForTRTdEdx   = 1)    # default is 1
+                         doSharedHits           = InDetFlags.doSharedHits())
 
     return getInDetTrackSummaryTool( name, **kwargs)
 
diff --git a/InnerDetector/InDetRecTools/TRT_ElectronPidTools/src/TRT_ElectronPidToolRun2.cxx b/InnerDetector/InDetRecTools/TRT_ElectronPidTools/src/TRT_ElectronPidToolRun2.cxx
index 01c1fd75077..6c7ba6afe3e 100644
--- a/InnerDetector/InDetRecTools/TRT_ElectronPidTools/src/TRT_ElectronPidToolRun2.cxx
+++ b/InnerDetector/InDetRecTools/TRT_ElectronPidTools/src/TRT_ElectronPidToolRun2.cxx
@@ -21,6 +21,7 @@
 
 // Tracking:
 #include "TrkTrack/Track.h"
+#include "TrkTrackSummary/TrackSummary.h"
 #include "TrkTrack/TrackStateOnSurface.h"
 #include "TrkMeasurementBase/MeasurementBase.h"
 #include "TrkRIO_OnTrack/RIO_OnTrack.h"
@@ -35,6 +36,9 @@
 // ToT Tool Interface
 #include "TRT_ElectronPidTools/ITRT_ToT_dEdx.h"
 
+// For the track length in straw calculations
+#include "TRT_ToT_dEdx.h"
+
 // Particle masses
 
 // Math functions:
@@ -48,6 +52,16 @@
 
 //#include "TRT_ElectronPidToolRun2_HTcalculation.cxx"
 
+// Helper method to store NN input variables into maps
+template <typename T>
+void storeNNVariable(std::map<std::string, T>& theMap, const std::string& name, const T& value) {
+  auto it = theMap.find(name);
+  if (it != theMap.end()) {
+    it->second = value;
+  }
+}
+
+
 
 /*****************************************************************************\
 |*%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%*|
@@ -121,13 +135,15 @@ StatusCode InDet::TRT_ElectronPidToolRun2::finalize()
 std::vector<float> InDet::TRT_ElectronPidToolRun2::electronProbability_old(const Trk::Track& track)
 {
   // Simply return values without calculation
-  std::vector<float> PIDvalues(4);
-  PIDvalues[0] = 0.5;
-  PIDvalues[1] = 0.5;
-  PIDvalues[2] = 0.0;
-  PIDvalues[3] = 0.5;
-  //PIDvalues[4] = 0.0;
-  //PIDvalues[5] = 0.5;
+  std::vector<float> PIDvalues(Trk::numberOfeProbabilityTypes);
+  PIDvalues[Trk::eProbabilityComb] = 0.5;
+  PIDvalues[Trk::eProbabilityHT] = 0.5;
+  PIDvalues[Trk::eProbabilityToT] = 0.5;
+  PIDvalues[Trk::eProbabilityBrem] = 0.5;
+  PIDvalues[Trk::eProbabilityNN] = 0.5;
+  PIDvalues[Trk::TRTTrackOccupancy] = 0.0;
+  PIDvalues[Trk::TRTdEdx] = 0.0;
+  PIDvalues[Trk::eProbabilityNumberOfTRTHitsUsedFordEdx] = 0.0;
   const Trk::TrackParameters* perigee = track.perigeeParameters();
   if (!perigee) { return PIDvalues; }
   return PIDvalues;
@@ -158,16 +174,16 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
    ATH_MSG_WARNING ("  No PID NN available from the DB.");
  }
 
-  //Initialize the return vector
-  std::vector<float> PIDvalues(6);
-  float & prob_El_Comb      = PIDvalues[0] = 0.5;
-  float & prob_El_HT        = PIDvalues[1] = 0.5;
-  float & prob_El_ToT       = PIDvalues[2] = 0.5;
-  float & prob_El_Brem      = PIDvalues[3] = 0.5;
-  float & occ_local         = PIDvalues[4] = 0.0;
-  float & prob_El_NN        = PIDvalues[5] = 0.5;
-
-  float dEdx = 0.0;
+  // Initialize the vector with default PID values
+  std::vector<float> PIDvalues(Trk::numberOfeProbabilityTypes);
+  PIDvalues[Trk::eProbabilityComb] = 0.5;
+  PIDvalues[Trk::eProbabilityHT] = 0.5;
+  PIDvalues[Trk::eProbabilityToT] = 0.5;
+  PIDvalues[Trk::eProbabilityBrem] = 0.5;
+  PIDvalues[Trk::eProbabilityNN] = 0.5;
+  PIDvalues[Trk::TRTTrackOccupancy] = 0.0;
+  PIDvalues[Trk::TRTdEdx] = 0.0;
+  PIDvalues[Trk::eProbabilityNumberOfTRTHitsUsedFordEdx] = 0.0;
 
   // Check for perigee:
   const Trk::TrackParameters* perigee = track.perigeeParameters();
@@ -196,7 +212,12 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
   double eta  = -log(tan(theta/2.0));
 
   // Check the tool to get the local occupancy (i.e. for the track in question):
-  occ_local = m_LocalOccTool->LocalOccupancy(ctx,track);
+  PIDvalues[Trk::TRTTrackOccupancy] = m_LocalOccTool->LocalOccupancy(ctx,track);
+
+  if (PIDvalues[Trk::TRTTrackOccupancy] > 1.0  || PIDvalues[Trk::TRTTrackOccupancy]  < 0.0) {
+    ATH_MSG_WARNING("  Occupancy was outside allowed range! Returning default Pid values. Occupancy = " << PIDvalues[Trk::TRTTrackOccupancy] );
+    return PIDvalues;
+  }
 
   ATH_MSG_DEBUG ("");
   ATH_MSG_DEBUG ("");
@@ -212,8 +233,20 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
   // Loop over TRT hits on track, and calculate HT and R-ToT probability:
   // ------------------------------------------------------------------------------------
 
+  std::vector<double> hit_HTMB;
+  std::vector<double> hit_gasType;
+  std::vector<double> hit_tot;
+  std::vector<double> hit_L;
+  std::vector<double> hit_rTrkWire;
+  std::vector<double> hit_HitZ;
+  std::vector<double> hit_HitR;
+  std::vector<double> hit_isPrec;
+
   unsigned int nTRThits     = 0;
   unsigned int nTRThitsHTMB = 0;
+  unsigned int nXehits      = 0;
+  unsigned int nArhits      = 0;
+  unsigned int nPrecHits    = 0;
 
 
   // Check for track states:
@@ -234,15 +267,7 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
     if (!measurement) continue;
 
     // Get drift circle (ensures that hit is from TRT):
-    const InDet::TRT_DriftCircleOnTrack* driftcircle = nullptr;
-    if (measurement->type(Trk::MeasurementBaseType::RIO_OnTrack)) {
-      const Trk::RIO_OnTrack* tmpRio =
-        static_cast<const Trk::RIO_OnTrack*>(measurement);
-      if (tmpRio->rioType(Trk::RIO_OnTrackType::TRT_DriftCircle)) {
-        driftcircle = static_cast<const InDet::TRT_DriftCircleOnTrack*>(tmpRio);
-      }
-    }
-
+    const InDet::TRT_DriftCircleOnTrack* driftcircle = dynamic_cast<const InDet::TRT_DriftCircleOnTrack*>(measurement);
     if (!driftcircle) continue;
 
     // From now (May 2015) onwards, we ONLY USE MIDDLE HT BIT:
@@ -250,6 +275,7 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
 
     nTRThits++;
     if (isHTMB) nTRThitsHTMB++;
+    hit_HTMB.push_back(double(isHTMB));
 
 
     // ------------------------------------------------------------------------------------
@@ -276,14 +302,17 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
     }
 
     // Get Z (Barrel) or R (Endcap) location of the hit, and distance from track to wire (i.e. anode) in straw:
-    double HitZ, HitR, rTrkWire;
-    bool hasTrackParameters= true; // Keep track of this for HT prob calculation
+    double HitZ = 0.;
+    double HitR = 0.;
+    double rTrkWire = 0.;
+    bool hasTrackParameters = true; // Keep track of this for HT prob calculation
     if ((*tsosIter)->trackParameters()) {
       // If we have precise information (from hit), get that:
       const Amg::Vector3D& gp = driftcircle->globalPosition();
       HitR = gp.perp();
       HitZ = gp.z();
       rTrkWire = fabs((*tsosIter)->trackParameters()->parameters()[Trk::driftRadius]);
+      if (rTrkWire > 2.2) rTrkWire = 2.175;   // Happens once in a while - no need for warning!
     } else {
       // Otherwise just use the straw coordinates:
       hasTrackParameters = false; // Jared - pass this to HT calculation
@@ -292,6 +321,12 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
       rTrkWire = 0;
     }
 
+    // fill vectors for NN PID
+    hit_HitZ.push_back(HitZ);
+    hit_HitR.push_back(HitR);
+    hit_rTrkWire.push_back(rTrkWire);
+    hit_L.push_back(TRT_ToT_dEdx::calculateTrackLengthInStraw((*tsosIter), m_trtId));
+    hit_tot.push_back(driftcircle->timeOverThreshold());
 
     // ------------------------------------------------------------------------------------
     // Collection and checks of input variables for HT probability calculation:
@@ -315,13 +350,6 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
       ZRpos[TrtPart] = ZRpos_min[TrtPart] + 0.001;
     }
 
-    if (rTrkWire > 2.2) rTrkWire = 2.175;   // Happens once in a while - no need for warning!
-
-    if (occ_local > 1.0  ||  occ_local < 0.0) {
-      ATH_MSG_WARNING("  Occupancy was outside allowed range!  TrtPart = " << TrtPart << "  Occupancy = " << occ_local);
-      continue;
-    }
-
     // ------------------------------------------------------------------------------------
     // Calculate the HT probability:
     // ------------------------------------------------------------------------------------
@@ -345,7 +373,24 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
                   << nTRThits << "  TrtPart: " << TrtPart
                   << "  GasType: " << GasType << "  SL: " << StrawLayer
                   << "  ZRpos: " << ZRpos[TrtPart] << "  TWdist: " << rTrkWire
-                  << "  Occ_Local: " << occ_local << "  HTMB: " << isHTMB);
+                  << "  Occ_Local: " << PIDvalues[Trk::TRTTrackOccupancy]  << "  HTMB: " << isHTMB);
+
+    // RNN gas type observables
+    hit_gasType.push_back(double(GasType));
+    if (GasType == 0) {
+      nXehits++;
+    } else if (GasType == 1) {
+      nArhits++;
+    }
+
+    // RNN hit preciion observables
+    float errDc = sqrt(driftcircle->localCovariance()(Trk::driftRadius, Trk::driftRadius));
+    bool isPrec = false;
+    if (errDc < 1.0) {
+      isPrec = true;
+      nPrecHits++;
+    }
+    hit_isPrec.push_back(double(isPrec));
 
     // Then call pHT functions with these values:
     // ------------------------------------------
@@ -357,7 +402,7 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
                                      StrawLayer,
                                      ZRpos[TrtPart],
                                      rTrkWire,
-                                     occ_local,
+                                     PIDvalues[Trk::TRTTrackOccupancy] ,
                                      hasTrackParameters);
     double pHTpi = HTcalc->getProbHT(pTrk,
                                      Trk::pion,
@@ -366,7 +411,7 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
                                      StrawLayer,
                                      ZRpos[TrtPart],
                                      rTrkWire,
-                                     occ_local,
+                                     PIDvalues[Trk::TRTTrackOccupancy] ,
                                      hasTrackParameters);
 
     if (pHTel > 0.999 || pHTpi > 0.999 || pHTel < 0.001 || pHTpi < 0.001) {
@@ -374,7 +419,7 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
                     << pHTel << "  pHTpi = " << pHTpi
                     << "     TrtPart: " << TrtPart << "  SL: " << StrawLayer
                     << "  ZRpos: " << ZRpos[TrtPart] << "  TWdist: " << rTrkWire
-                    << "  Occ_Local: " << occ_local);
+                    << "  Occ_Local: " << PIDvalues[Trk::TRTTrackOccupancy] );
       continue;
     }
 
@@ -383,7 +428,7 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
                     << pHTel << "  pHTpi = " << pHTpi
                     << "     TrtPart: " << TrtPart << "  SL: " << StrawLayer
                     << "  ZRpos: " << ZRpos[TrtPart] << "  TWdist: " << rTrkWire
-                    << "  Occ_Local: " << occ_local);
+                    << "  Occ_Local: " << PIDvalues[Trk::TRTTrackOccupancy] );
       continue;
     }
 
@@ -392,49 +437,84 @@ InDet::TRT_ElectronPidToolRun2::electronProbability(const Trk::Track& track) con
     else        {pHTel_prod *= 1.0-pHTel;  pHTpi_prod *= 1.0-pHTpi;}
     ATH_MSG_DEBUG ("check         pHT(el): " << pHTel << "  pHT(pi): " << pHTpi );
 
-    // Jared - Development Output...
-
-    //std::cout << "check         pHT(el): " << pHTel << "  pHT(pi): " << pHTpi << std::endl;
-
-  }//of loop over hits
+  } // end of loop over hits
 
 
   // If number of hits is adequate (default is 5 hits), calculate HT and ToT probability.
   if (not (nTRThits >= m_minTRThits)) return PIDvalues;
 
   // Calculate electron probability (HT)
-  prob_El_HT = pHTel_prod / (pHTel_prod + pHTpi_prod);
+  PIDvalues[Trk::eProbabilityHT] = pHTel_prod / (pHTel_prod + pHTpi_prod);
 
-  ATH_MSG_DEBUG ("check---------------------------------------------------------------------------------------");
-  ATH_MSG_DEBUG("check  nTRThits: " << nTRThits << "  : " << nTRThitsHTMB
-                                    << "  pHTel_prod: " << pHTel_prod
-                                    << "  pHTpi_prod: " << pHTpi_prod
-                                    << "  probEl: " << prob_El_HT);
-  ATH_MSG_DEBUG ("check---------------------------------------------------------------------------------------");
-  ATH_MSG_DEBUG ("");
-  ATH_MSG_DEBUG ("");
+  ATH_MSG_DEBUG ("check  nTRThits: " << nTRThits << "  : " << nTRThitsHTMB
+                                     << "  pHTel_prod: " << pHTel_prod
+                                     << "  pHTpi_prod: " << pHTpi_prod
+                                     << "  probEl: " << PIDvalues[Trk::eProbabilityHT]);
 
-  // Jared - ToT Implementation
-  dEdx = m_TRTdEdxTool->dEdx(ctx,&track, false); // Divide by L, exclude HT hits
-  double usedHits = m_TRTdEdxTool->usedHits(ctx,&track, false);
-  prob_El_ToT = m_TRTdEdxTool->getTest(ctx,dEdx, pTrk, Trk::electron, Trk::pion, usedHits);
+  PIDvalues[Trk::TRTdEdx] = m_TRTdEdxTool->dEdx(ctx,&track); // default dEdx using all hits
+  PIDvalues[Trk::eProbabilityNumberOfTRTHitsUsedFordEdx] = m_TRTdEdxTool->usedHits(ctx,&track);
+  double dEdx_noHTHits = m_TRTdEdxTool->dEdx(ctx,&track, false); // Divide by L, exclude HT hits
+  double dEdx_usedHits_noHTHits = m_TRTdEdxTool->usedHits(ctx,&track, false);
+  PIDvalues[Trk::eProbabilityToT] = m_TRTdEdxTool->getTest(ctx, dEdx_noHTHits, pTrk, Trk::electron, Trk::pion, dEdx_usedHits_noHTHits);
 
   // Limit the probability values the upper and lower limits that are given/trusted for each part:
-  double limProbHT = HTcalc->Limit(prob_El_HT);
-  double limProbToT = HTcalc->Limit(prob_El_ToT);
+  double limProbHT = HTcalc->Limit(PIDvalues[Trk::eProbabilityHT]);
+  double limProbToT = HTcalc->Limit(PIDvalues[Trk::eProbabilityToT]);
 
   // Calculate the combined probability, assuming no correlations (none are expected).
-  prob_El_Comb = (limProbHT * limProbToT ) / ( (limProbHT * limProbToT) + ( (1.0-limProbHT) * (1.0-limProbToT)) );
+  PIDvalues[Trk::eProbabilityComb] = (limProbHT * limProbToT ) / ( (limProbHT * limProbToT) + ( (1.0-limProbHT) * (1.0-limProbToT)) );
 
   // Troels: VERY NASTY NAMING, BUT AGREED UPON FOR NOW (for debugging, 27. NOV. 2014):
-  prob_El_Brem = pHTel_prod; // decorates electron LH to el brem for now... (still used?)
-
-  //std::cout << "Prob_HT = " << prob_El_HT << "   Prob_ToT = " << prob_El_ToT << "   Prob_Comb = " << prob_El_Comb << std::endl;
+  PIDvalues[Trk::eProbabilityBrem] = pHTel_prod; // decorates electron LH to el brem for now... (still used?)
 
   // Calculate RNN PID score
   std::map<std::string, std::map<std::string, double>> scalarInputs_NN = PIDNN->getScalarInputs();
   std::map<std::string, std::map<std::string, std::vector<double>>> vectorInputs_NN = PIDNN->getVectorInputs();
-  prob_El_NN = PIDNN->evaluate(scalarInputs_NN, vectorInputs_NN);
+  
+  // Calculate the hit fraction
+  double fAr = double(nArhits) / nTRThits;
+  double fHTMB = double(nTRThitsHTMB) / nTRThits;
+  double PHF = double(nPrecHits) / nTRThits;
+
+  if (!scalarInputs_NN.empty()) {
+    std::map<std::string, double>& trackVarMap = scalarInputs_NN.begin()->second;
+    storeNNVariable(trackVarMap, "trkOcc", (double) PIDvalues[Trk::TRTTrackOccupancy]);
+    storeNNVariable(trackVarMap, "p", pTrk);
+    storeNNVariable(trackVarMap, "pT", pT);
+    storeNNVariable(trackVarMap, "nXehits", (double) nXehits);
+    storeNNVariable(trackVarMap, "fAr", fAr);
+    storeNNVariable(trackVarMap, "fHTMB", fHTMB);
+    storeNNVariable(trackVarMap, "PHF", PHF);
+    storeNNVariable(trackVarMap, "dEdx", (double) dEdx_noHTHits);
+  }
+
+  if (!vectorInputs_NN.empty()) {
+    std::map<std::string, std::vector<double>>& hitVarMap = vectorInputs_NN.begin()->second;
+    storeNNVariable(hitVarMap, "hit_HTMB", hit_HTMB);
+    storeNNVariable(hitVarMap, "hit_gasType", hit_gasType);
+    storeNNVariable(hitVarMap, "hit_tot", hit_tot);
+    storeNNVariable(hitVarMap, "hit_L", hit_L);
+    storeNNVariable(hitVarMap, "hit_rTrkWire", hit_rTrkWire);
+    storeNNVariable(hitVarMap, "hit_HitZ", hit_HitZ);
+    storeNNVariable(hitVarMap, "hit_HitR", hit_HitR);
+    storeNNVariable(hitVarMap, "hit_isPrec", hit_isPrec);
+  }
+  PIDvalues[Trk::eProbabilityNN] = PIDNN->evaluate(scalarInputs_NN, vectorInputs_NN);
+
+  ATH_MSG_DEBUG ("check NN PID calculation: ");
+  for (auto scalarInputs : scalarInputs_NN) {
+    ATH_MSG_DEBUG ("  scalar inputs: " << scalarInputs.first);
+    for (auto variable : scalarInputs.second) {
+      ATH_MSG_DEBUG ("    " << variable.first << " = " << variable.second);
+    }
+  }
+  for (auto vectorInputs : vectorInputs_NN) {
+    ATH_MSG_DEBUG ("  vector inputs: " << vectorInputs.first);
+    for (auto variable : vectorInputs.second) {
+      ATH_MSG_DEBUG ("    " << variable.first << " = " << variable.second);
+    }
+  }
+  ATH_MSG_DEBUG ("  eProbilityNN: " << PIDvalues[Trk::eProbabilityNN]);
 
   return PIDvalues;
 }
diff --git a/InnerDetector/InDetRecTools/TRT_ElectronPidTools/src/TRT_ToT_dEdx.cxx b/InnerDetector/InDetRecTools/TRT_ElectronPidTools/src/TRT_ToT_dEdx.cxx
index cc7cb3ec278..56e65ca323a 100644
--- a/InnerDetector/InDetRecTools/TRT_ElectronPidTools/src/TRT_ToT_dEdx.cxx
+++ b/InnerDetector/InDetRecTools/TRT_ElectronPidTools/src/TRT_ToT_dEdx.cxx
@@ -1252,6 +1252,11 @@ double TRT_ToT_dEdx::calculateTrackLengthInStraw(const Trk::TrackStateOnSurface*
   const InDetDD::TRT_BaseElement* element = driftcircle->detectorElement();
   double strawphi = element->center(DCId).phi();
 
+  // check if track is an outlier
+  if (Trt_Rtrack >= 2.0) {
+    return 0.;
+  }
+
   double length=0;
   if (HitPart == 1) { //Barrel
     length = 2*std::sqrt(4-Trt_Rtrack*Trt_Rtrack)*1./std::abs(std::sin(Trt_HitTheta));
diff --git a/Tracking/TrkEvent/TrkTrackSummary/TrkTrackSummary/TrackSummary.h b/Tracking/TrkEvent/TrkTrackSummary/TrkTrackSummary/TrackSummary.h
index 98ff9e8f6d7..eadb982947b 100755
--- a/Tracking/TrkEvent/TrkTrackSummary/TrkTrackSummary/TrackSummary.h
+++ b/Tracking/TrkEvent/TrkTrackSummary/TrkTrackSummary/TrackSummary.h
@@ -125,9 +125,11 @@ enum SummaryType {
         eProbabilityBrem_res                = 50, //!< Electron probability from Brem fitting (DNA) [float]. 
         pixeldEdx_res                       = 51, //!< the dE/dx estimate, calculated using the pixel clusters [?]
         eProbabilityNN_res                  = 73, //!< Electron probability from NN [float].
+        TRTTrackOccupancy_res               = 74, //!< TRT track occupancy.
+        TRTdEdx_res                         = 75, //!< dEdx from TRT ToT measurement.
 
  // -- numbers...
-        numberOfTrackSummaryTypes = 74
+        numberOfTrackSummaryTypes = 76
     };
 
 // Troels.Petersen@cern.ch:
@@ -137,11 +139,11 @@ enum SummaryType {
         eProbabilityToT             = 2,       //!< Electron probability from Time-Over-Threshold (ToT) information.
         eProbabilityBrem            = 3,       //!< Electron probability from Brem fitting (DNA).
         eProbabilityNN              = 4,       //!< Electron probability from NN.
-        numberOfeProbabilityTypes   = 5        
-    }; 
-  // the eProbability vector is abused to store : 
-  // [5] TRT local occupancy
-  // [6] TRT dE/dx
+        TRTTrackOccupancy           = 5,       //!< TRT track occupancy.
+        TRTdEdx                     = 6,       //!< dEdx from TRT ToT measurement.
+        eProbabilityNumberOfTRTHitsUsedFordEdx = 7, //!< Number of TRT hits used for dEdx measurement.
+        numberOfeProbabilityTypes   = 8        
+    };
 
 /** enumerates the various detector types currently accessible from the isHit() method.
 \todo work out how to add muons to this*/
diff --git a/Tracking/TrkTools/TrkParticleCreator/src/TrackParticleCreatorTool.cxx b/Tracking/TrkTools/TrkParticleCreator/src/TrackParticleCreatorTool.cxx
index 7699e4803aa..8a43b488b1f 100644
--- a/Tracking/TrkTools/TrkParticleCreator/src/TrackParticleCreatorTool.cxx
+++ b/Tracking/TrkTools/TrkParticleCreator/src/TrackParticleCreatorTool.cxx
@@ -87,9 +87,8 @@ createEProbabilityMap(std::map<std::string, std::pair<Trk::eProbabilityType, boo
   eprob_map.insert(std::make_pair("eProbabilityToT", std::make_pair(Trk::eProbabilityToT, true)));
   eprob_map.insert(std::make_pair("eProbabilityBrem", std::make_pair(Trk::eProbabilityBrem, true)));
   eprob_map.insert(std::make_pair("eProbabilityNN", std::make_pair(Trk::eProbabilityNN, true)));
-  eprob_map.insert(std::make_pair("TRTTrackOccupancy", std::make_pair(Trk::numberOfeProbabilityTypes, true)));
-  eprob_map.insert(std::make_pair(
-    "TRTdEdx", std::make_pair(static_cast<Trk::eProbabilityType>(Trk::numberOfeProbabilityTypes + 1), true)));
+  eprob_map.insert(std::make_pair("TRTdEdx", std::make_pair(Trk::TRTdEdx, true)));
+  eprob_map.insert(std::make_pair("TRTTrackOccupancy", std::make_pair(Trk::TRTTrackOccupancy, true)));
 }
 
 void
diff --git a/Tracking/TrkTools/TrkTrackSummaryTool/TrkTrackSummaryTool/TrackSummaryTool.h b/Tracking/TrkTools/TrkTrackSummaryTool/TrkTrackSummaryTool/TrackSummaryTool.h
index 2c3ca7d0363..b83d211e1c9 100755
--- a/Tracking/TrkTools/TrkTrackSummaryTool/TrkTrackSummaryTool/TrackSummaryTool.h
+++ b/Tracking/TrkTools/TrkTrackSummaryTool/TrkTrackSummaryTool/TrackSummaryTool.h
@@ -12,7 +12,6 @@
 #include "TrkTrack/Track.h"
 #include "TrkTrackSummary/TrackSummary.h"
 
-#include "TRT_ElectronPidTools/ITRT_ToT_dEdx.h"
 #include "TrkToolInterfaces/IExtendedTrackSummaryHelperTool.h"
 #include "TrkToolInterfaces/IPixelToTPIDTool.h"
 #include "TrkToolInterfaces/ITRT_ElectronPidTool.h"
@@ -22,7 +21,6 @@
 
 class AtlasDetectorID;
 class Identifier;
-class ITRT_ToT_dEdx;
 
 namespace Trk {
 class ITRT_ElectronPidTool;
@@ -222,8 +220,6 @@ private:
                                                        "TRT_ElectronPidTool",
                                                        "",
                                                        "" };
-  /** tool to calculate the TRT_ToT_dEdx.*/
-  ToolHandle<ITRT_ToT_dEdx> m_trt_dEdxTool{ this, "TRT_ToT_dEdxTool", "", "" };
   /**tool to calculate dE/dx using pixel clusters*/
   ToolHandle<IPixelToTPIDTool> m_dedxtool{ this, "PixelToTPIDTool", "", "" };
   /**tool to decipher muon RoTs*/
@@ -253,13 +249,6 @@ private:
   /** switch to deactivate Pixel info init */
   Gaudi::Property<bool> m_pixelExists{ this, "PixelExists", true, "" };
 
-  /** Only compute TRT dE/dx if there are at least this number of TRT hits or
-   * outliers.*/
-  Gaudi::Property<int> m_minTRThitsForTRTdEdx{ this,
-                                               "minTRThitsForTRTdEdx",
-                                               1,
-                                               "" };
-
   Gaudi::Property<bool> m_alwaysRecomputeHoles {
     this, "AlwaysRecomputeHoles", false, ""
   };
diff --git a/Tracking/TrkTools/TrkTrackSummaryTool/src/TrackSummaryTool.cxx b/Tracking/TrkTools/TrkTrackSummaryTool/src/TrackSummaryTool.cxx
index 52ccbbc7353..ead091f996a 100755
--- a/Tracking/TrkTools/TrkTrackSummaryTool/src/TrackSummaryTool.cxx
+++ b/Tracking/TrkTools/TrkTrackSummaryTool/src/TrackSummaryTool.cxx
@@ -71,11 +71,6 @@ StatusCode
        if ( !m_eProbabilityTool.empty()) msg(MSG::INFO) << "Retrieved tool " << m_eProbabilityTool << endmsg;
 
 
-    if (!m_trt_dEdxTool.empty()) {
-      ATH_CHECK( m_trt_dEdxTool.retrieve() );
-    }
-
-
     if ( !m_dedxtool.empty() && m_dedxtool.retrieve().isFailure() )
     {
         ATH_MSG_ERROR ("Failed to retrieve pixel dEdx tool " << m_dedxtool);
@@ -196,8 +191,7 @@ information.resize(std::min(information.size(),
                             static_cast<size_t>(numberOfTrackSummaryTypes)));
 
 // Troels.Petersen@cern.ch:
-unsigned int numberOfeProbabilityTypes = Trk::numberOfeProbabilityTypes + 1;
-std::vector<float> eProbability(numberOfeProbabilityTypes, 0.5);
+std::vector<float> eProbability(Trk::numberOfeProbabilityTypes, 0.5);
 
   float dedx = -1;
   int nhitsuseddedx = -1;
@@ -318,22 +312,7 @@ std::vector<float> eProbability(numberOfeProbabilityTypes, 0.5);
     searchHolesStepWise(track,information, doHolesInDet, doHolesMuon);
   }
 
-  if (!m_trt_dEdxTool.empty()) {
-    if (information[Trk::numberOfTRTHits]+information[Trk::numberOfTRTOutliers]>=m_minTRThitsForTRTdEdx) {
-      int nhits = static_cast<int>( m_trt_dEdxTool->usedHits(&track) );
-      double fvalue = (nhits>0 ? m_trt_dEdxTool->dEdx(&track) : 0.0);
-      eProbability.push_back(fvalue);
-      information[ numberOfTRTHitsUsedFordEdx] = static_cast<uint8_t>(std::max(nhits,0));
-    }
-    else {
-      information[ numberOfTRTHitsUsedFordEdx]=0;
-      eProbability.push_back(0.0);
-    }
-  }
-  else {
-    eProbability.push_back(0.0);
-  }
-
+  information[Trk::numberOfTRTHitsUsedFordEdx] = eProbability[Trk::eProbabilityNumberOfTRTHitsUsedFordEdx];
   ts.m_eProbability = eProbability;
   ts.m_idHitPattern = hitPattern.to_ulong();
   ts.m_dedx = dedx;
@@ -360,28 +339,14 @@ void Trk::TrackSummaryTool::updateSharedHitCount(const Track& track, const Trk::
 
 void Trk::TrackSummaryTool::updateAdditionalInfo(const Track& track, TrackSummary &summary, bool initialise_to_zero) const
 {
-  unsigned int numberOfeProbabilityTypes = Trk::numberOfeProbabilityTypes+1;
-  std::vector<float> eProbability(numberOfeProbabilityTypes,0.5);
-  if ( !m_eProbabilityTool.empty() ) eProbability = m_eProbabilityTool->electronProbability(track);
-
-  if (!m_trt_dEdxTool.empty()) {
-    if (summary.get(Trk::numberOfTRTHits)+summary.get(Trk::numberOfTRTOutliers)>=m_minTRThitsForTRTdEdx) {
-      int nhits = static_cast<int>( m_trt_dEdxTool->usedHits(&track) );
-      double fvalue = (nhits>0 ? m_trt_dEdxTool->dEdx(&track) : 0.0);
-      eProbability.push_back(fvalue);
-      if (!summary.update(Trk::numberOfTRTHitsUsedFordEdx, static_cast<uint8_t>(std::max(nhits,0)) )) {
-        ATH_MSG_WARNING( "Attempt to update numberOfTRTHitsUsedFordEdx but this summary information is "
-                         "already set. numberOfTRTHitsUsedFordEdx is:" << summary.get(numberOfTRTHitsUsedFordEdx)
-                         << " =?= should:" << nhits );
-      }
-    }
-    else {
-      eProbability.push_back(0.0);
-      if (!summary.update(Trk::numberOfTRTHitsUsedFordEdx, 0) ) {
-        ATH_MSG_WARNING( "Attempt to update numberOfTRTHitsUsedFordEdx but this summary information is "
-                         "already set. numberOfTRTHitsUsedFordEdx is:" << summary.get(numberOfTRTHitsUsedFordEdx)
-                         << " =?= should:" << 0 );
-      }
+  std::vector<float> eProbability(Trk::numberOfeProbabilityTypes, 0.5);
+  if (!m_eProbabilityTool.empty()) {
+    eProbability = m_eProbabilityTool->electronProbability(track);
+    int nHits = eProbability[Trk::eProbabilityNumberOfTRTHitsUsedFordEdx];
+    if (!summary.update(Trk::numberOfTRTHitsUsedFordEdx, static_cast<uint8_t>(std::max(nHits,0)) )) {
+      ATH_MSG_WARNING("Attempt to update numberOfTRTHitsUsedFordEdx but this summary information is "
+                      "already set. numberOfTRTHitsUsedFordEdx is:" << summary.get(numberOfTRTHitsUsedFordEdx)
+                      << " =?= should:" << nHits );
     }
   }
 
@@ -393,7 +358,7 @@ void Trk::TrackSummaryTool::updateAdditionalInfo(const Track& track, TrackSummar
     dedx = m_dedxtool->dEdx(track, nhitsuseddedx, noverflowhitsdedx);
   }
 
-  m_idTool->updateAdditionalInfo(summary, eProbability,dedx, nhitsuseddedx,noverflowhitsdedx);
+  m_idTool->updateAdditionalInfo(summary, eProbability, dedx, nhitsuseddedx, noverflowhitsdedx);
 
   m_idTool->updateExpectedHitInfo(track, summary);
 
-- 
GitLab