diff --git a/PhysicsAnalysis/NtupleDumper/CMakeLists.txt b/PhysicsAnalysis/NtupleDumper/CMakeLists.txt index 944ed9689053915dd526251fabf1c5b1af5dfa4e..b81b98183b68fc67cf0df4592c7f1cad6917e6f4 100644 --- a/PhysicsAnalysis/NtupleDumper/CMakeLists.txt +++ b/PhysicsAnalysis/NtupleDumper/CMakeLists.txt @@ -5,7 +5,7 @@ atlas_add_component( src/NtupleDumperAlg.h src/NtupleDumperAlg.cxx src/component/NtupleDumper_entries.cxx - LINK_LIBRARIES AthenaBaseComps StoreGateLib xAODFaserWaveform xAODFaserTrigger ScintIdentifier FaserCaloIdentifier GeneratorObjects FaserActsGeometryLib TrackerSimEvent TrackerSimData TrackerIdentifier TrackerReadoutGeometry TrkTrack GeoPrimitives TrackerRIO_OnTrack TrackerSpacePoint + LINK_LIBRARIES AthenaBaseComps StoreGateLib xAODFaserWaveform xAODFaserTrigger ScintIdentifier FaserCaloIdentifier GeneratorObjects FaserActsGeometryLib TrackerSimEvent TrackerSimData TrackerIdentifier TrackerReadoutGeometry TrkTrack GeoPrimitives TrackerRIO_OnTrack TrackerSpacePoint FaserActsKalmanFilterLib ) atlas_install_python_modules(python/*.py) diff --git a/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.cxx b/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.cxx index d1a07f428c73dfb7d6c7f1bf69ab20b73f12a196..dba1c0e68822a2f54893294a781065d3bf052383 100644 --- a/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.cxx +++ b/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.cxx @@ -16,6 +16,9 @@ #include "xAODTruth/TruthParticle.h" #include <cmath> #include <TH1F.h> +#include <numeric> + +constexpr float NaN = std::numeric_limits<double>::quiet_NaN(); NtupleDumperAlg::NtupleDumperAlg(const std::string &name, @@ -70,7 +73,73 @@ void NtupleDumperAlg::FillWaveBranches(const xAOD::WaveformHitContainer &wave) c } } -StatusCode NtupleDumperAlg::initialize() +NtupleDumperAlg::Vector3 NtupleDumperAlg::getGlobalPosition(const FaserSiHit &hit) const { + Identifier waferId = + m_sctHelper->wafer_id(hit.getStation(), hit.getPlane(), hit.getRow(), + hit.getModule(), hit.getSensor()); + Vector3 localStartPos = hit.localStartPosition(); + Vector3 localEndPos = hit.localEndPosition(); + Vector3 localPos = 0.5 * (localEndPos + localStartPos); + const TrackerDD::SiDetectorElement *element = m_detMgr->getDetectorElement(waferId); + Vector3 globalPosition = Amg::EigenTransformToCLHEP(element->transformHit()) * localPos; + return globalPosition; +} + +StatusCode NtupleDumperAlg::getTruthPositions(int barcode) const { + SG::ReadHandle<FaserSiHitCollection> siHitCollection(m_siHitCollectionKey); + ATH_CHECK(siHitCollection.isValid()); + + // create map with truth positions in each station + std::array<std::vector<Vector3>, 4> hitMap {}; + for (const FaserSiHit &hit : *siHitCollection) { + if (hit.trackNumber() == barcode) { + Vector3 position = getGlobalPosition(hit); + hitMap[hit.getStation()].push_back(position); + } + } + + // calculate average position in each station + for (int station=0; station < 4; ++station) { + std::vector<Vector3> &hits {hitMap[station]}; + if (hits.empty()) { + m_t_st_x[station].push_back(NaN); + m_t_st_y[station].push_back(NaN); + m_t_st_z[station].push_back(NaN); + } else { + auto const count = static_cast<double>(hits.size()); + std::array<double, 3> sums {}; + for (const Vector3 &hit : hits) { + sums[0] += hit.x(); + sums[1] += hit.y(); + sums[2] += hit.z(); + } + m_t_st_x[station].push_back(sums[0] / count); + m_t_st_y[station].push_back(sums[1] / count); + m_t_st_z[station].push_back(sums[2] / count); + } + } + return StatusCode::SUCCESS; +} + +bool NtupleDumperAlg::isFiducial() const { + bool isFiducial {true}; + for (int station = 0; station < 4; ++station) { + double st_x {m_t_st_x[station].back()}; + double st_y {m_t_st_y[station].back()}; + if (!std::isnan(st_x) && !std::isnan(st_y)) { + // distance from center < 100 mm + if (st_x * st_x + st_y * st_y > 100 * 100) + isFiducial = false; + } else { + // there have to be simulated hits in stations 1 - 3 + if (station > 0) + isFiducial = false; + } + } + return isFiducial; +} + +StatusCode NtupleDumperAlg::initialize() { ATH_CHECK(m_truthEventContainer.initialize()); ATH_CHECK(m_truthParticleContainer.initialize()); @@ -85,6 +154,7 @@ StatusCode NtupleDumperAlg::initialize() ATH_CHECK(m_simDataCollection.initialize()); ATH_CHECK(m_FaserTriggerData.initialize()); ATH_CHECK(m_ClockWaveformContainer.initialize()); + ATH_CHECK(m_siHitCollectionKey.initialize()); ATH_CHECK(detStore()->retrieve(m_sctHelper, "FaserSCT_ID")); ATH_CHECK(detStore()->retrieve(m_vetoNuHelper, "VetoNuID")); @@ -96,6 +166,7 @@ StatusCode NtupleDumperAlg::initialize() ATH_CHECK(detStore()->retrieve(m_detMgr, "SCT")); ATH_CHECK(m_extrapolationTool.retrieve()); ATH_CHECK(m_trackingGeometryTool.retrieve()); + ATH_CHECK(m_trackTruthMatchingTool.retrieve()); ATH_CHECK(m_spacePointContainerKey.initialize()); @@ -220,6 +291,37 @@ StatusCode NtupleDumperAlg::initialize() m_tree->Branch("Track_ThetaX_atCalo", &m_thetaxCalo); m_tree->Branch("Track_ThetaY_atCalo", &m_thetayCalo); + m_tree->Branch("t_pdg", &m_t_pdg); + m_tree->Branch("t_barcode", &m_t_barcode); + m_tree->Branch("t_truthHitRatio", &m_t_truthHitRatio); + m_tree->Branch("t_prodVtx_x", &m_t_prodVtx_x); + m_tree->Branch("t_prodVtx_y", &m_t_prodVtx_y); + m_tree->Branch("t_prodVtx_z", &m_t_prodVtx_z); + m_tree->Branch("t_decayVtx_x", &m_t_decayVtx_x); + m_tree->Branch("t_decayVtx_y", &m_t_decayVtx_y); + m_tree->Branch("t_decayVtx_z", &m_t_decayVtx_z); + m_tree->Branch("t_px", &m_t_px); + m_tree->Branch("t_py", &m_t_py); + m_tree->Branch("t_pz", &m_t_pz); + m_tree->Branch("t_theta", &m_t_theta); + m_tree->Branch("t_phi", &m_t_phi); + m_tree->Branch("t_p", &m_t_p); + m_tree->Branch("t_pT", &m_t_pT); + m_tree->Branch("t_eta", &m_t_eta); + m_tree->Branch("t_st0_x", &m_t_st_x[0]); + m_tree->Branch("t_st0_y", &m_t_st_y[0]); + m_tree->Branch("t_st0_z", &m_t_st_z[0]); + m_tree->Branch("t_st1_x", &m_t_st_x[1]); + m_tree->Branch("t_st1_y", &m_t_st_y[1]); + m_tree->Branch("t_st1_z", &m_t_st_z[1]); + m_tree->Branch("t_st2_x", &m_t_st_x[2]); + m_tree->Branch("t_st2_y", &m_t_st_y[2]); + m_tree->Branch("t_st2_z", &m_t_st_z[2]); + m_tree->Branch("t_st3_x", &m_t_st_x[3]); + m_tree->Branch("t_st3_y", &m_t_st_y[3]); + m_tree->Branch("t_st3_z", &m_t_st_z[3]); + m_tree->Branch("isFiducial", &m_isFiducial); + m_tree->Branch("pTruthLepton", &m_truthLeptonMomentum, "pTruthLepton/D"); m_tree->Branch("truthBarcode", &m_truthBarcode, "truthBarcode/I"); m_tree->Branch("truthPdg", &m_truthPdg, "truthPdg/I"); @@ -273,7 +375,7 @@ StatusCode NtupleDumperAlg::initialize() } -StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const +StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const { clearTree(); @@ -339,7 +441,7 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const return StatusCode::SUCCESS; // finished with this event - } else if ( ((m_tap&8)==0) && (((m_tap&4)==0)||((m_tap&2)==0)) && (((m_tap&4)==0)||((m_tap&1)==0)) && (((m_tap&2)==0)||((m_tap&1)==0)) ) { // don't process events that don't trigger coincidence triggers: 1=calo, 2=veotnu|neto1|preshower, 4=TimingLayer, 8=(VetoNu|Veto2)&Preshower + } else if ( ((m_tap&8)==0) && (((m_tap&4)==0)||((m_tap&2)==0)) && (((m_tap&4)==0)||((m_tap&1)==0)) && (((m_tap&2)==0)||((m_tap&1)==0)) ) { // don't process events that don't trigger coincidence triggers: 1=calo, 2=veotnu|neto1|preshower, 4=TimingLayer, 8=(VetoNu|Veto2)&Preshower return StatusCode::SUCCESS; } m_tbp=triggerData->tbp(); @@ -420,7 +522,7 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const FillWaveBranches(*triggerContainer); FillWaveBranches(*preshowerContainer); FillWaveBranches(*ecalContainer); - + m_calo_total=m_wave_charge[0]+m_wave_charge[1]+m_wave_charge[2]+m_wave_charge[3]; m_calo_rawtotal=m_wave_raw_charge[0]+m_wave_raw_charge[1]+m_wave_raw_charge[2]+m_wave_raw_charge[3]; @@ -430,13 +532,13 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const m_Calo1_Edep = (m_wave_charge[1] / 24.333) * m_MIP_sim_Edep_calo; m_Calo2_Edep = (m_wave_charge[2] / 24.409) * m_MIP_sim_Edep_calo; m_Calo3_Edep = (m_wave_charge[3] / 25.555) * m_MIP_sim_Edep_calo; - } else if (m_CaloConfig == "Low_gain") { // assume low gain calo + } else if (m_CaloConfig == "Low_gain") { // assume low gain calo m_Calo0_Edep = (m_wave_charge[0] / 0.7909) * m_MIP_sim_Edep_calo; m_Calo1_Edep = (m_wave_charge[1] / 0.8197) * m_MIP_sim_Edep_calo; m_Calo2_Edep = (m_wave_charge[2] / 0.8256) * m_MIP_sim_Edep_calo; m_Calo3_Edep = (m_wave_charge[3] / 0.8821) * m_MIP_sim_Edep_calo; } else { - ATH_MSG_WARNING("Run config is neither High_gain nor Low_gain, it is " << m_CaloConfig << ", calo calibration will be zero"); + ATH_MSG_WARNING("Run config is neither High_gain nor Low_gain, it is " << m_CaloConfig << ", calo calibration will be zero"); } m_Calo_Total_Edep = m_Calo0_Edep + m_Calo1_Edep + m_Calo2_Edep + m_Calo3_Edep; @@ -580,6 +682,50 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const m_pzdown.push_back(candidateDownParameters->momentum().z()); m_pdown.push_back(sqrt( pow(candidateDownParameters->momentum().x(),2) + pow(candidateDownParameters->momentum().y(),2) + pow(candidateDownParameters->momentum().z(),2) )); + if (!realData) { + auto [truthParticle, hitCount] = m_trackTruthMatchingTool->getTruthParticle(track); + if (truthParticle != nullptr) { + m_t_pdg.push_back(truthParticle->pdgId()); + m_t_barcode.push_back(truthParticle->barcode()); + // the track fit eats up 5 degrees of freedom, thus the number of hits on track + // is m_DoF + 5 + m_t_truthHitRatio.push_back(hitCount / (m_DoF.back() + 5)); + ATH_CHECK(getTruthPositions(truthParticle->barcode())); + m_isFiducial.push_back(isFiducial()); + if (truthParticle->hasProdVtx()) { + m_t_prodVtx_x.push_back(truthParticle->prodVtx()->x()); + m_t_prodVtx_y.push_back(truthParticle->prodVtx()->y()); + m_t_prodVtx_z.push_back(truthParticle->prodVtx()->z()); + } else { + m_t_prodVtx_x.push_back(NaN); + m_t_prodVtx_y.push_back(NaN); + m_t_prodVtx_z.push_back(NaN); + } + if (truthParticle->hasDecayVtx()) { + m_t_decayVtx_x.push_back(truthParticle->decayVtx()->x()); + m_t_decayVtx_y.push_back(truthParticle->decayVtx()->y()); + m_t_decayVtx_z.push_back(truthParticle->decayVtx()->z()); + } else { + m_t_decayVtx_x.push_back(NaN); + m_t_decayVtx_y.push_back(NaN); + m_t_decayVtx_z.push_back(NaN); + } + m_t_px.push_back(truthParticle->px()); + m_t_py.push_back(truthParticle->py()); + m_t_pz.push_back(truthParticle->pz()); + m_t_theta.push_back(truthParticle->p4().Theta()); + m_t_phi.push_back(truthParticle->p4().Phi()); + m_t_p.push_back(truthParticle->p4().P()); + m_t_pT.push_back(truthParticle->p4().Pt()); + m_t_eta.push_back(truthParticle->p4().Eta()); + } else { + setNaN(); + } + } else { + ATH_MSG_WARNING("Can not find truthParticle."); + setNaN(); + } + // fill extrapolation vectors with filler values that get changed iif the track extrapolation succeeds m_xVetoNu.push_back(-10000); m_yVetoNu.push_back(-10000); @@ -676,7 +822,7 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const ATH_MSG_INFO("Trig null targetParameters"); } - } + } // extrapolate track from tracking station 3 if (stationMap.count(3) > 0) { // extrapolation crashes if the track does not end in the Station 3, as it is too far away to extrapolate @@ -751,7 +897,7 @@ StatusCode NtupleDumperAlg::execute(const EventContext &ctx) const } -StatusCode NtupleDumperAlg::finalize() +StatusCode NtupleDumperAlg::finalize() { return StatusCode::SUCCESS; } @@ -843,7 +989,7 @@ NtupleDumperAlg::clearTree() const m_charge.clear(); m_nLayers.clear(); m_longTracks = 0; - + m_nHit0.clear(); m_nHit1.clear(); m_nHit2.clear(); @@ -884,7 +1030,57 @@ NtupleDumperAlg::clearTree() const m_thetaxCalo.clear(); m_thetayCalo.clear(); + m_t_pdg.clear(); + m_t_barcode.clear(); + m_t_truthHitRatio.clear(); + m_t_prodVtx_x.clear(); + m_t_prodVtx_y.clear(); + m_t_prodVtx_z.clear(); + m_t_decayVtx_x.clear(); + m_t_decayVtx_y.clear(); + m_t_decayVtx_z.clear(); + m_t_px.clear(); + m_t_py.clear(); + m_t_pz.clear(); + m_t_theta.clear(); + m_t_phi.clear(); + m_t_p.clear(); + m_t_pT.clear(); + m_t_eta.clear(); + m_isFiducial.clear(); + for (int station = 0; station < 4; ++station) { + m_t_st_x[station].clear(); + m_t_st_y[station].clear(); + m_t_st_z[station].clear(); + } + m_truthLeptonMomentum = 0; m_truthBarcode = 0; m_truthPdg = 0; } + +void NtupleDumperAlg::setNaN() const { + m_t_pdg.push_back(0); + m_t_barcode.push_back(-1); + m_t_truthHitRatio.push_back(NaN); + m_t_prodVtx_x.push_back(NaN); + m_t_prodVtx_y.push_back(NaN); + m_t_prodVtx_z.push_back(NaN); + m_t_decayVtx_x.push_back(NaN); + m_t_decayVtx_y.push_back(NaN); + m_t_decayVtx_z.push_back(NaN); + m_t_px.push_back(NaN); + m_t_py.push_back(NaN); + m_t_pz.push_back(NaN); + m_t_theta.push_back(NaN); + m_t_phi.push_back(NaN); + m_t_p.push_back(NaN); + m_t_pT.push_back(NaN); + m_t_eta.push_back(NaN); + for (int station = 0; station < 4; ++station) { + m_t_st_x[station].push_back(NaN); + m_t_st_y[station].push_back(NaN); + m_t_st_z[station].push_back(NaN); + } + m_isFiducial.push_back(false); +} \ No newline at end of file diff --git a/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.h b/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.h index 1b26caadeb7290c34a199ef0e941f9203b600bba..e35d10f2ba5d074430b439d606b65f30f03e2170 100644 --- a/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.h +++ b/PhysicsAnalysis/NtupleDumper/src/NtupleDumperAlg.h @@ -15,6 +15,8 @@ #include "TrackerSimData/TrackerSimDataCollection.h" #include "FaserActsGeometryInterfaces/IFaserActsExtrapolationTool.h" #include "FaserActsGeometryInterfaces/IFaserActsTrackingGeometryTool.h" +#include "FaserActsKalmanFilter/ITrackTruthMatchingTool.h" +#include "TrackerSimEvent/FaserSiHitCollection.h" #include <vector> @@ -44,11 +46,16 @@ private: bool waveformHitOK(const xAOD::WaveformHit* hit) const; void clearTree() const; + void setNaN() const; void addBranch(const std::string &name,float* var); void addBranch(const std::string &name,unsigned int* var); void addWaveBranches(const std::string &name, int nchannels, int first); void FillWaveBranches(const xAOD::WaveformHitContainer &wave) const; + using Vector3 = HepGeom::Point3D<double>; + Vector3 getGlobalPosition(const FaserSiHit &hit) const; + StatusCode getTruthPositions(int barcode) const; + bool isFiducial() const; ServiceHandle <ITHistSvc> m_histSvc; SG::ReadHandleKey<xAOD::TruthEventContainer> m_truthEventContainer { this, "EventContainer", "TruthEvents", "Truth event container name." }; @@ -64,11 +71,13 @@ private: SG::ReadHandleKey<xAOD::WaveformHitContainer> m_ecalContainer { this, "EcalContainer", "CaloWaveformHits", "Ecal hit container name" }; SG::ReadHandleKey<Tracker::FaserSCT_ClusterContainer> m_clusterContainer { this, "ClusterContainer", "SCT_ClusterContainer", "Tracker cluster container name" }; SG::ReadHandleKey<FaserSCT_SpacePointContainer> m_spacePointContainerKey { this, "SpacePoints", "SCT_SpacePointContainer", "space point container"}; + SG::ReadHandleKey<FaserSiHitCollection> m_siHitCollectionKey{this, "FaserSiHitCollection", "SCT_Hits"}; SG::ReadHandleKey<xAOD::FaserTriggerData> m_FaserTriggerData { this, "FaserTriggerDataKey", "FaserTriggerData", "ReadHandleKey for xAOD::FaserTriggerData"}; SG::ReadHandleKey<xAOD::WaveformClock> m_ClockWaveformContainer { this, "WaveformClockKey", "WaveformClock", "ReadHandleKey for ClockWaveforms Container"}; ToolHandle<IFaserActsExtrapolationTool> m_extrapolationTool { this, "ExtrapolationTool", "FaserActsExtrapolationTool" }; ToolHandle<IFaserActsTrackingGeometryTool> m_trackingGeometryTool {this, "TrackingGeometryTool", "FaserActsTrackingGeometryTool"}; + ToolHandle<ITrackTruthMatchingTool> m_trackTruthMatchingTool {this, "TrackTruthMatchingTool", "TrackTruthMatchingTool"}; const TrackerDD::SCT_DetectorManager* m_detMgr {nullptr}; @@ -203,6 +212,28 @@ private: mutable std::vector<double> m_thetaxCalo; mutable std::vector<double> m_thetayCalo; + mutable std::vector<int> m_t_pdg; // pdg code of the truth matched particle + mutable std::vector<int> m_t_barcode; // barcode of the truth matched particle + mutable std::vector<double> m_t_truthHitRatio; // ratio of hits on track matched to the truth particle over all hits on track + mutable std::vector<double> m_t_prodVtx_x; // x component of the production vertex in mm + mutable std::vector<double> m_t_prodVtx_y; // y component of the production vertex in mm + mutable std::vector<double> m_t_prodVtx_z; // z component of the production vertex in mm + mutable std::vector<double> m_t_decayVtx_x; // x component of the decay vertex in mm + mutable std::vector<double> m_t_decayVtx_y; // y component of the decay vertex in mm + mutable std::vector<double> m_t_decayVtx_z; // z component of the decay vertex in mm + mutable std::vector<double> m_t_px; // truth momentum px in MeV + mutable std::vector<double> m_t_py; // truth momentum py in MeV + mutable std::vector<double> m_t_pz; // truth momentum pz in MeV + mutable std::vector<double> m_t_theta; // angle of truth particle with respsect to the beam axis in rad, theta = arctan(sqrt(px * px + py * py) / pz) + mutable std::vector<double> m_t_phi; // polar angle of truth particle in rad, phi = arctan(py / px) + mutable std::vector<double> m_t_p; // truth momentum p in MeV + mutable std::vector<double> m_t_pT; // transverse truth momentum pT in MeV + mutable std::vector<double> m_t_eta; // eta of truth particle + mutable std::array<std::vector<double>, 4> m_t_st_x; // vector of the x components of the simulated hits of the truth particle for each station + mutable std::array<std::vector<double>, 4> m_t_st_y; // vector of the y components of the simulated hits of the truth particle for each station + mutable std::array<std::vector<double>, 4> m_t_st_z; // vector of the z components of the simulated hits of the truth particle for each station + mutable std::vector<bool> m_isFiducial; // track is fiducial if there are simulated hits for stations 1 - 3 and the distance from the center is smaller than 100 mm + mutable double m_truthLeptonMomentum; mutable int m_truthBarcode; mutable int m_truthPdg; diff --git a/Tracking/Acts/FaserActsKalmanFilter/CMakeLists.txt b/Tracking/Acts/FaserActsKalmanFilter/CMakeLists.txt index 9377c4716810301c16fe7ca606ba31e6f5ba318c..aeebb0ae8a68d4390315be3bf85505b219c32226 100755 --- a/Tracking/Acts/FaserActsKalmanFilter/CMakeLists.txt +++ b/Tracking/Acts/FaserActsKalmanFilter/CMakeLists.txt @@ -41,6 +41,7 @@ atlas_add_component(FaserActsKalmanFilter FaserActsKalmanFilter/IndexSourceLink.h FaserActsKalmanFilter/ITrackFinderTool.h FaserActsKalmanFilter/ITrackSeedTool.h + FaserActsKalmanFilter/ITrackTruthMatchingTool.h KalmanFitterTool.h LinearFit.h # ClusterTrackSeedTool.h @@ -101,6 +102,8 @@ atlas_add_component(FaserActsKalmanFilter src/TrackClassification.cxx src/TrackSeedWriterTool.cxx src/TrackSelection.cxx + src/TrackTruthMatchingTool.h + src/TrackTruthMatchingTool.cxx # src/TruthTrackFinderTool.cxx # src/TruthSeededTrackFinderTool.cxx src/ThreeStationTrackSeedTool.cxx diff --git a/Tracking/Acts/FaserActsKalmanFilter/FaserActsKalmanFilter/ITrackTruthMatchingTool.h b/Tracking/Acts/FaserActsKalmanFilter/FaserActsKalmanFilter/ITrackTruthMatchingTool.h new file mode 100644 index 0000000000000000000000000000000000000000..045e597a9ebf916bad29665c6e07a4217c027ce8 --- /dev/null +++ b/Tracking/Acts/FaserActsKalmanFilter/FaserActsKalmanFilter/ITrackTruthMatchingTool.h @@ -0,0 +1,21 @@ +/* + Copyright (C) 2002-2022 CERN for the benefit of the ATLAS and FASER + collaborations +*/ + +#ifndef FASERACTSKALMANFILTER_ITRACKTRUTHMATCHINGTOOL_H +#define FASERACTSKALMANFILTER_ITRACKTRUTHMATCHINGTOOL_H + +#include "GaudiKernel/IAlgTool.h" +#include "TrkTrack/Track.h" +#include "xAODTruth/TruthParticle.h" + +class ITrackTruthMatchingTool : virtual public IAlgTool { +public: + DeclareInterfaceID(ITrackTruthMatchingTool, 1, 0); + + virtual std::pair<const xAOD::TruthParticle*, int> + getTruthParticle(const Trk::Track *track) const = 0; +}; + +#endif /* FASERACTSKALMANFILTER_ITRACKTRUTHMATCHINGTOOL_H */ diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.cxx b/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.cxx new file mode 100644 index 0000000000000000000000000000000000000000..502faa3d4bb4dde421f63d2105f8d5525c67a0be --- /dev/null +++ b/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.cxx @@ -0,0 +1,103 @@ +#include "TrackTruthMatchingTool.h" +#include "TrackerPrepRawData/FaserSCT_Cluster.h" +#include "TrackerRIO_OnTrack/FaserSCT_ClusterOnTrack.h" + +TrackTruthMatchingTool::TrackTruthMatchingTool(const std::string &type, + const std::string &name, + const IInterface *parent) + : base_class(type, name, parent) {} + +StatusCode TrackTruthMatchingTool::initialize() { + ATH_CHECK(m_simDataCollectionKey.initialize()); + ATH_CHECK(m_truthParticleContainerKey.initialize()); + return StatusCode::SUCCESS; +} + +std::pair<const xAOD::TruthParticle *, int> +TrackTruthMatchingTool::getTruthParticle(const Trk::Track *track) const { + const xAOD::TruthParticle *truthParticle = nullptr; + const EventContext &ctx = Gaudi::Hive::currentContext(); + SG::ReadHandle<xAOD::TruthParticleContainer> truthParticleContainer{ + m_truthParticleContainerKey, ctx}; + if (!truthParticleContainer.isValid()) { + ATH_MSG_WARNING("xAOD::TruthParticleContainer is not valid."); + return {truthParticle, -1}; + } + SG::ReadHandle<TrackerSimDataCollection> simDataCollection{ + m_simDataCollectionKey, ctx}; + if (!simDataCollection.isValid()) { + ATH_MSG_WARNING("TrackerSimDataCollection is not valid."); + return {truthParticle, -1}; + } + std::vector<ParticleHitCount> particleHitCounts{}; + identifyContributingParticles(*track, *simDataCollection, particleHitCounts); + if (particleHitCounts.empty()) { + ATH_MSG_WARNING("Cannot find any truth particle matched to the track."); + return {truthParticle, -1}; + } + int barcode = particleHitCounts.front().barcode; + int hitCount = particleHitCounts.front().hitCount; + auto it = std::find_if(truthParticleContainer->begin(), + truthParticleContainer->end(), + [barcode](const xAOD::TruthParticle_v1 *particle) { + return particle->barcode() == barcode; + }); + if (it == truthParticleContainer->end()) { + ATH_MSG_WARNING("Cannot find particle with barcode " + << barcode << " in truth particle container."); + return {truthParticle, -1}; + } + truthParticle = *it; + return {truthParticle, hitCount}; +} + +StatusCode TrackTruthMatchingTool::finalize() { return StatusCode::SUCCESS; } + +void TrackTruthMatchingTool::increaseHitCount( + std::vector<ParticleHitCount> &particleHitCounts, int barcode) { + auto it = std::find_if( + particleHitCounts.begin(), particleHitCounts.end(), + [=](const ParticleHitCount &phc) { return (phc.barcode == barcode); }); + // either increase count if we saw the particle before or add it + if (it != particleHitCounts.end()) { + it->hitCount += 1u; + } else { + particleHitCounts.push_back({barcode, 1u}); + } +} + +void TrackTruthMatchingTool::sortHitCount( + std::vector<ParticleHitCount> &particleHitCounts) { + std::sort(particleHitCounts.begin(), particleHitCounts.end(), + [](const ParticleHitCount &lhs, const ParticleHitCount &rhs) { + return (lhs.hitCount > rhs.hitCount); + }); +} + +void TrackTruthMatchingTool::identifyContributingParticles( + const Trk::Track &track, const TrackerSimDataCollection &simDataCollection, + std::vector<ParticleHitCount> &particleHitCounts) { + for (const Trk::MeasurementBase *meas : *track.measurementsOnTrack()) { + const auto *clusterOnTrack = + dynamic_cast<const Tracker::FaserSCT_ClusterOnTrack *>(meas); + if (!clusterOnTrack) + continue; + std::vector<int> barcodes{}; + const Tracker::FaserSCT_Cluster *cluster = clusterOnTrack->prepRawData(); + for (Identifier id : cluster->rdoList()) { + if (simDataCollection.count(id) == 0) + continue; + const auto &deposits = simDataCollection.at(id).getdeposits(); + for (const TrackerSimData::Deposit &deposit : deposits) { + int barcode = deposit.first->barcode(); + // count each barcode only once for a wafer + if (std::find(barcodes.begin(), barcodes.end(), barcode) == + barcodes.end()) { + barcodes.push_back(barcode); + increaseHitCount(particleHitCounts, barcode); + } + } + } + } + sortHitCount(particleHitCounts); +} diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.h b/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.h new file mode 100644 index 0000000000000000000000000000000000000000..94ab440588b4353a4d45a8d7a15a61aac434f6a3 --- /dev/null +++ b/Tracking/Acts/FaserActsKalmanFilter/src/TrackTruthMatchingTool.h @@ -0,0 +1,40 @@ +#ifndef FASERACTSKALMANFILTER_TRACKTRUTHMATCHINGTOOL_H +#define FASERACTSKALMANFILTER_TRACKTRUTHMATCHINGTOOL_H + +#include "AthenaBaseComps/AthAlgTool.h" +#include "FaserActsKalmanFilter/ITrackTruthMatchingTool.h" +#include "TrackerSimData/TrackerSimDataCollection.h" +#include "TrkTrack/Track.h" +#include "xAODTruth/TruthParticleContainer.h" + +class TrackTruthMatchingTool + : public extends<AthAlgTool, ITrackTruthMatchingTool> { +public: + TrackTruthMatchingTool(const std::string &type, const std::string &name, + const IInterface *parent); + virtual ~TrackTruthMatchingTool() = default; + virtual StatusCode initialize() override; + virtual StatusCode finalize() override; + + std::pair<const xAOD::TruthParticle*, int> + getTruthParticle(const Trk::Track *track) const; + +private: + struct ParticleHitCount { + int barcode; + size_t hitCount; + }; + static void increaseHitCount(std::vector<ParticleHitCount> &particleHitCounts, + int particleId); + static void sortHitCount(std::vector<ParticleHitCount> &particleHitCounts); + static void identifyContributingParticles( + const Trk::Track &track, const TrackerSimDataCollection &simDataCollection, + std::vector<ParticleHitCount> &particleHitCounts); + + SG::ReadHandleKey<TrackerSimDataCollection> m_simDataCollectionKey{ + this, "TrackerSimDataCollection", "SCT_SDO_Map"}; + SG::ReadHandleKey<xAOD::TruthParticleContainer> m_truthParticleContainerKey{ + this, "ParticleContainer", "TruthParticles"}; +}; + +#endif /* FASERACTSKALMANFILTER_TRACKTRUTHMATCHINGTOOL_H */ diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/components/FaserActsKalmanFilter_entries.cxx b/Tracking/Acts/FaserActsKalmanFilter/src/components/FaserActsKalmanFilter_entries.cxx index 967f80641f7846d303be4632c252ca30eaab6436..665a802301e8a0d5f156ad533eac838be053d898 100755 --- a/Tracking/Acts/FaserActsKalmanFilter/src/components/FaserActsKalmanFilter_entries.cxx +++ b/Tracking/Acts/FaserActsKalmanFilter/src/components/FaserActsKalmanFilter_entries.cxx @@ -29,6 +29,7 @@ #include "../CircleFitTrackSeedTool.h" #include "../GhostBusters.h" #include "../CreateTrkTrackTool.h" +#include "../TrackTruthMatchingTool.h" DECLARE_COMPONENT(FaserActsKalmanFilterAlg) DECLARE_COMPONENT(CombinatorialKalmanFilterAlg) @@ -57,3 +58,4 @@ DECLARE_COMPONENT(SeedingAlg) DECLARE_COMPONENT(CircleFitTrackSeedTool) DECLARE_COMPONENT(GhostBusters) DECLARE_COMPONENT(CreateTrkTrackTool) +DECLARE_COMPONENT(TrackTruthMatchingTool)