diff --git a/Pr/PrPixel/src/VeloKalman.cpp b/Pr/PrPixel/src/VeloKalman.cpp
index 954a57c584278e0d08e9f819afdf0f20939148e0..46fa19761d8f19591c48301fb96e76acaf7ace41 100644
--- a/Pr/PrPixel/src/VeloKalman.cpp
+++ b/Pr/PrPixel/src/VeloKalman.cpp
@@ -24,6 +24,9 @@
 #include "Event/PrVeloTracks.h"
 
 #include "PrKernel/PrPixelFastKalman.h"
+
+#include "VeloKalmanHelpers.h"
+
 /**
  * Velo only Kalman fit
  *
@@ -36,9 +39,9 @@ namespace LHCb::Pr::Velo {
     using TracksVP  = Tracks;
     using TracksFT  = Forward::Tracks;
     using TracksFit = Fitted::Forward::Tracks;
-    using dType     = SIMDWrapper::scalar::types;
-    using I         = dType::int_v;
-    using F         = dType::float_v;
+    using simd      = SIMDWrapper::avx256::types;
+    using I         = simd::int_v;
+    using F         = simd::float_v;
 
   public:
     Kalman( const std::string& name, ISvcLocator* pSvcLocator )
@@ -59,30 +62,27 @@ namespace LHCb::Pr::Velo {
       TracksFit out{&tracksFT, tracksFT.zipIdentifier(), LHCb::getMemResource( evtCtx )};
       m_nbTracksCounter += tracksFT.size();
 
-      for ( int t = 0; t < tracksFT.size(); t += dType::size ) {
-        auto loop_mask = dType::loop_mask( t, tracksFT.size() );
-
-        // TODO: v2 with less gathers (with transposition)
-        // TODO: masked gathers (for avx2/avx512)
+      for ( int t = 0; t < tracksFT.size(); t += simd::size ) {
+        auto loop_mask = simd::loop_mask( t, tracksFT.size() );
 
-        I       idxVP = tracksFT.trackVP<I>( t );
+        const I idxVP = tracksFT.trackVP<I>( t );
         const F qop   = tracksFT.stateQoP<F>( t );
 
-        auto [stateInfo, chi2, nDof] = trackFitter( hits, tracksVP, idxVP, qop );
+        auto [stateInfo, chi2, nDof] = fitBackwardWithMomentum( loop_mask, tracksVP, idxVP, qop, hits, 0 );
 
         // Store tracks in output
-        out.store_trackFT<I>( t, t ); // TODO: index vector (for avx2/avx512)
+        out.store_trackFT<I>( t, simd::indices( t ) );
 
         out.store_QoP<F>( t, qop );
-        out.store_beamStatePos<F>( t, stateInfo.pos );
-        out.store_beamStateDir<F>( t, Vec3<F>( stateInfo.tx, stateInfo.ty, 1. ) );
-        out.store_covX<F>( t, Vec3<F>( stateInfo.covXX, stateInfo.covXTx, stateInfo.covTxTx ) );
-        out.store_covY<F>( t, Vec3<F>( stateInfo.covYY, stateInfo.covYTy, stateInfo.covTyTy ) );
+        out.store_beamStatePos<F>( t, stateInfo.pos() );
+        out.store_beamStateDir<F>( t, stateInfo.dir() );
+        out.store_covX<F>( t, stateInfo.covX() );
+        out.store_covY<F>( t, stateInfo.covY() );
 
         out.store_chi2<F>( t, chi2 / F( nDof ) );
         out.store_chi2nDof<I>( t, nDof );
 
-        out.size() += dType::popcount( loop_mask );
+        out.size() += simd::popcount( loop_mask );
       }
 
       return out;
diff --git a/Pr/PrPixel/src/VeloKalmanHelpers.h b/Pr/PrPixel/src/VeloKalmanHelpers.h
index f85b2664bb5c259f41b0d0110ac2f9ebb968f44c..e4fb554133d019db35e40ba25468091b3893f5f4 100644
--- a/Pr/PrPixel/src/VeloKalmanHelpers.h
+++ b/Pr/PrPixel/src/VeloKalmanHelpers.h
@@ -24,6 +24,9 @@ namespace VeloKalmanParam {
   constexpr float err = 0.0158771324f; // 0.055f / sqrt( 12 ); //TODO: find a solution so this compile with clang
   constexpr float wx  = err * err;
   constexpr float wy  = wx;
+
+  constexpr float scatterSensorParameters[4] = {0.54772f, 1.478845f, 0.626634f, -0.78f}; // see filterWithMomentum
+  constexpr float scatterFoilParameters[2]   = {1.67f, 20.f}; // see fitBackwardWithMomentum
 } // namespace VeloKalmanParam
 
 template <typename T>
@@ -77,8 +80,8 @@ public:
 };
 
 template <typename M, typename F>
-inline void filter( const M& mask, const F& z, F& x, F& tx, F& covXX, F& covXTx, F& covTxTx, const F& zhit,
-                    const F& xhit, const F& winv ) {
+inline void filter( const M mask, const F z, F& x, F& tx, F& covXX, F& covXTx, F& covTxTx, const F zhit, const F xhit,
+                    const F winv ) {
   // compute prediction
   const F dz    = zhit - z;
   const F predx = x + dz * tx;
@@ -87,8 +90,7 @@ inline void filter( const M& mask, const F& z, F& x, F& tx, F& covXX, F& covXTx,
   const F predcovXTx   = covXTx + dz_t_covTxTx;
   const F dz_t_covXTx  = dz * covXTx;
 
-  const F predcovXX   = covXX + 2.f * dz_t_covXTx + dz * dz_t_covTxTx;
-  const F predcovTxTx = covTxTx;
+  const F predcovXX = covXX + 2.f * dz_t_covXTx + dz * dz_t_covTxTx;
 
   // compute the gain matrix
   const F R   = 1.0f / ( winv + predcovXX );
@@ -101,13 +103,13 @@ inline void filter( const M& mask, const F& z, F& x, F& tx, F& covXX, F& covXTx,
   tx        = select( mask, tx + KTx * r, tx );
 
   // update the covariance matrix
-  covXX   = select( mask, ( 1.f - Kx ) * predcovXX, covXX );
-  covXTx  = select( mask, ( 1.f - Kx ) * predcovXTx, covXTx );
-  covTxTx = select( mask, predcovTxTx - KTx * predcovXTx, covTxTx );
+  covTxTx = select( mask, R * ( covTxTx * ( winv + covXX ) - covXTx * covXTx ), covTxTx );
+  covXTx  = select( mask, winv * KTx, covXTx );
+  covXX   = select( mask, winv * Kx, covXX );
 }
 
 template <typename F, typename I, typename M>
-inline FittedState<F> fitBackward( const M& track_mask, LHCb::Pr::Velo::Tracks& tracks, int t,
+inline FittedState<F> fitBackward( const M track_mask, const LHCb::Pr::Velo::Tracks& tracks, int t,
                                    const LHCb::Pr::Velo::Hits& hits, const int state_id ) {
   I       nHits   = tracks.nHits<I>( t );
   int     maxHits = nHits.hmax( track_mask );
@@ -140,7 +142,7 @@ inline FittedState<F> fitBackward( const M& track_mask, LHCb::Pr::Velo::Tracks&
 }
 
 template <typename F, typename I, typename M>
-inline FittedState<F> fitForward( const M& track_mask, LHCb::Pr::Velo::Tracks& tracks, int t,
+inline FittedState<F> fitForward( const M track_mask, const LHCb::Pr::Velo::Tracks& tracks, int t,
                                   const LHCb::Pr::Velo::Hits& hits, const int state_id ) {
   I       nHits   = tracks.nHits<I>( t );
   int     maxHits = nHits.hmax( track_mask );
@@ -172,3 +174,115 @@ inline FittedState<F> fitForward( const M& track_mask, LHCb::Pr::Velo::Tracks& t
 
   return s;
 }
+
+template <typename M, typename F>
+inline F filterWithMomentum( const M mask, const F z, F& x, F& tx, F& covXX, F& covXTx, F& covTxTx, const F zhit,
+                             const F xhit, const F winv, const F qop ) {
+  // One Kalman-style filter step in one projection: predict at zhit, add qop-dependent noise, then masked update.
+  const F dz    = zhit - z;
+  const F predx = x + dz * tx; // straight-line extrapolation of x over dz
+
+  const F dz_t_covTxTx = dz * covTxTx;
+  const F dz_t_covXTx  = dz * covXTx;
+
+  // Add noise: eXX, eXTx, eTxTx are built from VeloKalmanParam::scatterSensorParameters, |qop| and |dz|
+  const F par1 = VeloKalmanParam::scatterSensorParameters[0];
+  const F par2 = VeloKalmanParam::scatterSensorParameters[1];
+  const F par6 = VeloKalmanParam::scatterSensorParameters[2];
+  const F par7 = VeloKalmanParam::scatterSensorParameters[3];
+
+  const F sigTx = par1 * 1e-5f + par2 * abs( qop ); // slope noise, linear in |qop|
+  const F sigX  = par6 * sigTx * abs( dz ); // position noise, scales with sigTx and |dz|
+  const F corr  = par7; // multiplies sigX * sigTx to give the off-diagonal term eXTx
+
+  const F eXX   = sigX * sigX;
+  const F eXTx  = corr * sigX * sigTx;
+  const F eTxTx = sigTx * sigTx;
+
+  const F predcovXX  = covXX + 2.f * dz_t_covXTx + dz * dz_t_covTxTx + eXX;
+  const F predcovXTx = covXTx + dz_t_covTxTx + eXTx;
+
+  // compute the gain matrix; R is the inverse of ( winv + predcovXX )
+  const F R   = 1.0f / ( winv + predcovXX );
+  const F Kx  = predcovXX * R;
+  const F KTx = predcovXTx * R;
+
+  // update the state vector from the residual r; select( mask, new, old ) gates every write below
+  const F r = xhit - predx;
+  x         = select( mask, predx + Kx * r, x );
+  tx        = select( mask, tx + KTx * r, tx );
+
+  // update the covariance matrix; covTxTx is written first because its expression reads the old covXX and covXTx
+  /*
+    Algebraic expansion of the covTxTx update, used to avoid absorption:
+
+    covTxTx = predcovTxTx - KTx * predcovXTx
+    covTxTx = predcovTxTx - predcovXTx^2 / ( winv + predcovXX )
+    with predcovTxTx = covTxTx + eTxTx (right-hand covTxTx is the value before the update):
+    covTxTx = eTxTx + (covTxTx * ( winv + predcovXX ) - predcovXTx^2) / ( winv + predcovXX )
+    Expanding the two products in the numerator:
+    predcovXTx^2 = (covXTx + dz*covTxTx + eXTx)^2
+                = covXTx^2 + (dz*covTxTx)^2 + eXTx^2 + 2*covXTx*dz*covTxTx + 2*covXTx*eXTx + 2*dz*covTxTx*eXTx
+    covTxTx * ( winv + predcovXX ) = covTxTx * ( winv + covXX + 2*dz*covXTx + dz^2*covTxTx + eXX )
+                                   = covTxTx * ( winv + covXX) + 2*dz*covXTx*covTxTx + (dz*covTxTx)^2 + eXX*covTxTx
+    The (dz*covTxTx)^2 and 2*dz*covXTx*covTxTx terms cancel in the difference, which leaves:
+    covTxTx = eTxTx + (covTxTx * ( winv + covXX) - covXTx^2 + eXX*covTxTx - eXTx*(eXTx + 2*(covXTx + dz*covTxTx))) / (
+    winv + predcovXX )
+   */
+  covTxTx = select( mask,
+                    eTxTx + R * ( covTxTx * ( winv + covXX ) - covXTx * covXTx + eXX * covTxTx -
+                                  eXTx * ( eXTx + 2.f * ( covXTx + dz_t_covTxTx ) ) ),
+                    covTxTx );
+  covXTx  = select( mask, winv * KTx, covXTx );
+  covXX   = select( mask, winv * Kx, covXX );
+
+  // return the chi2 increment r * r * R; it is not masked here, so callers must apply the mask themselves
+  return r * r * R;
+}
+
+template <typename F, typename I, typename M>
+inline std::tuple<FittedState<F>, F, I>
+fitBackwardWithMomentum( const M track_mask, const LHCb::Pr::Velo::Tracks& tracks, const I idxVP, const F qop,
+                         const LHCb::Pr::Velo::Hits& hits, const int state_id ) { // returns {state, chi2, nDof}
+  const F err = 0.0125f; // measurement error used by this fit; note VeloKalmanParam::err has a different value
+  const F wx  = err * err; // passed to filterWithMomentum as winv
+  const F wy  = wx;
+
+  I       nHits   = tracks.maskgather_nHits<I, I>( idxVP, track_mask, 0 ); // hit count of each track idxVP
+  int     maxHits = nHits.hmax( track_mask ); // presumably the max over entries in track_mask; loop bound
+  I       idxHit0 = tracks.maskgather_hit<I, I>( idxVP, track_mask, 0, 0 ); // index of hit 0 of each track
+  Vec3<F> dir     = tracks.maskgather_stateDir<F, I>( idxVP, track_mask, 0.f, state_id ); // from state state_id
+  Vec3<F> pos     = hits.maskgather_pos<F, I>( idxHit0, track_mask, 0.f ); // position of hit 0
+  // Seed the state at hit 0; the six constants presumably initialise the x and y covariance terms - TODO confirm
+  FittedState<F> s = FittedState<F>( pos, dir, 100.f, 0.f, 0.0001f, 100.f, 0.f, 0.0001f );
+
+  F chi2 = 0.f;
+
+  for ( int i = 1; i < maxHits; i++ ) { // hit 0 seeded the state, so filtering starts at hit 1
+    auto    mask   = track_mask && ( I( i ) < nHits ); // tracks that have a hit with index i
+    I       idxHit = tracks.maskgather_hit<I, I>( idxVP, mask, I( 0 ), i );
+    Vec3<F> hit    = hits.maskgather_pos<F, I>( idxHit, mask, 0.f );
+
+    chi2 = select( // x projection
+        mask,
+        chi2 + filterWithMomentum( mask, s.z, s.x, s.tx, s.covXX, s.covXTx, s.covTxTx, hit.z, hit.x, F( wx ), qop ),
+        chi2 );
+    chi2 = select( // y projection
+        mask,
+        chi2 + filterWithMomentum( mask, s.z, s.y, s.ty, s.covYY, s.covYTy, s.covTyTy, hit.z, hit.y, F( wy ), qop ),
+        chi2 );
+    s.z = select( mask, hit.z, s.z ); // advance z only after both projections have used the old s.z
+  }
+
+  // Convert state at first measurement to state closest to the beam: add foil noise to the slopes, then transport
+  const F t2 = s.dir().rho(); // NOTE(review): rho() presumably sqrt( tx^2 + ty^2 ) despite the name t2 - confirm
+
+  const F scat2RFFoil = // NOTE(review): 1.0 below is a double literal among floats; 1.0f would be consistent
+      VeloKalmanParam::scatterFoilParameters[0] * ( 1.0 + VeloKalmanParam::scatterFoilParameters[1] * t2 ) * qop * qop;
+  s.covTxTx = s.covTxTx + scat2RFFoil; // the same term is added to both slope variances
+  s.covTyTy = s.covTyTy + scat2RFFoil;
+
+  s.transportTo( s.zBeam() );
+
+  return {s, chi2, 2 * nHits - 4}; // 2 * nHits - 4 is used as nDof by VeloKalman.cpp
+}
\ No newline at end of file