diff --git a/PhysicsAnalysis/ElectronPhotonID/PhotonVertexSelection/Root/PhotonVertexSelectionTool.cxx b/PhysicsAnalysis/ElectronPhotonID/PhotonVertexSelection/Root/PhotonVertexSelectionTool.cxx
index ee9e67ddbbfafd746741686d8ee428670a32975b..cd417cd62ce0e5a533bd5f29bb343b45f7b91564 100644
--- a/PhysicsAnalysis/ElectronPhotonID/PhotonVertexSelection/Root/PhotonVertexSelectionTool.cxx
+++ b/PhysicsAnalysis/ElectronPhotonID/PhotonVertexSelection/Root/PhotonVertexSelectionTool.cxx
@@ -26,6 +26,24 @@
 
 namespace CP {
 
+  // helper function to get the vertex of a track
+  const xAOD::Vertex* getVertexFromTrack(const xAOD::TrackParticle* track,
+                                         const xAOD::VertexContainer* vertices)
+  {
+    const xAOD::Vertex* found_vx = nullptr;
+    for (const auto& vx : *vertices) {
+      for (const auto& tpLink : vx->trackParticleLinks()) {
+        if (*tpLink == track) {
+          found_vx = vx;
+          break;
+        }
+      }
+      if (found_vx) { break; }
+    }
+
+    return found_vx;
+  }
+
   //____________________________________________________________________________
   PhotonVertexSelectionTool::PhotonVertexSelectionTool(const std::string &name)
   : asg::AsgTool(name)
@@ -246,6 +264,12 @@ namespace CP {
     const xAOD::TrackParticle *tp = nullptr;
     size_t NumberOfTracks = 0;
 
+    const xAOD::VertexContainer* all_vertices = nullptr;
+    if (evtStore()->retrieve(all_vertices, m_vertexContainerName).isFailure()) {
+      ATH_MSG_WARNING("Couldn't retrieve " << m_vertexContainerName << " from TEvent, returning nullptr.");
+      return nullptr;
+    }
+
     for (auto photon: *photons) {
       conversionVertex = photon->vertex();
       if (conversionVertex == nullptr) continue;
@@ -261,8 +285,7 @@ namespace CP {
         tp = xAOD::EgammaHelpers::getOriginalTrackParticleFromGSF(gsfTp);
         if (tp == nullptr) continue;
 
-
-        primary = tp->vertex();
+        primary = getVertexFromTrack(tp, all_vertices);
         if (primary == nullptr) continue;
 
         if (primary->vertexType() == xAOD::VxType::VertexType::PriVtx ||