From 2c6174446d6e919f0621ba9eba18cb15ce907597 Mon Sep 17 00:00:00 2001
From: Tobias Boeckh <tobias.boeckh@cern.ch>
Date: Fri, 18 Nov 2022 01:41:18 +0100
Subject: [PATCH] added TrackTruthMathingTool which returns the matched
 xAOD::TruthParticle and hit count for a track and write out truth track
 variables in NtupleDumper

---
 PhysicsAnalysis/NtupleDumper/CMakeLists.txt   |   2 +-
 .../NtupleDumper/src/NtupleDumperAlg.cxx      | 214 +++++++++++++++++-
 .../NtupleDumper/src/NtupleDumperAlg.h        |  31 +++
 .../Acts/FaserActsKalmanFilter/CMakeLists.txt |   3 +
 .../ITrackTruthMatchingTool.h                 |  21 ++
 .../src/TrackTruthMatchingTool.cxx            | 103 +++++++++
 .../src/TrackTruthMatchingTool.h              |  40 ++++
 .../FaserActsKalmanFilter_entries.cxx         |   2 +
 8 files changed, 406 insertions(+), 10 deletions(-)
 create mode 100644 Tracking/Acts/FaserActsKalmanFilter/FaserActsKalmanFilter/ITrackTruthMatchingTool.h
 create mode 100644 Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.cxx
 create mode 100644 Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.h

