diff --git a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h index da691ad7e237ccd2fc8a5edfbd1f8d905b4a4c93..869ab6f4d13fbf71c9cb8840628d5e61eeaa0b8a 100644 --- a/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h +++ b/Control/AthOnnx/AthOnnxComps/src/OnnxRuntimeInferenceTool.h @@ -48,7 +48,7 @@ namespace AthOnnx { ServiceHandle<IOnnxRuntimeSvc> m_onnxRuntimeSvc{"AthOnnx::OnnxRuntimeSvc", "AthOnnx::OnnxRuntimeSvc"}; ToolHandle<IOnnxRuntimeSessionTool> m_onnxSessionTool{ this, "ORTSessionTool", - "AthOnnx::OnnxRuntimeInferenceToolCPU" + "AthOnnx::OnnxRuntimeSessionToolCPU" }; std::vector<std::string> m_inputNodeNames; std::vector<std::string> m_outputNodeNames; diff --git a/InnerDetector/InDetConfig/python/ITkTrackRecoConfig.py b/InnerDetector/InDetConfig/python/ITkTrackRecoConfig.py index 6347676557484e71b70328d0fca754e6849cbba2..6c91dddf0ba35c4d7216516ee5f995f8d1f8a927 100644 --- a/InnerDetector/InDetConfig/python/ITkTrackRecoConfig.py +++ b/InnerDetector/InDetConfig/python/ITkTrackRecoConfig.py @@ -36,6 +36,12 @@ def CombinedTrackingPassFlagSets(flags): flags_set += [flags.cloneAndReplace( "Tracking.ActiveConfig", "Tracking.ITkActsPass")] + + # GNN pass + if TrackingComponent.GNNChain in flags.Tracking.recoChain: + flags_set += [flags.cloneAndReplace( + "Tracking.ActiveConfig", + "Tracking.ITkGNNPass")] # Acts Conversion Pass if flags.Detector.EnableCalo and flags.Acts.doITkConversion: diff --git a/InnerDetector/InDetConfig/python/ITkTrackingSiPatternConfig.py b/InnerDetector/InDetConfig/python/ITkTrackingSiPatternConfig.py index 8d562d69a0b643ccdf52320583f69ac25282b0b3..95b3c48747e093681b4d9218e3b9d0c0e5c1915e 100644 --- a/InnerDetector/InDetConfig/python/ITkTrackingSiPatternConfig.py +++ b/InnerDetector/InDetConfig/python/ITkTrackingSiPatternConfig.py @@ -51,7 +51,8 @@ def ITkTrackingSiPatternCfg(flags, # # ------------------------------------------------------------ - runTruth = flags.Tracking.ActiveConfig.doAthenaTrack or flags.Tracking.ActiveConfig.doActsToAthenaTrack + runTruth = flags.Tracking.ActiveConfig.doAthenaTrack or flags.Tracking.ActiveConfig.doActsToAthenaTrack or flags.Tracking.ActiveConfig.doGNNTrack + # Athena Track if flags.Tracking.ActiveConfig.doAthenaTrack: @@ -68,6 +69,13 @@ def ITkTrackingSiPatternCfg(flags, flags, TracksLocation=SiSPSeededTrackCollectionKey)) + # GNN Track + if flags.Tracking.ActiveConfig.doGNNTrack: + from InDetGNNTracking.InDetGNNTrackingConfig import GNNTrackMakerCfg + acc.merge(GNNTrackMakerCfg( + flags, + TracksLocation=SiSPSeededTrackCollectionKey)) + # ACTS seed if flags.Tracking.ActiveConfig.doActsSeed: diff --git a/InnerDetector/InDetGNNTracking/README.md b/InnerDetector/InDetGNNTracking/README.md new file mode 100644 index 0000000000000000000000000000000000000000..229d4de43e4289bd6451b5568e1ad112b7fbe96d --- /dev/null +++ b/InnerDetector/InDetGNNTracking/README.md @@ -0,0 +1,25 @@ +# Graph Neural Network for ITk tracking + +## To Fit track candidates from ACORN + + +```bash +function gnn_tracking() { + rm InDetIdDict.xml PoolFileCatalog.xml + # export ATHENA_CORE_NUMBER=6 + #--skipEvents 44 + + Reco_tf.py \ + --CA 'all:True' --autoConfiguration 'everything' \ + --conditionsTag 'all:OFLCOND-MC15c-SDR-14-05' \ + --geometryVersion 'all:ATLAS-P2-RUN4-03-00-00' \ + --multithreaded 'True' \ + --steering 'doRAWtoALL' \ + --digiSteeringConf 'StandardInTimeOnlyTruth' \ + --postInclude 'all:PyJobTransforms.UseFrontier' \ + --preInclude 'all:Campaigns.PhaseIIPileUp200' 'InDetConfig.ConfigurationHelpers.OnlyTrackingPreInclude' 'InDetGNNTracking.InDetGNNTrackingConfig.gnnReaderValidation' \ + --inputRDOFile ${RDO_FILENAME} \ + --outputAODFile 'test.aod.gnnreader.debug.root' \ + --maxEvents 1 2>&1 | tee log.gnnreader_debug.txt +} +```` diff --git a/InnerDetector/InDetGNNTracking/python/InDetGNNTrackingConfig.py b/InnerDetector/InDetGNNTracking/python/InDetGNNTrackingConfig.py index f441e3305b92749eae972c168cd016f8cd3afab0..28ed827463f52f0c7bcbb039baf8732118334937 100644 --- a/InnerDetector/InDetGNNTracking/python/InDetGNNTrackingConfig.py +++ b/InnerDetector/InDetGNNTracking/python/InDetGNNTrackingConfig.py @@ -2,9 +2,12 @@ # Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration # +from pathlib import Path + from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator from AthenaConfiguration.ComponentFactory import CompFactory + def DumpObjectsCfg( flags, name="DumpObjects", outfile="Dump_GNN4Itk.root", **kwargs): ''' @@ -19,11 +22,85 @@ def DumpObjectsCfg( ) kwargs.setdefault("NtupleFileName", "/DumpObjects/") - #kwargs.setdefault("NtupleDirectoryName", "3-0-0") kwargs.setdefault("NtupleTreeName", "GNN4ITk") kwargs.setdefault("rootFile", True) acc.addEventAlgo(CompFactory.InDet.DumpObjects(name, **kwargs)) + return acc + +def GNNTrackFinderToolCfg(flags, name='GNNTrackFinderTool', **kwargs): + """Sets up a GNNTrackFinderTool tool and returns it.""" + acc = ComponentAccumulator() + + ### parameters for GNNTrackFinderTool + kwargs.setdefault("embeddingDim", 8) + kwargs.setdefault("rVal", 1.7) + kwargs.setdefault("knnVal", 500) + kwargs.setdefault("filterCut", 0.21) + kwargs.setdefault("inputMLModelDir", "TrainedMLModels4ITk") + kwargs.setdefault("UseCUDA", False) + + from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg + kwargs.setdefault("Embedding", acc.popToolsAndMerge( + OnnxRuntimeInferenceToolCfg(flags, Path("TrainedMLModels4ITk") / "embedding.onnx") + )) + kwargs.setdefault("Filtering", acc.popToolsAndMerge( + OnnxRuntimeInferenceToolCfg(flags, Path("TrainedMLModels4ITk") / "filtering.onnx") + )) + kwargs.setdefault("GNN", acc.popToolsAndMerge( + OnnxRuntimeInferenceToolCfg(flags, Path("TrainedMLModels4ITk") / "gnn.onnx") + )) + + acc.setPrivateTools(CompFactory.InDet.SiGNNTrackFinderTool(name, **kwargs)) + return acc + + +def SeedFitterToolCfg(flags, name="SeedFitterTool", **kwargs): + """Sets up a SeedFitter tool and returns it.""" + acc = ComponentAccumulator() + + ### parameters for SeedFitter + acc.setPrivateTools(CompFactory.InDet.SeedFitterTool(name, **kwargs)) + return acc + +def GNNTrackReaderToolCfg(flags, name='GNNTrackReaderTool', **kwargs): + """Set up a GNNTrackReader tool and return it.""" + acc = ComponentAccumulator() + + ### parameters for GNNTrackReader + kwargs.setdefault("inputTracksDir", "gnntracks") + kwargs.setdefault("csvPrefix", "track") + + acc.setPrivateTools(CompFactory.InDet.GNNTrackReaderTool(name, **kwargs)) + return acc + +def GNNTrackMakerCfg(flags, name="GNNTrackMaker", **kwargs): + """Sets up a GNNTrackMaker algorithm and returns it.""" + + acc = ComponentAccumulator() + + ## tools + SeedFitterTool = acc.popToolsAndMerge(SeedFitterToolCfg(flags)) + kwargs.setdefault("SeedFitterTool", SeedFitterTool) + + from TrkConfig.CommonTrackFitterConfig import ITkTrackFitterCfg + InDetTrackFitter = acc.popToolsAndMerge(ITkTrackFitterCfg(flags)) + kwargs.setdefault("TrackFitter", InDetTrackFitter) + if flags.Tracking.GNN.useTrackFinder: + InDetGNNTrackFinderTool = acc.popToolsAndMerge(GNNTrackFinderToolCfg(flags)) + kwargs.setdefault("GNNTrackFinderTool", InDetGNNTrackFinderTool) + kwargs.setdefault("GNNTrackReaderTool", None) + kwargs.setdefault("UseTrackFinder", True) + kwargs.setdefault("UseTrackReader", False) + elif flags.Tracking.GNN.useTrackReader: + InDetGNNTrackReader = acc.popToolsAndMerge(GNNTrackReaderToolCfg(flags)) + kwargs.setdefault("GNNTrackReaderTool", InDetGNNTrackReader) + kwargs.setdefault("GNNTrackFinderTool", None) + kwargs.setdefault("UseTrackFinder", False) + kwargs.setdefault("UseTrackReader", True) + else: + raise RuntimeError("GNNTrackFinder or GNNTrackReader must be enabled!") + acc.addEventAlgo(CompFactory.InDet.SiSPGNNTrackMaker(name, **kwargs)) return acc diff --git a/InnerDetector/InDetGNNTracking/python/InDetGNNTrackingFlags.py b/InnerDetector/InDetGNNTracking/python/InDetGNNTrackingFlags.py new file mode 100644 index 0000000000000000000000000000000000000000..816cc9cb1b5a312e09b5065fe068dc7157086efa --- /dev/null +++ b/InnerDetector/InDetGNNTracking/python/InDetGNNTrackingFlags.py @@ -0,0 +1,39 @@ +# +# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration +# +from TrkConfig.TrackingPassFlags import createITkTrackingPassFlags +from TrkConfig.TrkConfigFlags import TrackingComponent + + +def createGNNTrackingPassFlags(): + icf = createITkTrackingPassFlags() + icf.extension = "GNN" + icf.doAthenaCluster = True + icf.doAthenaSpacePoint = True + icf.doAthenaSeed = False + icf.doAthenaTrack = False + icf.doAthenaAmbiguityResolution = True + + icf.doGNNTrack = True + + icf.doActsCluster = False + icf.doActsSpacePoint = False + icf.doActsSeed = False + icf.doActsTrack = False + return icf + + +def gnnReaderValidation(flags): + """flags for Reco_tf with CA used in CI tests: use GNNChain during reconstruction""" + flags.Reco.EnableHGTDExtension = False + flags.Tracking.recoChain = [TrackingComponent.GNNChain] + flags.Tracking.GNN.useTrackReader = True + flags.Tracking.GNN.useTrackFinder = False + + +def gnnFinderValidation(flags): + """flags for Reco_tf with CA used in CI tests: use GNNChain during reconstruction""" + flags.Reco.EnableHGTDExtension = False + flags.Tracking.recoChain = [TrackingComponent.GNNChain] + flags.Tracking.GNN.useTrackReader = False + flags.Tracking.GNN.useTrackFinder = True diff --git a/InnerDetector/InDetGNNTracking/python/__init__.py b/InnerDetector/InDetGNNTracking/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/InnerDetector/InDetGNNTracking/src/GNNTrackReaderTool.cxx b/InnerDetector/InDetGNNTracking/src/GNNTrackReaderTool.cxx new file mode 100644 index 0000000000000000000000000000000000000000..47a40549c092ea2cea34789ccc65318ee21da24c --- /dev/null +++ b/InnerDetector/InDetGNNTracking/src/GNNTrackReaderTool.cxx @@ -0,0 +1,78 @@ +/* + Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration +*/ + +#include "GNNTrackReaderTool.h" + +// Framework include(s). +#include "PathResolver/PathResolver.h" +#include <cmath> +#include <fstream> + +InDet::GNNTrackReaderTool::GNNTrackReaderTool( + const std::string& type, const std::string& name, const IInterface* parent): + base_class(type, name, parent) +{ + declareInterface<IGNNTrackReaderTool>(this); +} + +MsgStream& InDet::GNNTrackReaderTool::dump( MsgStream& out ) const +{ + out<<std::endl; + return dumpevent(out); +} + +std::ostream& InDet::GNNTrackReaderTool::dump( std::ostream& out ) const +{ + return out; +} + +MsgStream& InDet::GNNTrackReaderTool::dumpevent( MsgStream& out ) const +{ + out<<"|---------------------------------------------------------------------|" + <<std::endl; + out<<"| Number output tracks | "<<std::setw(12) + <<" |"<<std::endl; + out<<"|---------------------------------------------------------------------|" + <<std::endl; + return out; +} + +void InDet::GNNTrackReaderTool::getTracks(uint32_t runNumber, uint32_t eventNumber, + std::vector<std::vector<uint32_t> >& trackCandidates) const +{ + std::string fileName = m_inputTracksDir + "/" + m_csvPrefix + "_" \ + + std::to_string(runNumber) + "_" + std::to_string(eventNumber) + ".csv"; + + trackCandidates.clear(); + std::ifstream csvFile(fileName); + + if (!csvFile.is_open()) { + ATH_MSG_ERROR("Cannot open file " << fileName); + return; + } else { + ATH_MSG_INFO("File " << fileName << " is opened."); + } + + std::string line; + while(std::getline(csvFile, line)){ + std::stringstream lineStream(line); + std::string cell; + std::vector<uint32_t> trackCandidate; + while(std::getline(lineStream, cell, ',')) + { + uint32_t cellId = 0; + try { + cellId = std::stoi(cell); + } catch (const std::invalid_argument& ia) { + std::cout << "Invalid argument: " << ia.what() << " for cell " << cell << std::endl; + continue; + } + + if (std::find(trackCandidate.begin(), trackCandidate.end(), cellId) == trackCandidate.end()) { + trackCandidate.push_back(cellId); + } + } + trackCandidates.push_back(std::move(trackCandidate)); + } +} \ No newline at end of file diff --git a/InnerDetector/InDetGNNTracking/src/GNNTrackReaderTool.h b/InnerDetector/InDetGNNTracking/src/GNNTrackReaderTool.h new file mode 100644 index 0000000000000000000000000000000000000000..8c9953dfb407471a3801320e1bbedae1822dab2f --- /dev/null +++ b/InnerDetector/InDetGNNTracking/src/GNNTrackReaderTool.h @@ -0,0 +1,74 @@ +/* + Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef GNNTrackReaderTool_H +#define GNNTrackReaderTool_H + +// System include(s). +#include <list> +#include <iostream> +#include <memory> + +#include "AthenaBaseComps/AthAlgTool.h" +#include "IGNNTrackReaderTool.h" + +class MsgStream; + +namespace InDet{ + /** + * @class InDet::GNNTrackReaderTool + * @brief InDet::GNNTrackReaderTool is a tool that reads track candidates + * of each event from a CSV file named as "track_runNumber_eventNumber.csv" + * or with a customized pre-fix. + * The directory, inputTracksDir, contains CSV files for all events. + * 1) If the inputTracksDir is not specified, the tool will read the CSV files + * from the current directory. + * 2) If the corresponding CSV file does not exist for a given event (runNumber, eventNumber), + * the tool will print an error message and return an empty list. + * @author xiangyang.ju@cern.ch + */ + + class GNNTrackReaderTool: public extends<AthAlgTool, IGNNTrackReaderTool> + { + public: + GNNTrackReaderTool(const std::string& type, const std::string& name, const IInterface* parent); + + /////////////////////////////////////////////////////////////////// + // Main methods for local track finding asked by the ISiMLTrackFinder + /////////////////////////////////////////////////////////////////// + + /** + * @brief Get track candidates from a CSV file named by runNumber and eventNumber. + * @param runNumber run number of the event. + * @param eventNumber event number of the event. + * @param tracks a list of track candidates in terms of spacepoint indices as read from the CSV file. + */ + virtual void getTracks(uint32_t runNumber, uint32_t eventNumber, + std::vector<std::vector<uint32_t> >& tracks) const override final; + + /////////////////////////////////////////////////////////////////// + // Print internal tool parameters and status + /////////////////////////////////////////////////////////////////// + virtual MsgStream& dump(MsgStream& out) const override; + virtual std::ostream& dump(std::ostream& out) const override; + + protected: + + GNNTrackReaderTool() = delete; + GNNTrackReaderTool(const GNNTrackReaderTool&) =delete; + GNNTrackReaderTool &operator=(const GNNTrackReaderTool&) = delete; + + StringProperty m_inputTracksDir{this, "inputTracksDir", "."}; + StringProperty m_csvPrefix{this, "csvPrefix", "track"}; + + MsgStream& dumpevent (MsgStream& out) const; + + }; + + MsgStream& operator << (MsgStream& ,const GNNTrackReaderTool&); + std::ostream& operator << (std::ostream&,const GNNTrackReaderTool&); + +} + +#endif diff --git a/InnerDetector/InDetGNNTracking/src/IGNNTrackReaderTool.h b/InnerDetector/InDetGNNTracking/src/IGNNTrackReaderTool.h new file mode 100644 index 0000000000000000000000000000000000000000..15f32b3e97ac19bfccd4ee5fcf29ec0d6b4b5cdd --- /dev/null +++ b/InnerDetector/InDetGNNTracking/src/IGNNTrackReaderTool.h @@ -0,0 +1,77 @@ +/* + Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration +*/ + +#ifndef IGNNTrackReaderTool_H +#define IGNNTrackReaderTool_H + +#include <list> +#include <vector> + +#include "GaudiKernel/AlgTool.h" + +class MsgStream; + +namespace InDet { + + /** + * @interface IGNNTrackReaderTool + * @brief Read GNN Track candidates from a CSV file + * @author xiangyang.ju@cern.ch + */ + class IGNNTrackReaderTool : virtual public IAlgTool + { + public: + /////////////////////////////////////////////////////////////////// + /// @name InterfaceID + /////////////////////////////////////////////////////////////////// + //@{ + DeclareInterfaceID(IGNNTrackReaderTool, 1, 0); + //@} + + /////////////////////////////////////////////////////////////////// + /// Main methods for reading track candidates + /////////////////////////////////////////////////////////////////// + /** + * @brief Get track candidates from a CSV file named by runNumber and eventNumber. + * @param runNumber run number of the event. + * @param eventNumber event number of the event. + * @param tracks a list of track candidates in terms of spacepoint indices as read from the CSV file. + */ + virtual void getTracks(uint32_t runNumber, uint32_t eventNumber, + std::vector<std::vector<uint32_t> >& tracks) const =0; + + + /////////////////////////////////////////////////////////////////// + // Print internal tool parameters and status + /////////////////////////////////////////////////////////////////// + + virtual MsgStream& dump(MsgStream& out) const=0; + virtual std::ostream& dump(std::ostream& out) const=0; + }; + + + /////////////////////////////////////////////////////////////////// + // Overload of << operator for MsgStream and std::ostream + /////////////////////////////////////////////////////////////////// + + MsgStream& operator << (MsgStream& ,const IGNNTrackReaderTool&); + std::ostream& operator << (std::ostream&,const IGNNTrackReaderTool&); + + /////////////////////////////////////////////////////////////////// + // Overload of << operator MsgStream + /////////////////////////////////////////////////////////////////// + + inline MsgStream& operator << (MsgStream& sl,const IGNNTrackReaderTool& se) { + return se.dump(sl); + } + /////////////////////////////////////////////////////////////////// + // Overload of << operator std::ostream + /////////////////////////////////////////////////////////////////// + + inline std::ostream& operator << (std::ostream& sl,const IGNNTrackReaderTool& se) { + return se.dump(sl); + } +} + +#endif \ No newline at end of file diff --git a/InnerDetector/InDetGNNTracking/src/SeedFitterTool.cxx b/InnerDetector/InDetGNNTracking/src/SeedFitterTool.cxx index 6f5493612aa37aadd89b5ef065189e7ac4a8b99e..fb9d7eab12c8f38193e43520a898f39e45b37e78 100644 --- a/InnerDetector/InDetGNNTracking/src/SeedFitterTool.cxx +++ b/InnerDetector/InDetGNNTracking/src/SeedFitterTool.cxx @@ -18,9 +18,9 @@ InDet::SeedFitterTool::SeedFitterTool( std::unique_ptr<const Trk::TrackParameters> InDet::SeedFitterTool::fit( const std::vector<const Trk::SpacePoint*>& spacePoints) const { - //// @todo maybe use a even simplier version to estimate track parameters. + //// @todo improve the estimate track parameters. //// Taken from the following link: - //// https://gitlab.cern.ch/xju/athena/-/blob/master/InnerDetector/InDetRecTools/SiTrackMakerTool_xk/src/SiTrackMaker_xk.cxx#L851-993 + //// https://gitlab.cern.ch/atlas/athena/-/blob/main/InnerDetector/InDetRecTools/SiTrackMakerTool_xk/src/SiTrackMaker_xk.cxx#L851-993 //// Only the first 3 spacepoints are used. //// the fitting was not stable. Now require at least 5 SPs. @@ -29,10 +29,10 @@ std::unique_ptr<const Trk::TrackParameters> InDet::SeedFitterTool::fit( return nullptr; } - double track_paras[9]; - + /// get the first cluster on the first hit const Trk::PrepRawData* cl = spacePoints[0]->clusterList().first; if(!cl) return nullptr; + /// and use the surface from this cluster as our reference plane const Trk::PlaneSurface* pla = static_cast<const Trk::PlaneSurface*>(&cl->detectorElement()->surface()); if(!pla) return nullptr; @@ -64,32 +64,26 @@ std::unique_ptr<const Trk::TrackParameters> InDet::SeedFitterTool::fit( // A,B are slope and intercept of the straight line in the u,v plane // connecting the three points. double A = v2/(u2-u1); - double B = 2.*(v2-A*u2); - double C = B/std::sqrt(1.+A*A); // curvature estimate. (2R)²=(1+A²)/b² => 1/2R = b/sqrt(1+A²) = B / sqrt(1+A²). - double T; // estimate of the track dz/dr (1/tanTheta) - std::abs(C) > 1.e-6 ? T = (z2*C)/std::asin(C*std::sqrt(rn)) : T = z2/std::sqrt(rn); + double T = z2*sqrt(r2); const Amg::Transform3D& Tp = pla->transform(); + /// local x of the surface in the global frame double Ax[3] = {Tp(0,0),Tp(1,0),Tp(2,0)}; + /// local y of the surface in the global frame double Ay[3] = {Tp(0,1),Tp(1,1),Tp(2,1)}; + /// centre of the surface in the global frame double D [3] = {Tp(0,3),Tp(1,3),Tp(2,3)}; - + /// location of the first SP w.r.t centre of the surface double d[3] = {x0-D[0],y0-D[1],z0-D[2]}; + double track_paras[5]; + /// local x, y - coordinates of the first SP in the local frame track_paras[0] = d[0]*Ax[0]+d[1]*Ax[1]+d[2]*Ax[2]; track_paras[1] = d[0]*Ay[0]+d[1]*Ay[1]+d[2]*Ay[2]; - - // use constant magnetic field to estimate theta and phi - double magnetic_field = 0.002; // kT - track_paras[2] = std::atan2(y2,x2); + track_paras[2] = std::atan2(b+a*A, a-b*A); track_paras[3] = std::atan2(1.,T) ; - track_paras[5] = -C / (0.3 * magnetic_field); // inverse momentum in GeV^-1 - - track_paras[4] = track_paras[5]/std::sqrt(1.+T*T); // qoverp from qoverpt and theta - track_paras[6] = x0; - track_paras[7] = y0; - track_paras[8] = z0; + track_paras[4] = 0.001/std::sqrt(1.+T*T); // qoverp from qoverpt and theta ATH_MSG_DEBUG( "linearConformalMapping: \n" << \ @@ -122,14 +116,6 @@ std::unique_ptr<const Trk::TrackParameters> InDet::SeedFitterTool::fit( return trkParameters; } -StatusCode InDet::SeedFitterTool::initialize() { - return StatusCode::SUCCESS; -} - -StatusCode InDet::SeedFitterTool::finalize() { - StatusCode sc = AlgTool::finalize(); - return sc; -} MsgStream& InDet::SeedFitterTool::dump( MsgStream& out ) const { diff --git a/InnerDetector/InDetGNNTracking/src/SeedFitterTool.h b/InnerDetector/InDetGNNTracking/src/SeedFitterTool.h index ac66e6dd7901c64ef4cc9f5ba5e0e00f3933dbff..93eec7cf094f48780125c815513adb1c2cd96b3b 100644 --- a/InnerDetector/InDetGNNTracking/src/SeedFitterTool.h +++ b/InnerDetector/InDetGNNTracking/src/SeedFitterTool.h @@ -26,9 +26,6 @@ namespace InDet { // Public methods: /////////////////////////////////////////////////////////////////// SeedFitterTool(const std::string&,const std::string&,const IInterface*); - virtual ~SeedFitterTool () = default; - virtual StatusCode initialize() override; - virtual StatusCode finalize () override; /////////////////////////////////////////////////////////////////// // Methods to convert spacepoints to Trk::Track diff --git a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx index e07013addd91ad7cc6c8f05f482a3a944717dcd3..8a21b0d0290632c5abd7664ccdaf8957b58f0ef6 100644 --- a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx +++ b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.cxx @@ -24,11 +24,6 @@ StatusCode InDet::SiGNNTrackFinderTool::initialize() { return StatusCode::SUCCESS; } -StatusCode InDet::SiGNNTrackFinderTool::finalize() { - StatusCode sc = AlgTool::finalize(); - return sc; -} - MsgStream& InDet::SiGNNTrackFinderTool::dump( MsgStream& out ) const { out<<std::endl; diff --git a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h index 3fca4b225e9784e64db5ad9da983f30a283ed70a..64dbf4af558667151cc5d842612ab43fa77e5aac 100644 --- a/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h +++ b/InnerDetector/InDetGNNTracking/src/SiGNNTrackFinderTool.h @@ -31,9 +31,7 @@ namespace InDet{ { public: SiGNNTrackFinderTool(const std::string& type, const std::string& name, const IInterface* parent); - virtual ~SiGNNTrackFinderTool() = default; virtual StatusCode initialize() override; - virtual StatusCode finalize() override; /////////////////////////////////////////////////////////////////// // Main methods for local track finding asked by the ISiMLTrackFinder diff --git a/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx b/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx index d7caaa467f6fe633ef9ae313a97dfff712c492ea..7e024c9c7ebdb8f9900c3c9a3b32b5c47914a765 100644 --- a/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx +++ b/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.cxx @@ -23,10 +23,23 @@ StatusCode InDet::SiSPGNNTrackMaker::initialize() ATH_CHECK(m_outputTracksKey.initialize()); - ATH_CHECK(m_gnnTrackFinder.retrieve()); ATH_CHECK(m_trackFitter.retrieve()); ATH_CHECK(m_seedFitter.retrieve()); + if (m_useTrackFinder == m_useTrackReader) { + ATH_MSG_ERROR("Use either track finder or track reader, not both."); + return StatusCode::FAILURE; + } + + if (m_useTrackFinder) { + ATH_MSG_INFO("Use GNN Track Finder"); + ATH_CHECK(m_gnnTrackFinder.retrieve()); + } + if (m_useTrackReader) { + ATH_MSG_INFO("Use GNN Track Reader"); + ATH_CHECK(m_gnnTrackReader.retrieve()); + } + return StatusCode::SUCCESS; } @@ -68,7 +81,14 @@ StatusCode InDet::SiSPGNNTrackMaker::execute(const EventContext& ctx) const getData(m_SpacePointsSCTKey); std::vector<std::vector<uint32_t> > TT; - ATH_CHECK(m_gnnTrackFinder->getTracks(spacePoints, TT)); + if (m_gnnTrackFinder.isSet()) { + ATH_CHECK(m_gnnTrackFinder->getTracks(spacePoints, TT)); + } else if (m_gnnTrackReader.isSet()) { + m_gnnTrackReader->getTracks(runNumber, eventNumber, TT); + } else { + ATH_MSG_ERROR("Both GNNTrackFinder and GNNTrackReader are not set"); + return StatusCode::FAILURE; + } ATH_MSG_DEBUG("Obtained " << TT.size() << " Tracks"); @@ -80,6 +100,7 @@ StatusCode InDet::SiSPGNNTrackMaker::execute(const EventContext& ctx) const std::vector<const Trk::PrepRawData*> clusters; std::vector<const Trk::SpacePoint*> trackCandiate; + trackCandiate.reserve(trackIndices.size()); trackCounter++; ATH_MSG_DEBUG("Track " << trackCounter << " has " << trackIndices.size() << " spacepoints"); @@ -89,7 +110,7 @@ StatusCode InDet::SiSPGNNTrackMaker::execute(const EventContext& ctx) const for (auto& id : trackIndices) { //// for each spacepoint, attach all prepRawData to a list. if (id > spacePoints.size()) { - ATH_MSG_ERROR("SpacePoint index out of range"); + ATH_MSG_WARNING("SpacePoint index "<< id << " out of range: " << spacePoints.size()); continue; } @@ -108,15 +129,23 @@ StatusCode InDet::SiSPGNNTrackMaker::execute(const EventContext& ctx) const // conformal mapping for track parameters auto trkParameters = m_seedFitter->fit(trackCandiate); if (trkParameters == nullptr) { - ATH_MSG_ERROR("Conformal mapping failed"); + ATH_MSG_WARNING("Conformal mapping failed"); continue; } - bool runOutlierRemoval = true; - Trk::ParticleHypothesis matEffects = Trk::pion; - auto track = m_trackFitter->fit(ctx, clusters, *trkParameters, runOutlierRemoval, matEffects); - if (track) { - outputTracks->push_back(track.release()); + Trk::ParticleHypothesis matEffects = Trk::pion; + // first fit the track with local parameters and without outlier removal. + std::unique_ptr<Trk::Track> track = m_trackFitter->fit(ctx, clusters, *trkParameters, false, matEffects); + if (track != nullptr && track->perigeeParameters() != nullptr) { + // fit the track again with perigee parameters and without outlier removal. + track = std::move(m_trackFitter->fit(ctx, clusters, *track->perigeeParameters(), false, matEffects)); + if (track != nullptr) { + // finally fit with outlier removal + track = std::move(m_trackFitter->fit(ctx, clusters, *track->perigeeParameters(), true, matEffects)); + if (track != nullptr && track->trackSummary() != nullptr) { + outputTracks->push_back(track.release()); + } + } } } @@ -125,16 +154,6 @@ StatusCode InDet::SiSPGNNTrackMaker::execute(const EventContext& ctx) const } -/////////////////////////////////////////////////////////////////// -// Finalize -/////////////////////////////////////////////////////////////////// - -StatusCode InDet::SiSPGNNTrackMaker::finalize() -{ - msg(MSG::INFO)<<(*this)<<endmsg; - return StatusCode::SUCCESS; -} - /////////////////////////////////////////////////////////////////// // Overload of << operator MsgStream /////////////////////////////////////////////////////////////////// diff --git a/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.h b/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.h index 5d052356a3ab72a88a05bede7fc038184614499e..9671d4609bcfec9c7701400619146d1cf377df43 100644 --- a/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.h +++ b/InnerDetector/InDetGNNTracking/src/SiSPGNNTrackMaker.h @@ -18,6 +18,7 @@ #include "InDetRecToolInterfaces/IGNNTrackFinder.h" #include "InDetRecToolInterfaces/ISeedFitter.h" #include "TrkFitterInterfaces/ITrackFitter.h" +#include "IGNNTrackReaderTool.h" namespace Trk { class ITrackFitter; @@ -37,10 +38,8 @@ namespace InDet { class SiSPGNNTrackMaker : public AthReentrantAlgorithm { public: SiSPGNNTrackMaker(const std::string& name, ISvcLocator* pSvcLocator); - virtual ~SiSPGNNTrackMaker() = default; virtual StatusCode initialize() override; virtual StatusCode execute(const EventContext& ctx) const override; - virtual StatusCode finalize() override; /// Make this algorithm clonable. virtual bool isClonable() const override { return true; }; @@ -56,9 +55,9 @@ namespace InDet { //@{ // input containers SG::ReadHandleKey<SpacePointContainer> m_SpacePointsPixelKey{ - this, "SpacePointsPixelName", "PixelSpacePoints"}; + this, "SpacePointsPixelName", "ITkPixelSpacePoints"}; SG::ReadHandleKey<SpacePointContainer> m_SpacePointsSCTKey{ - this, "SpacePointsSCTName", "SCT_SpacePoints"}; + this, "SpacePointsSCTName", "ITkStripSpacePoints"}; //@} // output container @@ -71,18 +70,25 @@ namespace InDet { //@{ /// GNN-based track finding tool that produces track candidates ToolHandle<IGNNTrackFinder> m_gnnTrackFinder{ - this, "GNNTrackFinder", - "InDet::SiGNNTrackFinder/InDetSiGNNTrackFinder", "Track Finder" + this, "GNNTrackFinderTool", + "InDet::SiGNNTrackFinderTool", "Track Finder" }; ToolHandle<ISeedFitter> m_seedFitter{ - this, "SeedFitter", - "InDet::SiSeedFitter/InDetSiSeedFitter", "Seed Fitter" + this, "SeedFitterTool", + "InDet::SiSeedFitterTool", "Seed Fitter" }; /// Track Fitter ToolHandle<Trk::ITrackFitter> m_trackFitter { this, "TrackFitter", "Trk::GlobalChi2Fitter/InDetTrackFitter", "Track Fitter" }; + ToolHandle<IGNNTrackReaderTool> m_gnnTrackReader{ + this, "GNNTrackReaderTool", + "InDet::GNNTrackReaderTool", "Track Reader" + }; + + BooleanProperty m_useTrackFinder{this, "UseTrackFinder", false}; + BooleanProperty m_useTrackReader{this, "UseTrackReader", true}; //@} MsgStream& dumptools(MsgStream& out) const; diff --git a/InnerDetector/InDetGNNTracking/src/components/InDetGNNTracking_entries.cxx b/InnerDetector/InDetGNNTracking/src/components/InDetGNNTracking_entries.cxx index 1283a12bd8ea1a48a9b90c60a314ca0bec422623..5bf2c52a8ad60aada6d1e765836140ca5a3159cd 100644 --- a/InnerDetector/InDetGNNTracking/src/components/InDetGNNTracking_entries.cxx +++ b/InnerDetector/InDetGNNTracking/src/components/InDetGNNTracking_entries.cxx @@ -1,11 +1,13 @@ #include "../SeedFitterTool.h" #include "../SiGNNTrackFinderTool.h" #include "../SiSPGNNTrackMaker.h" +#include "../GNNTrackReaderTool.h" #include "../DumpObjects.h" using namespace InDet; DECLARE_COMPONENT( SeedFitterTool ) DECLARE_COMPONENT( SiGNNTrackFinderTool ) +DECLARE_COMPONENT( GNNTrackReaderTool ) DECLARE_COMPONENT( SiSPGNNTrackMaker ) DECLARE_COMPONENT( DumpObjects ) diff --git a/Tracking/TrkConfig/python/TrackingPassFlags.py b/Tracking/TrkConfig/python/TrackingPassFlags.py index ce588ef382f9c4cde63087e9eb97619acf6ff05a..fcb940b4f287674468801f193472b8141949a357 100644 --- a/Tracking/TrkConfig/python/TrackingPassFlags.py +++ b/Tracking/TrkConfig/python/TrackingPassFlags.py @@ -465,6 +465,9 @@ def createITkTrackingPassFlags(): icf.addFlag("doActsToAthenaTrack", False) icf.addFlag("doActsToAthenaResolvedTrack", False) + # --- flags for GNN tracking + icf.addFlag("doGNNTrack", False) + return icf diff --git a/Tracking/TrkConfig/python/TrkConfigFlags.py b/Tracking/TrkConfig/python/TrkConfigFlags.py index d43a3d8d700c4387105bcd85b9e52be652699288..c244a76756844b08efada88ea6eb60309e094559 100644 --- a/Tracking/TrkConfig/python/TrkConfigFlags.py +++ b/Tracking/TrkConfig/python/TrkConfigFlags.py @@ -53,6 +53,8 @@ class TrackingComponent(FlagEnum): ValidateActsAmbiguityResolution = "ValidateActsAmbiguityResolution" # Benchmarking BenchmarkSpot = "BenchmarkSpot" + # GNN + GNNChain = "GNNChain" def createTrackingConfigFlags(): @@ -309,6 +311,10 @@ def createTrackingConfigFlags(): # SiSPSeededTrackFinder icf.addFlag("Tracking.useITkFTF", False) + # GNN for ITk flags + icf.addFlag("Tracking.GNN.useTrackFinder", False) + icf.addFlag("Tracking.GNN.useTrackReader", False) + # enable reco steps icf.addFlag("Tracking.recoChain", [TrackingComponent.AthenaChain]) @@ -461,6 +467,11 @@ def createTrackingConfigFlags(): icf.addFlagsCategory ("Tracking.ITkActsBenchmarkSpotPass", createActsBenchmarkSpotTrackingPassFlags, prefix=True) + # GNN + from InDetGNNTracking.InDetGNNTrackingFlags import createGNNTrackingPassFlags + icf.addFlagsCategory ("Tracking.ITkGNNPass", + createGNNTrackingPassFlags, prefix=True) + #################################################################### # Vertexing flags