From cd7224e6fbbac5354157577ecfc8fc682b733e7d Mon Sep 17 00:00:00 2001
From: Xiaocong Ai <xiaocong.ai@cern.ch>
Date: Fri, 21 Jun 2024 20:14:25 +0200
Subject: [PATCH] update FaserActsKalmanFilterAlg

---
 .../src/FaserActsKalmanFilterAlg.cxx          | 157 +++++++++---------
 .../src/FaserActsKalmanFilterAlg.h            |  66 +++++---
 .../src/TrackFittingFunction.cxx              | 126 +++++++++++---
 3 files changed, 233 insertions(+), 116 deletions(-)

diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.cxx b/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.cxx
index 5dd831f95..d6e8bb2eb 100755
--- a/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.cxx
+++ b/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.cxx
@@ -3,6 +3,7 @@
 */
 
 #include "FaserActsKalmanFilterAlg.h"
+#include "FaserActsGeometry/FASERMagneticFieldWrapper.h"
 
 // ATHENA
 #include "GaudiKernel/EventContext.h"
@@ -62,14 +63,13 @@
 #include <fstream>
 #include <cmath>
 
-using TrajectoryContainer = std::vector<FaserActsRecMultiTrajectory>;
 
 using namespace Acts::UnitLiterals;
 using Acts::VectorHelpers::eta;
 using Acts::VectorHelpers::perp;
 using Acts::VectorHelpers::phi;
 using Acts::VectorHelpers::theta;
-using ThisMeasurement = Acts::Measurement<IndexSourceLink, Acts::BoundIndices, 2>;
+using ThisMeasurement = Acts::Measurement<Acts::BoundIndices, 2>;
 using IdentifierMap = std::map<Identifier, Acts::GeometryIdentifier>;
 
 FaserActsKalmanFilterAlg::FaserActsKalmanFilterAlg(const std::string& name, ISvcLocator* pSvcLocator) :
