From cd7224e6fbbac5354157577ecfc8fc682b733e7d Mon Sep 17 00:00:00 2001 From: Xiaocong Ai <xiaocong.ai@cern.ch> Date: Fri, 21 Jun 2024 20:14:25 +0200 Subject: [PATCH] update FaserActsKalmanFilterAlg --- .../src/FaserActsKalmanFilterAlg.cxx | 157 +++++++++--------- .../src/FaserActsKalmanFilterAlg.h | 66 +++++--- .../src/TrackFittingFunction.cxx | 126 +++++++++++--- 3 files changed, 233 insertions(+), 116 deletions(-) diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.cxx b/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.cxx index 5dd831f95..d6e8bb2eb 100755 --- a/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.cxx +++ b/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.cxx @@ -3,6 +3,7 @@ */ #include "FaserActsKalmanFilterAlg.h" +#include "FaserActsGeometry/FASERMagneticFieldWrapper.h" // ATHENA #include "GaudiKernel/EventContext.h" @@ -62,14 +63,13 @@ #include <fstream> #include <cmath> -using TrajectoryContainer = std::vector<FaserActsRecMultiTrajectory>; using namespace Acts::UnitLiterals; using Acts::VectorHelpers::eta; using Acts::VectorHelpers::perp; using Acts::VectorHelpers::phi; using Acts::VectorHelpers::theta; -using ThisMeasurement = Acts::Measurement<IndexSourceLink, Acts::BoundIndices, 2>; +using ThisMeasurement = Acts::Measurement<Acts::BoundIndices, 2>; using IdentifierMap = std::map<Identifier, Acts::GeometryIdentifier>; FaserActsKalmanFilterAlg::FaserActsKalmanFilterAlg(const std::string& name, ISvcLocator* pSvcLocator) : @@ -80,12 +80,12 @@ StatusCode FaserActsKalmanFilterAlg::initialize() { ATH_CHECK(m_fieldCondObjInputKey.initialize()); ATH_CHECK(m_trackingGeometryTool.retrieve()); ATH_CHECK(m_trackFinderTool.retrieve()); - ATH_CHECK(m_trajectoryWriterTool.retrieve()); - ATH_CHECK(m_trajectoryStatesWriterTool.retrieve()); +//todo ATH_CHECK(m_trajectoryWriterTool.retrieve()); +//todo ATH_CHECK(m_trajectoryStatesWriterTool.retrieve()); // ATH_CHECK(m_protoTrackWriterTool.retrieve()); ATH_CHECK(m_trackCollection.initialize()); ATH_CHECK(detStore()->retrieve(m_idHelper,"FaserSCT_ID")); - m_fit = makeTrackFitterFunction(m_trackingGeometryTool->trackingGeometry()); + if (m_actsLogging == "VERBOSE") { m_logger = Acts::getDefaultLogger("KalmanFitter", Acts::Logging::VERBOSE); } else if (m_actsLogging == "DEBUG") { @@ -93,6 +93,12 @@ StatusCode FaserActsKalmanFilterAlg::initialize() { } else { m_logger = Acts::getDefaultLogger("KalmanFitter", Acts::Logging::INFO); } + + auto magneticField = std::make_shared<FASERMagneticFieldWrapper>(); + double reverseFilteringMomThreshold = 0.1; //@todo: needs to be validated + //@todo: the multiple scattering and energy loss are originallly turned off + m_fit = makeTrackFitterFunction(m_trackingGeometryTool->trackingGeometry(), magneticField, false, false, reverseFilteringMomThreshold, Acts::FreeToBoundCorrection(false), *m_logger); + return StatusCode::SUCCESS; } @@ -133,62 +139,58 @@ StatusCode FaserActsKalmanFilterAlg::execute() { int n_trackSeeds = initialTrackParameters->size(); - TrajectoryContainer trajectories; - trajectories.reserve(1); - - for (int i = 0; i < n_trackSeeds; ++i) { + auto actsTrackContainer = std::make_shared<Acts::VectorTrackContainer>(); + auto actsTrackStateContainer = std::make_shared<Acts::VectorMultiTrajectory>(); + TrackContainer tracks(actsTrackContainer, actsTrackStateContainer); - Acts::KalmanFitterOptions<MeasurementCalibrator, Acts::VoidOutlierFinder, Acts::VoidReverseFilteringLogic> kfOptions( - geoctx, - magctx, - calctx, - MeasurementCalibrator(measurements->at(i)), - Acts::VoidOutlierFinder(), - Acts::VoidReverseFilteringLogic(), - Acts::LoggerWrapper{*m_logger}, - Acts::PropagatorPlainOptions(), - &(*initialSurface) - ); - kfOptions.multipleScattering = false; - kfOptions.energyLoss = false; + //@todo: the initialSurface should be targetSurface + FaserActsKalmanFilterAlg::GeneralFitterOptions options{ + geoctx, magctx, calctx, &(*initialSurface), + Acts::PropagatorPlainOptions()}; + for (int i = 0; i < n_trackSeeds; ++i) { ATH_MSG_DEBUG("Invoke fitter"); - auto result = (*m_fit)(sourceLinks->at(i), initialTrackParameters->at(i), kfOptions); + auto sls = sourceLinks->at(i); + std::vector<Acts::SourceLink> actsSls; + for(const auto& sl: sls){ + actsSls.push_back(Acts::SourceLink{sl}); + } + auto result = (*m_fit)(actsSls, initialTrackParameters->at(i), options, MeasurementCalibratorAdapter(MeasurementCalibrator(), measurements->at(i)), tracks); int itrack = 0; if (result.ok()) { // Get the fit output object - const auto& fitOutput = result.value(); + //const auto& fitOutput = result.value(); std::unique_ptr<Trk::Track> track = makeTrack(geoctx, result, clusters->at(i)); if (track) { outputTracks->push_back(std::move(track)); } - // The track entry indices container. One element here. - std::vector<size_t> trackTips; - trackTips.reserve(1); - trackTips.emplace_back(fitOutput.lastMeasurementIndex); - // The fitted parameters container. One element (at most) here. - IndexedParams indexedParams; - - if (fitOutput.fittedParameters) { - const auto& params = fitOutput.fittedParameters.value(); - ATH_MSG_VERBOSE("Fitted paramemeters for track " << itrack); - ATH_MSG_VERBOSE(" parameters: " << params); - ATH_MSG_VERBOSE(" position: " << params.position(geoctx).transpose()); - ATH_MSG_VERBOSE(" momentum: " << params.momentum().transpose()); - // Push the fitted parameters to the container - indexedParams.emplace(fitOutput.lastMeasurementIndex, std::move(params)); - } else { - ATH_MSG_DEBUG("No fitted paramemeters for track " << itrack); - } - // Create a SimMultiTrajectory - trajectories.emplace_back(std::move(fitOutput.fittedStates), - std::move(trackTips), std::move(indexedParams)); + // // The track entry indices container. One element here. + // std::vector<size_t> trackTips; + // trackTips.reserve(1); + // trackTips.emplace_back(fitOutput.lastMeasurementIndex); + // // The fitted parameters container. One element (at most) here. + // IndexedParams indexedParams; + + // if (fitOutput.fittedParameters) { + // const auto& params = fitOutput.fittedParameters.value(); + // ATH_MSG_VERBOSE("Fitted paramemeters for track " << itrack); + // ATH_MSG_VERBOSE(" parameters: " << params); + // ATH_MSG_VERBOSE(" position: " << params.position(geoctx).transpose()); + // ATH_MSG_VERBOSE(" momentum: " << params.momentum().transpose()); + // // Push the fitted parameters to the container + // indexedParams.emplace(fitOutput.lastMeasurementIndex, std::move(params)); + // } else { + // ATH_MSG_DEBUG("No fitted paramemeters for track " << itrack); + // } + // // Create a SimMultiTrajectory + // trajectories.emplace_back(std::move(fitOutput.fittedStates), + // std::move(trackTips), std::move(indexedParams)); } else { ATH_MSG_WARNING("Fit failed for track " << itrack << " with error" << result.error()); // Fit failed, but still create a empty truth fit track - trajectories.push_back(FaserActsRecMultiTrajectory()); + // trajectories.push_back(FaserActsRecMultiTrajectory()); } } @@ -221,86 +223,88 @@ Acts::MagneticFieldContext FaserActsKalmanFilterAlg::getMagneticFieldContext(con std::unique_ptr<Trk::Track> FaserActsKalmanFilterAlg::makeTrack(Acts::GeometryContext& geoCtx, TrackFitterResult& fitResult, std::vector<const Tracker::FaserSCT_Cluster*> clusters) const { - using ConstTrackStateProxy = - Acts::detail_lt::TrackStateProxy<IndexSourceLink, 6, true>; std::unique_ptr<Trk::Track> newtrack = nullptr; //Get the fit output object - const auto& fitOutput = fitResult.value(); - if (fitOutput.fittedParameters) { - DataVector<const Trk::TrackStateOnSurface>* finalTrajectory = new DataVector<const Trk::TrackStateOnSurface>{}; + const auto& track = fitResult.value(); + if (track.hasReferenceSurface()) { + std::unique_ptr<DataVector<const Trk::TrackStateOnSurface>> finalTrajectory = std::make_unique<DataVector<const Trk::TrackStateOnSurface>>(); std::vector<std::unique_ptr<const Acts::BoundTrackParameters>> actsSmoothedParam; // Loop over all the output state to create track state - fitOutput.fittedStates.visitBackwards(fitOutput.lastMeasurementIndex, [&](const ConstTrackStateProxy& state) { + for (const auto& state : track.trackStatesReversed()) { auto flag = state.typeFlags(); if (state.referenceSurface().associatedDetectorElement() != nullptr) { // We need to determine the type of state std::bitset<Trk::TrackStateOnSurface::NumberOfTrackStateOnSurfaceTypes> typePattern; - const Trk::TrackParameters *parm; + std::unique_ptr<Trk::TrackParameters> parm; // State is a hole (no associated measurement), use predicted para meters - if (flag[Acts::TrackStateFlag::HoleFlag] == true) { + if (flag.test(Acts::TrackStateFlag::HoleFlag) == true) { + //@todo: ParticleHypothesis? const Acts::BoundTrackParameters actsParam(state.referenceSurface().getSharedPtr(), state.predicted(), - state.predictedCovariance()); - parm = ConvertActsTrackParameterToATLAS(actsParam, geoCtx); + state.predictedCovariance(), Acts::ParticleHypothesis::pion()); + parm = std::move(ConvertActsTrackParameterToATLAS(actsParam, geoCtx)); // auto boundaryCheck = m_boundaryCheckTool->boundaryCheck(*p arm); typePattern.set(Trk::TrackStateOnSurface::Hole); } // The state was tagged as an outlier, use filtered parameters - else if (flag[Acts::TrackStateFlag::OutlierFlag] == true) { + else if (flag.test(Acts::TrackStateFlag::OutlierFlag) == true) { + //@todo: ParticleHypothesis? const Acts::BoundTrackParameters actsParam(state.referenceSurface().getSharedPtr(), - state.filtered(), state.filteredCovariance()); - parm = ConvertActsTrackParameterToATLAS(actsParam, geoCtx); + state.filtered(), state.filteredCovariance(), Acts::ParticleHypothesis::pion()); + parm = std::move(ConvertActsTrackParameterToATLAS(actsParam, geoCtx)); typePattern.set(Trk::TrackStateOnSurface::Outlier); } // The state is a measurement state, use smoothed parameters else { + //@todo: ParticleHypothesis? const Acts::BoundTrackParameters actsParam(state.referenceSurface().getSharedPtr(), - state.smoothed(), state.smoothedCovariance()); + state.smoothed(), state.smoothedCovariance(), Acts::ParticleHypothesis::pion()); actsSmoothedParam.push_back(std::make_unique<const Acts::BoundTrackParameters>(Acts::BoundTrackParameters(actsParam))); // const auto& psurface=actsParam.referenceSurface(); Acts::Vector2 local(actsParam.parameters()[Acts::eBoundLoc0], actsParam.parameters()[Acts::eBoundLoc1]); // const Acts::Vector3 dir = Acts::makeDirectionUnitFromPhiTheta(actsParam.parameters()[Acts::eBoundPhi], actsParam.parameters()[Acts::eBoundTheta]); // auto pos=actsParam.position(tgContext); - parm = ConvertActsTrackParameterToATLAS(actsParam, geoCtx); + parm = std::move(ConvertActsTrackParameterToATLAS(actsParam, geoCtx)); typePattern.set(Trk::TrackStateOnSurface::Measurement); } - Tracker::FaserSCT_ClusterOnTrack* measState = nullptr; - if (state.hasUncalibrated()) { - const Tracker::FaserSCT_Cluster* fitCluster = clusters.at(state.uncalibrated().index()); + std::unique_ptr<Tracker::FaserSCT_ClusterOnTrack> measState = nullptr; + if (state.hasUncalibratedSourceLink()) { + auto sl = state.getUncalibratedSourceLink().template get<IndexSourceLink>(); + const Tracker::FaserSCT_Cluster* fitCluster = clusters.at(sl.index()); if (fitCluster->detectorElement() != nullptr) { - measState = new Tracker::FaserSCT_ClusterOnTrack{ + measState = std::make_unique<Tracker::FaserSCT_ClusterOnTrack>( fitCluster, Trk::LocalParameters{ Trk::DefinedParameter{fitCluster->localPosition()[0], Trk::loc1}, Trk::DefinedParameter{fitCluster->localPosition()[1], Trk::loc2} }, - fitCluster->localCovariance(), + Amg::MatrixX(fitCluster->localCovariance()), m_idHelper->wafer_hash(fitCluster->detectorElement()->identify()) - }; + ); } } double nDoF = state.calibratedSize(); const Trk::FitQualityOnSurface *quality = new Trk::FitQualityOnSurface(state.chi2(), nDoF); - const Trk::TrackStateOnSurface *perState = new Trk::TrackStateOnSurface(measState, parm, quality, nullptr, typePattern); + const Trk::TrackStateOnSurface *perState = new Trk::TrackStateOnSurface(*quality, std::move(measState), std::move(parm), nullptr, typePattern); // If a state was succesfully created add it to the trajectory if (perState) { - finalTrajectory->insert(finalTrajectory->begin(), perState); + (*finalTrajectory).insert((*finalTrajectory).begin(), perState); } - } - return; - }); + } // state has referenceSurface + } //end loop for all states // Create the track using the states const Trk::TrackInfo newInfo(Trk::TrackInfo::TrackFitter::KalmanFitter, Trk::ParticleHypothesis::muon); // Trk::FitQuality* q = nullptr; // newInfo.setTrackFitter(Trk::TrackInfo::TrackFitter::KalmanFitter ); //Mark the fitter as KalmanFitter - newtrack = std::make_unique<Trk::Track>(newInfo, std::move(*finalTrajectory), nullptr); - } + newtrack = std::make_unique<Trk::Track>(newInfo, std::move(finalTrajectory), nullptr); + } // hasFittedParameters + return newtrack; } -const Trk::TrackParameters* +std::unique_ptr<Trk::TrackParameters> FaserActsKalmanFilterAlg ::ConvertActsTrackParameterToATLAS(const Acts::BoundTrackParameters &actsParameter, const Acts::GeometryContext& gctx) const { using namespace Acts::UnitLiterals; std::optional<AmgSymMatrix(5)> cov = std::nullopt; @@ -321,7 +325,6 @@ FaserActsKalmanFilterAlg ::ConvertActsTrackParameterToATLAS(const Acts::BoundTra double tqOverP=actsParameter.get<Acts::eBoundQOverP>()*1_MeV; double p = std::abs(1. / tqOverP); Amg::Vector3D tmom(p * std::cos(tphi) * std::sin(ttheta), p * std::sin(tphi) * std::sin(ttheta), p * std::cos(ttheta)); - const Trk::CurvilinearParameters * curv = new Trk::CurvilinearParameters(pos,tmom,tqOverP>0, cov); - return curv; + return std::make_unique<Trk::CurvilinearParameters>(pos, tmom, tqOverP>0, cov); } diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.h b/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.h index a2b4c7295..cf2f254e4 100755 --- a/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.h +++ b/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.h @@ -24,32 +24,36 @@ #include "GeneratorObjects/McEventCollection.h" #include "TrackerSimData/TrackerSimDataCollection.h" #include "TrkTrack/TrackCollection.h" -#include "TrajectoryWriterTool.h" +//todo#include "TrajectoryWriterTool.h" // ACTS #include "Acts/MagneticField/ConstantBField.hpp" #include "Acts/MagneticField/InterpolatedBFieldMap.hpp" -#include "Acts/MagneticField/SharedBField.hpp" #include "Acts/Propagator/EigenStepper.hpp" #include "Acts/Propagator/Propagator.hpp" #include "Acts/Propagator/detail/SteppingLogger.hpp" #include "Acts/TrackFitting/KalmanFitter.hpp" #include "Acts/Geometry/TrackingGeometry.hpp" #include "Acts/EventData/TrackParameters.hpp" +#include "Acts/EventData/TrackContainer.hpp" +#include "Acts/EventData/TrackProxy.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" #include "Acts/Geometry/GeometryIdentifier.hpp" #include "Acts/Utilities/Helpers.hpp" #include "Acts/Definitions/Common.hpp" +#include "Acts/Utilities/Result.hpp" + // PACKAGE #include "FaserActsGeometry/FASERMagneticFieldWrapper.h" #include "FaserActsGeometryInterfaces/IFaserActsTrackingGeometryTool.h" #include "FaserActsGeometryInterfaces/IFaserActsExtrapolationTool.h" -#include "FaserActsRecMultiTrajectory.h" #include "FaserActsKalmanFilter/IndexSourceLink.h" #include "FaserActsKalmanFilter/Measurement.h" #include "FaserActsKalmanFilter/ITrackFinderTool.h" //#include "ProtoTrackWriterTool.h" -#include "RootTrajectoryStatesWriterTool.h" +//todo #include "RootTrajectoryStatesWriterTool.h" // STL #include <memory> @@ -70,8 +74,9 @@ namespace TrackerDD { class SCT_DetectorManager; } -using TrajectoryContainer = std::vector<FaserActsRecMultiTrajectory>; -using BField_t = FASERMagneticFieldWrapper; +//@toberemoved +//using TrajectoryContainer = std::vector<FaserActsRecMultiTrajectory>; +//using BField_t = FASERMagneticFieldWrapper; //class FaserActsKalmanFilterAlg : public AthReentrantAlgorithm { class FaserActsKalmanFilterAlg : public AthAlgorithm { @@ -83,26 +88,47 @@ public: // StatusCode execute(const EventContext& ctx) const override; StatusCode execute() override; StatusCode finalize() override; + +//@toberemoved +// using IndexedParams = std::unordered_map<size_t, Acts::BoundTrackParameters>; +// using TrackFitterOptions = +// Acts::KalmanFitterOptions<MeasurementCalibrator, Acts::VoidOutlierFinder, +// Acts::VoidReverseFilteringLogic>; + using TrackContainer = + Acts::TrackContainer<Acts::VectorTrackContainer, + Acts::VectorMultiTrajectory, std::shared_ptr>; + + struct GeneralFitterOptions { + std::reference_wrapper<const Acts::GeometryContext> geoContext; + std::reference_wrapper<const Acts::MagneticFieldContext> magFieldContext; + std::reference_wrapper<const Acts::CalibrationContext> calibrationContext; + const Acts::Surface* referenceSurface = nullptr; + Acts::PropagatorPlainOptions propOptions; + }; - using IndexedParams = std::unordered_map<size_t, Acts::BoundTrackParameters>; - using TrackFitterOptions = - Acts::KalmanFitterOptions<MeasurementCalibrator, Acts::VoidOutlierFinder, - Acts::VoidReverseFilteringLogic>; - using TrackFitterResult = - Acts::Result<Acts::KalmanFitterResult<IndexSourceLink>>; + using TrackFitterResult = Acts::Result<TrackContainer::TrackProxy>; - using TrackParameters = Acts::CurvilinearTrackParameters; + using TrackParameters = Acts::BoundTrackParameters; class TrackFitterFunction { public: virtual ~TrackFitterFunction() = default; - virtual TrackFitterResult operator()(const std::vector<IndexSourceLink>&, - const TrackParameters&, - const TrackFitterOptions&) const = 0; + virtual TrackFitterResult operator()( + const std::vector<Acts::SourceLink> &sourceLinks, + const TrackParameters &initialParameters, + const GeneralFitterOptions& options, + const MeasurementCalibratorAdapter& calibrator, + TrackContainer& tracks + ) const = 0; }; static std::shared_ptr<TrackFitterFunction> makeTrackFitterFunction( - std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry); + std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry, + std::shared_ptr<const Acts::MagneticFieldProvider> magneticField, + bool multipleScattering, bool energyLoss, + double reverseFilteringMomThreshold, + Acts::FreeToBoundCorrection freeToBoundCorrection, + const Acts::Logger& logger); virtual Acts::MagneticFieldContext getMagneticFieldContext(const EventContext& ctx) const; @@ -113,15 +139,15 @@ private: Gaudi::Property<std::string> m_actsLogging {this, "ActsLogging", "VERBOSE"}; std::unique_ptr<Trk::Track> makeTrack(Acts::GeometryContext& tgContext, TrackFitterResult& fitResult, std::vector<const Tracker::FaserSCT_Cluster*> clusters) const; - const Trk::TrackParameters* ConvertActsTrackParameterToATLAS(const Acts::BoundTrackParameters &actsParameter, const Acts::GeometryContext& gctx) const; + std::unique_ptr<Trk::TrackParameters> ConvertActsTrackParameterToATLAS(const Acts::BoundTrackParameters &actsParameter, const Acts::GeometryContext& gctx) const; // Read handle for conditions object to get the field cache SG::ReadCondHandleKey<FaserFieldCacheCondObj> m_fieldCondObjInputKey {this, "FaserFieldCacheCondObj", "fieldCondObj", "Name of the Magnetic Field conditions object key"}; ToolHandle<ITrackFinderTool> m_trackFinderTool {this, "TrackFinderTool", "TruthTrackFinderTool"}; ToolHandle<IFaserActsTrackingGeometryTool> m_trackingGeometryTool {this, "TrackingGeometryTool", "FaserActsTrackingGeometryTool"}; - ToolHandle<TrajectoryWriterTool> m_trajectoryWriterTool {this, "OutputTool", "TrajectoryWriterTool"}; - ToolHandle<RootTrajectoryStatesWriterTool> m_trajectoryStatesWriterTool {this, "RootTrajectoryStatesWriterTool", "RootTrajectoryStatesWriterTool"}; +//todo ToolHandle<TrajectoryWriterTool> m_trajectoryWriterTool {this, "OutputTool", "TrajectoryWriterTool"}; +//todo ToolHandle<RootTrajectoryStatesWriterTool> m_trajectoryStatesWriterTool {this, "RootTrajectoryStatesWriterTool", "RootTrajectoryStatesWriterTool"}; // ToolHandle<ProtoTrackWriterTool> m_protoTrackWriterTool {this, "ProtoTrackWriterTool", "ProtoTrackWriterTool"}; SG::ReadHandleKey<McEventCollection> m_mcEventKey { this, "McEventCollection", "BeamTruthEvent" }; diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/TrackFittingFunction.cxx b/Tracking/Acts/FaserActsKalmanFilter/src/TrackFittingFunction.cxx index 460300bac..c8d2bbe95 100644 --- a/Tracking/Acts/FaserActsKalmanFilter/src/TrackFittingFunction.cxx +++ b/Tracking/Acts/FaserActsKalmanFilter/src/TrackFittingFunction.cxx @@ -1,34 +1,102 @@ #include "FaserActsKalmanFilterAlg.h" -#include "FaserActsGeometry/FASERMagneticFieldWrapper.h" - +//#include "FaserActsGeometry/FASERMagneticFieldWrapper.h" +//#include "FaserActsKalmanFilter/IndexSourceLink.h" + +#include "Acts/Definitions/Direction.hpp" +#include "Acts/Definitions/TrackParametrization.hpp" +#include "Acts/EventData/MultiTrajectory.hpp" +#include "Acts/EventData/TrackContainer.hpp" +#include "Acts/EventData/TrackStatePropMask.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp" +#include "Acts/Geometry/GeometryIdentifier.hpp" +#include "Acts/Propagator/DirectNavigator.hpp" #include "Acts/Propagator/EigenStepper.hpp" #include "Acts/Propagator/Navigator.hpp" #include "Acts/Propagator/Propagator.hpp" #include "Acts/TrackFitting/GainMatrixSmoother.hpp" #include "Acts/TrackFitting/GainMatrixUpdater.hpp" +#include "Acts/TrackFitting/KalmanFitter.hpp" +#include "Acts/Utilities/Delegate.hpp" +#include "Acts/Utilities/Logger.hpp" namespace { -using Updater = Acts::GainMatrixUpdater; -using Smoother = Acts::GainMatrixSmoother; using Stepper = Acts::EigenStepper<>; using Propagator = Acts::Propagator<Stepper, Acts::Navigator>; -using Fitter = Acts::KalmanFitter<Propagator, Updater, Smoother>; +using Fitter = Acts::KalmanFitter<Propagator, Acts::VectorMultiTrajectory>; + + +struct SimpleReverseFilteringLogic { + double momentumThreshold = 0; + + bool doBackwardFiltering( + Acts::VectorMultiTrajectory::ConstTrackStateProxy trackState) const { + auto momentum = fabs(1 / trackState.filtered()[Acts::eBoundQOverP]); + return (momentum <= momentumThreshold); + } +}; struct TrackFitterFunctionImpl : public FaserActsKalmanFilterAlg::TrackFitterFunction { Fitter trackFitter; - TrackFitterFunctionImpl(Fitter &&f) : trackFitter(std::move(f)) {} + Acts::GainMatrixUpdater kfUpdater; + Acts::GainMatrixSmoother kfSmoother; + SimpleReverseFilteringLogic reverseFilteringLogic; + + bool multipleScattering = false; + bool energyLoss = false; + Acts::FreeToBoundCorrection freeToBoundCorrection; + + IndexSourceLink::SurfaceAccessor slSurfaceAccessor; + + TrackFitterFunctionImpl(Fitter &&f, const Acts::TrackingGeometry& trkGeo) : trackFitter(std::move(f)),slSurfaceAccessor{trkGeo} {} + + template <typename calibrator_t> + auto makeKfOptions(const FaserActsKalmanFilterAlg::GeneralFitterOptions& options, + const calibrator_t& calibrator) const { + Acts::KalmanFitterExtensions<Acts::VectorMultiTrajectory> extensions; + extensions.updater.connect< + &Acts::GainMatrixUpdater::operator()<Acts::VectorMultiTrajectory>>( + &kfUpdater); + extensions.smoother.connect< + &Acts::GainMatrixSmoother::operator()<Acts::VectorMultiTrajectory>>( + &kfSmoother); + extensions.reverseFilteringLogic + .connect<&SimpleReverseFilteringLogic::doBackwardFiltering>( + &reverseFilteringLogic); + + Acts::KalmanFitterOptions<Acts::VectorMultiTrajectory> kfOptions( + options.geoContext, options.magFieldContext, options.calibrationContext, + extensions, options.propOptions, &(*options.referenceSurface)); + + kfOptions.referenceSurfaceStrategy = + Acts::KalmanFitterTargetSurfaceStrategy::first; + kfOptions.multipleScattering = multipleScattering; + kfOptions.energyLoss = energyLoss; + kfOptions.freeToBoundCorrection = freeToBoundCorrection; + kfOptions.extensions.calibrator.connect<&calibrator_t::calibrate>( + &calibrator); + kfOptions.extensions.surfaceAccessor + .connect<&IndexSourceLink::SurfaceAccessor::operator()>( + &slSurfaceAccessor); + + return kfOptions; + } FaserActsKalmanFilterAlg::TrackFitterResult operator()( - const std::vector<IndexSourceLink> &sourceLinks, + const std::vector<Acts::SourceLink> &sourceLinks, const FaserActsKalmanFilterAlg::TrackParameters &initialParameters, - const FaserActsKalmanFilterAlg::TrackFitterOptions &options) - const override { - return trackFitter.fit(sourceLinks, initialParameters, options); - }; + const FaserActsKalmanFilterAlg::GeneralFitterOptions& options, + const MeasurementCalibratorAdapter& calibrator, + FaserActsKalmanFilterAlg::TrackContainer& tracks) const override { + const auto kfOptions = makeKfOptions(options, calibrator); + return trackFitter.fit(sourceLinks.begin(), sourceLinks.end(), initialParameters, + kfOptions, tracks); + } }; } // namespace @@ -36,19 +104,39 @@ struct TrackFitterFunctionImpl std::shared_ptr<FaserActsKalmanFilterAlg::TrackFitterFunction> FaserActsKalmanFilterAlg::makeTrackFitterFunction( - std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry) { - auto magneticField = std::make_shared<FASERMagneticFieldWrapper>(); - auto stepper = Stepper(std::move(magneticField)); - Acts::Navigator::Config cfg{trackingGeometry}; + std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry, + std::shared_ptr<const Acts::MagneticFieldProvider> magneticField, + bool multipleScattering, bool energyLoss, + double reverseFilteringMomThreshold, + Acts::FreeToBoundCorrection freeToBoundCorrection, + const Acts::Logger& logger) { + // Stepper should be copied into the fitters + const Stepper stepper(std::move(magneticField)); + + // Standard fitter + const auto& geo = *trackingGeometry; + Acts::Navigator::Config cfg{std::move(trackingGeometry)}; cfg.resolvePassive = false; cfg.resolveMaterial = true; cfg.resolveSensitive = true; - Acts::Navigator navigator(cfg); - Propagator propagator(std::move(stepper), std::move(navigator)); - Fitter trackFitter(std::move(propagator)); - return std::make_shared<TrackFitterFunctionImpl>(std::move(trackFitter)); + Acts::Navigator navigator(cfg, logger.cloneWithSuffix("Navigator")); + Propagator propagator(stepper, std::move(navigator), + logger.cloneWithSuffix("Propagator")); + Fitter trackFitter(std::move(propagator), logger.cloneWithSuffix("Fitter")); + + // build the fitter function. owns the fitter object. + auto fitterFunction = std::make_shared<TrackFitterFunctionImpl>( + std::move(trackFitter), geo); + fitterFunction->multipleScattering = multipleScattering; + fitterFunction->energyLoss = energyLoss; + fitterFunction->reverseFilteringLogic.momentumThreshold = + reverseFilteringMomThreshold; + fitterFunction->freeToBoundCorrection = freeToBoundCorrection; + + return fitterFunction; } +// The following are obsolete /* namespace ActsExtrapolationDetail { -- GitLab