diff --git a/PhysicsAnalysis/NtupleDumper/CMakeLists.txt b/PhysicsAnalysis/NtupleDumper/CMakeLists.txt
index 944ed9689..b81b98183 100644
--- a/PhysicsAnalysis/NtupleDumper/CMakeLists.txt
+++ b/PhysicsAnalysis/NtupleDumper/CMakeLists.txt
@@ -5,7 +5,7 @@ atlas_add_component(
         src/NtupleDumperAlg.h
         src/NtupleDumperAlg.cxx
         src/component/NtupleDumper_entries.cxx
-        LINK_LIBRARIES AthenaBaseComps StoreGateLib xAODFaserWaveform xAODFaserTrigger ScintIdentifier FaserCaloIdentifier GeneratorObjects FaserActsGeometryLib TrackerSimEvent TrackerSimData TrackerIdentifier TrackerReadoutGeometry TrkTrack GeoPrimitives TrackerRIO_OnTrack TrackerSpacePoint
+        LINK_LIBRARIES AthenaBaseComps StoreGateLib xAODFaserWaveform xAODFaserTrigger ScintIdentifier FaserCaloIdentifier GeneratorObjects FaserActsGeometryLib TrackerSimEvent TrackerSimData TrackerIdentifier TrackerReadoutGeometry TrkTrack GeoPrimitives TrackerRIO_OnTrack TrackerSpacePoint FaserActsKalmanFilterLib
 )
 
 atlas_install_python_modules(python/*.py)
diff --git a/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.cxx b/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.cxx
index d1a07f428..dba1c0e68 100644
--- a/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.cxx
+++ b/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.cxx
@@ -16,6 +16,9 @@
 #include "xAODTruth/TruthParticle.h"
 #include <cmath>
 #include <TH1F.h>
+#include <numeric>
+
+constexpr float NaN = std::numeric_limits<double>::quiet_NaN();
 
 
 NtupleDumperAlg::NtupleDumperAlg(const std::string &name, 
@@ -70,7 +73,73 @@ void NtupleDumperAlg::FillWaveBranches(const xAOD::WaveformHitContainer &wave) c
   }
 }
 
-StatusCode NtupleDumperAlg::initialize() 
+NtupleDumperAlg::Vector3 NtupleDumperAlg::getGlobalPosition(const FaserSiHit &hit) const {
+  Identifier waferId =
+      m_sctHelper->wafer_id(hit.getStation(), hit.getPlane(), hit.getRow(),
+                            hit.getModule(), hit.getSensor());
+  Vector3 localStartPos = hit.localStartPosition();
+  Vector3 localEndPos = hit.localEndPosition();
+  Vector3 localPos = 0.5 * (localEndPos + localStartPos);
+  const TrackerDD::SiDetectorElement *element = m_detMgr->getDetectorElement(waferId);
+  Vector3 globalPosition = Amg::EigenTransformToCLHEP(element->transformHit()) * localPos;
+  return globalPosition;
+}
+
+StatusCode NtupleDumperAlg::getTruthPositions(int barcode) const {
+  SG::ReadHandle<FaserSiHitCollection> siHitCollection(m_siHitCollectionKey);
+  ATH_CHECK(siHitCollection.isValid());
+
+  // create map with truth positions in each station
+  std::array<std::vector<Vector3>, 4> hitMap {};
+  for (const FaserSiHit &hit : *siHitCollection) {
+    if (hit.trackNumber() == barcode) {
+      Vector3 position = getGlobalPosition(hit);
+      hitMap[hit.getStation()].push_back(position);
+    }
+  }
+
+  // calculate average position in each station
+  for (int station=0; station < 4; ++station) {
+    std::vector<Vector3> &hits {hitMap[station]};
+    if (hits.empty()) {
+      m_t_st_x[station].push_back(NaN);
+      m_t_st_y[station].push_back(NaN);
+      m_t_st_z[station].push_back(NaN);
+    } else {
+      auto const count = static_cast<double>(hits.size());
+      std::array<double, 3> sums {};
+      for (const Vector3 &hit : hits) {
+        sums[0] += hit.x();
+        sums[1] += hit.y();
+        sums[2] += hit.z();
+      }
+      m_t_st_x[station].push_back(sums[0] / count);
+      m_t_st_y[station].push_back(sums[1] / count);
+      m_t_st_z[station].push_back(sums[2] / count);
+    }
+  }
+  return StatusCode::SUCCESS;
+}
+
+bool NtupleDumperAlg::isFiducial() const {
+  bool isFiducial {true};
+  for (int station = 0; station < 4; ++station) {
+    double st_x {m_t_st_x[station].back()};
+    double st_y {m_t_st_y[station].back()};
+    if (!std::isnan(st_x) && !std::isnan(st_y)) {
+      // distance from center < 100 mm
+      if (st_x * st_x + st_y * st_y > 100 * 100)
+      isFiducial = false;
+    } else {
+      // there have to be simulated hits in stations 1 - 3
+      if (station > 0)
+        isFiducial = false;
+    }
+  }
+  return isFiducial;
+}
+
+StatusCode NtupleDumperAlg::initialize()
 {
   ATH_CHECK(m_truthEventContainer.initialize());
   ATH_CHECK(m_truthParticleContainer.initialize());
@@ -85,6 +154,7 @@ StatusCode NtupleDumperAlg::initialize()
   ATH_CHECK(m_simDataCollection.initialize());
   ATH_CHECK(m_FaserTriggerData.initialize());
   ATH_CHECK(m_ClockWaveformContainer.initialize());
+  ATH_CHECK(m_siHitCollectionKey.initialize());
 
   ATH_CHECK(detStore()->retrieve(m_sctHelper,       "FaserSCT_ID"));
   ATH_CHECK(detStore()->retrieve(m_vetoNuHelper,    "VetoNuID"));
@@ -96,6 +166,7 @@ StatusCode NtupleDumperAlg::initialize()
   ATH_CHECK(detStore()->retrieve(m_detMgr, "SCT"));
   ATH_CHECK(m_extrapolationTool.retrieve());
   ATH_CHECK(m_trackingGeometryTool.retrieve());
+  ATH_CHECK(m_trackTruthMatchingTool.retrieve());
 
   ATH_CHECK(m_spacePointContainerKey.initialize());
 
@@ -220,6 +291,37 @@ StatusCode NtupleDumperAlg::initialize()
   m_tree->Branch("Track_ThetaX_atCalo", &m_thetaxCalo);
   m_tree->Branch("Track_ThetaY_atCalo", &m_thetayCalo);
 
+  m_tree->Branch("t_pdg", &m_t_pdg);
+  m_tree->Branch("t_barcode", &m_t_barcode);
+  m_tree->Branch("t_truthHitRatio", &m_t_truthHitRatio);
+  m_tree->Branch("t_prodVtx_x", &m_t_prodVtx_x);
+  m_tree->Branch("t_prodVtx_y", &m_t_prodVtx_y);
+  m_tree->Branch("t_prodVtx_z", &m_t_prodVtx_z);
+  m_tree->Branch("t_decayVtx_x", &m_t_decayVtx_x);
+  m_tree->Branch("t_decayVtx_y", &m_t_decayVtx_y);
+  m_tree->Branch("t_decayVtx_z", &m_t_decayVtx_z);
+  m_tree->Branch("t_px", &m_t_px);
+  m_tree->Branch("t_py", &m_t_py);
+  m_tree->Branch("t_pz", &m_t_pz);
+  m_tree->Branch("t_theta", &m_t_theta);
+  m_tree->Branch("t_phi", &m_t_phi);
+  m_tree->Branch("t_p", &m_t_p);
+  m_tree->Branch("t_pT", &m_t_pT);
+  m_tree->Branch("t_eta", &m_t_eta);
+  m_tree->Branch("t_st0_x", &m_t_st_x[0]);
+  m_tree->Branch("t_st0_y", &m_t_st_y[0]);
+  m_tree->Branch("t_st0_z", &m_t_st_z[0]);
+  m_tree->Branch("t_st1_x", &m_t_st_x[1]);
+  m_tree->Branch("t_st1_y", &m_t_st_y[1]);
+  m_tree->Branch("t_st1_z", &m_t_st_z[1]);
+  m_tree->Branch("t_st2_x", &m_t_st_x[2]);
+  m_tree->Branch("t_st2_y", &m_t_st_y[2]);
+  m_tree->Branch("t_st2_z", &m_t_st_z[2]);
+  m_tree->Branch("t_st3_x", &m_t_st_x[3]);
+  m_tree->Branch("t_st3_y", &m_t_st_y[3]);
+  m_tree->Branch("t_st3_z", &m_t_st_z[3]);
+  m_tree->Branch("isFiducial", &m_isFiducial);
+
   m_tree->Branch("pTruthLepton", &m_truthLeptonMomentum, "pTruthLepton/D");
   m_tree->Branch("truthBarcode", &m_truthBarcode, "truthBarcode/I");
   m_tree->Branch("truthPdg", &m_truthPdg, "truthPdg/I");
@@ -273,7 +375,7 @@ StatusCode NtupleDumperAlg::initialize()
 }
 
 
-StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const 
+StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const
 {
   clearTree();
 
@@ -339,7 +441,7 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const
 
       return StatusCode::SUCCESS; // finished with this event
 
-    } else if ( ((m_tap&8)==0) && (((m_tap&4)==0)||((m_tap&2)==0)) && (((m_tap&4)==0)||((m_tap&1)==0)) && (((m_tap&2)==0)||((m_tap&1)==0)) ) { // don't process events that don't trigger coincidence triggers: 1=calo, 2=veotnu|neto1|preshower, 4=TimingLayer, 8=(VetoNu|Veto2)&Preshower 
+    } else if ( ((m_tap&8)==0) && (((m_tap&4)==0)||((m_tap&2)==0)) && (((m_tap&4)==0)||((m_tap&1)==0)) && (((m_tap&2)==0)||((m_tap&1)==0)) ) { // don't process events that don't trigger coincidence triggers: 1=calo, 2=veotnu|neto1|preshower, 4=TimingLayer, 8=(VetoNu|Veto2)&Preshower
       return StatusCode::SUCCESS;
     }
     m_tbp=triggerData->tbp();
@@ -420,7 +522,7 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const
   FillWaveBranches(*triggerContainer);
   FillWaveBranches(*preshowerContainer);
   FillWaveBranches(*ecalContainer);
-  
+
   m_calo_total=m_wave_charge[0]+m_wave_charge[1]+m_wave_charge[2]+m_wave_charge[3];
   m_calo_rawtotal=m_wave_raw_charge[0]+m_wave_raw_charge[1]+m_wave_raw_charge[2]+m_wave_raw_charge[3];
 
@@ -430,13 +532,13 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const
     m_Calo1_Edep = (m_wave_charge[1] / 24.333) * m_MIP_sim_Edep_calo;
     m_Calo2_Edep = (m_wave_charge[2] / 24.409) * m_MIP_sim_Edep_calo;
     m_Calo3_Edep = (m_wave_charge[3] / 25.555) * m_MIP_sim_Edep_calo;
-  } else if (m_CaloConfig == "Low_gain") { // assume low gain calo 
+  } else if (m_CaloConfig == "Low_gain") { // assume low gain calo
     m_Calo0_Edep = (m_wave_charge[0] / 0.7909) * m_MIP_sim_Edep_calo;
     m_Calo1_Edep = (m_wave_charge[1] / 0.8197) * m_MIP_sim_Edep_calo;
     m_Calo2_Edep = (m_wave_charge[2] / 0.8256) * m_MIP_sim_Edep_calo;
     m_Calo3_Edep = (m_wave_charge[3] / 0.8821) * m_MIP_sim_Edep_calo;
   } else {
-   ATH_MSG_WARNING("Run config is neither High_gain nor Low_gain, it is " << m_CaloConfig << ", calo calibration will be zero"); 
+   ATH_MSG_WARNING("Run config is neither High_gain nor Low_gain, it is " << m_CaloConfig << ", calo calibration will be zero");
   }
   m_Calo_Total_Edep = m_Calo0_Edep + m_Calo1_Edep + m_Calo2_Edep + m_Calo3_Edep;
 
@@ -580,6 +682,50 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const
     m_pzdown.push_back(candidateDownParameters->momentum().z());
     m_pdown.push_back(sqrt( pow(candidateDownParameters->momentum().x(),2) + pow(candidateDownParameters->momentum().y(),2) + pow(candidateDownParameters->momentum().z(),2) ));
 
+    if (!realData) {
+      auto [truthParticle, hitCount] = m_trackTruthMatchingTool->getTruthParticle(track);
+      if (truthParticle != nullptr) {
+        m_t_pdg.push_back(truthParticle->pdgId());
+        m_t_barcode.push_back(truthParticle->barcode());
+        // the track fit eats up 5 degrees of freedom, thus the number of hits on track
+        // is m_DoF + 5
+        m_t_truthHitRatio.push_back(hitCount / (m_DoF.back() + 5));
+        ATH_CHECK(getTruthPositions(truthParticle->barcode()));
+        m_isFiducial.push_back(isFiducial());
+        if (truthParticle->hasProdVtx()) {
+          m_t_prodVtx_x.push_back(truthParticle->prodVtx()->x());
+          m_t_prodVtx_y.push_back(truthParticle->prodVtx()->y());
+          m_t_prodVtx_z.push_back(truthParticle->prodVtx()->z());
+        } else {
+          m_t_prodVtx_x.push_back(NaN);
+          m_t_prodVtx_y.push_back(NaN);
+          m_t_prodVtx_z.push_back(NaN);
+        }
+        if (truthParticle->hasDecayVtx()) {
+          m_t_decayVtx_x.push_back(truthParticle->decayVtx()->x());
+          m_t_decayVtx_y.push_back(truthParticle->decayVtx()->y());
+          m_t_decayVtx_z.push_back(truthParticle->decayVtx()->z());
+        } else {
+          m_t_decayVtx_x.push_back(NaN);
+          m_t_decayVtx_y.push_back(NaN);
+          m_t_decayVtx_z.push_back(NaN);
+        }
+        m_t_px.push_back(truthParticle->px());
+        m_t_py.push_back(truthParticle->py());
+        m_t_pz.push_back(truthParticle->pz());
+        m_t_theta.push_back(truthParticle->p4().Theta());
+        m_t_phi.push_back(truthParticle->p4().Phi());
+        m_t_p.push_back(truthParticle->p4().P());
+        m_t_pT.push_back(truthParticle->p4().Pt());
+        m_t_eta.push_back(truthParticle->p4().Eta());
+      } else {
+        setNaN();
+      }
+    } else {
+      ATH_MSG_WARNING("Can not find truthParticle.");
+      setNaN();
+    }
+
     // fill extrapolation vectors with filler values that get changed iif the track extrapolation succeeds
     m_xVetoNu.push_back(-10000);
     m_yVetoNu.push_back(-10000);
@@ -676,7 +822,7 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const
         ATH_MSG_INFO("Trig null targetParameters");
       }
 
-    } 
+    }
 
     // extrapolate track from tracking station 3
     if (stationMap.count(3) > 0) { // extrapolation crashes if the track does not end in the Station 3, as it is too far away to extrapolate
@@ -751,7 +897,7 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const
 }
 
 
-StatusCode NtupleDumperAlg::finalize() 
+StatusCode NtupleDumperAlg::finalize()
 {
   return StatusCode::SUCCESS;
 }
@@ -843,7 +989,7 @@ NtupleDumperAlg::clearTree() const
   m_charge.clear();
   m_nLayers.clear();
   m_longTracks = 0;
- 
+
   m_nHit0.clear();
   m_nHit1.clear();
   m_nHit2.clear();
@@ -884,7 +1030,57 @@ NtupleDumperAlg::clearTree() const
   m_thetaxCalo.clear();
   m_thetayCalo.clear();
 
+  m_t_pdg.clear();
+  m_t_barcode.clear();
+  m_t_truthHitRatio.clear();
+  m_t_prodVtx_x.clear();
+  m_t_prodVtx_y.clear();
+  m_t_prodVtx_z.clear();
+  m_t_decayVtx_x.clear();
+  m_t_decayVtx_y.clear();
+  m_t_decayVtx_z.clear();
+  m_t_px.clear();
+  m_t_py.clear();
+  m_t_pz.clear();
+  m_t_theta.clear();
+  m_t_phi.clear();
+  m_t_p.clear();
+  m_t_pT.clear();
+  m_t_eta.clear();
+  m_isFiducial.clear();
+  for (int station = 0; station < 4; ++station) {
+    m_t_st_x[station].clear();
+    m_t_st_y[station].clear();
+    m_t_st_z[station].clear();
+  }
+
   m_truthLeptonMomentum = 0;
   m_truthBarcode = 0;
   m_truthPdg = 0;
 }
+
+void NtupleDumperAlg::setNaN() const {
+  m_t_pdg.push_back(0);
+  m_t_barcode.push_back(-1);
+  m_t_truthHitRatio.push_back(NaN);
+  m_t_prodVtx_x.push_back(NaN);
+  m_t_prodVtx_y.push_back(NaN);
+  m_t_prodVtx_z.push_back(NaN);
+  m_t_decayVtx_x.push_back(NaN);
+  m_t_decayVtx_y.push_back(NaN);
+  m_t_decayVtx_z.push_back(NaN);
+  m_t_px.push_back(NaN);
+  m_t_py.push_back(NaN);
+  m_t_pz.push_back(NaN);
+  m_t_theta.push_back(NaN);
+  m_t_phi.push_back(NaN);
+  m_t_p.push_back(NaN);
+  m_t_pT.push_back(NaN);
+  m_t_eta.push_back(NaN);
+  for (int station = 0; station < 4; ++station) {
+    m_t_st_x[station].push_back(NaN);
+    m_t_st_y[station].push_back(NaN);
+    m_t_st_z[station].push_back(NaN);
+  }
+  m_isFiducial.push_back(false);
+}
\ No newline at end of file
diff --git a/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.h b/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.h
index 1b26caade..e35d10f2b 100644
--- a/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.h
+++ b/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.h
@@ -15,6 +15,8 @@
 #include "TrackerSimData/TrackerSimDataCollection.h"
 #include "FaserActsGeometryInterfaces/IFaserActsExtrapolationTool.h"
 #include "FaserActsGeometryInterfaces/IFaserActsTrackingGeometryTool.h"
+#include "FaserActsKalmanFilter/ITrackTruthMatchingTool.h"
+#include "TrackerSimEvent/FaserSiHitCollection.h"
 
 #include <vector>
 
@@ -44,11 +46,16 @@ private:
 
   bool waveformHitOK(const xAOD::WaveformHit* hit) const;
   void clearTree() const;
+  void setNaN() const;
   void addBranch(const std::string &name,float* var);
   void addBranch(const std::string &name,unsigned int* var);
   void addWaveBranches(const std::string &name, int nchannels, int first);
   void FillWaveBranches(const xAOD::WaveformHitContainer &wave) const;
 
+  using Vector3 = HepGeom::Point3D<double>;
+  Vector3 getGlobalPosition(const FaserSiHit &hit) const;
+  StatusCode getTruthPositions(int barcode) const;
+  bool isFiducial() const;
   ServiceHandle <ITHistSvc> m_histSvc;
 
   SG::ReadHandleKey<xAOD::TruthEventContainer> m_truthEventContainer { this, "EventContainer", "TruthEvents", "Truth event container name." };
@@ -64,11 +71,13 @@ private:
   SG::ReadHandleKey<xAOD::WaveformHitContainer> m_ecalContainer { this, "EcalContainer", "CaloWaveformHits", "Ecal hit container name" };
   SG::ReadHandleKey<Tracker::FaserSCT_ClusterContainer> m_clusterContainer { this, "ClusterContainer", "SCT_ClusterContainer", "Tracker cluster container name" };
   SG::ReadHandleKey<FaserSCT_SpacePointContainer> m_spacePointContainerKey { this, "SpacePoints", "SCT_SpacePointContainer", "space point container"};
+  SG::ReadHandleKey<FaserSiHitCollection> m_siHitCollectionKey{this, "FaserSiHitCollection", "SCT_Hits"};
 
   SG::ReadHandleKey<xAOD::FaserTriggerData> m_FaserTriggerData     { this, "FaserTriggerDataKey", "FaserTriggerData", "ReadHandleKey for xAOD::FaserTriggerData"};
   SG::ReadHandleKey<xAOD::WaveformClock> m_ClockWaveformContainer     { this, "WaveformClockKey", "WaveformClock", "ReadHandleKey for ClockWaveforms Container"};
   ToolHandle<IFaserActsExtrapolationTool> m_extrapolationTool { this, "ExtrapolationTool", "FaserActsExtrapolationTool" };  
   ToolHandle<IFaserActsTrackingGeometryTool> m_trackingGeometryTool {this, "TrackingGeometryTool", "FaserActsTrackingGeometryTool"};
+  ToolHandle<ITrackTruthMatchingTool> m_trackTruthMatchingTool {this, "TrackTruthMatchingTool", "TrackTruthMatchingTool"};
 
   const TrackerDD::SCT_DetectorManager* m_detMgr {nullptr};
 
@@ -203,6 +212,28 @@ private:
   mutable std::vector<double> m_thetaxCalo;
   mutable std::vector<double> m_thetayCalo;
 
+  mutable std::vector<int> m_t_pdg; // pdg code of the truth matched particle
+  mutable std::vector<int> m_t_barcode; // barcode of the truth matched particle
+  mutable std::vector<double> m_t_truthHitRatio; // ratio of hits on track matched to the truth particle over all hits on track
+  mutable std::vector<double> m_t_prodVtx_x; // x component of the production vertex in mm
+  mutable std::vector<double> m_t_prodVtx_y; // y component of the production vertex in mm
+  mutable std::vector<double> m_t_prodVtx_z; // z component of the production vertex in mm
+  mutable std::vector<double> m_t_decayVtx_x; // x component of the decay vertex in mm
+  mutable std::vector<double> m_t_decayVtx_y; // y component of the decay vertex in mm
+  mutable std::vector<double> m_t_decayVtx_z; // z component of the decay vertex in mm
+  mutable std::vector<double> m_t_px; // truth momentum px in MeV
+  mutable std::vector<double> m_t_py;  // truth momentum py in MeV
+  mutable std::vector<double> m_t_pz;  // truth momentum pz in MeV
+  mutable std::vector<double> m_t_theta; // angle of truth particle with respsect to the beam axis in rad, theta = arctan(sqrt(px * px + py * py) / pz)
+  mutable std::vector<double> m_t_phi; // polar angle of truth particle in rad, phi = arctan(py / px)
+  mutable std::vector<double> m_t_p; // truth momentum p in MeV
+  mutable std::vector<double> m_t_pT; // transverse truth momentum pT in MeV
+  mutable std::vector<double> m_t_eta; // eta of truth particle
+  mutable std::array<std::vector<double>, 4> m_t_st_x; // vector of the x components of the simulated hits of the truth particle for each station
+  mutable std::array<std::vector<double>, 4> m_t_st_y; // vector of the y components of the simulated hits of the truth particle for each station
+  mutable std::array<std::vector<double>, 4> m_t_st_z; // vector of the z components of the simulated hits of the truth particle for each station
+  mutable std::vector<bool> m_isFiducial; // track is fiducial if there are simulated hits for stations 1 - 3 and the distance from the center is smaller than 100 mm
+
   mutable double m_truthLeptonMomentum;
   mutable int    m_truthBarcode;
   mutable int    m_truthPdg;
diff --git a/Tracking/Acts/FaserActsKalmanFilter/CMakeLists.txt b/Tracking/Acts/FaserActsKalmanFilter/CMakeLists.txt
index 9377c4716..aeebb0ae8 100755
--- a/Tracking/Acts/FaserActsKalmanFilter/CMakeLists.txt
+++ b/Tracking/Acts/FaserActsKalmanFilter/CMakeLists.txt
@@ -41,6 +41,7 @@ atlas_add_component(FaserActsKalmanFilter
     FaserActsKalmanFilter/IndexSourceLink.h
     FaserActsKalmanFilter/ITrackFinderTool.h
     FaserActsKalmanFilter/ITrackSeedTool.h
+    FaserActsKalmanFilter/ITrackTruthMatchingTool.h
     KalmanFitterTool.h
     LinearFit.h
 #    ClusterTrackSeedTool.h
@@ -101,6 +102,8 @@ atlas_add_component(FaserActsKalmanFilter
     src/TrackClassification.cxx
     src/TrackSeedWriterTool.cxx
     src/TrackSelection.cxx
+    src/TrackTruthMatchingTool.h
+    src/TrackTruthMatchingTool.cxx
 #    src/TruthTrackFinderTool.cxx
 #    src/TruthSeededTrackFinderTool.cxx
     src/ThreeStationTrackSeedTool.cxx
diff --git a/Tracking/Acts/FaserActsKalmanFilter/FaserActsKalmanFilter/ITrackTruthMatchingTool.h b/Tracking/Acts/FaserActsKalmanFilter/FaserActsKalmanFilter/ITrackTruthMatchingTool.h
new file mode 100644
index 000000000..045e597a9
--- /dev/null
+++ b/Tracking/Acts/FaserActsKalmanFilter/FaserActsKalmanFilter/ITrackTruthMatchingTool.h
@@ -0,0 +1,21 @@
+/*
+  Copyright (C) 2002-2022 CERN for the benefit of the ATLAS and FASER
+  collaborations
+*/
+
+#ifndef FASERACTSKALMANFILTER_ITRACKTRUTHMATCHINGTOOL_H
+#define FASERACTSKALMANFILTER_ITRACKTRUTHMATCHINGTOOL_H
+
+#include "GaudiKernel/IAlgTool.h"
+#include "TrkTrack/Track.h"
+#include "xAODTruth/TruthParticle.h"
+
+class ITrackTruthMatchingTool : virtual public IAlgTool {
+public:
+  DeclareInterfaceID(ITrackTruthMatchingTool, 1, 0);
+
+  virtual std::pair<const xAOD::TruthParticle*, int>
+  getTruthParticle(const Trk::Track *track) const = 0;
+};
+
+#endif /* FASERACTSKALMANFILTER_ITRACKTRUTHMATCHINGTOOL_H */
diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.cxx b/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.cxx
new file mode 100644
index 000000000..502faa3d4
--- /dev/null
+++ b/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.cxx
@@ -0,0 +1,103 @@
+#include "TrackTruthMatchingTool.h"
+#include "TrackerPrepRawData/FaserSCT_Cluster.h"
+#include "TrackerRIO_OnTrack/FaserSCT_ClusterOnTrack.h"
+
+TrackTruthMatchingTool::TrackTruthMatchingTool(const std::string &type,
+                                               const std::string &name,
+                                               const IInterface *parent)
+    : base_class(type, name, parent) {}
+
+StatusCode TrackTruthMatchingTool::initialize() {
+  ATH_CHECK(m_simDataCollectionKey.initialize());
+  ATH_CHECK(m_truthParticleContainerKey.initialize());
+  return StatusCode::SUCCESS;
+}
+
+std::pair<const xAOD::TruthParticle *, int>
+TrackTruthMatchingTool::getTruthParticle(const Trk::Track *track) const {
+  const xAOD::TruthParticle *truthParticle = nullptr;
+  const EventContext &ctx = Gaudi::Hive::currentContext();
+  SG::ReadHandle<xAOD::TruthParticleContainer> truthParticleContainer{
+      m_truthParticleContainerKey, ctx};
+  if (!truthParticleContainer.isValid()) {
+    ATH_MSG_WARNING("xAOD::TruthParticleContainer is not valid.");
+    return {truthParticle, -1};
+  }
+  SG::ReadHandle<TrackerSimDataCollection> simDataCollection{
+      m_simDataCollectionKey, ctx};
+  if (!simDataCollection.isValid()) {
+    ATH_MSG_WARNING("TrackerSimDataCollection is not valid.");
+    return {truthParticle, -1};
+  }
+  std::vector<ParticleHitCount> particleHitCounts{};
+  identifyContributingParticles(*track, *simDataCollection, particleHitCounts);
+  if (particleHitCounts.empty()) {
+    ATH_MSG_WARNING("Cannot find any truth particle matched to the track.");
+    return {truthParticle, -1};
+  }
+  int barcode = particleHitCounts.front().barcode;
+  int hitCount = particleHitCounts.front().hitCount;
+  auto it = std::find_if(truthParticleContainer->begin(),
+                         truthParticleContainer->end(),
+                         [barcode](const xAOD::TruthParticle_v1 *particle) {
+                           return particle->barcode() == barcode;
+                         });
+  if (it == truthParticleContainer->end()) {
+    ATH_MSG_WARNING("Cannot find particle with barcode "
+                    << barcode << " in truth particle container.");
+    return {truthParticle, -1};
+  }
+  truthParticle = *it;
+  return {truthParticle, hitCount};
+}
+
+StatusCode TrackTruthMatchingTool::finalize() { return StatusCode::SUCCESS; }
+
+void TrackTruthMatchingTool::increaseHitCount(
+    std::vector<ParticleHitCount> &particleHitCounts, int barcode) {
+  auto it = std::find_if(
+      particleHitCounts.begin(), particleHitCounts.end(),
+      [=](const ParticleHitCount &phc) { return (phc.barcode == barcode); });
+  // either increase count if we saw the particle before or add it
+  if (it != particleHitCounts.end()) {
+    it->hitCount += 1u;
+  } else {
+    particleHitCounts.push_back({barcode, 1u});
+  }
+}
+
+void TrackTruthMatchingTool::sortHitCount(
+    std::vector<ParticleHitCount> &particleHitCounts) {
+  std::sort(particleHitCounts.begin(), particleHitCounts.end(),
+            [](const ParticleHitCount &lhs, const ParticleHitCount &rhs) {
+              return (lhs.hitCount > rhs.hitCount);
+            });
+}
+
+void TrackTruthMatchingTool::identifyContributingParticles(
+    const Trk::Track &track, const TrackerSimDataCollection &simDataCollection,
+    std::vector<ParticleHitCount> &particleHitCounts) {
+  for (const Trk::MeasurementBase *meas : *track.measurementsOnTrack()) {
+    const auto *clusterOnTrack =
+        dynamic_cast<const Tracker::FaserSCT_ClusterOnTrack *>(meas);
+    if (!clusterOnTrack)
+      continue;
+    std::vector<int> barcodes{};
+    const Tracker::FaserSCT_Cluster *cluster = clusterOnTrack->prepRawData();
+    for (Identifier id : cluster->rdoList()) {
+      if (simDataCollection.count(id) == 0)
+        continue;
+      const auto &deposits = simDataCollection.at(id).getdeposits();
+      for (const TrackerSimData::Deposit &deposit : deposits) {
+        int barcode = deposit.first->barcode();
+        // count each barcode only once for a wafer
+        if (std::find(barcodes.begin(), barcodes.end(), barcode) ==
+            barcodes.end()) {
+          barcodes.push_back(barcode);
+          increaseHitCount(particleHitCounts, barcode);
+        }
+      }
+    }
+  }
+  sortHitCount(particleHitCounts);
+}
diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.h b/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.h
new file mode 100644
index 000000000..94ab44058
--- /dev/null
+++ b/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.h
@@ -0,0 +1,40 @@
+#ifndef FASERACTSKALMANFILTER_TRACKTRUTHMATCHINGTOOL_H
+#define FASERACTSKALMANFILTER_TRACKTRUTHMATCHINGTOOL_H
+
+#include "AthenaBaseComps/AthAlgTool.h"
+#include "FaserActsKalmanFilter/ITrackTruthMatchingTool.h"
+#include "TrackerSimData/TrackerSimDataCollection.h"
+#include "TrkTrack/Track.h"
+#include "xAODTruth/TruthParticleContainer.h"
+
+class TrackTruthMatchingTool
+    : public extends<AthAlgTool, ITrackTruthMatchingTool> {
+public:
+  TrackTruthMatchingTool(const std::string &type, const std::string &name,
+                         const IInterface *parent);
+  virtual ~TrackTruthMatchingTool() = default;
+  virtual StatusCode initialize() override;
+  virtual StatusCode finalize() override;
+
+  std::pair<const xAOD::TruthParticle*, int>
+  getTruthParticle(const Trk::Track *track) const;
+
+private:
+  struct ParticleHitCount {
+    int barcode;
+    size_t hitCount;
+  };
+  static void increaseHitCount(std::vector<ParticleHitCount> &particleHitCounts,
+                               int particleId);
+  static void sortHitCount(std::vector<ParticleHitCount> &particleHitCounts);
+  static void identifyContributingParticles(
+      const Trk::Track &track, const TrackerSimDataCollection &simDataCollection,
+      std::vector<ParticleHitCount> &particleHitCounts);
+
+  SG::ReadHandleKey<TrackerSimDataCollection> m_simDataCollectionKey{
+      this, "TrackerSimDataCollection", "SCT_SDO_Map"};
+  SG::ReadHandleKey<xAOD::TruthParticleContainer> m_truthParticleContainerKey{
+      this, "ParticleContainer", "TruthParticles"};
+};
+
+#endif /* FASERACTSKALMANFILTER_TRACKTRUTHMATCHINGTOOL_H */
diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/components/FaserActsKalmanFilter_entries.cxx b/Tracking/Acts/FaserActsKalmanFilter/src/components/FaserActsKalmanFilter_entries.cxx
index 967f80641..665a80230 100755
--- a/Tracking/Acts/FaserActsKalmanFilter/src/components/FaserActsKalmanFilter_entries.cxx
+++ b/Tracking/Acts/FaserActsKalmanFilter/src/components/FaserActsKalmanFilter_entries.cxx
@@ -29,6 +29,7 @@
 #include "../CircleFitTrackSeedTool.h"
 #include "../GhostBusters.h"
 #include "../CreateTrkTrackTool.h"
+#include "../TrackTruthMatchingTool.h"
 
 DECLARE_COMPONENT(FaserActsKalmanFilterAlg)
 DECLARE_COMPONENT(CombinatorialKalmanFilterAlg)
@@ -57,3 +58,4 @@ DECLARE_COMPONENT(SeedingAlg)
 DECLARE_COMPONENT(CircleFitTrackSeedTool)
 DECLARE_COMPONENT(GhostBusters)
 DECLARE_COMPONENT(CreateTrkTrackTool)
+DECLARE_COMPONENT(TrackTruthMatchingTool)
-- 
GitLab