@@ -80,12 +80,12 @@ StatusCode FaserActsKalmanFilterAlg::initialize() {
   ATH_CHECK(m_fieldCondObjInputKey.initialize());
   ATH_CHECK(m_trackingGeometryTool.retrieve());
   ATH_CHECK(m_trackFinderTool.retrieve());
-  ATH_CHECK(m_trajectoryWriterTool.retrieve());
-  ATH_CHECK(m_trajectoryStatesWriterTool.retrieve());
+//todo  ATH_CHECK(m_trajectoryWriterTool.retrieve());
+//todo  ATH_CHECK(m_trajectoryStatesWriterTool.retrieve());
 //  ATH_CHECK(m_protoTrackWriterTool.retrieve());
   ATH_CHECK(m_trackCollection.initialize());
   ATH_CHECK(detStore()->retrieve(m_idHelper,"FaserSCT_ID"));
-  m_fit = makeTrackFitterFunction(m_trackingGeometryTool->trackingGeometry());
+
   if (m_actsLogging == "VERBOSE") {
     m_logger = Acts::getDefaultLogger("KalmanFitter", Acts::Logging::VERBOSE);
   } else if (m_actsLogging == "DEBUG") {
@@ -93,6 +93,12 @@ StatusCode FaserActsKalmanFilterAlg::initialize() {
   } else {
     m_logger = Acts::getDefaultLogger("KalmanFitter", Acts::Logging::INFO);
   }
+  
+  auto magneticField = std::make_shared<FASERMagneticFieldWrapper>();
+  double reverseFilteringMomThreshold = 0.1; //@todo: needs to be validated
+  //@todo: the multiple scattering and energy loss are originallly turned off 
+  m_fit = makeTrackFitterFunction(m_trackingGeometryTool->trackingGeometry(), magneticField, false, false, reverseFilteringMomThreshold, Acts::FreeToBoundCorrection(false), *m_logger);
+
   return StatusCode::SUCCESS;
 }
 
@@ -133,62 +139,58 @@ StatusCode FaserActsKalmanFilterAlg::execute() {
 
   int n_trackSeeds = initialTrackParameters->size();
 
-  TrajectoryContainer trajectories;
-  trajectories.reserve(1);
-
-  for (int i = 0; i < n_trackSeeds; ++i) {
+  auto actsTrackContainer = std::make_shared<Acts::VectorTrackContainer>();
+  auto actsTrackStateContainer = std::make_shared<Acts::VectorMultiTrajectory>();
+  TrackContainer tracks(actsTrackContainer, actsTrackStateContainer);
 
-    Acts::KalmanFitterOptions<MeasurementCalibrator, Acts::VoidOutlierFinder, Acts::VoidReverseFilteringLogic> kfOptions(
-        geoctx,
-        magctx,
-        calctx,
-        MeasurementCalibrator(measurements->at(i)),
-        Acts::VoidOutlierFinder(),
-        Acts::VoidReverseFilteringLogic(),
-        Acts::LoggerWrapper{*m_logger},
-        Acts::PropagatorPlainOptions(),
-        &(*initialSurface)
-    );
-    kfOptions.multipleScattering = false;
-    kfOptions.energyLoss = false;
+  //@todo: the initialSurface should be targetSurface
+  FaserActsKalmanFilterAlg::GeneralFitterOptions options{
+      geoctx, magctx, calctx, &(*initialSurface),
+      Acts::PropagatorPlainOptions()};
 
+  for (int i = 0; i < n_trackSeeds; ++i) {
     ATH_MSG_DEBUG("Invoke fitter");
-    auto result = (*m_fit)(sourceLinks->at(i), initialTrackParameters->at(i), kfOptions);
+    auto sls = sourceLinks->at(i);
+    std::vector<Acts::SourceLink> actsSls;
+    for(const auto& sl: sls){
+      actsSls.push_back(Acts::SourceLink{sl}); 
+    }
+    auto result = (*m_fit)(actsSls, initialTrackParameters->at(i), options, MeasurementCalibratorAdapter(MeasurementCalibrator(), measurements->at(i)), tracks);
 
     int itrack = 0;
     if (result.ok()) {
       // Get the fit output object
-      const auto& fitOutput = result.value();
+      //const auto& fitOutput = result.value();
       std::unique_ptr<Trk::Track> track = makeTrack(geoctx, result, clusters->at(i));
        if (track) {
          outputTracks->push_back(std::move(track));
        }
 
-      // The track entry indices container. One element here.
-      std::vector<size_t> trackTips;
-      trackTips.reserve(1);
-      trackTips.emplace_back(fitOutput.lastMeasurementIndex);
-      // The fitted parameters container. One element (at most) here.
-      IndexedParams indexedParams;
-
-      if (fitOutput.fittedParameters) {
-        const auto& params = fitOutput.fittedParameters.value();
-        ATH_MSG_VERBOSE("Fitted paramemeters for track " << itrack);
-        ATH_MSG_VERBOSE("  parameters: " << params);
-        ATH_MSG_VERBOSE("  position: " << params.position(geoctx).transpose());
-        ATH_MSG_VERBOSE("  momentum: " << params.momentum().transpose());
-        // Push the fitted parameters to the container
-        indexedParams.emplace(fitOutput.lastMeasurementIndex, std::move(params));
-      } else {
-        ATH_MSG_DEBUG("No fitted paramemeters for track " << itrack);
-      }
-      // Create a SimMultiTrajectory
-      trajectories.emplace_back(std::move(fitOutput.fittedStates),
-                                std::move(trackTips), std::move(indexedParams));
+    //  // The track entry indices container. One element here.
+    //  std::vector<size_t> trackTips;
+    //  trackTips.reserve(1);
+    //  trackTips.emplace_back(fitOutput.lastMeasurementIndex);
+    //  // The fitted parameters container. One element (at most) here.
+    //  IndexedParams indexedParams;
+
+    //  if (fitOutput.fittedParameters) {
+    //    const auto& params = fitOutput.fittedParameters.value();
+    //    ATH_MSG_VERBOSE("Fitted paramemeters for track " << itrack);
+    //    ATH_MSG_VERBOSE("  parameters: " << params);
+    //    ATH_MSG_VERBOSE("  position: " << params.position(geoctx).transpose());
+    //    ATH_MSG_VERBOSE("  momentum: " << params.momentum().transpose());
+    //    // Push the fitted parameters to the container
+    //    indexedParams.emplace(fitOutput.lastMeasurementIndex, std::move(params));
+    //  } else {
+    //    ATH_MSG_DEBUG("No fitted paramemeters for track " << itrack);
+    //  }
+    //  // Create a SimMultiTrajectory
+    //  trajectories.emplace_back(std::move(fitOutput.fittedStates),
+    //                            std::move(trackTips), std::move(indexedParams));
     } else {
       ATH_MSG_WARNING("Fit failed for track " << itrack << " with error" << result.error());
       // Fit failed, but still create a empty truth fit track
-      trajectories.push_back(FaserActsRecMultiTrajectory());
+    //  trajectories.push_back(FaserActsRecMultiTrajectory());
     }
 
   }
@@ -221,86 +223,88 @@ Acts::MagneticFieldContext FaserActsKalmanFilterAlg::getMagneticFieldContext(con
 
 std::unique_ptr<Trk::Track>
 FaserActsKalmanFilterAlg::makeTrack(Acts::GeometryContext& geoCtx, TrackFitterResult& fitResult, std::vector<const Tracker::FaserSCT_Cluster*> clusters) const {
-  using ConstTrackStateProxy =
-     Acts::detail_lt::TrackStateProxy<IndexSourceLink, 6, true>;
   std::unique_ptr<Trk::Track> newtrack = nullptr;
   //Get the fit output object
-  const auto& fitOutput = fitResult.value();
-  if (fitOutput.fittedParameters) {
-    DataVector<const Trk::TrackStateOnSurface>* finalTrajectory = new DataVector<const Trk::TrackStateOnSurface>{};
+  const auto& track = fitResult.value();
+  if (track.hasReferenceSurface()) {
+    std::unique_ptr<DataVector<const Trk::TrackStateOnSurface>> finalTrajectory = std::make_unique<DataVector<const Trk::TrackStateOnSurface>>();
     std::vector<std::unique_ptr<const Acts::BoundTrackParameters>> actsSmoothedParam;
     // Loop over all the output state to create track state
-    fitOutput.fittedStates.visitBackwards(fitOutput.lastMeasurementIndex, [&](const ConstTrackStateProxy& state) {
+    for (const auto& state : track.trackStatesReversed()) { 
       auto flag = state.typeFlags();
       if (state.referenceSurface().associatedDetectorElement() != nullptr) {
         // We need to determine the type of state
         std::bitset<Trk::TrackStateOnSurface::NumberOfTrackStateOnSurfaceTypes> typePattern;
-        const Trk::TrackParameters *parm;
+        std::unique_ptr<Trk::TrackParameters> parm; 
 
         // State is a hole (no associated measurement), use predicted para meters
-        if (flag[Acts::TrackStateFlag::HoleFlag] == true) {
+        if (flag.test(Acts::TrackStateFlag::HoleFlag) == true) {
+          //@todo: ParticleHypothesis? 
           const Acts::BoundTrackParameters actsParam(state.referenceSurface().getSharedPtr(),
                                                      state.predicted(),
-                                                     state.predictedCovariance());
-          parm = ConvertActsTrackParameterToATLAS(actsParam, geoCtx);
+                                                     state.predictedCovariance(), Acts::ParticleHypothesis::pion());
+          parm = std::move(ConvertActsTrackParameterToATLAS(actsParam, geoCtx));
           // auto boundaryCheck = m_boundaryCheckTool->boundaryCheck(*p arm);
           typePattern.set(Trk::TrackStateOnSurface::Hole);
         }
           // The state was tagged as an outlier, use filtered parameters
-        else if (flag[Acts::TrackStateFlag::OutlierFlag] == true) {
+        else if (flag.test(Acts::TrackStateFlag::OutlierFlag) == true) {
+          //@todo: ParticleHypothesis? 
           const Acts::BoundTrackParameters actsParam(state.referenceSurface().getSharedPtr(),
-                                                     state.filtered(), state.filteredCovariance());
-          parm = ConvertActsTrackParameterToATLAS(actsParam, geoCtx);
+                                                     state.filtered(), state.filteredCovariance(), Acts::ParticleHypothesis::pion());
+          parm = std::move(ConvertActsTrackParameterToATLAS(actsParam, geoCtx));
           typePattern.set(Trk::TrackStateOnSurface::Outlier);
         }
           // The state is a measurement state, use smoothed parameters
         else {
+          //@todo: ParticleHypothesis? 
           const Acts::BoundTrackParameters actsParam(state.referenceSurface().getSharedPtr(),
-                                                     state.smoothed(), state.smoothedCovariance());
+                                                     state.smoothed(), state.smoothedCovariance(), Acts::ParticleHypothesis::pion());
           actsSmoothedParam.push_back(std::make_unique<const Acts::BoundTrackParameters>(Acts::BoundTrackParameters(actsParam)));
           //  const auto& psurface=actsParam.referenceSurface();
           Acts::Vector2 local(actsParam.parameters()[Acts::eBoundLoc0], actsParam.parameters()[Acts::eBoundLoc1]);
           //  const Acts::Vector3 dir = Acts::makeDirectionUnitFromPhiTheta(actsParam.parameters()[Acts::eBoundPhi], actsParam.parameters()[Acts::eBoundTheta]);
           //  auto pos=actsParam.position(tgContext);
-          parm = ConvertActsTrackParameterToATLAS(actsParam, geoCtx);
+          parm = std::move(ConvertActsTrackParameterToATLAS(actsParam, geoCtx));
           typePattern.set(Trk::TrackStateOnSurface::Measurement);
         }
-        Tracker::FaserSCT_ClusterOnTrack* measState = nullptr;
-        if (state.hasUncalibrated()) {
-          const Tracker::FaserSCT_Cluster* fitCluster = clusters.at(state.uncalibrated().index());
+        std::unique_ptr<Tracker::FaserSCT_ClusterOnTrack> measState = nullptr;
+        if (state.hasUncalibratedSourceLink()) {
+          auto sl = state.getUncalibratedSourceLink().template get<IndexSourceLink>();
+          const Tracker::FaserSCT_Cluster* fitCluster = clusters.at(sl.index());
           if (fitCluster->detectorElement() != nullptr) {
-            measState = new Tracker::FaserSCT_ClusterOnTrack{
+            measState = std::make_unique<Tracker::FaserSCT_ClusterOnTrack>(
                 fitCluster,
                 Trk::LocalParameters{
                     Trk::DefinedParameter{fitCluster->localPosition()[0], Trk::loc1},
                     Trk::DefinedParameter{fitCluster->localPosition()[1], Trk::loc2}
                 },
-                fitCluster->localCovariance(),
+                Amg::MatrixX(fitCluster->localCovariance()),
                 m_idHelper->wafer_hash(fitCluster->detectorElement()->identify())
-            };
+            );
           }
         }
         double nDoF = state.calibratedSize();
         const Trk::FitQualityOnSurface *quality = new Trk::FitQualityOnSurface(state.chi2(), nDoF);
-        const Trk::TrackStateOnSurface *perState = new Trk::TrackStateOnSurface(measState, parm, quality, nullptr, typePattern);
+        const Trk::TrackStateOnSurface *perState = new Trk::TrackStateOnSurface(*quality, std::move(measState), std::move(parm), nullptr, typePattern);
         // If a state was succesfully created add it to the trajectory
         if (perState) {
-          finalTrajectory->insert(finalTrajectory->begin(), perState);
+          (*finalTrajectory).insert((*finalTrajectory).begin(), perState);
         }
-      }
-      return;
-    });
+      } // state has referenceSurface
+    } //end loop for all states
 
     // Create the track using the states
     const Trk::TrackInfo newInfo(Trk::TrackInfo::TrackFitter::KalmanFitter, Trk::ParticleHypothesis::muon);
     // Trk::FitQuality* q = nullptr;
     // newInfo.setTrackFitter(Trk::TrackInfo::TrackFitter::KalmanFitter     ); //Mark the fitter as KalmanFitter
-    newtrack = std::make_unique<Trk::Track>(newInfo, std::move(*finalTrajectory), nullptr);
-  }
+    newtrack = std::make_unique<Trk::Track>(newInfo, std::move(finalTrajectory), nullptr);
+  } // hasFittedParameters
+
   return newtrack;
 }
 
-const Trk::TrackParameters*
+std::unique_ptr<Trk::TrackParameters>
 FaserActsKalmanFilterAlg ::ConvertActsTrackParameterToATLAS(const Acts::BoundTrackParameters &actsParameter, const Acts::GeometryContext& gctx) const      {
   using namespace Acts::UnitLiterals;
   std::optional<AmgSymMatrix(5)> cov = std::nullopt;
@@ -321,7 +325,6 @@ FaserActsKalmanFilterAlg ::ConvertActsTrackParameterToATLAS(const Acts::BoundTra
   double tqOverP=actsParameter.get<Acts::eBoundQOverP>()*1_MeV;
   double p = std::abs(1. / tqOverP);
   Amg::Vector3D tmom(p * std::cos(tphi) * std::sin(ttheta), p * std::sin(tphi) * std::sin(ttheta), p * std::cos(ttheta));
-  const Trk::CurvilinearParameters * curv = new Trk::CurvilinearParameters(pos,tmom,tqOverP>0, cov);
-  return curv;
+  return std::make_unique<Trk::CurvilinearParameters>(pos, tmom, tqOverP>0, cov);
 }
 
diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.h b/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.h
index a2b4c7295..cf2f254e4 100755
--- a/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.h
+++ b/Tracking/Acts/FaserActsKalmanFilter/src/FaserActsKalmanFilterAlg.h
@@ -24,32 +24,36 @@
 #include "GeneratorObjects/McEventCollection.h"
 #include "TrackerSimData/TrackerSimDataCollection.h"
 #include "TrkTrack/TrackCollection.h"
-#include "TrajectoryWriterTool.h"
+//todo#include "TrajectoryWriterTool.h"
 
 // ACTS
 #include "Acts/MagneticField/ConstantBField.hpp"
 #include "Acts/MagneticField/InterpolatedBFieldMap.hpp"
-#include "Acts/MagneticField/SharedBField.hpp"
 #include "Acts/Propagator/EigenStepper.hpp"
 #include "Acts/Propagator/Propagator.hpp"
 #include "Acts/Propagator/detail/SteppingLogger.hpp"
 #include "Acts/TrackFitting/KalmanFitter.hpp"
 #include "Acts/Geometry/TrackingGeometry.hpp"
 #include "Acts/EventData/TrackParameters.hpp"
+#include "Acts/EventData/TrackContainer.hpp"
+#include "Acts/EventData/TrackProxy.hpp"
+#include "Acts/EventData/VectorTrackContainer.hpp"
+#include "Acts/EventData/VectorMultiTrajectory.hpp"
 #include "Acts/Geometry/GeometryIdentifier.hpp"
 #include "Acts/Utilities/Helpers.hpp"
 #include "Acts/Definitions/Common.hpp"
+#include "Acts/Utilities/Result.hpp"
+
 
 // PACKAGE
 #include "FaserActsGeometry/FASERMagneticFieldWrapper.h"
 #include "FaserActsGeometryInterfaces/IFaserActsTrackingGeometryTool.h"
 #include "FaserActsGeometryInterfaces/IFaserActsExtrapolationTool.h"
-#include "FaserActsRecMultiTrajectory.h"
 #include "FaserActsKalmanFilter/IndexSourceLink.h"
 #include "FaserActsKalmanFilter/Measurement.h"
 #include "FaserActsKalmanFilter/ITrackFinderTool.h"
 //#include "ProtoTrackWriterTool.h"
-#include "RootTrajectoryStatesWriterTool.h"
+//todo #include "RootTrajectoryStatesWriterTool.h"
 
 // STL
 #include <memory>
@@ -70,8 +74,9 @@ namespace TrackerDD {
   class SCT_DetectorManager;
 }
 
-using TrajectoryContainer = std::vector<FaserActsRecMultiTrajectory>;
-using BField_t = FASERMagneticFieldWrapper;
+//@toberemoved
+//using TrajectoryContainer = std::vector<FaserActsRecMultiTrajectory>;
+//using BField_t = FASERMagneticFieldWrapper;
 
 //class FaserActsKalmanFilterAlg : public AthReentrantAlgorithm {
 class FaserActsKalmanFilterAlg : public AthAlgorithm {
@@ -83,26 +88,47 @@ public:
 //  StatusCode execute(const EventContext& ctx) const override;
   StatusCode execute() override;
   StatusCode finalize() override;
+ 
+//@toberemoved
+//  using IndexedParams = std::unordered_map<size_t, Acts::BoundTrackParameters>;
+//  using TrackFitterOptions =
+//    Acts::KalmanFitterOptions<MeasurementCalibrator, Acts::VoidOutlierFinder,
+//                              Acts::VoidReverseFilteringLogic>;
+  using TrackContainer =
+      Acts::TrackContainer<Acts::VectorTrackContainer,
+                           Acts::VectorMultiTrajectory, std::shared_ptr>;
+  
+  struct GeneralFitterOptions {
+    std::reference_wrapper<const Acts::GeometryContext> geoContext;
+    std::reference_wrapper<const Acts::MagneticFieldContext> magFieldContext;
+    std::reference_wrapper<const Acts::CalibrationContext> calibrationContext;
+    const Acts::Surface* referenceSurface = nullptr;
+    Acts::PropagatorPlainOptions propOptions;
+  };
 
-  using IndexedParams = std::unordered_map<size_t, Acts::BoundTrackParameters>;
-  using TrackFitterOptions =
-    Acts::KalmanFitterOptions<MeasurementCalibrator, Acts::VoidOutlierFinder,
-                              Acts::VoidReverseFilteringLogic>;
-  using TrackFitterResult =
-    Acts::Result<Acts::KalmanFitterResult<IndexSourceLink>>;
+  using TrackFitterResult = Acts::Result<TrackContainer::TrackProxy>;
 
-  using TrackParameters = Acts::CurvilinearTrackParameters;
+  using TrackParameters = Acts::BoundTrackParameters;
 
   class TrackFitterFunction {
   public:
     virtual ~TrackFitterFunction() = default;
-    virtual TrackFitterResult operator()(const std::vector<IndexSourceLink>&,
-                                         const TrackParameters&,
-                                         const TrackFitterOptions&) const = 0;
+    virtual TrackFitterResult operator()(
+      const std::vector<Acts::SourceLink> &sourceLinks,
+      const TrackParameters &initialParameters,
+      const GeneralFitterOptions& options,
+      const MeasurementCalibratorAdapter& calibrator,
+      TrackContainer& tracks  
+    ) const = 0;
   };
 
   static std::shared_ptr<TrackFitterFunction> makeTrackFitterFunction(
-    std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry);
+    std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
+    std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
+    bool multipleScattering, bool energyLoss,
+    double reverseFilteringMomThreshold,
+    Acts::FreeToBoundCorrection freeToBoundCorrection,
+    const Acts::Logger& logger);
 
   virtual Acts::MagneticFieldContext getMagneticFieldContext(const EventContext& ctx) const;
 
@@ -113,15 +139,15 @@ private:
 
   Gaudi::Property<std::string> m_actsLogging {this, "ActsLogging", "VERBOSE"};
   std::unique_ptr<Trk::Track> makeTrack(Acts::GeometryContext& tgContext, TrackFitterResult& fitResult, std::vector<const Tracker::FaserSCT_Cluster*> clusters) const;
-  const Trk::TrackParameters* ConvertActsTrackParameterToATLAS(const Acts::BoundTrackParameters &actsParameter, const Acts::GeometryContext& gctx) const;
+  std::unique_ptr<Trk::TrackParameters> ConvertActsTrackParameterToATLAS(const Acts::BoundTrackParameters &actsParameter, const Acts::GeometryContext& gctx) const;
 
   // Read handle for conditions object to get the field cache
   SG::ReadCondHandleKey<FaserFieldCacheCondObj> m_fieldCondObjInputKey {this, "FaserFieldCacheCondObj", "fieldCondObj", "Name of the Magnetic Field conditions object key"};
 
   ToolHandle<ITrackFinderTool> m_trackFinderTool {this, "TrackFinderTool", "TruthTrackFinderTool"};
   ToolHandle<IFaserActsTrackingGeometryTool> m_trackingGeometryTool {this, "TrackingGeometryTool", "FaserActsTrackingGeometryTool"};
-  ToolHandle<TrajectoryWriterTool> m_trajectoryWriterTool {this, "OutputTool", "TrajectoryWriterTool"};
-  ToolHandle<RootTrajectoryStatesWriterTool> m_trajectoryStatesWriterTool {this, "RootTrajectoryStatesWriterTool", "RootTrajectoryStatesWriterTool"};
+//todo  ToolHandle<TrajectoryWriterTool> m_trajectoryWriterTool {this, "OutputTool", "TrajectoryWriterTool"};
+//todo  ToolHandle<RootTrajectoryStatesWriterTool> m_trajectoryStatesWriterTool {this, "RootTrajectoryStatesWriterTool", "RootTrajectoryStatesWriterTool"};
 //  ToolHandle<ProtoTrackWriterTool> m_protoTrackWriterTool {this, "ProtoTrackWriterTool", "ProtoTrackWriterTool"};
 
   SG::ReadHandleKey<McEventCollection> m_mcEventKey { this, "McEventCollection", "BeamTruthEvent" };
diff --git a/Tracking/Acts/FaserActsKalmanFilter/src/TrackFittingFunction.cxx b/Tracking/Acts/FaserActsKalmanFilter/src/TrackFittingFunction.cxx
index 460300bac..c8d2bbe95 100644
--- a/Tracking/Acts/FaserActsKalmanFilter/src/TrackFittingFunction.cxx
+++ b/Tracking/Acts/FaserActsKalmanFilter/src/TrackFittingFunction.cxx
@@ -1,34 +1,102 @@
 #include "FaserActsKalmanFilterAlg.h"
-#include "FaserActsGeometry/FASERMagneticFieldWrapper.h"
-
+//#include "FaserActsGeometry/FASERMagneticFieldWrapper.h"
+//#include "FaserActsKalmanFilter/IndexSourceLink.h"
+
+#include "Acts/Definitions/Direction.hpp"
+#include "Acts/Definitions/TrackParametrization.hpp"
+#include "Acts/EventData/MultiTrajectory.hpp"
+#include "Acts/EventData/TrackContainer.hpp"
+#include "Acts/EventData/TrackStatePropMask.hpp"
+#include "Acts/EventData/VectorMultiTrajectory.hpp"
+#include "Acts/EventData/VectorTrackContainer.hpp"
+#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp"
+#include "Acts/Geometry/GeometryIdentifier.hpp"
+#include "Acts/Propagator/DirectNavigator.hpp"
 #include "Acts/Propagator/EigenStepper.hpp"
 #include "Acts/Propagator/Navigator.hpp"
 #include "Acts/Propagator/Propagator.hpp"
 #include "Acts/TrackFitting/GainMatrixSmoother.hpp"
 #include "Acts/TrackFitting/GainMatrixUpdater.hpp"
+#include "Acts/TrackFitting/KalmanFitter.hpp"
+#include "Acts/Utilities/Delegate.hpp"
+#include "Acts/Utilities/Logger.hpp"
 
 
 namespace {
 
-using Updater = Acts::GainMatrixUpdater;
-using Smoother = Acts::GainMatrixSmoother;
 using Stepper = Acts::EigenStepper<>;
 using Propagator = Acts::Propagator<Stepper, Acts::Navigator>;
-using Fitter = Acts::KalmanFitter<Propagator, Updater, Smoother>;
+using Fitter = Acts::KalmanFitter<Propagator, Acts::VectorMultiTrajectory>;
+
+
+struct SimpleReverseFilteringLogic {
+  double momentumThreshold = 0;
+
+  bool doBackwardFiltering(
+      Acts::VectorMultiTrajectory::ConstTrackStateProxy trackState) const {
+    auto momentum = fabs(1 / trackState.filtered()[Acts::eBoundQOverP]);
+    return (momentum <= momentumThreshold);
+  }
+};
 
 struct TrackFitterFunctionImpl
     : public FaserActsKalmanFilterAlg::TrackFitterFunction {
   Fitter trackFitter;
 
-  TrackFitterFunctionImpl(Fitter &&f) : trackFitter(std::move(f)) {}
+  Acts::GainMatrixUpdater kfUpdater;
+  Acts::GainMatrixSmoother kfSmoother;
+  SimpleReverseFilteringLogic reverseFilteringLogic;
+
+  bool multipleScattering = false;
+  bool energyLoss = false;
+  Acts::FreeToBoundCorrection freeToBoundCorrection;
+
+  IndexSourceLink::SurfaceAccessor slSurfaceAccessor;
+
+  TrackFitterFunctionImpl(Fitter &&f, const Acts::TrackingGeometry& trkGeo) : trackFitter(std::move(f)),slSurfaceAccessor{trkGeo} {}
+
+  template <typename calibrator_t>
+  auto makeKfOptions(const FaserActsKalmanFilterAlg::GeneralFitterOptions& options,
+                     const calibrator_t& calibrator) const {
+    Acts::KalmanFitterExtensions<Acts::VectorMultiTrajectory> extensions;
+    extensions.updater.connect<
+        &Acts::GainMatrixUpdater::operator()<Acts::VectorMultiTrajectory>>(
+        &kfUpdater);
+    extensions.smoother.connect<
+        &Acts::GainMatrixSmoother::operator()<Acts::VectorMultiTrajectory>>(
+        &kfSmoother);
+    extensions.reverseFilteringLogic
+        .connect<&SimpleReverseFilteringLogic::doBackwardFiltering>(
+            &reverseFilteringLogic);
+
+    Acts::KalmanFitterOptions<Acts::VectorMultiTrajectory> kfOptions(
+        options.geoContext, options.magFieldContext, options.calibrationContext,
+        extensions, options.propOptions, &(*options.referenceSurface));
+
+    kfOptions.referenceSurfaceStrategy =
+        Acts::KalmanFitterTargetSurfaceStrategy::first;
+    kfOptions.multipleScattering = multipleScattering;
+    kfOptions.energyLoss = energyLoss;
+    kfOptions.freeToBoundCorrection = freeToBoundCorrection;
+    kfOptions.extensions.calibrator.connect<&calibrator_t::calibrate>(
+        &calibrator);
+    kfOptions.extensions.surfaceAccessor
+        .connect<&IndexSourceLink::SurfaceAccessor::operator()>(
+            &slSurfaceAccessor);
+
+    return kfOptions;
+  }
 
   FaserActsKalmanFilterAlg::TrackFitterResult operator()(
-      const std::vector<IndexSourceLink> &sourceLinks,
+      const std::vector<Acts::SourceLink> &sourceLinks,
       const FaserActsKalmanFilterAlg::TrackParameters &initialParameters,
-      const FaserActsKalmanFilterAlg::TrackFitterOptions &options)
-  const override {
-    return trackFitter.fit(sourceLinks, initialParameters, options);
-  };
+      const FaserActsKalmanFilterAlg::GeneralFitterOptions& options,
+      const MeasurementCalibratorAdapter& calibrator,
+      FaserActsKalmanFilterAlg::TrackContainer& tracks) const override { 
+    const auto kfOptions = makeKfOptions(options, calibrator);
+    return trackFitter.fit(sourceLinks.begin(), sourceLinks.end(), initialParameters,
+                      kfOptions, tracks);
+  }
 };
 
 }  // namespace
@@ -36,19 +104,39 @@ struct TrackFitterFunctionImpl
 
 std::shared_ptr<FaserActsKalmanFilterAlg::TrackFitterFunction>
 FaserActsKalmanFilterAlg::makeTrackFitterFunction(
-    std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry) {
-  auto magneticField = std::make_shared<FASERMagneticFieldWrapper>();
-  auto stepper = Stepper(std::move(magneticField));
-  Acts::Navigator::Config cfg{trackingGeometry};
+    std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
+    std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
+    bool multipleScattering, bool energyLoss,
+    double reverseFilteringMomThreshold,
+    Acts::FreeToBoundCorrection freeToBoundCorrection,
+    const Acts::Logger& logger) {
+  // Stepper should be copied into the fitters
+  const Stepper stepper(std::move(magneticField));
+
+  // Standard fitter
+  const auto& geo = *trackingGeometry;
+  Acts::Navigator::Config cfg{std::move(trackingGeometry)};
   cfg.resolvePassive = false;
   cfg.resolveMaterial = true;
   cfg.resolveSensitive = true;
-  Acts::Navigator navigator(cfg);
-  Propagator propagator(std::move(stepper), std::move(navigator));
-  Fitter trackFitter(std::move(propagator));
-  return std::make_shared<TrackFitterFunctionImpl>(std::move(trackFitter));
+  Acts::Navigator navigator(cfg, logger.cloneWithSuffix("Navigator"));
+  Propagator propagator(stepper, std::move(navigator),
+                        logger.cloneWithSuffix("Propagator"));
+  Fitter trackFitter(std::move(propagator), logger.cloneWithSuffix("Fitter"));
+
+  // build the fitter function. owns the fitter object.
+  auto fitterFunction = std::make_shared<TrackFitterFunctionImpl>(
+      std::move(trackFitter), geo);
+  fitterFunction->multipleScattering = multipleScattering;
+  fitterFunction->energyLoss = energyLoss;
+  fitterFunction->reverseFilteringLogic.momentumThreshold =
+      reverseFilteringMomThreshold;
+  fitterFunction->freeToBoundCorrection = freeToBoundCorrection;
+
+  return fitterFunction;
 }
 
+// The following are obsolete
 /*
 
 namespace ActsExtrapolationDetail {
-- 
GitLab