diff --git a/Reconstruction/tauRecTools/Root/TauPi0ScoreCalculator.cxx b/Reconstruction/tauRecTools/Root/TauPi0ScoreCalculator.cxx
index 1adadc66ea3637a3f1e7d51944c462132c4e110f..93832ed995735e4a3421dd4ba7450502140fc6e2 100644
--- a/Reconstruction/tauRecTools/Root/TauPi0ScoreCalculator.cxx
+++ b/Reconstruction/tauRecTools/Root/TauPi0ScoreCalculator.cxx
@@ -2,38 +2,20 @@
   Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
 */
 
-//-----------------------------------------------------------------------------
-// file:        TauPi0ScoreCalculator.cxx
-// package:     Reconstruction/tauRec
-// authors:     Benedict Winter, Will Davey
-// date:        2012-10-09
-//-----------------------------------------------------------------------------
-
 #include "tauRecTools/TauPi0ScoreCalculator.h"
 #include "tauRecTools/HelperFunctions.h"
 #include "xAODPFlow/PFO.h"
 
-//-------------------------------------------------------------------------
-// Constructor
-//-------------------------------------------------------------------------
+
 
 TauPi0ScoreCalculator::TauPi0ScoreCalculator(const std::string& name) :
-  TauRecToolBase(name),
-  m_mvaBDT(nullptr)
-{
-    declareProperty("BDTWeightFile", m_weightfile);
+    TauRecToolBase(name) {
+  declareProperty("BDTWeightFile", m_weightfile = "");
 }
 
-//-------------------------------------------------------------------------
-// Destructor
-//-------------------------------------------------------------------------
 
-TauPi0ScoreCalculator::~TauPi0ScoreCalculator() 
-{
-}
 
-StatusCode TauPi0ScoreCalculator::initialize() 
-{
+StatusCode TauPi0ScoreCalculator::initialize() {
   std::string weightFile = find_file(m_weightfile);
 
   m_mvaBDT = std::make_unique<tauRecTools::BDTHelper>();
@@ -42,124 +24,119 @@ StatusCode TauPi0ScoreCalculator::initialize()
   return StatusCode::SUCCESS;
 }
 
-StatusCode TauPi0ScoreCalculator::executePi0nPFO(xAOD::TauJet& pTau, xAOD::PFOContainer& neutralPFOContainer) const
-{
-    //---------------------------------------------------------------------
-    // only run on 1-5 prong taus 
-    //---------------------------------------------------------------------
-    if (pTau.nTracks() == 0 || pTau.nTracks() >5 ) {
-        return StatusCode::SUCCESS;
-    }
-    ATH_MSG_DEBUG("ScoreCalculator: new tau. \tpt = " << pTau.pt() << "\teta = " << pTau.eta() << "\tphi = " << pTau.phi() << "\tnprongs = " << pTau.nTracks());
-
-    //---------------------------------------------------------------------
-    // retrieve neutral PFOs from tau, calculate BDT scores and store them in PFO
-    //---------------------------------------------------------------------
-    for( auto neutralPFO : neutralPFOContainer )
-    {
-      float BDTScore = calculateScore(neutralPFO);
-      neutralPFO->setBDTPi0Score(BDTScore);
-    }
-
-    ATH_MSG_DEBUG("End of TauPi0ScoreCalculator::execute");
 
+
+StatusCode TauPi0ScoreCalculator::executePi0nPFO(xAOD::TauJet& pTau, xAOD::PFOContainer& neutralPFOContainer) const {
+  // Only run on 1-5 prong taus 
+  if (pTau.nTracks() == 0 || pTau.nTracks() > 5 ) {
     return StatusCode::SUCCESS;
+  }
+
+  // retrieve neutral PFOs from tau, calculate BDT scores and store them in PFO
+  for (xAOD::PFO* neutralPFO : neutralPFOContainer) {
+    float BDTScore = calculateScore(neutralPFO);
+    neutralPFO->setBDTPi0Score(BDTScore);
+  }
+
+  return StatusCode::SUCCESS;
 }
 
-float TauPi0ScoreCalculator::calculateScore(const xAOD::PFO* neutralPFO) const
-{
-    std::map<TString, float> availableVariables; // map of the variable name to its value
-    
-    float Abs_FIRST_ETA = 0.;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_FIRST_ETA, Abs_FIRST_ETA) == false) {
-        ATH_MSG_WARNING("Can't find FIRST_ETA. Set it to 0.");
-    }
-    Abs_FIRST_ETA = std::abs(Abs_FIRST_ETA);
-    availableVariables.insert(std::make_pair("Pi0Cluster_Abs_FIRST_ETA", Abs_FIRST_ETA));
-
-    float SECOND_R = 0.;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_SECOND_R, SECOND_R) == false) {
-        ATH_MSG_WARNING("Can't find SECOND_R. Set it to 0.");
-    }
-    availableVariables.insert(std::make_pair("Pi0Cluster_SECOND_R", SECOND_R));
-
-    float Abs_DELTA_THETA = 0.;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_DELTA_THETA, Abs_DELTA_THETA) == false) {
-        ATH_MSG_WARNING("Can't find DELTA_THETA. Set it to 0.");
-    }
-    Abs_DELTA_THETA = std::abs(Abs_DELTA_THETA);
-    availableVariables.insert(std::make_pair("Pi0Cluster_Abs_DELTA_THETA", Abs_DELTA_THETA));
-
-    float CENTER_LAMBDA_helped = 0.;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_CENTER_LAMBDA, CENTER_LAMBDA_helped) == false) {
-        ATH_MSG_WARNING("Can't find CENTER_LAMBDA. Set it to 0.");
-    }
-    CENTER_LAMBDA_helped = fmin(CENTER_LAMBDA_helped, 1000.);
-    availableVariables.insert(std::make_pair("Pi0Cluster_CENTER_LAMBDA_helped", CENTER_LAMBDA_helped));
-    
-    float LONGITUDINAL = 0.;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_LONGITUDINAL, LONGITUDINAL) == false) {
-        ATH_MSG_WARNING("Can't find LONGITUDINAL. Set it to 0.");
-    }
-    availableVariables.insert(std::make_pair("Pi0Cluster_LONGITUDINAL", LONGITUDINAL));
-
-    float ENG_FRAC_EM = 0.;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_ENG_FRAC_EM, ENG_FRAC_EM) == false) {
-        ATH_MSG_WARNING("Can't find ENG_FRAC_EM. Set it to 0.");
-    }
-    availableVariables.insert(std::make_pair("Pi0Cluster_ENG_FRAC_EM", ENG_FRAC_EM));
-
-    float ENG_FRAC_CORE = 0.;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_ENG_FRAC_CORE, ENG_FRAC_CORE) == false) { 
-        ATH_MSG_WARNING("Can't find ENG_FRAC_CORE. Set it to 0.");
-    }
-    availableVariables.insert(std::make_pair("Pi0Cluster_ENG_FRAC_CORE", ENG_FRAC_CORE));
-
-    float log_SECOND_ENG_DENS = 0.;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_SECOND_ENG_DENS, log_SECOND_ENG_DENS) == false) { 
-        ATH_MSG_WARNING("Can't find SECOND_ENG_DENS. Set it to 0.");
-    }
-    if(log_SECOND_ENG_DENS==0.) {
-        log_SECOND_ENG_DENS=-50.;
-    }
-    else {
-        log_SECOND_ENG_DENS = log(log_SECOND_ENG_DENS);
-    }
-    availableVariables.insert(std::make_pair("Pi0Cluster_log_SECOND_ENG_DENS", log_SECOND_ENG_DENS));
-
-    float EcoreOverEEM1 = 0.;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_EM1CoreFrac, EcoreOverEEM1) == false) { 
-        ATH_MSG_WARNING("Can't find EM1CoreFrac. Set it to 0.");
-    }
-    availableVariables.insert(std::make_pair("Pi0Cluster_EcoreOverEEM1", EcoreOverEEM1));
-    
-    int NPosECells_EM1 = 0;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM1, NPosECells_EM1) == false) { 
-        ATH_MSG_WARNING("Can't find NPosECells_EM1. Set it to 0.");
-    }
-    availableVariables.insert(std::make_pair("Pi0Cluster_NPosECells_EM1", static_cast<float>(NPosECells_EM1)));
-
-    int NPosECells_EM2 = 0;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM2, NPosECells_EM2) == false) { 
-        ATH_MSG_WARNING("Can't find NPosECells_EM2. Set it to 0.");
-    }
-    availableVariables.insert(std::make_pair("Pi0Cluster_NPosECells_EM2", static_cast<float>(NPosECells_EM2)));
-    
-    float AbsFirstEtaWRTClusterPosition_EM1 = 0.;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM1, AbsFirstEtaWRTClusterPosition_EM1) == false) { 
-        ATH_MSG_WARNING("Can't find firstEtaWRTClusterPosition_EM1. Set it to 0.");
-    }
-    AbsFirstEtaWRTClusterPosition_EM1 = std::abs(AbsFirstEtaWRTClusterPosition_EM1);
-    availableVariables.insert(std::make_pair("Pi0Cluster_AbsFirstEtaWRTClusterPosition_EM1", AbsFirstEtaWRTClusterPosition_EM1));
-
-    float secondEtaWRTClusterPosition_EM2 = 0.;
-    if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM2, secondEtaWRTClusterPosition_EM2) == false) { 
-        ATH_MSG_WARNING("Can't find secondEtaWRTClusterPosition_EM2. Set it to 0.");
-    }
-    availableVariables.insert(std::make_pair("Pi0Cluster_secondEtaWRTClusterPosition_EM2", secondEtaWRTClusterPosition_EM2)); 
-
-    // Calculate BDT score, will be -999 when availableVariables lack variables
-    float score = m_mvaBDT->getGradBoostMVA(availableVariables);
-
-    return score;
+
+
+float TauPi0ScoreCalculator::calculateScore(const xAOD::PFO* neutralPFO) const {
+  
+  std::map<TString, float> availableVariables;
+  
+  float Abs_FIRST_ETA = 0.;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_FIRST_ETA, Abs_FIRST_ETA) == false) {
+    ATH_MSG_WARNING("Can't find FIRST_ETA. Set it to 0.");
+  }
+  Abs_FIRST_ETA = std::abs(Abs_FIRST_ETA);
+  availableVariables.insert(std::make_pair("Pi0Cluster_Abs_FIRST_ETA", Abs_FIRST_ETA));
+
+  float SECOND_R = 0.;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_SECOND_R, SECOND_R) == false) {
+    ATH_MSG_WARNING("Can't find SECOND_R. Set it to 0.");
+  }
+  availableVariables.insert(std::make_pair("Pi0Cluster_SECOND_R", SECOND_R));
+
+  float Abs_DELTA_THETA = 0.;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_DELTA_THETA, Abs_DELTA_THETA) == false) {
+    ATH_MSG_WARNING("Can't find DELTA_THETA. Set it to 0.");
+  }
+  Abs_DELTA_THETA = std::abs(Abs_DELTA_THETA);
+  availableVariables.insert(std::make_pair("Pi0Cluster_Abs_DELTA_THETA", Abs_DELTA_THETA));
+
+  float CENTER_LAMBDA_helped = 0.;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_CENTER_LAMBDA, CENTER_LAMBDA_helped) == false) {
+    ATH_MSG_WARNING("Can't find CENTER_LAMBDA. Set it to 0.");
+  }
+  CENTER_LAMBDA_helped = fmin(CENTER_LAMBDA_helped, 1000.);
+  availableVariables.insert(std::make_pair("Pi0Cluster_CENTER_LAMBDA_helped", CENTER_LAMBDA_helped));
+  
+  float LONGITUDINAL = 0.;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_LONGITUDINAL, LONGITUDINAL) == false) {
+    ATH_MSG_WARNING("Can't find LONGITUDINAL. Set it to 0.");
+  }
+  availableVariables.insert(std::make_pair("Pi0Cluster_LONGITUDINAL", LONGITUDINAL));
+
+  float ENG_FRAC_EM = 0.;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_ENG_FRAC_EM, ENG_FRAC_EM) == false) {
+    ATH_MSG_WARNING("Can't find ENG_FRAC_EM. Set it to 0.");
+  }
+  availableVariables.insert(std::make_pair("Pi0Cluster_ENG_FRAC_EM", ENG_FRAC_EM));
+
+  float ENG_FRAC_CORE = 0.;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_ENG_FRAC_CORE, ENG_FRAC_CORE) == false) { 
+    ATH_MSG_WARNING("Can't find ENG_FRAC_CORE. Set it to 0.");
+  }
+  availableVariables.insert(std::make_pair("Pi0Cluster_ENG_FRAC_CORE", ENG_FRAC_CORE));
+
+  float log_SECOND_ENG_DENS = 0.;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_SECOND_ENG_DENS, log_SECOND_ENG_DENS) == false) { 
+    ATH_MSG_WARNING("Can't find SECOND_ENG_DENS. Set it to 0.");
+  }
+  if(log_SECOND_ENG_DENS==0.) {
+    log_SECOND_ENG_DENS=-50.;
+  }
+  else {
+    log_SECOND_ENG_DENS = log(log_SECOND_ENG_DENS);
+  }
+  availableVariables.insert(std::make_pair("Pi0Cluster_log_SECOND_ENG_DENS", log_SECOND_ENG_DENS));
+
+  float EcoreOverEEM1 = 0.;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_EM1CoreFrac, EcoreOverEEM1) == false) { 
+    ATH_MSG_WARNING("Can't find EM1CoreFrac. Set it to 0.");
+  }
+  availableVariables.insert(std::make_pair("Pi0Cluster_EcoreOverEEM1", EcoreOverEEM1));
+  
+  int NPosECells_EM1 = 0;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM1, NPosECells_EM1) == false) { 
+    ATH_MSG_WARNING("Can't find NPosECells_EM1. Set it to 0.");
+  }
+  availableVariables.insert(std::make_pair("Pi0Cluster_NPosECells_EM1", static_cast<float>(NPosECells_EM1)));
+
+  int NPosECells_EM2 = 0;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM2, NPosECells_EM2) == false) { 
+    ATH_MSG_WARNING("Can't find NPosECells_EM2. Set it to 0.");
+  }
+  availableVariables.insert(std::make_pair("Pi0Cluster_NPosECells_EM2", static_cast<float>(NPosECells_EM2)));
+  
+  float AbsFirstEtaWRTClusterPosition_EM1 = 0.;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM1, AbsFirstEtaWRTClusterPosition_EM1) == false) { 
+    ATH_MSG_WARNING("Can't find firstEtaWRTClusterPosition_EM1. Set it to 0.");
+  }
+  AbsFirstEtaWRTClusterPosition_EM1 = std::abs(AbsFirstEtaWRTClusterPosition_EM1);
+  availableVariables.insert(std::make_pair("Pi0Cluster_AbsFirstEtaWRTClusterPosition_EM1", AbsFirstEtaWRTClusterPosition_EM1));
+
+  float secondEtaWRTClusterPosition_EM2 = 0.;
+  if(neutralPFO->attribute(xAOD::PFODetails::PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM2, secondEtaWRTClusterPosition_EM2) == false) { 
+    ATH_MSG_WARNING("Can't find secondEtaWRTClusterPosition_EM2. Set it to 0.");
+  }
+  availableVariables.insert(std::make_pair("Pi0Cluster_secondEtaWRTClusterPosition_EM2", secondEtaWRTClusterPosition_EM2)); 
+
+  // Calculate BDT score, will be -999 when availableVariables lack variables
+  float score = m_mvaBDT->getGradBoostMVA(availableVariables);
+
+  return score;
 }
diff --git a/Reconstruction/tauRecTools/Root/TauPi0Selector.cxx b/Reconstruction/tauRecTools/Root/TauPi0Selector.cxx
index 6c80984f9de6337c9e235fa280909587a2196c83..592e021630fc7efbb033204a0cdcb6a724fa04c8 100644
--- a/Reconstruction/tauRecTools/Root/TauPi0Selector.cxx
+++ b/Reconstruction/tauRecTools/Root/TauPi0Selector.cxx
@@ -2,149 +2,120 @@
   Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
 */
 
-//-----------------------------------------------------------------------------
-// file:        TauPi0Selector.cxx
-// package:     Reconstruction/tauRec
-// authors:     Benedict Winter, Will Davey
-// date:        2012-10-09
-//-----------------------------------------------------------------------------
-
 #include "tauRecTools/TauPi0Selector.h"
 
-//-------------------------------------------------------------------------
-// Constructor
-//-------------------------------------------------------------------------
 
-TauPi0Selector::TauPi0Selector(const std::string& name) :
-    TauRecToolBase(name)
-{
-    declareProperty("ClusterEtCut",             m_clusterEtCut);
-    declareProperty("ClusterBDTCut_1prong",     m_clusterBDTCut_1prong);
-    declareProperty("ClusterBDTCut_mprong",     m_clusterBDTCut_mprong);
+
+TauPi0Selector::TauPi0Selector(const std::string& name) : 
+    TauRecToolBase(name) {
+  declareProperty("ClusterEtCut", m_clusterEtCut);
+  declareProperty("ClusterBDTCut_1prong", m_clusterBDTCut_1prong);
+  declareProperty("ClusterBDTCut_mprong", m_clusterBDTCut_mprong);
 }
 
-//-------------------------------------------------------------------------
-// Destructor
-//-------------------------------------------------------------------------
 
-TauPi0Selector::~TauPi0Selector() 
-{
-}
 
-StatusCode TauPi0Selector::executePi0nPFO(xAOD::TauJet& pTau, xAOD::PFOContainer& neutralPFOContainer) const
-{
-    // decay mode enum
-    auto kDecayModeProto = xAOD::TauJetParameters::PanTau_DecayModeProto;
-    // Clear vector of cell-based pi0 PFO Links. Required when rerunning on xAOD level.
-    pTau.clearProtoPi0PFOLinks();
-
-    //---------------------------------------------------------------------
-    // only run on 1-5 prong taus 
-    //---------------------------------------------------------------------
-    if (pTau.nTracks() == 0 || pTau.nTracks() >5 ) {
-        // Set proto decay mode to "not set". Will be overwritten for taus with 1-5 tracks
-        pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_NotSet);
-        return StatusCode::SUCCESS;
+StatusCode TauPi0Selector::executePi0nPFO(xAOD::TauJet& pTau, xAOD::PFOContainer& neutralPFOContainer) const {
+  // Clear vector of cell-based pi0 PFO Links. Required when rerunning on xAOD level.
+  pTau.clearProtoPi0PFOLinks();
+  
+  // Decay mode enum
+  auto kDecayModeProto = xAOD::TauJetParameters::PanTau_DecayModeProto;
+
+  // 0, >=5 prong taus will have Mode_NotSet 
+  if (pTau.nTracks() == 0 || pTau.nTracks() >5 ) {
+      pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_NotSet);
+      return StatusCode::SUCCESS;
+  }
+
+  // 1-5 prong taus have Mode_Other by default
+  // 1, 3 prong taus will be over-written
+  pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_Other);
+
+  // Apply selection to the pi0, and count the number
+  int nRecoPi0s=0;
+  for (xAOD::PFO* neutralPFO : neutralPFOContainer) {
+    // Set number of pi0s to 0 for all neutral PFOs. Required when rerunning on xAOD level
+    neutralPFO->setAttribute<int>(xAOD::PFODetails::PFOAttributes::nPi0Proto, 0);
+
+    // Only consider PFOs within 0.2 cone of the tau axis
+    if (pTau.p4().DeltaR(neutralPFO->p4()) > 0.2) continue;
+    
+    int etaBin = getEtaBin( neutralPFO->cluster(0)->eta() );
+    
+    // Apply Et cut
+    if (neutralPFO->p4().Et() < m_clusterEtCut.at(etaBin)) continue;
+    
+    // Apply BDT score cut
+    double BDTScore = neutralPFO->bdtPi0Score();
+    if ((pTau.nTracks() == 1 && BDTScore < m_clusterBDTCut_1prong.at(etaBin)) || 
+        (pTau.nTracks() > 1 && BDTScore < m_clusterBDTCut_mprong.at(etaBin))) continue;
+
+    int nHitsInEM1 = 0;
+    if (!neutralPFO->attribute(xAOD::PFODetails::cellBased_NHitsInEM1, nHitsInEM1)) { 
+      ATH_MSG_WARNING("Couldn't retrieve nHitsInEM1. Will set it to 0.");
     }
 
-    // Set proto decay mode to "other". Will be overwritten for taus with 1 or 3 tracks
-    pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_Other);
-
-    //---------------------------------------------------------------------
-    // retrieve neutral PFOs from tau. Apply selection and create links to
-    // Pi0NeutralPFOs 
-    //---------------------------------------------------------------------
-    int nRecoPi0s=0;
-    for( auto neutralPFO : neutralPFOContainer )
-    {
-        // Set number of pi0s to 0 for all neutral PFOs. Required when rerunning on xAOD level
-        neutralPFO->setAttribute<int>(xAOD::PFODetails::PFOAttributes::nPi0Proto, 0);
-
-        // Get eta bin
-        int etaBin = getPi0Cluster_etaBin( neutralPFO->cluster(0)->eta() );
-
-        // Preselection
-        if(neutralPFO->p4().Et() < m_clusterEtCut.at(etaBin)) continue;
-        if(pTau.p4().DeltaR(neutralPFO->p4()) > 0.2) continue; // TODO: Replace by shrinking cone?
-
-        // BDT Selection
-        float BDTScore = neutralPFO->bdtPi0Score();
-        ATH_MSG_DEBUG("etaBin = " << etaBin 
-                   << ", m_clusterEtCut.at(etaBin) = " <<m_clusterEtCut.at(etaBin) 
-                   << ", m_clusterBDTCut_1prong.at(etaBin) = " << m_clusterBDTCut_1prong.at(etaBin) 
-                   << ", m_clusterBDTCut_mprong.at(etaBin) = " << m_clusterBDTCut_mprong.at(etaBin));
-        if( (pTau.nTracks()==1 && BDTScore < m_clusterBDTCut_1prong.at(etaBin)) 
-                || (pTau.nTracks()>1 && BDTScore < m_clusterBDTCut_mprong.at(etaBin)) ) continue;
-
-        // Set number of pi0s
-        int nHitsInEM1 = 0;
-        if(!neutralPFO->attribute(xAOD::PFODetails::cellBased_NHitsInEM1, nHitsInEM1)) 
-            ATH_MSG_WARNING("Couldn't retrieve nHitsInEM1. Will set it to 0.");
-        if(nHitsInEM1<3){ 
-            neutralPFO->setAttribute<int>(xAOD::PFODetails::PFOAttributes::nPi0Proto, 1);
-            nRecoPi0s++;
-        }   
-        else{ 
-            neutralPFO->setAttribute<int>(xAOD::PFODetails::PFOAttributes::nPi0Proto, 2);
-            nRecoPi0s+=2;
-        }
-
-        // Set element link to Pi0tagged PFO
-        pTau.addProtoPi0PFOLink(ElementLink< xAOD::PFOContainer > (neutralPFO, neutralPFOContainer));
+    // nHitsInEM1 < 3 --- one pi0; nHitsInEM1 >= 3 --- two pi0s
+    // FIXME: what about nHitsInEM1 == 0 ?
+    if (nHitsInEM1 < 3) { 
+      neutralPFO->setAttribute<int>(xAOD::PFODetails::PFOAttributes::nPi0Proto, 1);
+      ++nRecoPi0s;
+    }   
+    else {
+      neutralPFO->setAttribute<int>(xAOD::PFODetails::PFOAttributes::nPi0Proto, 2);
+      nRecoPi0s += 2;
     }
 
-    // Set Proto Decay Mode
-    if(pTau.nTracks()==1){
-      if(nRecoPi0s==0)      pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_1p0n);
-      else if(nRecoPi0s==1) pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_1p1n);
-      else                  pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_1pXn);
+    pTau.addProtoPi0PFOLink(ElementLink< xAOD::PFOContainer > (neutralPFO, neutralPFOContainer));
+  }
+
+  // Set Proto Decay Mode based on the number charged tracks and pi0s
+  if (pTau.nTracks()==1) {
+    if (nRecoPi0s==0) {
+      pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_1p0n);
     }
-    if(pTau.nTracks()==3){
-      if(nRecoPi0s==0)      pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_3p0n);
-      else                  pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_3pXn);
+    else if (nRecoPi0s==1) {
+      pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_1p1n);
     }
-    
-    return StatusCode::SUCCESS;
+    else {
+      pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_1pXn);
+    }
+  }
+  else if (pTau.nTracks()==3) {
+    if (nRecoPi0s==0) {
+      pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_3p0n);
+    }
+    else {
+      pTau.setPanTauDetail(kDecayModeProto, xAOD::TauJetParameters::DecayMode::Mode_3pXn);
+    }
+  }
+  
+  return StatusCode::SUCCESS;
 }
 
-int TauPi0Selector::getPi0Cluster_etaBin(double Pi0Cluster_eta) const {
-    int Pi0Cluster_etaBin = -1;
-    double Pi0Cluster_noCorr_ABSeta = std::abs(Pi0Cluster_eta);
 
-    if( Pi0Cluster_noCorr_ABSeta < 0.80 ) Pi0Cluster_etaBin = 0;
-    else if( Pi0Cluster_noCorr_ABSeta < 1.40 ) Pi0Cluster_etaBin = 1;
-    else if( Pi0Cluster_noCorr_ABSeta < 1.50 ) Pi0Cluster_etaBin = 2;
-    else if( Pi0Cluster_noCorr_ABSeta < 1.90 ) Pi0Cluster_etaBin = 3;
-    else Pi0Cluster_etaBin = 4;
-    return Pi0Cluster_etaBin;
-}
 
-TLorentzVector TauPi0Selector::getP4(const xAOD::TauJet& pTau) const
-{
-    TLorentzVector p4(0.,0.,0.,0.);
-    // Add charged PFOs 
-    for( auto chargedPFOLink : pTau.protoChargedPFOLinks() ){
-        if( not chargedPFOLink.isValid() ){
-            ATH_MSG_WARNING("Invalid protoChargedPFOLink");
-            continue;
-        }
-        p4+=(*chargedPFOLink)->p4();
-    }
-    // Add pi0 PFOs
-    for( auto pi0PFOLink : pTau.protoPi0PFOLinks() )
-    {
-        if( not pi0PFOLink.isValid() ){
-            ATH_MSG_WARNING("Invalid protoPi0PFOLink");
-            continue;
-        }
-        const xAOD::PFO* pi0PFO = (*pi0PFOLink);
-        // assign neutral pion mass
-        double mass = 134.9766;
-        double p  = std::sqrt(std::pow(pi0PFO->e(),2) - std::pow(mass,2));
-        double pt = p/std::cosh(pi0PFO->eta());
-        TLorentzVector pi0_corrP4;
-        pi0_corrP4.SetPtEtaPhiM(pt,pi0PFO->eta(),pi0PFO->phi(),mass);
-        p4+=pi0_corrP4;
-    }
-    return p4;
+int TauPi0Selector::getEtaBin(double eta) const {
+  int etaBin = -1;
+  
+  double absEta = std::abs(eta);
+
+  if (absEta < 0.80) {
+    etaBin = 0;
+  }
+  else if (absEta < 1.40) {
+    etaBin = 1;
+  }
+  else if (absEta < 1.50) {
+    etaBin = 2;
+  }
+  else if (absEta < 1.90) {
+    etaBin = 3;
+  }
+  else {
+    etaBin = 4;
+  }
+
+  return etaBin;
 }
diff --git a/Reconstruction/tauRecTools/src/TauPi0ClusterCreator.cxx b/Reconstruction/tauRecTools/src/TauPi0ClusterCreator.cxx
index 2dc44c96932cf4933f60dfc76eb1cd64c2843433..241dc39b8b00e66d38daab8109dd5a9c8517a783 100644
--- a/Reconstruction/tauRecTools/src/TauPi0ClusterCreator.cxx
+++ b/Reconstruction/tauRecTools/src/TauPi0ClusterCreator.cxx
@@ -3,35 +3,20 @@
 */
 
 #ifndef XAOD_ANALYSIS
-//-----------------------------------------------------------------------------
-// file:        TauPi0ClusterCreator.cxx
-// package:     Reconstruction/tauEvent
-// authors:     Benedict Winter, Will Davey, Stephanie Yuen
-// date:        2012-10-09
-//-----------------------------------------------------------------------------
+
+#include "TauPi0ClusterCreator.h"
+#include "tauRecTools/HelperFunctions.h"
 
 #include "CaloUtils/CaloClusterStoreHelper.h"
 #include "FourMomUtils/P4Helpers.h"
 #include "xAODJet/Jet.h"
 
-#include "TauPi0ClusterCreator.h"
-#include "tauRecTools/HelperFunctions.h"
 
-//-------------------------------------------------------------------------
-// Constructor
-//-------------------------------------------------------------------------
 
 TauPi0ClusterCreator::TauPi0ClusterCreator(const std::string& name) :
     TauRecToolBase(name) {
 }
 
-//-------------------------------------------------------------------------
-// Destructor
-//-------------------------------------------------------------------------
-
-TauPi0ClusterCreator::~TauPi0ClusterCreator() 
-{
-}
 
 
 StatusCode TauPi0ClusterCreator::initialize() {
@@ -39,454 +24,463 @@ StatusCode TauPi0ClusterCreator::initialize() {
   return StatusCode::SUCCESS;
 }
 
-//______________________________________________________________________________
-StatusCode TauPi0ClusterCreator::executePi0ClusterCreator(xAOD::TauJet& pTau, xAOD::PFOContainer& neutralPFOContainer,
-							  xAOD::PFOContainer& hadronicClusterPFOContainer,
-							  const xAOD::CaloClusterContainer& pi0ClusterContainer) const
-{
-    // Any tau needs to have PFO vectors. Set empty vectors before nTrack cut
-    std::vector<ElementLink<xAOD::PFOContainer> > empty;
-    pTau.setProtoChargedPFOLinks(empty);
-    pTau.setProtoNeutralPFOLinks(empty);
-    pTau.setProtoPi0PFOLinks(empty);
-    pTau.setHadronicPFOLinks(empty);
-
-    // only run shower subtraction on 1-5 prong taus 
-    if (pTau.nTracks() == 0 || pTau.nTracks() >5) {
-        return StatusCode::SUCCESS;
+
+
+StatusCode TauPi0ClusterCreator::executePi0ClusterCreator(xAOD::TauJet& tau, xAOD::PFOContainer& neutralPFOContainer,
+							  xAOD::PFOContainer& hadronicPFOContainer,
+							  const xAOD::CaloClusterContainer& pi0ClusterContainer) const {
+  // Any tau needs to have PFO vectors. Set empty vectors before nTrack cut
+  std::vector<ElementLink<xAOD::PFOContainer>> empty;
+  tau.setProtoNeutralPFOLinks(empty);
+  tau.setHadronicPFOLinks(empty);
+
+  // only run shower subtraction on 1-5 prong taus 
+  if (tau.nTracks() == 0 || tau.nTracks() > 5) {
+    return StatusCode::SUCCESS;
+  }
+
+  // Retrieve Ecal1 shots and match them to clusters
+  std::vector<const xAOD::PFO*> shotPFOs;
+  
+  unsigned nShots = tau.nShotPFOs();
+  for (unsigned index=0; index<nShots; ++index) {
+    const xAOD::PFO* shotPFO = tau.shotPFO(index);
+    shotPFOs.push_back(shotPFO);
+  }
+  
+  // Map shot to the pi0 cluster 
+  std::map<unsigned, const xAOD::CaloCluster*> shotToClusterMap = getShotToClusterMap(shotPFOs, pi0ClusterContainer, tau);
+
+  // FIXME: These clusters are custom ones, so could be corrected using tau vertex directly
+  if (! tau.jetLink().isValid()) {
+    ATH_MSG_ERROR("Tau jet link is invalid.");
+    return StatusCode::FAILURE;
+  }
+  const xAOD::Jet *jetSeed = tau.jet();
+  const xAOD::Vertex* jetVertex = m_tauVertexCorrection->getJetVertex(*jetSeed);
+  const xAOD::Vertex* tauVertex = nullptr;
+  if (tau.vertexLink().isValid()) tauVertex = tau.vertex();
+  TLorentzVector tauAxis = m_tauVertexCorrection->getTauAxis(tau);
+
+  // Loop over clusters, and create neutral PFOs
+  for (const xAOD::CaloCluster* cluster: pi0ClusterContainer) {
+    TLorentzVector clusterP4 = m_tauVertexCorrection->getVertexCorrectedP4(*cluster, tauVertex, jetVertex);
+    
+    // Clusters must have enough energy, and within 0.4 cone of the tau candidate
+    if (clusterP4.Pt() < m_clusterEtCut)   continue;
+    if (clusterP4.DeltaR(tauAxis) > 0.4) continue;
+
+    // Create the neutral PFOs
+    xAOD::PFO* neutralPFO = new xAOD::PFO();
+    neutralPFOContainer.push_back(neutralPFO);
+    
+    // Add the link to the tau candidate 
+    ElementLink<xAOD::PFOContainer> PFOElementLink;
+    PFOElementLink.toContainedElement(neutralPFOContainer, neutralPFO);
+    tau.addProtoNeutralPFOLink(PFOElementLink);
+
+    ATH_CHECK(configureNeutralPFO(*cluster, pi0ClusterContainer, tau, shotPFOs, shotToClusterMap, *neutralPFO));
+  }
+
+  // Loop over clusters, and create hadronic PFOs
+  std::vector<const xAOD::CaloCluster*> clusterList;
+  ATH_CHECK(tauRecTools::GetJetClusterList(jetSeed, clusterList, m_useSubtractedCluster));
+  for (const xAOD::CaloCluster* cluster: clusterList) {
+    TLorentzVector clusterP4 = m_tauVertexCorrection->getVertexCorrectedP4(*cluster, tauVertex, jetVertex);
+       
+    // Clusters must have positive energy, and within 0.2 cone of the tau candidate 
+    if(clusterP4.E()<=0.) continue;
+    if(clusterP4.DeltaR(tauAxis) > 0.2) continue;
+
+    double clusterEnergyHad = 0.;
+    const CaloClusterCellLink* cellLinks = cluster->getCellLinks();
+    CaloClusterCellLink::const_iterator cellLink = cellLinks->begin();
+	for (; cellLink != cellLinks->end(); ++cellLink) {
+	  const CaloCell* cell = static_cast<const CaloCell*>(*cellLink);
+       
+      int sampling = cell->caloDDE()->getSampling();
+      if (sampling < 8) continue;
+
+      double cellEnergy = cell->e() * cellLink.weight();
+      clusterEnergyHad += cellEnergy;
     }
-    ATH_MSG_DEBUG("ClusterCreator: new tau. \tpt = " << pTau.pt() << "\teta = " << pTau.eta() << "\tphi = " << pTau.phi() << "\tnprongs = " << pTau.nTracks());
-
-    // Retrieve Ecal1 shots and match them to clusters
-    std::vector<const xAOD::PFO*> shotVector;
-    unsigned nShots = pTau.nShotPFOs();
-    for(unsigned iShot=0;iShot<nShots;++iShot){
-        const xAOD::PFO* thisShot = pTau.shotPFO(iShot);
-        shotVector.push_back( thisShot );
+    
+    // Energy in Had Calorimeter must be positive
+    if(clusterEnergyHad <= 0.) continue;
+  
+    // Create the hadrnic PFO
+    xAOD::PFO* hadronicPFO = new xAOD::PFO();
+    hadronicPFOContainer.push_back(hadronicPFO);
+    
+    // Add element link from tau to hadronic PFO
+    ElementLink<xAOD::PFOContainer> PFOElementLink;
+    PFOElementLink.toContainedElement( hadronicPFOContainer, hadronicPFO );
+    tau.addHadronicPFOLink( PFOElementLink );
+    
+    ATH_CHECK(configureHadronicPFO(*cluster, clusterEnergyHad, *hadronicPFO));
+  }
+
+  return StatusCode::SUCCESS;
+}
+
+
+
+std::map<unsigned, const xAOD::CaloCluster*> TauPi0ClusterCreator::getShotToClusterMap(const std::vector<const xAOD::PFO*>& shotPFOs,
+										 const xAOD::CaloClusterContainer& pi0ClusterContainer,
+										 const xAOD::TauJet &tau) const {
+  std::map<unsigned, const xAOD::CaloCluster*> shotToClusterMap;
+  for (unsigned index = 0; index < shotPFOs.size(); ++index) {
+    const xAOD::PFO* shotPFO = shotPFOs.at(index);
+
+    int seedHashInt = -1;
+    if (!shotPFO->attribute(xAOD::PFODetails::PFOAttributes::tauShots_seedHash, seedHashInt)) {
+      ATH_MSG_WARNING("Couldn't find seed hash. Set it to -1, no cluster will be associated to shot.");
     }
-    std::map<unsigned, xAOD::CaloCluster*> clusterToShotMap = getClusterToShotMap(shotVector, pi0ClusterContainer, pTau);
- 
-    if (! pTau.jetLink().isValid()) {
+    const IdentifierHash seedHash = static_cast<const IdentifierHash>(seedHashInt);
+
+    const xAOD::Jet *jetSeed = tau.jet();
+    if (!jetSeed) {
       ATH_MSG_ERROR("Tau jet link is invalid.");
-      return StatusCode::FAILURE;
+      return shotToClusterMap;
     }
-    const xAOD::Jet *jetSeed = pTau.jet();
-    
     const xAOD::Vertex* jetVertex = m_tauVertexCorrection->getJetVertex(*jetSeed);
-    
+  
     const xAOD::Vertex* tauVertex = nullptr;
-    if (pTau.vertexLink().isValid()) tauVertex = pTau.vertex();
+    if (tau.vertexLink().isValid()) tauVertex = tau.vertex();
     
-    TLorentzVector tauAxis = m_tauVertexCorrection->getTauAxis(pTau);
+    TLorentzVector tauAxis = m_tauVertexCorrection->getTauAxis(tau);
 
-    for (const xAOD::CaloCluster* cluster: pi0ClusterContainer){
-        TLorentzVector clusterP4 = m_tauVertexCorrection->getVertexCorrectedP4(*cluster, tauVertex, jetVertex);
+    float weightInCluster = -1.;
+    float weightInPreviousCluster = -1;
+    
+    for (const xAOD::CaloCluster* cluster : pi0ClusterContainer) {
+      // FIXME: cluster here is not corrected
+      TLorentzVector clusterP4 = m_tauVertexCorrection->getVertexCorrectedP4(*cluster, tauVertex, jetVertex);
+      
+      weightInCluster = -1.;
+      if (clusterP4.Et() < m_clusterEtCut) continue;
+      if (clusterP4.DeltaR(tauAxis) > 0.4)  continue;
         
-        // selection
-        if (clusterP4.Pt() < m_clusterEtCut)   continue;
-        // Cluster container has clusters for all taus.
-        // Only run on clusters that belong to this tau
-        if (clusterP4.DeltaR(tauAxis) > 0.4) continue;
-
-        // Get shots in this cluster. Need to use (CaloCluster*) (*clusterItr) 
-        // (not a copy!) since the pointer will otherwise be different than in clusterToShotMap
-        std::vector<unsigned> shotsInCluster = getShotsMatchedToCluster( shotVector, clusterToShotMap, cluster);
-
-        // Calculate input variables for fake supression. 
-        // Do this before applying the vertex correction, 
-        // since the position of the cluster in the 
-        // calorimeter is required.
-        float EM1CoreFrac = getEM1CoreFrac(cluster);
-        int NHitsInEM1 = getNPhotons(shotVector, shotsInCluster);
-	std::vector<int> NPosECellsInLayer = getNPosECells(cluster);
-	std::vector<float> firstEtaWRTClusterPositionInLayer = get1stEtaMomWRTCluster(cluster);
-	std::vector<float> secondEtaWRTClusterPositionInLayer = get2ndEtaMomWRTCluster(cluster);
-
-        // Retrieve cluster moments that are used for fake supression and that are not stored in AOD
-        // for every cluster. Do this after applying the vertex correction, since the moments 
-        // (especcially DELTA_PHI and DELTA_THETA) must be calculated WRT the tau vertex
-        double CENTER_MAG = 0.0;
-        double FIRST_ETA = 0.0;
-        double SECOND_R = 0.0;
-        double SECOND_LAMBDA = 0.0;
-        double DELTA_PHI = 0.0;
-        double DELTA_THETA = 0.0;
-        double CENTER_LAMBDA = 0.0;
-        double LATERAL = 0.0;
-        double LONGITUDINAL = 0.0;
-        double ENG_FRAC_EM = 0.0;
-        double ENG_FRAC_MAX = 0.0;
-        double ENG_FRAC_CORE = 0.0;
-        double SECOND_ENG_DENS = 0.0;
-
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::CENTER_MAG, CENTER_MAG) ) ATH_MSG_WARNING("Couldn't retrieve CENTER_MAG moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::FIRST_ETA, FIRST_ETA) ) ATH_MSG_WARNING("Couldn't retrieve FIRST_ETA moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::SECOND_R, SECOND_R) ) ATH_MSG_WARNING("Couldn't retrieve SECOND_R moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::SECOND_LAMBDA, SECOND_LAMBDA) ) ATH_MSG_WARNING("Couldn't retrieve SECOND_LAMBDA moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::DELTA_PHI, DELTA_PHI) ) ATH_MSG_WARNING("Couldn't retrieve DELTA_PHI moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::DELTA_THETA, DELTA_THETA) ) ATH_MSG_WARNING("Couldn't retrieve DELTA_THETA moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::CENTER_LAMBDA, CENTER_LAMBDA) ) ATH_MSG_WARNING("Couldn't retrieve CENTER_LAMBDA moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::LATERAL, LATERAL) ) ATH_MSG_WARNING("Couldn't retrieve LATERAL moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::LONGITUDINAL, LONGITUDINAL) ) ATH_MSG_WARNING("Couldn't retrieve LONGITUDINAL moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::ENG_FRAC_EM, ENG_FRAC_EM) ) ATH_MSG_WARNING("Couldn't retrieve ENG_FRAC_EM moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::ENG_FRAC_MAX, ENG_FRAC_MAX) ) ATH_MSG_WARNING("Couldn't retrieve ENG_FRAC_MAX moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::ENG_FRAC_CORE, ENG_FRAC_CORE) ) ATH_MSG_WARNING("Couldn't retrieve ENG_FRAC_CORE moment. Set it to 0.");
-        if( !cluster->retrieveMoment(xAOD::CaloCluster::MomentType::SECOND_ENG_DENS, SECOND_ENG_DENS) ) ATH_MSG_WARNING("Couldn't retrieve SECOND_ENG_DENS moment. Set it to 0.");
-
-       	float E_EM1 = cluster->eSample(CaloSampling::EMB1) + cluster->eSample(CaloSampling::EME1);
-	      float E_EM2 = cluster->eSample(CaloSampling::EMB2) + cluster->eSample(CaloSampling::EME2);
+      const CaloClusterCellLink* cellLinks = cluster->getCellLinks();
+      CaloClusterCellLink::const_iterator cellLink = cellLinks->begin();
+	  for (; cellLink != cellLinks->end(); ++cellLink) {
+        const CaloCell* cell = static_cast<const CaloCell*>(*cellLink);
         
-        // create neutral PFO. Set BDTScore to dummy value <-1. The BDT score is calculated within TauPi0Selector.cxx.
-        xAOD::PFO* neutralPFO = new xAOD::PFO();
-        neutralPFOContainer.push_back( neutralPFO );
-
-        // Create element link from tau to neutral PFO
-        ElementLink<xAOD::PFOContainer> PFOElementLink;
-        PFOElementLink.toContainedElement( neutralPFOContainer, neutralPFO );
-        pTau.addProtoNeutralPFOLink( PFOElementLink );
-
-        // Set PFO variables
-        ElementLink<xAOD::CaloClusterContainer> clusElementLink;
-        clusElementLink.toContainedElement( pi0ClusterContainer, cluster );
-        neutralPFO->setClusterLink( clusElementLink );
+        // Check if seed cell is in cluster.
+        if (cell->caloDDE()->calo_hash() != seedHash) continue;
         
-        neutralPFO->setP4( (float) cluster->pt(), (float) cluster->eta(), (float) cluster->phi(), (float) cluster->m());
-        neutralPFO->setBDTPi0Score( (float) -9999. );
-        neutralPFO->setCharge( 0. );
-        neutralPFO->setCenterMag( (float) CENTER_MAG);
-        neutralPFO->setAttribute<int>(xAOD::PFODetails::PFOAttributes::nPi0Proto, -1);
-
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_FIRST_ETA,       (float) FIRST_ETA);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_SECOND_R,        (float) SECOND_R);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_SECOND_LAMBDA,   (float) SECOND_LAMBDA);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_DELTA_PHI,       (float) DELTA_PHI);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_DELTA_THETA,     (float) DELTA_THETA);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_CENTER_LAMBDA,   (float) CENTER_LAMBDA);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_LATERAL,         (float) LATERAL);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_LONGITUDINAL,    (float) LONGITUDINAL);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_ENG_FRAC_EM,     (float) ENG_FRAC_EM);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_ENG_FRAC_MAX,    (float) ENG_FRAC_MAX);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_ENG_FRAC_CORE,   (float) ENG_FRAC_CORE);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_SECOND_ENG_DENS, (float) SECOND_ENG_DENS);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_energy_EM1,      (float) E_EM1);
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_energy_EM2,      (float) E_EM2);
-
-
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_EM1CoreFrac, EM1CoreFrac);
-        neutralPFO->setAttribute<int>(xAOD::PFODetails::PFOAttributes::cellBased_NHitsInEM1, NHitsInEM1);
-        neutralPFO->setAttribute<int>(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_PS,  NPosECellsInLayer.at(0));
-        neutralPFO->setAttribute<int>(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM1, NPosECellsInLayer.at(1));
-        neutralPFO->setAttribute<int>(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM2, NPosECellsInLayer.at(2));
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM1, firstEtaWRTClusterPositionInLayer.at(1));
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM2, firstEtaWRTClusterPositionInLayer.at(2));
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM1, secondEtaWRTClusterPositionInLayer.at(1));
-        neutralPFO->setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM2, secondEtaWRTClusterPositionInLayer.at(2));
-
-        // Store shot element links in neutral PFO
-        std::vector<ElementLink<xAOD::IParticleContainer> > shotlinks;
-        for(unsigned iShot = 0;iShot<shotsInCluster.size();++iShot){
-            ElementLink<xAOD::PFOContainer> shotPFOElementLink = pTau.shotPFOLinks().at(shotsInCluster.at(iShot));
-            ElementLink<xAOD::IParticleContainer> shotElementLink;
-            shotPFOElementLink.toPersistent();
-            shotElementLink.resetWithKeyAndIndex( shotPFOElementLink.persKey(), shotPFOElementLink.persIndex() ); 
-            if (!shotElementLink.isValid()) ATH_MSG_WARNING("Created an invalid element link to xAOD::PFO");
-            shotlinks.push_back(shotElementLink);
+        weightInCluster = cellLink.weight();
+        // found cell, no need to loop over other cells
+        break;
+      }
+      
+      if (weightInCluster < 0) continue;
+        
+      // Check if cell was already found in a previous cluster
+      if (weightInPreviousCluster < 0) {
+        // Cell not found in a previous cluster. 
+        // Have to check whether cell is shared with other cluster
+        shotToClusterMap[index] = cluster;
+        weightInPreviousCluster = weightInCluster;
+      }
+      else {
+        // Cell has been found in a previous cluster
+        // assign shot to this cluster if it has larger weight for the cell
+        // otherwise the shots keeps assigned to the previous cluster
+        if (weightInCluster > weightInPreviousCluster) {
+            shotToClusterMap[index] = cluster;
         }
-        if(!neutralPFO->setAssociatedParticleLinks( xAOD::PFODetails::TauShot,shotlinks)) 
-            ATH_MSG_WARNING("Couldn't add shot links to neutral PFO!");
+        // FIXME: why break here ? Should loop all the cluster, and find the largest weight
+        break;
+      }
     }
+  }
+  
+  return shotToClusterMap;
+}
 
-    // Create hadronic PFOs, put them in output container and store links to tau
-    if(!setHadronicClusterPFOs(pTau, hadronicClusterPFOContainer)){
-        ATH_MSG_ERROR("Could not set hadronic PFOs");
-        return StatusCode::FAILURE;
-    }
 
-    return StatusCode::SUCCESS;
+
+std::vector<unsigned> TauPi0ClusterCreator::getShotsMatchedToCluster(const std::vector<const xAOD::PFO*>& shotPFOs,
+								                                     const std::map<unsigned, const xAOD::CaloCluster*>& shotToClusterMap, 
+								                                     const xAOD::CaloCluster& pi0Cluster) const {
+  std::vector<unsigned> shotsMatchedToCluster;
+  
+  // Loop over the shots, and select those matched to the cluster
+  for (unsigned index = 0; index < shotPFOs.size(); ++index) {
+    auto iterator = shotToClusterMap.find(index);
+    if (iterator == shotToClusterMap.end()) continue;
+    if (iterator->second != &pi0Cluster) continue;
+    
+    shotsMatchedToCluster.push_back(index);
+  }
+  
+  return shotsMatchedToCluster;
 }
 
-//______________________________________________________________________________
-// Functions used to calculate BDT variables other than those provided by the CaloClusterMomentsMaker
-float TauPi0ClusterCreator::getEM1CoreFrac(const xAOD::CaloCluster* pi0Candidate) const
-{
-    float coreEnergy=0.;
-    float sumEPosCellsEM1=0.;
-
-    const CaloClusterCellLink* theCellLink = pi0Candidate->getCellLinks();
-    CaloClusterCellLink::const_iterator cellInClusterItr  = theCellLink->begin();
-    CaloClusterCellLink::const_iterator cellInClusterItrE = theCellLink->end();
-    for(;cellInClusterItr!=cellInClusterItrE;++cellInClusterItr){
-        CaloCell* cellInCluster = (CaloCell*) *cellInClusterItr;
-        int sampling = cellInCluster->caloDDE()->getSampling();
-        if(sampling!=1 && sampling!=5) continue;
-        float cellE = cellInCluster->e() * cellInClusterItr.weight();
-        if(cellE<=0) continue;
-        sumEPosCellsEM1 += cellE;
-        float cellEtaWRTCluster = cellInCluster->eta()-pi0Candidate->eta();
-        float cellPhiWRTCluster = P4Helpers::deltaPhi(cellInCluster->phi(), pi0Candidate->phi());
-        if(std::abs(cellPhiWRTCluster) > 0.05 || std::abs(cellEtaWRTCluster) > 2 * 0.025/8.) continue;
-        coreEnergy+=cellE;
+
+
+int TauPi0ClusterCreator::getNPhotons(const std::vector<const xAOD::PFO*>& shotPFOs,
+				                      const std::vector<unsigned>& shotsInCluster) const {
+  int totalPhotons = 0;
+  
+  for (unsigned index = 0; index < shotsInCluster.size(); ++index) {
+    int nPhotons = 0;
+    const xAOD::PFO* shotPFO = shotPFOs.at(shotsInCluster.at(index));
+    if (! shotPFO->attribute(xAOD::PFODetails::PFOAttributes::tauShots_nPhotons, nPhotons)) { 
+      ATH_MSG_WARNING("Can't find NHitsInEM1. Set it to 0.");
     }
-    if(sumEPosCellsEM1<=0.) return 0.;
-    return coreEnergy/sumEPosCellsEM1;
+    totalPhotons += nPhotons;
+  }
+  
+  return totalPhotons;
 }
 
-//______________________________________________________________________________
-// Do cluster to shot matching. 
-// A cluster is matched to a shot if the seed cell of the shot is in the cluster
-std::map<unsigned, xAOD::CaloCluster*> TauPi0ClusterCreator::getClusterToShotMap(const std::vector<const xAOD::PFO*>& shotVector,
-										 const xAOD::CaloClusterContainer& pi0ClusterContainer,
-										 const xAOD::TauJet &pTau) const
-{
-    std::map<unsigned, xAOD::CaloCluster*> clusterToShotMap;
-    for(unsigned iShot = 0;iShot<shotVector.size();++iShot){
-        int seedHash_int = -1;
-        if( shotVector.at(iShot)->attribute(xAOD::PFODetails::PFOAttributes::tauShots_seedHash, seedHash_int) == false) {
-            std::cout << "WARNING: Couldn't find seed hash. Set it to -1, no cluster will be associated to shot." << std::endl;
-        }
-        const IdentifierHash seedHash = (const IdentifierHash) seedHash_int; 
-        xAOD::CaloClusterContainer::const_iterator clusterItr   (pi0ClusterContainer.begin()),
-                                                   clusterItrEnd(pi0ClusterContainer.end());
-        float weightInCluster=-1.;
-        float weightInPreviousCluster=-1;
-    
-        const xAOD::Jet *jetSeed = pTau.jet();
-        if (!jetSeed) {
-          ATH_MSG_ERROR("Tau jet link is invalid.");
-          return clusterToShotMap;
-        }
-        const xAOD::Vertex* jetVertex = m_tauVertexCorrection->getJetVertex(*jetSeed);
+
+
+std::vector<int> TauPi0ClusterCreator::getNPosECells(const xAOD::CaloCluster& cluster) const {
+  std::vector<int> nPosECells(3, 0);
+
+  const CaloClusterCellLink* cellLinks = cluster.getCellLinks();
+  CaloClusterCellLink::const_iterator cellLink = cellLinks->begin();
+  for (; cellLink != cellLinks->end(); ++cellLink) {
+    const CaloCell* cell = static_cast<const CaloCell*>(*cellLink);
+    int sampling = cell->caloDDE()->getSampling();
     
-        const xAOD::Vertex* tauVertex = nullptr;
-        if (pTau.vertexLink().isValid()) tauVertex = pTau.vertex();
-      
-        TLorentzVector tauAxis = m_tauVertexCorrection->getTauAxis(pTau);
-
-        for (; clusterItr != clusterItrEnd; ++clusterItr){
-            xAOD::CaloCluster* cluster = (xAOD::CaloCluster*) (*clusterItr);
-            TLorentzVector clusterP4 = m_tauVertexCorrection->getVertexCorrectedP4(*cluster, tauVertex, jetVertex);
-            
-            weightInCluster=-1.;
-            if (clusterP4.Et() < m_clusterEtCut) continue; // Not interested in clusters that fail the Et cut
-            // Cluster container has clusters for all taus.
-            // Only run on clusters that belong to this tau
-            if (clusterP4.DeltaR(tauAxis) > 0.4)  continue;
-            const CaloClusterCellLink* theCellLink = cluster->getCellLinks();
-            CaloClusterCellLink::const_iterator cellItr  = theCellLink->begin();
-            CaloClusterCellLink::const_iterator cellItrE = theCellLink->end();
-            for(;cellItr!=cellItrE; ++cellItr){
-                CaloCell* cellInCluster = (CaloCell*) *cellItr;
-                // Check if seed cell is in cluster.
-                if(cellInCluster->caloDDE()->calo_hash()!=seedHash) continue;
-                weightInCluster = cellItr.weight();
-                // found cell, no need to loop over other cells
-                break;
-            }
-            if(weightInCluster<0) continue;
-            // Check if cell was already found in a previous cluster
-            if(weightInPreviousCluster<0){
-                // Cell not found in a previous cluster. 
-                // Have to check whether cell is shared with other cluster
-                clusterToShotMap[iShot] = cluster;
-                weightInPreviousCluster = weightInCluster;
-            }
-            else{
-                // Cell has been found in a previous cluster
-                // assign shot to this cluster if it has larger weight for the cell
-                // otherwise the shots keeps assigned to the previous cluster
-                if(weightInCluster>weightInPreviousCluster){
-                    clusterToShotMap[iShot] = cluster;
-                }
-                // No need to loop over other clusters as cells can not be shared by more than two clusters
-                break;
-            }
-        }
+    // layer0: PS, layer1: EM1, layer2: EM2
+    int layer = sampling%4;  
+    if (layer < 3 && cell->e() > 0) {
+      ++nPosECells[layer];
     }
-    return clusterToShotMap;
-}
+  }
 
-//______________________________________________________________________________
-std::vector<unsigned> TauPi0ClusterCreator::getShotsMatchedToCluster(const std::vector<const xAOD::PFO*>& shotVector,
-								     const std::map<unsigned, xAOD::CaloCluster*>& clusterToShotMap, 
-								     const xAOD::CaloCluster* pi0Cluster) const
-{
-    std::vector<unsigned> shotsMatchedToCluster;
-    for(unsigned iShot = 0;iShot<shotVector.size();++iShot){
-        auto itr = clusterToShotMap.find(iShot);
-        if(itr==clusterToShotMap.end()) continue;
-        if(itr->second!=pi0Cluster) continue;
-        shotsMatchedToCluster.push_back(iShot);
-    }
-    return shotsMatchedToCluster;
+  return nPosECells;
 }
 
-//______________________________________________________________________________
-int TauPi0ClusterCreator::getNPhotons(const std::vector<const xAOD::PFO*>& shotVector,
-				      const std::vector<unsigned>& shotsInCluster ) const
-{
-    int nPhotons = 0;
-    for(unsigned iShot = 0;iShot<shotsInCluster.size();++iShot){
-        int curNPhotons=0;
-        if(shotVector.at(shotsInCluster.at(iShot))->attribute(xAOD::PFODetails::PFOAttributes::tauShots_nPhotons,curNPhotons) == false)
-            ATH_MSG_WARNING("Can't find NHitsInEM1. Set it to 0.");
-        nPhotons+=curNPhotons;
-    }
-    return nPhotons;
+
+
+float TauPi0ClusterCreator::getEM1CoreFrac(const xAOD::CaloCluster& cluster) const {
+  float coreEnergyEM1 = 0.;
+  float totalEnergyEM1 = 0.;
+  
+  const CaloClusterCellLink* cellLinks = cluster.getCellLinks();
+  CaloClusterCellLink::const_iterator cellLink = cellLinks->begin();
+  for (; cellLink != cellLinks->end(); ++cellLink) {
+    const CaloCell* cell = static_cast<const CaloCell*>(*cellLink);
+    
+    // Only consider EM1
+    int sampling = cell->caloDDE()->getSampling();
+    if (sampling != 1 && sampling != 5) continue;
+    
+    // Only consider positive cells
+    // FIXME: is the weight needed ?
+    float cellEnergy = cell->e() * cellLink.weight();
+    if (cellEnergy <= 0) continue;
+    
+    totalEnergyEM1 += cellEnergy;
+
+    float deltaEta = cell->eta() - cluster.eta();
+    float deltaPhi = P4Helpers::deltaPhi(cell->phi(), cluster.phi());
+    
+    // Core region: [0.05, 0.05/8]
+    if(std::abs(deltaPhi) > 0.05 || std::abs(deltaEta) > 2 * 0.025/8.) continue;
+    
+    coreEnergyEM1 += cellEnergy;
+  }
+  
+  if (totalEnergyEM1 <= 0.) return 0.;
+  return coreEnergyEM1/totalEnergyEM1;
 }
 
-//______________________________________________________________________________
-std::vector<int> TauPi0ClusterCreator::getNPosECells(const xAOD::CaloCluster* pi0Candidate) const
-{
-    std::vector<int> nPosECellsInLayer(3,0); // 3 layers initialised with 0 +ve cells
-
-    const CaloClusterCellLink* theCellLink = pi0Candidate->getCellLinks();
-    CaloClusterCellLink::const_iterator cellInClusterItr  = theCellLink->begin();
-    CaloClusterCellLink::const_iterator cellInClusterItrE = theCellLink->end();
-
-    for(;cellInClusterItr!=cellInClusterItrE; ++cellInClusterItr){
-        const CaloCell* cellInCluster = static_cast<const CaloCell*>( *cellInClusterItr);
-        int sampling = cellInCluster->caloDDE()->getSampling();
-        // Get cell layer: PSB and PSE belong to layer 0,  
-        // EMB1 and EME1 to layer 1, EMB2 and EME2 to layer 2. 
-        int cellLayer = sampling%4;  
-        if(cellLayer < 3 && cellInCluster->e() > 0) nPosECellsInLayer[cellLayer]++;
+
+
+std::vector<float> TauPi0ClusterCreator::get1stEtaMomWRTCluster(const xAOD::CaloCluster& cluster) const {
+  std::vector<float> deltaEtaFirstMom (3, 0.);
+  std::vector<float> totalEnergy (3, 0.);
+
+  const CaloClusterCellLink* cellLinks = cluster.getCellLinks();
+  CaloClusterCellLink::const_iterator cellLink = cellLinks->begin();
+  for (; cellLink != cellLinks->end(); ++cellLink) {
+    const CaloCell* cell = static_cast<const CaloCell*>(*cellLink);
+    
+    // Only consider PS, EM1, and EM2
+    int sampling = cell->caloDDE()->getSampling();
+    int layer = sampling%4;
+    if (layer >= 3) continue;
+
+    // Only consider positive cells
+    float cellEnergy = cell->e();
+    if (cellEnergy <= 0) continue;
+
+    float deltaEta = cell->eta() - cluster.eta();
+    deltaEtaFirstMom[layer] += deltaEta * cellEnergy;
+    totalEnergy[layer] += cellEnergy;
+  }
+
+  for (int layer=0; layer < 3; ++layer) {
+    if (totalEnergy[layer] != 0.) {
+      deltaEtaFirstMom[layer]/=std::abs(totalEnergy[layer]);
+    }
+    else {
+      deltaEtaFirstMom[layer]=0.;
     }
-    return nPosECellsInLayer;
+  }
+  
+  return deltaEtaFirstMom;
 }
 
-//______________________________________________________________________________
-std::vector<float> TauPi0ClusterCreator::get1stEtaMomWRTCluster(const xAOD::CaloCluster* pi0Candidate) const
-{
-    std::vector<float> firstEtaWRTClusterPositionInLayer (4, 0.);  //init with 0. for 0-3 layers
-    std::vector<float> sumEInLayer (4, 0.); //init with 0. for 0-3 layers
-
-    const CaloClusterCellLink* theCellLink = pi0Candidate->getCellLinks();
-    CaloClusterCellLink::const_iterator cellInClusterItr  = theCellLink->begin();
-    CaloClusterCellLink::const_iterator cellInClusterItrE = theCellLink->end();
-
-    for(;cellInClusterItr!=cellInClusterItrE;++cellInClusterItr){
-        CaloCell* cellInCluster = (CaloCell*) *cellInClusterItr;
-        int sampling = cellInCluster->caloDDE()->getSampling();
-        // Get cell layer: PSB and PSE belong to layer 0,  
-        // EMB1 and EME1 to layer 1, EMB2 and EME2 to layer 2. 
-        int cellLayer = sampling%4;
-        
-        float cellEtaWRTClusterPos=cellInCluster->eta()-pi0Candidate->eta();
-        float cellE=cellInCluster->e();
-        if(cellE<=0  || cellLayer>=3) continue;
-        firstEtaWRTClusterPositionInLayer[cellLayer]+=cellEtaWRTClusterPos*cellE;
-        sumEInLayer[cellLayer]+=cellE;
-    }
 
-    for(int iLayer=0;iLayer<4;++iLayer){
-        if(sumEInLayer[iLayer]!=0) 
-            firstEtaWRTClusterPositionInLayer[iLayer]/=std::abs(sumEInLayer[iLayer]);
-        else firstEtaWRTClusterPositionInLayer[iLayer]=0.;
+
+std::vector<float> TauPi0ClusterCreator::get2ndEtaMomWRTCluster(const xAOD::CaloCluster& cluster) const {
+  std::vector<float> deltaEtaSecondMom (3, 0.);
+  std::vector<float> totalEnergy (3, 0.);
+
+  const CaloClusterCellLink* cellLinks = cluster.getCellLinks();
+  CaloClusterCellLink::const_iterator cellLink = cellLinks->begin();
+  for (; cellLink != cellLinks->end(); ++cellLink) {
+    const CaloCell* cell = static_cast<const CaloCell*>(*cellLink);
+    
+    // Only consider PS, EM1, and EM2
+    int sampling = cell->caloDDE()->getSampling();
+    int layer = sampling%4;
+    if (layer >= 3) continue;
+
+    // Only consider positive cells
+    float cellEnergy=cell->e();
+    if (cellEnergy <= 0) continue;
+
+    float deltaEta = cell->eta() - cluster.eta();
+    deltaEtaSecondMom[layer] += deltaEta * deltaEta * cellEnergy;
+    totalEnergy[layer] += cellEnergy;
+  }
+
+  for (int layer=0; layer < 3; ++layer) {
+    if (totalEnergy[layer] != 0.) {
+      deltaEtaSecondMom[layer]/=std::abs(totalEnergy[layer]);
+    }
+    else {
+      deltaEtaSecondMom[layer]=0.;
     }
-    return firstEtaWRTClusterPositionInLayer;
+  }
+  
+  return deltaEtaSecondMom;
 }
 
-//______________________________________________________________________________
-std::vector<float> TauPi0ClusterCreator::get2ndEtaMomWRTCluster( const xAOD::CaloCluster* pi0Candidate) const
-{
-      std::vector<float> secondEtaWRTClusterPositionInLayer (4, 0.); //init with 0. for 0-3 layers
-      std::vector<float> sumEInLayer (4, 0.); //init with 0. for 0-3 layers
-
-      const CaloClusterCellLink* theCellLinks = pi0Candidate->getCellLinks();
-
-      for(const CaloCell* cellInCluster: *theCellLinks){
-            int sampling = cellInCluster->caloDDE()->getSampling();
-            // Get cell layer: PSB and PSE belong to layer 0,  
-            // EMB1 and EME1 to layer 1, EMB2 and EME2 to layer 2. 
-            int cellLayer = sampling%4;
-
-            float cellEtaWRTClusterPos=cellInCluster->eta()-pi0Candidate->eta();
-            float cellE=cellInCluster->e();
-            if(cellE<=0  || cellLayer>=3) continue;
-            secondEtaWRTClusterPositionInLayer[cellLayer]+=cellEtaWRTClusterPos*cellEtaWRTClusterPos*cellE;
-            sumEInLayer[cellLayer]+=cellE;
-      }
 
-      for(int iLayer=0;iLayer<4;++iLayer){
-            if(sumEInLayer[iLayer]!=0) 
-                secondEtaWRTClusterPositionInLayer[iLayer]/=std::abs(sumEInLayer[iLayer]);
-            else secondEtaWRTClusterPositionInLayer[iLayer]=0.;
-      }
-      return secondEtaWRTClusterPositionInLayer;
-}
 
-//______________________________________________________________________________
-bool TauPi0ClusterCreator::setHadronicClusterPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& pHadronPFOContainer) const
-{
-    if (! pTau.jetLink().isValid()) {
-      ATH_MSG_ERROR("Tau jet link is invalid.");
-      return false;
+StatusCode TauPi0ClusterCreator::configureNeutralPFO(const xAOD::CaloCluster& cluster,
+                                                     const xAOD::CaloClusterContainer& pi0ClusterContainer,
+                                                     const xAOD::TauJet& tau,
+                                                     const std::vector<const xAOD::PFO*>& shotPFOs, 
+                                                     const std::map<unsigned, const xAOD::CaloCluster*>& shotToClusterMap,
+                                                     xAOD::PFO& neutralPFO) const {
+  // Set the property of the PFO
+  // -- Four momentum: not corrected yet
+  neutralPFO.setP4(cluster.pt(), cluster.eta(), cluster.phi(), cluster.m());
+  
+  // -- Default value
+  neutralPFO.setBDTPi0Score(-9999.);
+  neutralPFO.setCharge(0);
+  neutralPFO.setAttribute<int>(xAOD::PFODetails::PFOAttributes::nPi0Proto, -1);
+
+  // -- CENTER_MAG
+  double CENTER_MAG = 0.0;
+  if (!cluster.retrieveMoment(xAOD::CaloCluster::MomentType::CENTER_MAG, CENTER_MAG)) {
+    ATH_MSG_WARNING("Couldn't retrieve CENTER_MAG moment. Set it to 0.");
+  }
+  neutralPFO.setCenterMag( (float) CENTER_MAG);
+  
+  // -- Number of photons 
+  std::vector<unsigned> shotsInCluster = getShotsMatchedToCluster(shotPFOs, shotToClusterMap, cluster);
+  int NHitsInEM1 = getNPhotons(shotPFOs, shotsInCluster);
+  neutralPFO.setAttribute<int>(xAOD::PFODetails::PFOAttributes::cellBased_NHitsInEM1, NHitsInEM1);
+  
+  // -- Energy at each layer
+  float eEM1 = cluster.eSample(CaloSampling::EMB1) + cluster.eSample(CaloSampling::EME1);
+  neutralPFO.setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_energy_EM1, eEM1);
+  
+  float eEM2 = cluster.eSample(CaloSampling::EMB2) + cluster.eSample(CaloSampling::EME2);
+  neutralPFO.setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_energy_EM2, eEM2);
+  
+  // -- Number of positive cells in each layer
+  std::vector<int> nPosECells = getNPosECells(cluster);
+  neutralPFO.setAttribute<int>(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_PS,  nPosECells.at(0));
+  neutralPFO.setAttribute<int>(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM1, nPosECells.at(1));
+  neutralPFO.setAttribute<int>(xAOD::PFODetails::PFOAttributes::cellBased_NPosECells_EM2, nPosECells.at(2));
+ 
+  // -- Core Fraction of the energy in EM1 
+  float EM1CoreFrac = getEM1CoreFrac(cluster);
+  neutralPFO.setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_EM1CoreFrac, EM1CoreFrac);
+
+  // -- First moment of deltaEta(cluster, cell) in EM1 and EM2 
+  std::vector<float> deltaEtaFirstMom = get1stEtaMomWRTCluster(cluster);
+  neutralPFO.setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM1, deltaEtaFirstMom.at(1));
+  neutralPFO.setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_firstEtaWRTClusterPosition_EM2, deltaEtaFirstMom.at(2));
+  
+  // -- Second moment of deltaEta(cluster, cell) in EM1 and EM2
+  std::vector<float> secondEtaWRTClusterPositionInLayer = get2ndEtaMomWRTCluster(cluster);
+  neutralPFO.setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM1, secondEtaWRTClusterPositionInLayer.at(1));
+  neutralPFO.setAttribute<float>(xAOD::PFODetails::PFOAttributes::cellBased_secondEtaWRTClusterPosition_EM2, secondEtaWRTClusterPositionInLayer.at(2));
+
+  // -- Retrieve cluster moments
+  using Moment = xAOD::CaloCluster::MomentType;
+  using Attribute = xAOD::PFODetails::PFOAttributes; 
+  const std::array< std::pair<Moment, Attribute>, 12> momentAttributePairs {{
+      {Moment::FIRST_ETA, Attribute::cellBased_FIRST_ETA},
+      {Moment::SECOND_R, Attribute::cellBased_SECOND_R}, 
+      {Moment::SECOND_LAMBDA, Attribute::cellBased_SECOND_LAMBDA},
+      {Moment::DELTA_PHI, Attribute::cellBased_DELTA_PHI},
+      {Moment::DELTA_THETA, Attribute::cellBased_DELTA_THETA},
+      {Moment::CENTER_LAMBDA, Attribute::cellBased_CENTER_LAMBDA},
+      {Moment::LATERAL, Attribute::cellBased_LATERAL},
+      {Moment::LONGITUDINAL, Attribute::cellBased_LONGITUDINAL},
+      {Moment::ENG_FRAC_EM, Attribute::cellBased_ENG_FRAC_EM},
+      {Moment::ENG_FRAC_MAX, Attribute::cellBased_ENG_FRAC_MAX},
+      {Moment::ENG_FRAC_CORE, Attribute::cellBased_ENG_FRAC_CORE},
+      {Moment::SECOND_ENG_DENS, Attribute::cellBased_SECOND_ENG_DENS}
+  }};
+  
+  for (const auto& [moment, attribute] : momentAttributePairs) {
+    double value = 0.0;
+    if (! cluster.retrieveMoment(moment, value)) {
+      ATH_MSG_WARNING("Cound not retrieve " << moment);
     }
-    const xAOD::Jet *jetSeed = pTau.jet();
-    
-    const xAOD::Vertex* jetVertex = m_tauVertexCorrection->getJetVertex(*jetSeed);
-    
-    const xAOD::Vertex* tauVertex = nullptr;
-    if (pTau.vertexLink().isValid()) tauVertex = pTau.vertex();
-    
-    TLorentzVector tauAxis = m_tauVertexCorrection->getTauAxis(pTau);
-    
-    std::vector<const xAOD::CaloCluster*> clusterList;
-    StatusCode sc = tauRecTools::GetJetClusterList(jetSeed, clusterList, m_useSubtractedCluster);
-    if (!sc) return false;
-
-    for (const xAOD::CaloCluster* cluster : clusterList){
-        // Procedure: 
-        // - Calculate cluster energy in Hcal. This is to treat -ve energy cells correctly
-        // - Then set 4momentum via setP4(E/cosh(eta), eta, phi, m). This forces the PFO to have the correct energy and mass
-        // - Ignore clusters outside 0.2 cone and those with overall negative energy or negative energy in Hcal
-
-        // Don't create PFOs for clusters with overall (Ecal+Hcal) negative energy (noise)
-        TLorentzVector clusterP4 = m_tauVertexCorrection->getVertexCorrectedP4(*cluster, tauVertex, jetVertex);
-        
-        if(clusterP4.E()<=0.) continue;
+    neutralPFO.setAttribute(attribute, static_cast<float>(value));
+  }
 
-        // Only need clusters in core cone. Others are not needed for subtraction
-        if(tauAxis.DeltaR(clusterP4) > 0.2) continue;
+  // -- Element link to the cluster 
+  ElementLink<xAOD::CaloClusterContainer> clusElementLink;
+  clusElementLink.toContainedElement(pi0ClusterContainer, &cluster);
+  neutralPFO.setClusterLink( clusElementLink );
+ 
+  // -- Element link to the shots
+  std::vector<ElementLink<xAOD::IParticleContainer>> shotlinks;
+  for (unsigned index = 0; index < shotsInCluster.size(); ++index) {
+    ElementLink<xAOD::PFOContainer> shotPFOElementLink = tau.shotPFOLinks().at(shotsInCluster.at(index));
+    ElementLink<xAOD::IParticleContainer> shotElementLink;
+    shotPFOElementLink.toPersistent();
+    shotElementLink.resetWithKeyAndIndex(shotPFOElementLink.persKey(), shotPFOElementLink.persIndex()); 
+    if (!shotElementLink.isValid()) {
+      ATH_MSG_WARNING("Created an invalid element link to xAOD::PFO");
+    }
+    shotlinks.push_back(shotElementLink);
+  }
+  if(!neutralPFO.setAssociatedParticleLinks( xAOD::PFODetails::TauShot,shotlinks)) { 
+    ATH_MSG_WARNING("Couldn't add shot links to neutral PFO!");
+  }
 
-        // Loop over cells to calculate cluster energy in Hcal
-        double clusterE_Hcal=0.;
-	const CaloClusterCellLink* theCellLink = cluster->getCellLinks();
-	CaloClusterCellLink::const_iterator cellInClusterItr  = theCellLink->begin();
-	CaloClusterCellLink::const_iterator cellInClusterItrE = theCellLink->end();
+  return StatusCode::SUCCESS;
+}
 
-	for(; cellInClusterItr != cellInClusterItrE; ++cellInClusterItr){
-	   const CaloCell* cellInCluster = static_cast<const CaloCell*> (*cellInClusterItr);
 
-            //Get only HCAL cells
-            int sampling = cellInCluster->caloDDE()->getSampling();
-            if (sampling < 8) continue;
 
-            double cellE = cellInCluster->e()*cellInClusterItr.weight();
-            clusterE_Hcal+=cellE;
-        }
-        // Don't save PFOs for clusters with negative energy in Hcal 
-        if(clusterE_Hcal<=0.) continue;
-
-        // Create hadronic PFO
-        xAOD::PFO* hadronicPFO = new xAOD::PFO();
-        pHadronPFOContainer.push_back( hadronicPFO );
-
-        // Set 4mom. Eta and phi are taken from cluster
-        double cluster_Pt_Hcal = clusterE_Hcal/std::cosh(cluster->eta());
-        hadronicPFO->setP4( (float) cluster_Pt_Hcal, (float) cluster->eta(), (float) cluster->phi(), (float) 0.);
-
-        // TODO: May want to set element link to the cluster the PFO is originating from
-        // ElementLink<xAOD::CaloClusterContainer> clusElementLink;
-        // clusElementLink.toContainedElement( CLUSTERCONTAINER, cluster );
-        // hadronicPFO->setClusterLink( clusElementLink );
-
-        // Create element link from tau to hadronic PFO
-        ElementLink<xAOD::PFOContainer> PFOElementLink;
-        PFOElementLink.toContainedElement( pHadronPFOContainer, hadronicPFO );
-        pTau.addHadronicPFOLink( PFOElementLink );
-    }
-    return true;
+StatusCode TauPi0ClusterCreator::configureHadronicPFO(const xAOD::CaloCluster& cluster, 
+                                                      double clusterEnergyHad, 
+                                                      xAOD::PFO& hadronicPFO) const {
+  double clusterPtHad = clusterEnergyHad/std::cosh(cluster.eta());
+  hadronicPFO.setP4(clusterPtHad, cluster.eta(), cluster.phi(), 0.);
+
+  return StatusCode::SUCCESS;
 }
 
 #endif
diff --git a/Reconstruction/tauRecTools/src/TauPi0ClusterCreator.h b/Reconstruction/tauRecTools/src/TauPi0ClusterCreator.h
index 4ac1db7f16326dbf39eb13075398e3cfaef5d078..fcb8c415b532cb25f71b4ad9c741b6735de82fcc 100644
--- a/Reconstruction/tauRecTools/src/TauPi0ClusterCreator.h
+++ b/Reconstruction/tauRecTools/src/TauPi0ClusterCreator.h
@@ -27,50 +27,64 @@
  */
 
 class TauPi0ClusterCreator : public TauRecToolBase {
+
 public:
-    TauPi0ClusterCreator(const std::string& name) ;
-    ASG_TOOL_CLASS2(TauPi0ClusterCreator, TauRecToolBase, ITauToolBase);
-    virtual ~TauPi0ClusterCreator();
-
-    virtual StatusCode initialize() override;
-    virtual StatusCode executePi0ClusterCreator(xAOD::TauJet& pTau, xAOD::PFOContainer& neutralPFOContainer, 
-						xAOD::PFOContainer& hadronicClusterPFOContainer,
-						const xAOD::CaloClusterContainer& pi0CaloClusContainer) const override;
-    
+  
+  ASG_TOOL_CLASS2(TauPi0ClusterCreator, TauRecToolBase, ITauToolBase);
+  
+  TauPi0ClusterCreator(const std::string& name) ;
+  virtual ~TauPi0ClusterCreator() = default;
+
+  virtual StatusCode initialize() override;
+  virtual StatusCode executePi0ClusterCreator(xAOD::TauJet& pTau, xAOD::PFOContainer& neutralPFOContainer, 
+  					xAOD::PFOContainer& hadronicClusterPFOContainer,
+  					const xAOD::CaloClusterContainer& pi0CaloClusContainer) const override;
+  
 private:
-    /** @brief fraction of cluster enegry in central EM1 cells */
-    float getEM1CoreFrac( const xAOD::CaloCluster* pi0Candidate) const;
-    
-    /** @brief number of cells from cluster with positive energy in PS, EM1 and EM2 */
-    std::vector<int> getNPosECells( const xAOD::CaloCluster* pi0Candidate) const;
-
-    std::map<unsigned, xAOD::CaloCluster*> getClusterToShotMap(
-        const std::vector<const xAOD::PFO*>& shotVector,
-        const xAOD::CaloClusterContainer& pi0ClusterContainer,
-        const xAOD::TauJet &pTau) const;
-
-    std::vector<unsigned> getShotsMatchedToCluster(
-        const std::vector<const xAOD::PFO*>& shotVector,
-        const std::map<unsigned, xAOD::CaloCluster*>& clusterToShotMap,
-        const xAOD::CaloCluster* pi0Cluster) const;
-
-    int getNPhotons( const std::vector<const xAOD::PFO*>& shotVector,
-                     const std::vector<unsigned>& shotsInCluster) const;
-
-    /** @brief first eta moment in PS, EM1 and EM2 w.r.t cluster eta: (eta_i - eta_cluster) */
-    std::vector<float> get1stEtaMomWRTCluster( const xAOD::CaloCluster* pi0Candidate) const;
-
-    /** @brief second eta moment in PS, EM1 and EM2 w.r.t cluster eta: (eta_i - eta_cluster)^2 */ 
-    std::vector<float> get2ndEtaMomWRTCluster(const xAOD::CaloCluster* pi0Candidate) const;
-
-    /** @brief get hadronic cluster PFOs*/
-    bool setHadronicClusterPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& pHadronicClusterContainer) const;
-
-    Gaudi::Property<double> m_clusterEtCut {this, "ClusterEtCut", 0.5 * Gaudi::Units::GeV, "Et threshould for pi0 candidate clusters"};
-    Gaudi::Property<bool> m_useSubtractedCluster {this, "UseSubtractedCluster", true, "use shower subtracted clusters in calo calculations"};
-
-    ToolHandle<ITauVertexCorrection> m_tauVertexCorrection { this, 
-      "TauVertexCorrection", "TauVertexCorrection", "Tool to perform the vertex correction"};
+  
+  /** @brief Configure the neutral PFO*/
+  StatusCode configureNeutralPFO(const xAOD::CaloCluster& cluster,
+                                 const xAOD::CaloClusterContainer& pi0ClusterContainer,
+                                 const xAOD::TauJet& tau,
+                                 const std::vector<const xAOD::PFO*>& shotPFOs,
+                                 const std::map<unsigned, const xAOD::CaloCluster*>& shotsInCluster,
+                                 xAOD::PFO& neutralPFO) const;
+
+  /** @brief Configure the haronic PFO*/
+  StatusCode configureHadronicPFO(const xAOD::CaloCluster& cluster,
+                                  double clusterEnergyHad,
+                                  xAOD::PFO& hadronicPFO) const;
+
+  std::map<unsigned, const xAOD::CaloCluster*> getShotToClusterMap(
+      const std::vector<const xAOD::PFO*>& shotVector,
+      const xAOD::CaloClusterContainer& pi0ClusterContainer,
+      const xAOD::TauJet &pTau) const;
+
+  std::vector<unsigned> getShotsMatchedToCluster(
+      const std::vector<const xAOD::PFO*>& shotVector,
+      const std::map<unsigned, const xAOD::CaloCluster*>& clusterToShotMap,
+      const xAOD::CaloCluster& pi0Cluster) const;
+
+  int getNPhotons( const std::vector<const xAOD::PFO*>& shotVector,
+                   const std::vector<unsigned>& shotsInCluster) const;
+
+  /** @brief fraction of cluster enegry in central EM1 cells */
+  float getEM1CoreFrac(const xAOD::CaloCluster& cluster) const;
+  
+  /** @brief number of cells from cluster with positive energy in PS, EM1 and EM2 */
+  std::vector<int> getNPosECells(const xAOD::CaloCluster& cluster) const;
+
+  /** @brief first eta moment in PS, EM1 and EM2 w.r.t cluster eta */
+  std::vector<float> get1stEtaMomWRTCluster(const xAOD::CaloCluster& cluster) const;
+
+  /** @brief second eta moment in PS, EM1 and EM2 w.r.t cluster eta */ 
+  std::vector<float> get2ndEtaMomWRTCluster(const xAOD::CaloCluster& cluster) const;
+
+  Gaudi::Property<double> m_clusterEtCut {this, "ClusterEtCut", 0.5 * Gaudi::Units::GeV, "Et threshould for pi0 candidate clusters"};
+  Gaudi::Property<bool> m_useSubtractedCluster {this, "UseSubtractedCluster", true, "use shower subtracted clusters in calo calculations"};
+
+  ToolHandle<ITauVertexCorrection> m_tauVertexCorrection { this, 
+    "TauVertexCorrection", "TauVertexCorrection", "Tool to perform the vertex correction"};
 };
 
 #endif	/* TAUPI0CLUSTERCREATOR_H */
diff --git a/Reconstruction/tauRecTools/src/TauPi0ClusterScaler.cxx b/Reconstruction/tauRecTools/src/TauPi0ClusterScaler.cxx
index 45aeb323e2db670626a9ac6cf8ebded2293c96d9..9bd8f013719507e97ad7ee69db98af774cf1b064 100644
--- a/Reconstruction/tauRecTools/src/TauPi0ClusterScaler.cxx
+++ b/Reconstruction/tauRecTools/src/TauPi0ClusterScaler.cxx
@@ -2,323 +2,325 @@
   Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
 */
 
-//-----------------------------------------------------------------------------
-// file:        TauPi0ClusterScaler.cxx
-// package:     Reconstruction/tauRec
-// authors:     Stephanie Yuen, Benedict Winter, Will Davey
-// date:        2014-08-04
-//-----------------------------------------------------------------------------
-
-#include <vector>
 
 #include "TauPi0ClusterScaler.h"
+
 #include "xAODTau/TauJet.h"
 #include "xAODPFlow/PFO.h"
-#include "tauRecTools/ITauToolBase.h"
-#include "FourMomUtils/xAODP4Helpers.h"
 #include "xAODCaloEvent/CaloVertexedTopoCluster.h"
+#include "FourMomUtils/xAODP4Helpers.h"
+
+#include <vector>
+
 
-//-------------------------------------------------------------------------
-// Constructor
-//-------------------------------------------------------------------------
 
 TauPi0ClusterScaler::TauPi0ClusterScaler(const std::string& name) :
-    TauRecToolBase(name)
-{
+    TauRecToolBase(name) {
 }
 
-//-------------------------------------------------------------------------
-// Destructor
-//-------------------------------------------------------------------------
 
-TauPi0ClusterScaler::~TauPi0ClusterScaler()
-{
+
+StatusCode TauPi0ClusterScaler::executePi0ClusterScaler(xAOD::TauJet& tau, 
+                                                        xAOD::PFOContainer& neutralPFOContainer, 
+                                                        xAOD::PFOContainer& chargedPFOContainer) const {
+  // Clear vector of cell-based charged PFO Links, which are required when running xAOD 
+  tau.clearProtoChargedPFOLinks();
+ 
+  // Only run on 1-5 prong taus 
+  if (tau.nTracks() == 0 or tau.nTracks() >5) { 
+    return StatusCode::SUCCESS;
+  }
+ 
+  ATH_MSG_DEBUG("Process a new tau candidate, addreess " << &tau
+                  << ", e: " << tau.pt()
+                  << ", eta: " << tau.eta()
+                  << ", pt: " << tau.pt());
+
+  // Correct neutral PFO kinematics to point at tau vertex, this is needed since the 
+  // charged shower subtraction is performed several times for each neutral PFO
+  correctNeutralPFOs(tau, neutralPFOContainer);
+  
+  // Create new proto charged PFOs
+  createChargedPFOs(tau, chargedPFOContainer);
+  
+  // Associate hadronic PFOs to charged PFOs using extrapolated positions in HCal
+  associateHadronicToChargedPFOs(tau, chargedPFOContainer);
+  
+  // Associate charged PFOs to neutral PFOs using extrapolated positions in ECal
+  associateChargedToNeutralPFOs(tau, neutralPFOContainer);
+  
+  // Estimate charged PFO EM energy and subtract from neutral PFOs
+  subtractChargedEnergyFromNeutralPFOs(neutralPFOContainer);
+
+  for (xAOD::PFO* pfo : neutralPFOContainer) {
+    ATH_MSG_DEBUG("Final Neutral PFO, address " << pfo
+                  << ", e: " << pfo->pt()
+                  << ", eta: " << pfo->eta()
+                  << ", pt: " << pfo->pt());
+  }
+
+  return StatusCode::SUCCESS;
 }
 
-//______________________________________________________________________________
-StatusCode TauPi0ClusterScaler::executePi0ClusterScaler(xAOD::TauJet& pTau, xAOD::PFOContainer& neutralPFOContainer, xAOD::PFOContainer& chargedPFOContainer) const
-{
-    // Clear vector of cell-based charged PFO Links. 
-    // Required when rerunning on xAOD level.
-    pTau.clearProtoChargedPFOLinks();
 
-    // Only run on 1-5 prong taus 
-    if (pTau.nTracks() == 0 or pTau.nTracks() >5 ) 
-        return StatusCode::SUCCESS;
-    
-    ATH_MSG_DEBUG("new tau pt = " << pTau.pt() 
-                  << ", eta = " << pTau.eta() 
-                  << ", phi = " << pTau.phi() 
-                  << ", nprongs = " << pTau.nTracks());
-
-    // reset neutral PFO kinematics (incase re-run on AOD)
-    resetNeutralPFOs(pTau, neutralPFOContainer);
-    // create new proto charged PFOs, extrapolate tracks, add to tau 
-    createChargedPFOs(pTau, chargedPFOContainer);
-    // associate hadronic PFOs to charged PFOs using extrapolated positions in HCal
-    associateHadronicToChargedPFOs(pTau, chargedPFOContainer);
-    // associate charged PFOs to neutral PFOs using extrapolated positions in ECal
-    associateChargedToNeutralPFOs(pTau, neutralPFOContainer);
-    // estimate charged PFO EM energy and subtract from neutral PFOs
-    subtractChargedEnergyFromNeutralPFOs(neutralPFOContainer);
-
-    ATH_MSG_DEBUG("End of TauPi0ClusterScaler::execute");
 
-    return StatusCode::SUCCESS;
+void TauPi0ClusterScaler::clearAssociatedParticleLinks(xAOD::PFOContainer& pfoContainer, xAOD::PFODetails::PFOParticleType type) const {
+  std::vector<ElementLink<xAOD::IParticleContainer>> emptyLinks;
+  
+  for (xAOD::PFO* pfo : pfoContainer) {
+    pfo->setAssociatedParticleLinks(type, emptyLinks);
+  }
 }
 
-//______________________________________________________________________________
-void TauPi0ClusterScaler::resetNeutralPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& neutralPFOContainer) const
-{
-    // Set neutral PFO kinematics to vertex corrected cluster
-    ATH_MSG_DEBUG("Resetting neutral PFO kinematics");
-    for( auto pfo : neutralPFOContainer )
-    {
-        const xAOD::CaloCluster* cl = pfo->cluster(0);
-
-        // apply cluster vertex correction 
-        if(pTau.vertexLink().isValid()){
-            auto clcorr = xAOD::CaloVertexedTopoCluster(*cl, pTau.vertex()->position());
-            pfo->setP4(clcorr.pt(), clcorr.eta(), clcorr.phi(), 0.0);
-        }
-        else{
-            pfo->setP4(cl->pt(), cl->eta(), cl->phi(), 0.0);
-        }
-
-        ATH_MSG_DEBUG("Neutral PFO, ptr: " <<  cl
-                        << ", e: " << pfo->e() 
-                        << ", pt: " << pfo->pt()
-                        << ", eta: " << pfo->eta()
-                        << ", eta(unorr): " << cl->eta());
+
+
+void TauPi0ClusterScaler::correctNeutralPFOs(xAOD::TauJet& tau, xAOD::PFOContainer& neutralPFOContainer) const {
+  // FIXME: Loop over existing neutral PFOs, this may include those not associated to the tau candidate
+  // What if two taus have different vertex ??? Seems rare.
+  for (xAOD::PFO* pfo : neutralPFOContainer ) {
+    const xAOD::CaloCluster* cluster = pfo->cluster(0);
+
+    // apply cluster vertex correction 
+    if(tau.vertexLink().isValid()) {
+      auto clusterAtTauVertx = xAOD::CaloVertexedTopoCluster(*cluster, tau.vertex()->position());
+      pfo->setP4(clusterAtTauVertx.pt(), clusterAtTauVertx.eta(), clusterAtTauVertx.phi(), 0.0);
+    }
+    else{
+      pfo->setP4(cluster->pt(), cluster->eta(), cluster->phi(), 0.0);
     }
+    
+    ATH_MSG_DEBUG("Original Neutral PFO" 
+                  << ", e: " << cluster->pt() 
+                  << ", eta: " << cluster->eta() 
+                  << ", pt: " << cluster->pt());
+
+    ATH_MSG_DEBUG("Corrected Neutral PFO" 
+                  << ", e: " << pfo->pt()
+                  << ", eta: " << pfo->eta()
+                  << ", pt: " << pfo->pt());
+  }
 }
 
-//______________________________________________________________________________
-void TauPi0ClusterScaler::createChargedPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& cPFOContainer) const
-{
-    ATH_MSG_DEBUG("Creating charged PFOs");
-    for(auto tauTrackLink : pTau.tauTrackLinks(xAOD::TauJetParameters::classifiedCharged)){
-        if( not tauTrackLink.isValid() ){
-            ATH_MSG_WARNING("Invalid tauTrackLink");
-            continue;
-        }
-        const xAOD::TauTrack* tauTrack = (*tauTrackLink);
-        // create pfo
-        xAOD::PFO* chargedPFO = new xAOD::PFO();
-        cPFOContainer.push_back(chargedPFO);
-        // set properties
-        chargedPFO->setCharge(tauTrack->track()->charge());
-        chargedPFO->setP4(tauTrack->p4());
-        // link to track
-	if(not chargedPFO->setTrackLink((*tauTrackLink)->trackLinks().at(0)))
-	  ATH_MSG_WARNING("Could not add Track to PFO");
-	// now directly using tau track link from above
-        if(not chargedPFO->setAssociatedParticleLink(xAOD::PFODetails::CaloCluster,tauTrackLink))
-	  ATH_MSG_WARNING("Could not add TauTrack to PFO");
-
-	// link from tau
-        pTau.addProtoChargedPFOLink(ElementLink< xAOD::PFOContainer >
-                                    (chargedPFO, cPFOContainer));
+
+
+void TauPi0ClusterScaler::createChargedPFOs(xAOD::TauJet& tau, xAOD::PFOContainer& chargedPFOContainer) const {
+  for (auto tauTrackLink : tau.tauTrackLinks(xAOD::TauJetParameters::classifiedCharged)) {
+    if (not tauTrackLink.isValid()) {
+      ATH_MSG_WARNING("Invalid tauTrackLink");
+      continue;
+    }
+    const xAOD::TauTrack* tauTrack = (*tauTrackLink);
+    
+    // Create charged PFO
+    xAOD::PFO* chargedPFO = new xAOD::PFO();
+    chargedPFOContainer.push_back(chargedPFO);
+
+    // Set properties
+    chargedPFO->setCharge(tauTrack->track()->charge());
+    chargedPFO->setP4(tauTrack->p4());
+    
+    // Link to track
+    if (not chargedPFO->setTrackLink(tauTrack->trackLinks().at(0))) {
+      ATH_MSG_WARNING("Could not add Track to PFO");
+    }
+
+    // FIXME: Better to change xAOD::PFODetails::CaloCluster, it is confusing 
+    if (not chargedPFO->setAssociatedParticleLink(xAOD::PFODetails::CaloCluster, tauTrackLink)) {
+      ATH_MSG_WARNING("Could not add TauTrack to PFO");
     }
+
+    tau.addProtoChargedPFOLink(ElementLink<xAOD::PFOContainer>(chargedPFO, chargedPFOContainer));
+  }
+}
+
+
+
+float TauPi0ClusterScaler::getExtrapolatedPosition(const xAOD::PFO& chargedPFO, xAOD::TauJetParameters::TrackDetail detail) const {
+  float position = -10.0;
+  
+  // Obtain the associated TauTrack  
+  std::vector<const xAOD::IParticle*> tauTrackParticles;
+  // FIXME: The type here is confusing
+  chargedPFO.associatedParticles(xAOD::PFODetails::CaloCluster, tauTrackParticles);
+  if (tauTrackParticles.empty()) {
+    ATH_MSG_WARNING("ChargedPFO has no associated TauTrack, will set -10.0 to " << detail);
+    return -10.0;
+  }
+
+  const xAOD::TauTrack* tauTrack = dynamic_cast<const xAOD::TauTrack*>(tauTrackParticles.at(0));
+  if (not tauTrack) {
+    ATH_MSG_WARNING("Failed to retrieve TauTrack from ChargedPFO, will set -10.0 to " << detail);
+    return -10.0;
+  }
+  
+  if( not tauTrack->detail(detail, position)) {
+    ATH_MSG_WARNING("Failed to retrieve extrapolated chargedPFO position, will set -10.0 to " << detail);
+    return -10.0;
+  }
+  
+  return position; 
 }
 
-//______________________________________________________________________________
-void TauPi0ClusterScaler::associateHadronicToChargedPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& chargedPFOContainer) const
-{
-    ATH_MSG_DEBUG("Associating hadronic PFOs to charged PFOs");
-
-    // Will: I'm ashamed of this link-map, but its necessary until the 
-    // PFO EDM is improved to allow sequential addition of particle links
-    std::map< xAOD::PFO*,std::vector< ElementLink< xAOD::IParticleContainer > > > linkMap;
-    ATH_MSG_DEBUG("nHadPFOs: " << pTau.nHadronicPFOs() );
-    for( auto hadPFOLink : pTau.hadronicPFOLinks() ){
-        if( not hadPFOLink.isValid() ){
-            ATH_MSG_WARNING("Invalid hadPFOLink");
-            continue;
-        }
-        ATH_MSG_DEBUG("hadPFO " << hadPFOLink.index() 
-                      << ", eta: " << (*hadPFOLink)->eta() 
-                      << ", phi: " << (*hadPFOLink)->phi() );
-        xAOD::PFO* chargedPFOMatch = nullptr;
-        // assign hadPFO to closest extrapolated chargedPFO track within dR<0.4
-        float dRmin = 0.4; 
-        for( auto chargedPFO : chargedPFOContainer ){
-            // get extrapolated positions from tau-track
-            std::vector<const xAOD::IParticle*> tauTrackPcleVec;
-            chargedPFO->associatedParticles(xAOD::PFODetails::CaloCluster, tauTrackPcleVec);
-            if( tauTrackPcleVec.empty() ){
-                ATH_MSG_WARNING("ChargedPFO has no associated TauTrack");
-                continue;
-            }
-
-	    auto tauTrack = dynamic_cast<const xAOD::TauTrack*>(tauTrackPcleVec.at(0));
-	    if( not tauTrack ){
-                ATH_MSG_WARNING("Failed to retrieve TauTrack from ChargedPFO");
-                continue;
-	    }
-	    float etaCalo = -10.0;
-	    float phiCalo = -10.0;
-	    if( not tauTrack->detail(xAOD::TauJetParameters::CaloSamplingEtaHad, etaCalo))
-	      ATH_MSG_WARNING("Failed to retrieve extrapolated chargedPFO eta");
-	    if( not tauTrack->detail(xAOD::TauJetParameters::CaloSamplingPhiHad, phiCalo))
-	      ATH_MSG_WARNING("Failed to retrieve extrapolated chargedPFO phi");
-            // calculate dR (false means use eta instead of rapidity)
-            float dR = xAOD::P4Helpers::deltaR((**hadPFOLink), etaCalo, phiCalo, false);
-            ATH_MSG_DEBUG("chargedPFO, pt: " << chargedPFO->pt()
-			  << ", type: " << tauTrack->flagSet()
-			  << ", eta: " << etaCalo
-			  << ", phi: " << phiCalo
-			  << ", dR: " << dR );
-            if (dR < dRmin){
-                dRmin = dR;
-                chargedPFOMatch = chargedPFO;
-            }
-        }
-        if( not chargedPFOMatch ){
-            ATH_MSG_DEBUG("Unassigned Hadronic PFO");
-            continue; 
-        }
-
-        // create link to had PFO (add to chargedPFO later)
-        ElementLink< xAOD::IParticleContainer > newHadLink;
-        newHadLink.toPersistent();
-        newHadLink.resetWithKeyAndIndex( hadPFOLink.persKey(), hadPFOLink.persIndex() );
-        if (not newHadLink.isValid()){
-            ATH_MSG_WARNING("Created an invalid element link to xAOD::PFO");
-            continue;
-        }
-
-        // temporarily store in linkMap since we can't sequentially add to chargedPFOMatch
-        if( not linkMap.count(chargedPFOMatch) )
-            linkMap[chargedPFOMatch] = std::vector< ElementLink< xAOD::IParticleContainer > >();
-        linkMap[chargedPFOMatch].push_back(newHadLink);
+
+
+void TauPi0ClusterScaler::associateHadronicToChargedPFOs(xAOD::TauJet& tau, xAOD::PFOContainer& chargedPFOContainer) const {
+  std::map< xAOD::PFO*,std::vector< ElementLink< xAOD::IParticleContainer > > > linkMap;
+ 
+  // For each hadronic PFO, associate it to the cloest charged PFO. It assumes that one hadronic PFO comes from at 
+  // most one charged PFO.
+  for (auto hadPFOLink : tau.hadronicPFOLinks()) {
+    if (not hadPFOLink.isValid()) {
+      ATH_MSG_WARNING("Invalid hadPFOLink");
+      continue;
+    }
+    ATH_MSG_DEBUG("hadPFO " << hadPFOLink.index() << ", eta: " << (*hadPFOLink)->eta() << ", phi: " << (*hadPFOLink)->phi() );
+    
+    // Assign hadPFO to closest extrapolated chargedPFO track within dR < 0.4
+    xAOD::PFO* chargedPFOMatch = nullptr;
+    float dRmin = 0.4;
+    
+    // FIXME: This loops over the existing charged PFO container, and could contain PFO not associated to this tau.
+    // It could make the association depending on the order of the tau candidate, but the point is that 
+    // hadronic PFO in one tau candidate is unlikely to be associated to charged PFO in another tau candidate
+    for (xAOD::PFO* chargedPFO : chargedPFOContainer) {
+      
+      float etaCalo = getExtrapolatedPosition(*chargedPFO, xAOD::TauJetParameters::CaloSamplingEtaHad);
+      float phiCalo = getExtrapolatedPosition(*chargedPFO, xAOD::TauJetParameters::CaloSamplingPhiHad);
+    
+      float dR = xAOD::P4Helpers::deltaR((**hadPFOLink), etaCalo, phiCalo, false);
+      if (dR < dRmin){
+        dRmin = dR;
+        chargedPFOMatch = chargedPFO;
+      }
+    }
+    
+    if( not chargedPFOMatch ){
+      ATH_MSG_DEBUG("Unassigned Hadronic PFO");
+      continue; 
     }
 
-    // finally set hadronic PFO links (note: we use existing TauShot enum)
-    for( auto [k,v] : linkMap ){
-        if(not k->setAssociatedParticleLinks(xAOD::PFODetails::TauShot, v))
-            ATH_MSG_WARNING("Couldn't add hadronic PFO links to charged PFO!");
+    // create link to had PFO (add to chargedPFO later)
+    ElementLink< xAOD::IParticleContainer > newHadLink;
+    newHadLink.toPersistent();
+    newHadLink.resetWithKeyAndIndex( hadPFOLink.persKey(), hadPFOLink.persIndex() );
+    if (not newHadLink.isValid()){
+        ATH_MSG_WARNING("Created an invalid element link to xAOD::PFO");
+        continue;
+    }
+    
+    if( not linkMap.count(chargedPFOMatch) ) {
+        linkMap[chargedPFOMatch] = std::vector< ElementLink< xAOD::IParticleContainer > >();
     }
+
+    linkMap[chargedPFOMatch].push_back(newHadLink);
+  }
+
+  // finally set hadronic PFO links (note: we use existing TauShot enum)
+  for (auto [k,v] : linkMap) {
+    if(not k->setAssociatedParticleLinks(xAOD::PFODetails::TauShot, v))
+      ATH_MSG_WARNING("Couldn't add hadronic PFO links to charged PFO!");
+  }
 }
 
-//______________________________________________________________________________
-void TauPi0ClusterScaler::associateChargedToNeutralPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& neutralPFOContainer) const
-{
-    ATH_MSG_DEBUG("Associating charged PFOs to neutral PFOs");
-    // Will: I'm ashamed of this link-map, but its necessary until the 
-    // PFO EDM is improved to allow sequential addition of particle links
-    std::map< xAOD::PFO*,std::vector< ElementLink< xAOD::IParticleContainer > > > linkMap;
-    ATH_MSG_DEBUG("nChargedPFOs: " << pTau.nProtoChargedPFOs() );
-    for( auto chargedPFOLink : pTau.protoChargedPFOLinks() ){
-        if( not chargedPFOLink.isValid() ){
-            ATH_MSG_WARNING("Invalid protoChargedPFOLink");
-            continue;
-        }
-        const xAOD::PFO* chargedPFO = (*chargedPFOLink);
-        
-        // get extrapolated positions from tau-track
-        std::vector<const xAOD::IParticle*> tauTrackPcleVec;
-        chargedPFO->associatedParticles(xAOD::PFODetails::CaloCluster, tauTrackPcleVec);
-        if( tauTrackPcleVec.empty() ){
-            ATH_MSG_WARNING("ChargedPFO has no associated TauTrack");
-            continue;
-        }
-        auto tauTrack = dynamic_cast<const xAOD::TauTrack*>(tauTrackPcleVec.at(0));
-        if( not tauTrack ){
-            ATH_MSG_WARNING("Failed to retrieve TauTrack from ChargedPFO");
-            continue;
-        } 
-        float etaCalo = -10.0;
-        float phiCalo = -10.0;
-        if( not tauTrack->detail(xAOD::TauJetParameters::CaloSamplingEtaEM, etaCalo))
-            ATH_MSG_WARNING("Failed to retrieve extrapolated chargedPFO eta");
-        if( not tauTrack->detail(xAOD::TauJetParameters::CaloSamplingPhiEM, phiCalo))
-            ATH_MSG_WARNING("Failed to retrieve extrapolated chargedPFO phi");
-        ATH_MSG_DEBUG("chargedPFO " << chargedPFOLink.index() 
-                      << ", eta: " << etaCalo 
-                      << ", phi: " << phiCalo );
+
+
+void TauPi0ClusterScaler::associateChargedToNeutralPFOs(xAOD::TauJet& tau, xAOD::PFOContainer& neutralPFOContainer) const {
+  std::map< xAOD::PFO*,std::vector< ElementLink< xAOD::IParticleContainer > > > linkMap;
+  for (auto chargedPFOLink : tau.protoChargedPFOLinks()) {
+    if (not chargedPFOLink.isValid()) {
+      ATH_MSG_WARNING("Invalid protoChargedPFOLink");
+      continue;
+    }
+    const xAOD::PFO* chargedPFO = (*chargedPFOLink);
+    
+    float etaCalo = getExtrapolatedPosition(*chargedPFO, xAOD::TauJetParameters::CaloSamplingEtaEM);
+    float phiCalo = getExtrapolatedPosition(*chargedPFO, xAOD::TauJetParameters::CaloSamplingPhiEM);
+    
+    // Assign extrapolated chargedPFO to closest neutralPFO within dR<0.04
+    xAOD::PFO* neutralPFOMatch = nullptr;
+    
+    // FIXME: This loops over the existing neutral PFO container, and could contain PFO not associated to this tau.
+    // It could make the association depending on the order of the tau candidate. but the point is that 
+    // charged PFO in one tau candidate is unlikely to be associated to the neutral PFO in another tau candidate
+    float dRmin = 0.04; 
+    for (xAOD::PFO* neutralPFO : neutralPFOContainer) {
+      // FIXME: cluster p4 is not corrected to the tau axis 
+      float dR = xAOD::P4Helpers::deltaR((*neutralPFO->cluster(0)), etaCalo, phiCalo, false);
+      if (dR < dRmin){
+        dRmin = dR;
+        neutralPFOMatch = neutralPFO;
+      }
+    }
+    
+    if (not neutralPFOMatch){
+      ATH_MSG_DEBUG("Unassigned Charged PFO");
+      continue; 
+    }
+
+    // create link to charged PFO 
+    ElementLink<xAOD::IParticleContainer> newChargedLink;
+    newChargedLink.toPersistent();
+    newChargedLink.resetWithKeyAndIndex(chargedPFOLink.persKey(), chargedPFOLink.persIndex());
+    if (not newChargedLink.isValid()){
+      ATH_MSG_WARNING("Created an invalid element link to xAOD::PFO");
+      continue;
+    }
         
-        // assign extrapolated chargedPFO to closest neutralPFO within dR<0.04
-        xAOD::PFO* neutralPFOMatch = nullptr;
-        float dRmin = 0.04; 
-        for( auto neutralPFO : neutralPFOContainer ){
-            // calculate dR (false means use eta instead of rapidity)
-            float dR = xAOD::P4Helpers::deltaR((*neutralPFO->cluster(0)), etaCalo, phiCalo, false);
-            ATH_MSG_DEBUG("neutralPFO, eta: " << neutralPFO->cluster(0)->eta()
-                            << ", phi: " << neutralPFO->cluster(0)->phi()
-                            << ", dR: " << dR );
-            if (dR < dRmin){
-                dRmin = dR;
-                neutralPFOMatch = neutralPFO;
-            }
-        }
-        if( not neutralPFOMatch ){
-            ATH_MSG_DEBUG("Unassigned Charged PFO");
-            continue; 
-        }
-        else ATH_MSG_DEBUG("Assigned Charged PFO");
-
-        // create link to charged PFO 
-        ElementLink< xAOD::IParticleContainer > newChargedLink;
-        newChargedLink.toPersistent();
-        newChargedLink.resetWithKeyAndIndex( chargedPFOLink.persKey(), chargedPFOLink.persIndex() );
-        if (not newChargedLink.isValid()){
-            ATH_MSG_WARNING("Created an invalid element link to xAOD::PFO");
-            continue;
-        }
-
-        // temporarily store in linkMap since we can't sequentially add to neutralPFOMatch
-        if( not linkMap.count(neutralPFOMatch) )
-            linkMap[neutralPFOMatch] = std::vector< ElementLink< xAOD::IParticleContainer > >();
-        linkMap[neutralPFOMatch].push_back(newChargedLink);
+    if( not linkMap.count(neutralPFOMatch) ) {
+      linkMap[neutralPFOMatch] = std::vector< ElementLink< xAOD::IParticleContainer > >();
     }
 
-    // finally set charged PFO links
-    for( auto [k,v] : linkMap ){
-        if(not k->setAssociatedParticleLinks(xAOD::PFODetails::Track,v))
-            ATH_MSG_WARNING("Couldn't add charged PFO links to neutral PFO!");
+    linkMap[neutralPFOMatch].push_back(newChargedLink);
+  }
+    
+  // Finally set charged PFO links,
+  for (auto [k,v] : linkMap) {
+    if(not k->setAssociatedParticleLinks(xAOD::PFODetails::Track,v)) {
+      ATH_MSG_WARNING("Couldn't add charged PFO links to neutral PFO!");
     }
+  }
 }
 
-//______________________________________________________________________________
-void TauPi0ClusterScaler::subtractChargedEnergyFromNeutralPFOs(xAOD::PFOContainer& neutralPFOContainer) const
-{
-    ATH_MSG_DEBUG("Subtracting charged energy from neutral PFOs");
-    for( auto neutralPFO : neutralPFOContainer )
-    {
-        // get associated charged PFOs
-        std::vector<const xAOD::IParticle*> chargedPFOs;
-        neutralPFO->associatedParticles(xAOD::PFODetails::Track, chargedPFOs);
-        if( chargedPFOs.empty() ){
-            ATH_MSG_DEBUG("No associated charged to subtract"); 
-            continue;
-        }
-        ATH_MSG_DEBUG("Associated charged PFOs: " << chargedPFOs.size() );
-
-        // estimate charged EM energy and subtract
-        float neutralEnergy = neutralPFO->e();
-        for( auto chargedPcle : chargedPFOs )
-        {
-            // since PFO stores element links as IParticle, need to cast back
-            const xAOD::PFO* chargedPFO = dynamic_cast<const xAOD::PFO*>(chargedPcle);
-            if( not chargedPFO ){
-                ATH_MSG_WARNING("Failed to downcast IParticle ptr: " << chargedPcle << ", to ChargedPFO! " );
-                continue;
-            }
-            float chargedEMEnergy = chargedPFO->e();
-            std::vector<const xAOD::IParticle*> hadPFOs;
-            chargedPFO->associatedParticles(xAOD::PFODetails::TauShot, hadPFOs);
-            for( auto hadPFO : hadPFOs )
-                chargedEMEnergy -= hadPFO->e();
-            
-            if( chargedEMEnergy < 0.0 ) chargedEMEnergy = 0.0;
-            neutralEnergy -= chargedEMEnergy;
-            ATH_MSG_DEBUG("Subtracting charged energy: " << chargedEMEnergy );
-        } 
-        float neutralPt = neutralEnergy / std::cosh(neutralPFO->eta());
-        if(neutralPt <= 100.) neutralPt = 100.0;
-        
-        ATH_MSG_DEBUG("Neutral PFO pt, orig: " << neutralPFO->pt() << "  new: " << neutralPt); 
-        neutralPFO->setP4(neutralPt , neutralPFO->eta(), neutralPFO->phi(), neutralPFO->m());
+
+
+void TauPi0ClusterScaler::subtractChargedEnergyFromNeutralPFOs(xAOD::PFOContainer& neutralPFOContainer) const {
+  // FIXME: It loops all the exsiting PFOs, will make the PFO kinematic depend on the current 
+  // tau candidate. The kinematics written to the xAOD is the one for the last tau candidate.
+  
+  for (xAOD::PFO* neutralPFO : neutralPFOContainer) {
+    // Get associated charged PFOs
+    std::vector<const xAOD::IParticle*> chargedPFOs;
+    neutralPFO->associatedParticles(xAOD::PFODetails::Track, chargedPFOs);
+    if (chargedPFOs.empty()) {
+      ATH_MSG_DEBUG("No associated charged to subtract"); 
+      continue;
     }
+    ATH_MSG_DEBUG("Associated charged PFOs: " << chargedPFOs.size() );
+
+    // estimate charged EM energy and subtract
+    float neutralEnergy = neutralPFO->e();
+    for (const xAOD::IParticle* chargedParticle : chargedPFOs) {
+      const xAOD::PFO* chargedPFO = dynamic_cast<const xAOD::PFO*>(chargedParticle);
+      if( not chargedPFO ){
+          ATH_MSG_WARNING("Failed to downcast IParticle ptr: " << chargedParticle << ", to ChargedPFO! " );
+          continue;
+      }
+      float chargedEMEnergy = chargedPFO->e();
+      
+      std::vector<const xAOD::IParticle*> hadPFOs;
+      chargedPFO->associatedParticles(xAOD::PFODetails::TauShot, hadPFOs);
+      for (auto hadPFO : hadPFOs) {
+          chargedEMEnergy -= hadPFO->e();
+      }
+
+      if( chargedEMEnergy < 0.0 ) chargedEMEnergy = 0.0;
+      neutralEnergy -= chargedEMEnergy;
+      ATH_MSG_DEBUG("Subtracting charged energy: " << chargedEMEnergy );
+    } 
+    float neutralPt = neutralEnergy / std::cosh(neutralPFO->eta());
+    if (neutralPt <= 100.) neutralPt = 100.0;
+    
+    ATH_MSG_DEBUG("Neutral PFO pt, original: " << neutralPFO->pt() << "  subtracted: " << neutralPt); 
+    neutralPFO->setP4(neutralPt , neutralPFO->eta(), neutralPFO->phi(), neutralPFO->m());
+  }
 }
diff --git a/Reconstruction/tauRecTools/src/TauPi0ClusterScaler.h b/Reconstruction/tauRecTools/src/TauPi0ClusterScaler.h
index bb05d81d099398cdb304e1982152bafd45383e1c..90e9083ea8336bf9a178c12ce69eafe60ab8fd6b 100644
--- a/Reconstruction/tauRecTools/src/TauPi0ClusterScaler.h
+++ b/Reconstruction/tauRecTools/src/TauPi0ClusterScaler.h
@@ -19,51 +19,39 @@
  * @author Will Davey <will.davey@cern.ch> 
  */
 
-//namespace Trk {
-//    class IParticleCaloExtensionTool;
-//}
 class TauPi0ClusterScaler : virtual public TauRecToolBase {
-public:
-    TauPi0ClusterScaler(const std::string& name);
-    ASG_TOOL_CLASS2(TauPi0ClusterScaler, TauRecToolBase, ITauToolBase)
-    virtual ~TauPi0ClusterScaler();
-
-    virtual StatusCode executePi0ClusterScaler(xAOD::TauJet& pTau, xAOD::PFOContainer& pNeutralPFOContainer, xAOD::PFOContainer& pChargedPFOContainer) const override; 
-
-private:
-
-    /** @brief tool handles */
-    //ToolHandle<Trk::IParticleCaloExtensionTool> m_caloExtensionTool;
 
-    /** @brief reset neutral PFO kinematics (for AOD running) */
-    void resetNeutralPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& pNeutralPFOContainer) const;
-
-    /** @brief create charged PFOs */
-    void createChargedPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& pChargedPFOContainer) const;
-
-    /** @brief extrapolate charged PFO tracks to EM and HAD layers */
-    //void extrapolateChargedPFOs(xAOD::TauJet& pTau);
-    
-    /** @brief associate hadronic PFOs to charged PFOs */
-    void associateHadronicToChargedPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& pChargedPFOContainer) const;
-    
-    /** @brief associate charged PFOs to neutral PFOs */
-    void associateChargedToNeutralPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& pNeutralPFOContainer) const;
-    
-    /** @brief associate charged PFOs to neutral PFOs */
-    void subtractChargedEnergyFromNeutralPFOs(xAOD::PFOContainer& pNeutralPFOContainer) const;
+public:
+  
+  ASG_TOOL_CLASS2(TauPi0ClusterScaler, TauRecToolBase, ITauToolBase)
 
-    /** @brief sets of EM/Had samplings for track extrapolation */
-    //std::set<CaloSampling::CaloSample> m_EMSamplings;
-    //std::set<CaloSampling::CaloSample> m_HadSamplings;
+  TauPi0ClusterScaler(const std::string& name);
+  virtual ~TauPi0ClusterScaler() = default;
 
-    /** dodgy re-purposed PFOAttributes enums */
-    //xAOD::PFODetails::PFOAttributes ETAECAL; 
-    //xAOD::PFODetails::PFOAttributes PHIECAL;
-    //xAOD::PFODetails::PFOAttributes ETAHCAL;
-    //xAOD::PFODetails::PFOAttributes PHIHCAL;
+  virtual StatusCode executePi0ClusterScaler(xAOD::TauJet& pTau, xAOD::PFOContainer& pNeutralPFOContainer, xAOD::PFOContainer& pChargedPFOContainer) const override; 
 
+private:
+  
+  /** @brief Clear accosicated partcle links for the pfo container */
+  void clearAssociatedParticleLinks(xAOD::PFOContainer& pfoContainer, xAOD::PFODetails::PFOParticleType type) const;
+
+  /** @brief Get extrapolated position to the CAL */
+  float getExtrapolatedPosition(const xAOD::PFO& chargedPFO, xAOD::TauJetParameters::TrackDetail detail) const; 
+
+  /** @brief Correct neutral PFO kinematics to point at the current tau vertex */
+  void correctNeutralPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& pNeutralPFOContainer) const;
+
+  /** @brief create charged PFOs */
+  void createChargedPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& pChargedPFOContainer) const;
+
+  /** @brief associate hadronic PFOs to charged PFOs */
+  void associateHadronicToChargedPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& pChargedPFOContainer) const;
+  
+  /** @brief associate charged PFOs to neutral PFOs */
+  void associateChargedToNeutralPFOs(xAOD::TauJet& pTau, xAOD::PFOContainer& pNeutralPFOContainer) const;
+  
+  /** @brief associate charged PFOs to neutral PFOs */
+  void subtractChargedEnergyFromNeutralPFOs(xAOD::PFOContainer& pNeutralPFOContainer) const;
 };
 
 #endif  /* TAUPI0CLUSTERSCALER_H */
-
diff --git a/Reconstruction/tauRecTools/src/TauPi0CreateROI.cxx b/Reconstruction/tauRecTools/src/TauPi0CreateROI.cxx
index 586d7b68d0d155b12ec07c856b90803de7b40f8d..e2cf019bdbd38632ba65797eef143136047b62d8 100644
--- a/Reconstruction/tauRecTools/src/TauPi0CreateROI.cxx
+++ b/Reconstruction/tauRecTools/src/TauPi0CreateROI.cxx
@@ -3,33 +3,20 @@
 */
 
 #ifndef XAOD_ANALYSIS
-//-----------------------------------------------------------------------------
-// file:        TauPi0CreateROI.cxx
-// package:     Reconstruction/tauEvent
-// authors:     Will Davey, Benedict Winter, Stephanie Yuen
-// date:        2012-10-09
-//-----------------------------------------------------------------------------
 
-#include "CaloUtils/CaloCellList.h"
 #include "TauPi0CreateROI.h"
 
+#include "CaloUtils/CaloCellList.h"
+
 #include <boost/scoped_ptr.hpp>
 
-//-------------------------------------------------------------------------
-// Constructor
-//-------------------------------------------------------------------------
+
 
 TauPi0CreateROI::TauPi0CreateROI(const std::string& name) :
-     TauRecToolBase(name)
-{
+     TauRecToolBase(name) {
 }
    
-//-------------------------------------------------------------------------
-// Destructor
-//-------------------------------------------------------------------------
 
-TauPi0CreateROI::~TauPi0CreateROI() {
-}
 
 StatusCode TauPi0CreateROI::initialize() {
     
@@ -38,58 +25,45 @@ StatusCode TauPi0CreateROI::initialize() {
     return StatusCode::SUCCESS;
 }
 
-//______________________________________________________________________________
-StatusCode TauPi0CreateROI::executePi0CreateROI(xAOD::TauJet& pTau, CaloCellContainer& pPi0CellContainer, boost::dynamic_bitset<>& addedCellsMap) const {
 
-    //---------------------------------------------------------------------
-    // only run on 1-5 prong taus 
-    //---------------------------------------------------------------------
-    if (pTau.nTracks() == 0 || pTau.nTracks() >5 ) {
-        return StatusCode::SUCCESS;
-    }
-    ATH_MSG_DEBUG("new tau. \tpt = " << pTau.pt() << "\teta = " << pTau.eta() << "\tphi = " << pTau.phi() << "\tnprongs = " << pTau.nTracks());
-
-    //---------------------------------------------------------------------
-    // retrieve cells around tau 
-    //---------------------------------------------------------------------
-    // get all calo cell container
-    SG::ReadHandle<CaloCellContainer> caloCellInHandle( m_caloCellInputContainer );
-    if (!caloCellInHandle.isValid()) {
-      ATH_MSG_ERROR ("Could not retrieve HiveDataObj with key " << caloCellInHandle.key());
-      return StatusCode::FAILURE;
-    }
-    const CaloCellContainer *pCellContainer = NULL;
-    pCellContainer = caloCellInHandle.cptr();
-    
-    // get only EM cells within dR<0.4
-    std::vector<CaloCell_ID::SUBCALO> emSubCaloBlocks;
-    emSubCaloBlocks.push_back(CaloCell_ID::LAREM);
-    boost::scoped_ptr<CaloCellList> pCells(new CaloCellList(pCellContainer,emSubCaloBlocks)); 
-    pCells->select(pTau.eta(), pTau.phi(), 0.4); // TODO: change hardcoded 0.4 to tau cone variable, (or func. from TauJet)?
-
-    //---------------------------------------------------------------------
-    // Put Ecal cells in output container
-    //---------------------------------------------------------------------
-
-    CaloCellList::list_iterator cellItr(pCells->begin()), cellItrE(pCells->end());
-    for(; cellItr != cellItrE; ++cellItr) {
-        const CaloCell* cell = (*cellItr);
-
-        // only keep cells that are in Ecal (PS, EM1, EM2 and EM3, both barrel and endcap).
-        int samp = cell->caloDDE()->getSampling();
-        if(samp>7) continue;
-
-        // Store cell in output container
-        const IdentifierHash cellHash = cell->caloDDE()->calo_hash();
-
-	if(!addedCellsMap.test(cellHash)) {
-            CaloCell* copyCell = cell->clone();
-            pPi0CellContainer.push_back(copyCell);
-	    addedCellsMap.set(cellHash);
-        }
-    }
 
+StatusCode TauPi0CreateROI::executePi0CreateROI(xAOD::TauJet& tau, CaloCellContainer& pi0CellContainer, boost::dynamic_bitset<>& addedCellsMap) const {
+  // only run on 1-5 prong taus 
+  if (tau.nTracks() == 0 || tau.nTracks() >5 ) {
     return StatusCode::SUCCESS;
+  }
+
+  SG::ReadHandle<CaloCellContainer> caloCellInHandle( m_caloCellInputContainer );
+  if (!caloCellInHandle.isValid()) {
+    ATH_MSG_ERROR ("Could not retrieve HiveDataObj with key " << caloCellInHandle.key());
+    return StatusCode::FAILURE;
+  }
+  const CaloCellContainer *cellContainer = caloCellInHandle.cptr();;
+  
+  // get only EM cells within dR < 0.4
+  // TODO: change hardcoded 0.4 to meaningful variable
+  std::vector<CaloCell_ID::SUBCALO> emSubCaloBlocks;
+  emSubCaloBlocks.push_back(CaloCell_ID::LAREM);
+  boost::scoped_ptr<CaloCellList> cellList(new CaloCellList(cellContainer,emSubCaloBlocks)); 
+  // FIXME: tau p4 is corrected to point at tau vertex, but the cells are not
+  cellList->select(tau.eta(), tau.phi(), 0.4);
+
+  for (const CaloCell* cell : *cellList) {
+    // only keep cells that are in Ecal (PS, EM1, EM2 and EM3, both barrel and endcap).
+    int sampling = cell->caloDDE()->getSampling();
+    if (sampling > 7) continue;
+
+    // Store cell in output container
+    const IdentifierHash cellHash = cell->caloDDE()->calo_hash();
+
+    if (!addedCellsMap.test(cellHash)) {
+      CaloCell* newCell = cell->clone();
+      pi0CellContainer.push_back(newCell);
+      addedCellsMap.set(cellHash);
+    }
+  }
+
+  return StatusCode::SUCCESS;
 }
 
 #endif
diff --git a/Reconstruction/tauRecTools/src/TauPi0CreateROI.h b/Reconstruction/tauRecTools/src/TauPi0CreateROI.h
index ee9e0ea640ae615e95a55282a14045012bd9fdf4..e3c6b54f1aaf3d86fc351139228d26e521742865 100644
--- a/Reconstruction/tauRecTools/src/TauPi0CreateROI.h
+++ b/Reconstruction/tauRecTools/src/TauPi0CreateROI.h
@@ -18,7 +18,7 @@
 #include "xAODTau/TauJet.h"
 
 /**
- * @brief Create ROIs for the Pi0 finder.
+ * @brief Find the cells used to create pi0 cluster
  * 
  * @author Will Davey <will.davey@cern.ch> 
  * @author Benedict Winter <benedict.tobias.winter@cern.ch> 
@@ -26,17 +26,21 @@
  */
 
 class TauPi0CreateROI : public TauRecToolBase {
+
 public:
-    TauPi0CreateROI(const std::string& name);
-    ASG_TOOL_CLASS2(TauPi0CreateROI, TauRecToolBase, ITauToolBase);
-    virtual ~TauPi0CreateROI();
 
-    virtual StatusCode initialize() override;
-    virtual StatusCode executePi0CreateROI(xAOD::TauJet& pTau, CaloCellContainer& Pi0CellContainer, boost::dynamic_bitset<>& map) const override;
+  ASG_TOOL_CLASS2(TauPi0CreateROI, TauRecToolBase, ITauToolBase);
+  
+  TauPi0CreateROI(const std::string& name);
+  virtual ~TauPi0CreateROI() = default;
+
+  virtual StatusCode initialize() override;
+  virtual StatusCode executePi0CreateROI(xAOD::TauJet& pTau, CaloCellContainer& Pi0CellContainer, boost::dynamic_bitset<>& map) const override;
 
 private:
-    SG::ReadHandleKey<CaloCellContainer> m_caloCellInputContainer{this,"Key_caloCellInputContainer", "AllCalo", "input vertex container key"};
+    
+  SG::ReadHandleKey<CaloCellContainer> m_caloCellInputContainer{this,"Key_caloCellInputContainer", "AllCalo", "input vertex container key"};
+
 };
 
 #endif	/* TAUPI0CREATEROI_H */
-
diff --git a/Reconstruction/tauRecTools/src/TauShotFinder.cxx b/Reconstruction/tauRecTools/src/TauShotFinder.cxx
index 1ebb673a60f559186064b3904bd4600e691c6f5b..04c9964045d9a5f5f3154ff07d0d98a51f4f0567 100644
--- a/Reconstruction/tauRecTools/src/TauShotFinder.cxx
+++ b/Reconstruction/tauRecTools/src/TauShotFinder.cxx
@@ -3,346 +3,401 @@
 */
 
 #ifndef XAOD_ANALYSIS
-//-----------------------------------------------------------------------------
-// file:        TauShotFinder.cxx
-// package:     Reconstruction/tauRec
-// authors:     Will Davey, Benedict Winter, Stephanie Yuen
-// date:        2013-05-22
-//-----------------------------------------------------------------------------
 
-#include <boost/scoped_ptr.hpp>
+#include "TauShotFinder.h"
+#include "TauShotVariableHelpers.h"
 
 #include "xAODCaloEvent/CaloClusterContainer.h"
 #include "xAODCaloEvent/CaloClusterKineHelper.h"
 #include "CaloUtils/CaloClusterStoreHelper.h"
 #include "CaloUtils/CaloCellList.h"
-#include "TauShotFinder.h"
-#include "TauShotVariableHelpers.h"
 #include "xAODPFlow/PFOContainer.h"
 #include "xAODPFlow/PFOAuxContainer.h"
 #include "xAODPFlow/PFO.h"
 
-//-------------------------------------------------------------------------
-// Constructor
-//-------------------------------------------------------------------------
+#include <boost/scoped_ptr.hpp>
+
+
 
 TauShotFinder::TauShotFinder(const std::string& name) :
     TauRecToolBase(name) {
 }
 
-//-------------------------------------------------------------------------
-// Destructor
-//-------------------------------------------------------------------------
 
-TauShotFinder::~TauShotFinder() {
-}
 
-//______________________________________________________________________________
 StatusCode TauShotFinder::initialize() {
-    
-    // retrieve tools
-    ATH_MSG_DEBUG( "Retrieving tools" );
-    CHECK( m_caloWeightTool.retrieve() );
+  
+  ATH_CHECK(m_caloWeightTool.retrieve());
+  ATH_CHECK(m_caloCellInputContainer.initialize());
+  ATH_CHECK(detStore()->retrieve (m_calo_id, "CaloCell_ID"));
 
-    ATH_CHECK( m_caloCellInputContainer.initialize() );
+  return StatusCode::SUCCESS;
+}
 
-    // initialize calo cell geo
-    ATH_CHECK( detStore()->retrieve (m_calo_id, "CaloCell_ID") );
 
-    return StatusCode::SUCCESS;
-}
 
-//______________________________________________________________________________
-StatusCode TauShotFinder::executeShotFinder(xAOD::TauJet& pTau, xAOD::CaloClusterContainer& tauShotClusterContainer,
-					    xAOD::PFOContainer& tauShotPFOContainer) const {
+StatusCode TauShotFinder::executeShotFinder(xAOD::TauJet& tau, xAOD::CaloClusterContainer& shotClusterContainer,
+					    xAOD::PFOContainer& shotPFOContainer) const {
 
-    ATH_MSG_DEBUG("execute");
-    // Any tau needs to have shot PFO vectors. Set empty vectors before nTrack cut
-    std::vector<ElementLink<xAOD::PFOContainer> > empty;
-    pTau.setShotPFOLinks(empty);
+  // Any tau needs to have shot PFO vectors. Set empty vectors before nTrack cut
+  std::vector<ElementLink<xAOD::PFOContainer>> empty;
+  tau.setShotPFOLinks(empty);
+  
+  // Only run on 1-5 prong taus 
+  if (tau.nTracks() == 0 || tau.nTracks() >5 ) {
+     return StatusCode::SUCCESS;
+  }
+    
+  SG::ReadHandle<CaloCellContainer> caloCellInHandle( m_caloCellInputContainer );
+  if (!caloCellInHandle.isValid()) {
+    ATH_MSG_ERROR ("Could not retrieve HiveDataObj with key " << caloCellInHandle.key());
+    return StatusCode::FAILURE;
+  }
+  const CaloCellContainer *cellContainer = caloCellInHandle.cptr();;
+    
+  // Select seed cells:
+  // -- dR < 0.4, EM1, pt > 100
+  // -- largest pt among the neighbours in eta direction 
+  // -- no other seed cell as neighbour in eta direction 
+  std::vector<const CaloCell*> seedCells = selectSeedCells(tau, *cellContainer);
+  ATH_MSG_DEBUG("seedCells.size() = " << seedCells.size());
     
-    //---------------------------------------------------------------------
-    // only run shower subtraction on 1-5 prong taus 
-    //---------------------------------------------------------------------
-    if (pTau.nTracks() == 0 || pTau.nTracks() >5 ) {
-       return StatusCode::SUCCESS;
+  // Construt shot by merging neighbour cells in phi direction 
+  while (seedCells.size()) {
+    // Find the neighbour in phi direction, and choose the one with highest pt
+    const CaloCell* cell = seedCells.front(); 
+    const CaloCell* phiNeigCell = getPhiNeighbour(*cell, seedCells);
+    
+    // Construct shot PFO candidate
+    xAOD::PFO* shot = new xAOD::PFO();
+    shotPFOContainer.push_back(shot);
+
+    // -- Construct the shot cluster 
+    xAOD::CaloCluster* shotCluster = createShotCluster(cell, phiNeigCell, *cellContainer);
+    shotClusterContainer.push_back(shotCluster);
+   
+    ElementLink<xAOD::CaloClusterContainer> clusElementLink;
+    clusElementLink.toContainedElement( shotClusterContainer, shotCluster );
+    shot->setClusterLink( clusElementLink );
+   
+    // -- Calculate the four momentum
+    // TODO: simplify the calculation 
+    if (phiNeigCell) {
+      // interpolate position
+      double dPhi = TVector2::Phi_mpi_pi( phiNeigCell->phi() - cell->phi());
+      double ratio = phiNeigCell->pt()*m_caloWeightTool->wtCell(phiNeigCell)/(cell->pt()*m_caloWeightTool->wtCell(cell) + phiNeigCell->pt()*m_caloWeightTool->wtCell(phiNeigCell));
+      float phi = cell->phi()+dPhi*ratio;
+      float pt = cell->pt()*m_caloWeightTool->wtCell(cell)+phiNeigCell->pt()*m_caloWeightTool->wtCell(phiNeigCell);
+
+      shot->setP4( (float) pt, (float) cell->eta(), (float) phi, (float) cell->m());
     }
-    //---------------------------------------------------------------------
-    // retrieve cells around tau 
-    //---------------------------------------------------------------------
-    // get all calo cell container
-    SG::ReadHandle<CaloCellContainer> caloCellInHandle( m_caloCellInputContainer );
-    if (!caloCellInHandle.isValid()) {
-      ATH_MSG_ERROR ("Could not retrieve HiveDataObj with key " << caloCellInHandle.key());
-      return StatusCode::FAILURE;
+    else {
+      shot->setP4( (float) cell->pt()*m_caloWeightTool->wtCell(cell), (float) cell->eta(), (float) cell->phi(), (float) cell->m()); 
     }
-    const CaloCellContainer *pCellContainer = caloCellInHandle.cptr();;
+
+    // -- Set the Attribute 
+    shot->setBDTPi0Score(-9999.);
+    shot->setCharge(0);
+    shot->setCenterMag(0.0);
     
-    // get only EM cells within dR<0.4
-    std::vector<CaloCell_ID::SUBCALO> emSubCaloBlocks;
-    emSubCaloBlocks.push_back(CaloCell_ID::LAREM);
-    boost::scoped_ptr<CaloCellList> pCells(new CaloCellList(pCellContainer,emSubCaloBlocks)); 
-    pCells->select(pTau.eta(), pTau.phi(), 0.4); 
-
-    // Dump cells into a std::vector since CaloCellList wont allow sorting
-    // Also apply very basic preselection
-    std::vector<const CaloCell*> cells;
-    CaloCellList::list_iterator cellItr = pCells->begin();
-    for(; cellItr!=pCells->end();++cellItr){
-        // require cells above 100 MeV
-        if( (*cellItr)->pt()*m_caloWeightTool->wtCell(*cellItr) < 100. ) continue;
-        // require cells in EM1 
-        int samp = (*cellItr)->caloDDE()->getSampling();
-        if( !( samp == CaloCell_ID::EMB1 || samp == CaloCell_ID::EME1 ) ) continue;
-        cells.push_back(*cellItr);
-    }
-    // sort cells in descending pt    
-    std::sort(cells.begin(),cells.end(),ptSort(*this));
+    shot->setAttribute<int>(xAOD::PFODetails::PFOAttributes::tauShots_nCellsInEta, m_nCellsInEta);
     
-    //---------------------------------------------------------------------
-    // shot seeding 
-    //---------------------------------------------------------------------
-    // get seed cells
-    std::vector<const CaloCell*> seedCells; 
-    std::set<IdentifierHash> seedCellHashes;
-    cellItr = cells.begin();
-    for(; cellItr != cells.end(); ++cellItr) {
-        const CaloCell* cell = (*cellItr);
-        const IdentifierHash cellHash = cell->caloDDE()->calo_hash();
-
-        // apply seed selection on nearest neighbours
-        std::vector<IdentifierHash> nextEta, prevEta;
-        m_calo_id->get_neighbours(cellHash,LArNeighbours::nextInEta,nextEta);
-        m_calo_id->get_neighbours(cellHash,LArNeighbours::prevInEta,prevEta);
-        std::vector<IdentifierHash> neighbours = nextEta;
-        neighbours.insert(neighbours.end(),prevEta.begin(),prevEta.end()); 
-        bool status = true;
-        std::vector<IdentifierHash>::iterator hashItr = neighbours.begin();
-        for(;hashItr!=neighbours.end();++hashItr){
-            // must not be next to seed cell (TODO: maybe this requirement can be removed)
-            if( seedCellHashes.find(*hashItr) != seedCellHashes.end() ){
-                status = false;
-                break;
-            }
-            // must be maximum
-            const CaloCell* neigCell = pCellContainer->findCell(*hashItr);
-            if( !neigCell ) continue;
-            if( neigCell->pt()*m_caloWeightTool->wtCell(neigCell) >= cell->pt()*m_caloWeightTool->wtCell(cell) ){
-                status = false;
-                break;
-            }
-        }
-        if( !status ) continue;        
-        seedCells.push_back(cell); 
-        seedCellHashes.insert(cellHash);
-    } // preselected cells
-    ATH_MSG_DEBUG("seedCells.size() = " << seedCells.size());
+    const IdentifierHash seedHash = cell->caloDDE()->calo_hash();
+    shot->setAttribute<int>(xAOD::PFODetails::PFOAttributes::tauShots_seedHash, seedHash);
+
+    std::vector<std::vector<const CaloCell*>> cellBlock = TauShotVariableHelpers::getCellBlock(*shot, m_calo_id);
+
+    float pt1 = TauShotVariableHelpers::ptWindow(cellBlock, 1, m_caloWeightTool);
+    shot->setAttribute<float>(xAOD::PFODetails::PFOAttributes::tauShots_pt1, pt1);
     
-    // merge across phi and construct shots
-    while( seedCells.size() ){
-        
-        const CaloCell* cell = seedCells.front(); 
-        const IdentifierHash seedHash = cell->caloDDE()->calo_hash();
-
-        // look for match across phi in current seeds
-        const CaloCell* nextPhi = NULL;
-        const CaloCell* prevPhi = NULL;
-        for( cellItr = seedCells.begin(); cellItr!=seedCells.end(); ++cellItr){
-            if( (*cellItr) == cell ) continue;
-            IdentifierHash shotCellHash = (*cellItr)->caloDDE()->calo_hash();
-            if( this->isPhiNeighbour(seedHash,shotCellHash,true) )       nextPhi = (*cellItr);
-            else if( this->isPhiNeighbour(seedHash,shotCellHash,false) ) prevPhi = (*cellItr);
-        }
-       
-        const CaloCell* mergePhi = NULL;
-        if( nextPhi && prevPhi ){
-            // take higest-pt if merged up and down
-            if( nextPhi->pt()*m_caloWeightTool->wtCell(nextPhi) > prevPhi->pt()*m_caloWeightTool->wtCell(prevPhi) ) mergePhi = nextPhi;
-            else                                mergePhi = prevPhi;
-        }
-        else if (nextPhi) mergePhi = nextPhi;
-        else if (prevPhi) mergePhi = prevPhi;
-
-        // get neighbours in 5x1 window
-        std::vector<const CaloCell*> windowNeighbours = this->getNeighbours(pCellContainer,cell,2);
-        if( mergePhi ){
-            std::vector<const CaloCell*> mergeCells = this->getNeighbours(pCellContainer,mergePhi,2);
-            windowNeighbours.push_back(mergePhi);
-            windowNeighbours.insert(windowNeighbours.end(),mergeCells.begin(),mergeCells.end());
-        }
-
-        
-        // create seed cluster
-        xAOD::CaloCluster* shotCluster = CaloClusterStoreHelper::makeCluster(pCellContainer);
-        shotCluster->getOwnCellLinks()->reserve(windowNeighbours.size()+1);
-        shotCluster->addCell(pCellContainer->findIndex(seedHash), 1.);
-        cellItr = windowNeighbours.begin();
-        for( ; cellItr!=windowNeighbours.end(); ++cellItr)
-            shotCluster->addCell(pCellContainer->findIndex((*cellItr)->caloDDE()->calo_hash()),1.0);
-        CaloClusterKineHelper::calculateKine(shotCluster,true,true);
-        tauShotClusterContainer.push_back(shotCluster);
-        
-        // create shot PFO and store it in output container
-        xAOD::PFO* shot = new xAOD::PFO();
-        tauShotPFOContainer.push_back( shot );
-
-        // Create element link from tau to shot
-        ElementLink<xAOD::PFOContainer> PFOElementLink;
-        PFOElementLink.toContainedElement( tauShotPFOContainer, shot );
-        pTau.addShotPFOLink( PFOElementLink );
-       
-        if( mergePhi ){
-            // interpolate position
-            double dPhi = TVector2::Phi_mpi_pi( mergePhi->phi() - cell->phi());
-            double ratio = mergePhi->pt()*m_caloWeightTool->wtCell(mergePhi)/(cell->pt()*m_caloWeightTool->wtCell(cell) + mergePhi->pt()*m_caloWeightTool->wtCell(mergePhi));
-            float phi = cell->phi()+dPhi*ratio;
-            float pt = cell->pt()*m_caloWeightTool->wtCell(cell)+mergePhi->pt()*m_caloWeightTool->wtCell(mergePhi);
-
-            shot->setP4( (float) pt, (float) cell->eta(), (float) phi, (float) cell->m());
-        }
-        else shot->setP4( (float) cell->pt()*m_caloWeightTool->wtCell(cell), (float) cell->eta(), (float) cell->phi(), (float) cell->m());
-        
-        shot->setBDTPi0Score( (float) -9999. );
-        shot->setCharge( 0. );
-        double center_mag = 0.0;
-        shot->setCenterMag( (float) center_mag);
-        
-        ElementLink<xAOD::CaloClusterContainer> clusElementLink;
-        clusElementLink.toContainedElement( tauShotClusterContainer, shotCluster );
-        shot->setClusterLink( clusElementLink );
-        shot->setAttribute<int>(xAOD::PFODetails::PFOAttributes::tauShots_nCellsInEta, m_nCellsInEta);
-        shot->setAttribute<int>(xAOD::PFODetails::PFOAttributes::tauShots_seedHash, seedHash);
-
-        // Get cell block for variable calculations
-        std::vector<std::vector<const CaloCell*> > cellBlock = TauShotVariableHelpers::getCellBlock(shot, m_calo_id);
-
-        // Some DEBUG statements
-        if (msgLvl(MSG::DEBUG)) { 
-          if(cell->pt()*m_caloWeightTool->wtCell(cell)>300){
-            ATH_MSG_DEBUG("New shot. \t block size phi = " << cellBlock.size() << " \t block size eta = " << cellBlock.at(0).size() << "\t shot->pt() = " << shot->pt());
-            for(unsigned iCellPhi = 0; iCellPhi<cellBlock.size();++iCellPhi){
-              for(unsigned iCellEta = 0; iCellEta<cellBlock.at(iCellPhi).size();++iCellEta){
-                const CaloCell* cell = cellBlock.at(iCellPhi).at(iCellEta);
-                if( cell==NULL ) ATH_MSG_DEBUG("Cell" << iCellPhi << iCellEta << ": \t NULL" );
-                else            ATH_MSG_DEBUG("Cell"<<iCellPhi<<iCellEta<<":\tPt = "<< cell->pt()*m_caloWeightTool->wtCell(cell)<<"\teta = "<<cell->eta()<<"\tphi = "<<cell->phi());
-              }
-            }
-          }
-        }
-        // Get eta bin
-        int etaBin = getEtaBin(cell->eta());
-
-        // set variables used for photon counting
-        float pt1=TauShotVariableHelpers::ptWindow(cellBlock,1,m_caloWeightTool);
-        float pt3=TauShotVariableHelpers::ptWindow(cellBlock,3,m_caloWeightTool);
-        float pt5=TauShotVariableHelpers::ptWindow(cellBlock,5,m_caloWeightTool);
-
-        // Calculate number of photons in shot
-        int nPhotons = getNPhotons(etaBin, pt1);
-
-        // Set variables in shot PFO
-        shot->setAttribute<float>(xAOD::PFODetails::PFOAttributes::tauShots_pt1, pt1);
-        shot->setAttribute<float>(xAOD::PFODetails::PFOAttributes::tauShots_pt3, pt3);
-        shot->setAttribute<float>(xAOD::PFODetails::PFOAttributes::tauShots_pt5, pt5);
-        shot->setAttribute<int>(xAOD::PFODetails::PFOAttributes::tauShots_nPhotons, nPhotons);
-
-        // remove shot(s) from list
-	std::vector<const CaloCell*>::iterator cellItrNonConst;
-        cellItrNonConst = std::find(seedCells.begin(),seedCells.end(),cell);
-        seedCells.erase(cellItrNonConst);
-        if( mergePhi ){
-            cellItrNonConst = std::find(seedCells.begin(),seedCells.end(),mergePhi);
-            seedCells.erase(cellItrNonConst);
-        }
-    } // seed cells
+    float pt3 = TauShotVariableHelpers::ptWindow(cellBlock, 3, m_caloWeightTool);
+    shot->setAttribute<float>(xAOD::PFODetails::PFOAttributes::tauShots_pt3, pt3);
     
+    float pt5 = TauShotVariableHelpers::ptWindow(cellBlock, 5, m_caloWeightTool);
+    shot->setAttribute<float>(xAOD::PFODetails::PFOAttributes::tauShots_pt5, pt5);
     
-    return StatusCode::SUCCESS;
+    int nPhotons = getNPhotons(cell->eta(), pt1);
+    shot->setAttribute<int>(xAOD::PFODetails::PFOAttributes::tauShots_nPhotons, nPhotons);
+
+    // Add Element link to the shot PFO container
+    ElementLink<xAOD::PFOContainer> PFOElementLink;
+    PFOElementLink.toContainedElement(shotPFOContainer, shot);
+    tau.addShotPFOLink(PFOElementLink);
+    
+    // Remove used cells from list
+    std::vector<const CaloCell*>::iterator cellItrNonConst;
+    auto cellIndex = std::find(seedCells.begin(), seedCells.end(), cell);
+    seedCells.erase(cellIndex);
+    if (phiNeigCell) {
+      cellIndex = std::find(seedCells.begin(), seedCells.end(), phiNeigCell);
+      seedCells.erase(cellIndex);
+    }
+  } // Loop over seed cells
+  
+  return StatusCode::SUCCESS;
 }
 
-//______________________________________________________________________________
-std::vector<const CaloCell*> TauShotFinder::getNeighbours(const CaloCellContainer* pCellContainer, 
-							  const CaloCell* cell, 
-							  int maxDepth) const
-{
-    std::vector<const CaloCell*> cells;
-    this->addNeighbours(pCellContainer,cell,cells,0,maxDepth,true);  //next
-    this->addNeighbours(pCellContainer,cell,cells,0,maxDepth,false); //prev
-    return cells; 
+
+
+int TauShotFinder::getEtaBin(float eta) const {
+  float absEta=std::abs(eta);
+  
+  if (absEta < 0.80) {
+    return 0; // Central Barrel
+  }
+  if (absEta<1.39) {
+    return 1; // Outer Barrel
+  }
+  if (absEta<1.51) {
+    return 2; // Crack region
+  }
+  if (absEta<1.80) {
+    return 3; // Endcap, fine granularity
+  }
+  return 4; // ndcap, coarse granularity
 }
 
-//______________________________________________________________________________
-void TauShotFinder::addNeighbours(const CaloCellContainer* pCellContainer,
-                                  const CaloCell* cell, 
-                                  std::vector<const CaloCell*>& cells,
-                                  int depth,
-                                  int maxDepth,
-                                  bool next) const
-{
-    depth++; 
-    if( depth > maxDepth ) return;
 
+
+int TauShotFinder::getNPhotons(float eta, float energy) const {
+  int etaBin = getEtaBin(eta);
+  
+  // No photons in crack region
+  if(etaBin==2) return 0;
+
+  const std::vector<float>& minPtCut = m_minPtCut.value();
+  const std::vector<float>& doubleShotCut = m_doubleShotCut.value();
+  ATH_MSG_DEBUG("etaBin = " << etaBin  << ", energy = " << energy);
+  ATH_MSG_DEBUG("MinPtCut: " << minPtCut.at(etaBin) << "DoubleShotCut: " << doubleShotCut.at(etaBin));
+
+  if (energy < minPtCut.at(etaBin)) return 0;
+  if (energy > doubleShotCut.at(etaBin)) return 2;
+  return 1;
+}
+
+
+
+std::vector<const CaloCell*> TauShotFinder::selectCells(const xAOD::TauJet& tau,
+                                                        const CaloCellContainer& cellContainer) const {
+  // Get only cells within dR < 0.4
+  // -- TODO: change the hardcoded 0.4
+  std::vector<CaloCell_ID::SUBCALO> emSubCaloBlocks;
+  emSubCaloBlocks.push_back(CaloCell_ID::LAREM);
+  boost::scoped_ptr<CaloCellList> cellList(new CaloCellList(&cellContainer,emSubCaloBlocks)); 
+  // -- FIXME: tau p4 is corrected to point at tau vertex, but the cells are not 
+  cellList->select(tau.eta(), tau.phi(), 0.4); 
+
+  std::vector<const CaloCell*> cells;
+  for (const CaloCell* cell : *cellList) {
+    // Require cells above 100 MeV
+    // FIXME: cells are not corrected to point at tau vertex
+    if (cell->pt() * m_caloWeightTool->wtCell(cell) < 100.) continue;
+    
+    // Require cells in EM1 
+    int sampling = cell->caloDDE()->getSampling();
+    if( !( sampling == CaloCell_ID::EMB1 || sampling == CaloCell_ID::EME1 ) ) continue;
+    
+    cells.push_back(cell);
+  }
+
+  return cells;
+} 
+
+
+
+std::vector<const CaloCell*> TauShotFinder::selectSeedCells(const xAOD::TauJet& tau,
+                                                            const CaloCellContainer& cellContainer) const {
+
+  // Apply pre-selection of the cells
+  std::vector<const CaloCell*> cells = selectCells(tau, cellContainer);
+  std::sort(cells.begin(),cells.end(),ptSort(*this));
+
+  std::vector<const CaloCell*> seedCells;  
+  std::set<IdentifierHash> seedCellHashes;
+
+  // Loop the pt sorted cells, and select the seed cells
+  for (const CaloCell* cell: cells) {
     const IdentifierHash cellHash = cell->caloDDE()->calo_hash();
-    std::vector<IdentifierHash> neigHashes;
-    if( next )
-        m_calo_id->get_neighbours(cellHash,LArNeighbours::nextInEta,neigHashes);
-    else
-        m_calo_id->get_neighbours(cellHash,LArNeighbours::prevInEta,neigHashes);
+
+    std::vector<IdentifierHash> nextEtaHashes;
+    m_calo_id->get_neighbours(cellHash, LArNeighbours::nextInEta, nextEtaHashes);
+    std::vector<IdentifierHash> prevEtaHashes;
+    m_calo_id->get_neighbours(cellHash, LArNeighbours::prevInEta, prevEtaHashes);
+   
+    std::vector<IdentifierHash> neighHashes = nextEtaHashes; 
+    neighHashes.insert(neighHashes.end(),prevEtaHashes.begin(),prevEtaHashes.end()); 
+   
+    // Check whether it is a seed cell
+    bool status = true;
+    for (const IdentifierHash& neighHash : neighHashes) {
+      // Seed cells must not have seed cells as neighbours
+      // TODO: maybe this requirement can be removed
+      if (seedCellHashes.find(neighHash) != seedCellHashes.end()) {
+        status = false;
+        break;
+      }
+      
+      // Pt of seed cells must be larger than neighbours'
+      const CaloCell* neighCell = cellContainer.findCell(neighHash);
+      if (!neighCell) continue;
+      if (neighCell->pt() * m_caloWeightTool->wtCell(neighCell) >= cell->pt() * m_caloWeightTool->wtCell(cell)) {
+        status = false;
+        break;
+      }
+    } // End of the loop of neighbour cells
     
-    std::vector<IdentifierHash>::iterator hashItr = neigHashes.begin();
-    for( ; hashItr!=neigHashes.end(); ++hashItr ){
-        const CaloCell* newCell = pCellContainer->findCell(*hashItr);
-        if(!newCell)continue;
-        cells.push_back(newCell);
-        this->addNeighbours(pCellContainer,newCell,cells,depth,maxDepth,next);
-        // no EM1 cell should have more than one neighbor. Just add this neigbor for now
-        // FIXME: Check whether it happens that a cell has > 1 neighbors
-        break; 
-    } 
+    if (!status) continue; 
+    
+    seedCells.push_back(cell); 
+    seedCellHashes.insert(cellHash);
+  } // End of the loop of cells
+
+  return seedCells; 
+} 
+
+
+
+bool TauShotFinder::isPhiNeighbour(IdentifierHash cell1Hash, IdentifierHash cell2Hash) const {
+  std::vector<IdentifierHash> neigHashes;
+ 
+  // Next cell in phi direction 
+  m_calo_id->get_neighbours(cell1Hash,LArNeighbours::nextInPhi,neigHashes);
+  if (neigHashes.size() > 1) {
+    ATH_MSG_WARNING(cell1Hash << " has " << neigHashes.size()  <<  " neighbours in the next phi direction !"); 
+  }
+  if (std::find(neigHashes.begin(), neigHashes.end(), cell2Hash) != neigHashes.end()) {
+    return true;
+  }
+  
+  // Previous cell in phi direction
+  m_calo_id->get_neighbours(cell1Hash,LArNeighbours::prevInPhi,neigHashes);
+  if (neigHashes.size() > 1) {
+    ATH_MSG_WARNING(cell1Hash << " has " << neigHashes.size()  <<  " neighbours in the previous phi direction !"); 
+  }
+  if (std::find(neigHashes.begin(), neigHashes.end(), cell2Hash) != neigHashes.end()) {
+    return true;
+  }
+
+  return false;
 }
 
-//______________________________________________________________________________
-bool TauShotFinder::isPhiNeighbour(IdentifierHash cell1Hash, IdentifierHash cell2Hash, bool next) const{
-    std::vector<IdentifierHash> neigHashes;
-    if( next ) m_calo_id->get_neighbours(cell1Hash,LArNeighbours::nextInPhi,neigHashes);
-    else       m_calo_id->get_neighbours(cell1Hash,LArNeighbours::prevInPhi,neigHashes);
-    std::vector<IdentifierHash>::iterator itr = neigHashes.begin();
-    for( ; itr!=neigHashes.end(); ++itr ){
-        if(cell2Hash == (*itr)) return true;
-    } 
-    return false;
+
+
+const CaloCell* TauShotFinder::getPhiNeighbour(const CaloCell& seedCell, 
+                                               const std::vector<const CaloCell*>& seedCells) const {
+
+  const IdentifierHash seedHash = seedCell.caloDDE()->calo_hash();
+ 
+  // Obtain the neighbour cells in the phi direction 
+  std::vector<const CaloCell*> neighCells;
+  for (const CaloCell* neighCell : seedCells) {
+    if (neighCell == &seedCell) continue;
+    
+    IdentifierHash neighHash = neighCell->caloDDE()->calo_hash();
+    if (this->isPhiNeighbour(seedHash, neighHash)) {
+      neighCells.push_back(neighCell);
+    }
+  }
+  std::sort(neighCells.begin(),neighCells.end(),ptSort(*this)); 
+
+  // Select the one with largest pt
+  const CaloCell* phiNeigCell = nullptr;
+  if (neighCells.size() >= 1) {
+    phiNeigCell = neighCells[0];
+  } 
+
+  return phiNeigCell;
 }
 
-//______________________________________________________________________________
-float TauShotFinder::getEtaBin(float seedEta) const {
-    float absSeedEta=std::abs(seedEta);
-    if(absSeedEta < 0.80)      return 0; // Central Barrel
-    else if(absSeedEta<1.39) return 1; // Outer Barrel
-    else if(absSeedEta<1.51) return 2; // crack
-    else if(absSeedEta<1.80) return 3; // endcap, fine granularity
-    else return 4;                           // endcap, coarse granularity
+
+
+std::vector<const CaloCell*> TauShotFinder::getEtaNeighbours(const CaloCell& cell,
+                                                             const CaloCellContainer& cellContainer, 
+							                                 int maxDepth) const {
+    std::vector<const CaloCell*> cells;
+    
+    // Add neighbours in next eta direction
+    this->addEtaNeighbours(cell, cellContainer, cells, 0, maxDepth, true);
+    // Add neighbours in previous eta direction
+    this->addEtaNeighbours(cell, cellContainer, cells, 0, maxDepth, false);
+
+    return cells; 
 }
 
-//______________________________________________________________________________
-float TauShotFinder::getNPhotons(int etaBin, float seedEnergy) const {
-    // no photon counting in crack region, e.g. [1.39, 1.51]
-    if(etaBin==2) return 0;
 
-    const std::vector<float>& minPtCut = m_minPtCut.value();
-    const std::vector<float>& autoDoubleShotCut = m_autoDoubleShotCut.value();
-    ATH_MSG_DEBUG("etaBin = " << etaBin  << ", seedEnergy = " << seedEnergy);
-    ATH_MSG_DEBUG("MinPtCut: " << minPtCut.at(etaBin) << "DoubleShotCut: " << autoDoubleShotCut.at(etaBin));
 
-    if( seedEnergy < minPtCut.at(etaBin) ) return 0;
-    if( seedEnergy > autoDoubleShotCut.at(etaBin) ) return 2;
-    return 1;
+void TauShotFinder::addEtaNeighbours(const CaloCell& cell,
+                                     const CaloCellContainer& cellContainer,
+                                     std::vector<const CaloCell*>& cells,
+                                     int depth,
+                                     int maxDepth,
+                                     bool next) const {
+  ++depth; 
+  
+  if (depth > maxDepth) return;
+
+  const IdentifierHash cellHash = cell.caloDDE()->calo_hash();
+  
+  std::vector<IdentifierHash> neigHashes;
+  if (next) {
+    m_calo_id->get_neighbours(cellHash,LArNeighbours::nextInEta,neigHashes);
+  }
+  else {
+    m_calo_id->get_neighbours(cellHash,LArNeighbours::prevInEta,neigHashes);
+  }
+
+  for (const IdentifierHash& hash : neigHashes) {
+    const CaloCell* newCell = cellContainer.findCell(hash);
+    
+    if (!newCell) continue;
+    
+    cells.push_back(newCell);
+    this->addEtaNeighbours(*newCell, cellContainer, cells, depth, maxDepth, next);
+  
+    if (neigHashes.size() > 1) {
+      ATH_MSG_WARNING(cellHash << " has " << neigHashes.size()  <<  " neighbours in the eta direction !"); 
+      break; 
+    }
+  } 
 }
 
-//______________________________________________________________________________
-// some really slick c++ way of doing sort (since we need to use the member m_caloWeightTool)
+
+
+xAOD::CaloCluster* TauShotFinder::createShotCluster(const CaloCell* cell, 
+                                                    const CaloCell* phiNeigCell, 
+                                                    const CaloCellContainer& cellContainer) const {
+    
+  xAOD::CaloCluster* shotCluster = CaloClusterStoreHelper::makeCluster(&cellContainer);
+  
+  int maxDepth = (m_nCellsInEta - 1) / 2;
+
+  std::vector<const CaloCell*> windowNeighbours = this->getEtaNeighbours(*cell, cellContainer, maxDepth);
+  if (phiNeigCell) {
+    std::vector<const CaloCell*> mergeCells = this->getEtaNeighbours(*phiNeigCell, cellContainer, maxDepth);
+    windowNeighbours.push_back(phiNeigCell);
+    windowNeighbours.insert(windowNeighbours.end(), mergeCells.begin(), mergeCells.end());
+  }
+
+  shotCluster->getOwnCellLinks()->reserve(windowNeighbours.size()+1);
+  const IdentifierHash seedHash = cell->caloDDE()->calo_hash();
+  shotCluster->addCell(cellContainer.findIndex(seedHash), 1.);
+  
+  for (const CaloCell* cell : windowNeighbours) {
+    shotCluster->addCell(cellContainer.findIndex(cell->caloDDE()->calo_hash()), 1.0);
+  }
+
+  CaloClusterKineHelper::calculateKine(shotCluster,true,true);
+
+  return shotCluster;
+} 
+
+
+
 TauShotFinder::ptSort::ptSort( const TauShotFinder& info ) : m_info(info) { } 
-bool TauShotFinder::ptSort::operator()( const CaloCell* c1, const CaloCell* c2 ){
-     return  c1->pt()*m_info.m_caloWeightTool->wtCell(c1) > c2->pt()*m_info.m_caloWeightTool->wtCell(c2);  
+bool TauShotFinder::ptSort::operator()( const CaloCell* cell1, const CaloCell* cell2 ){
+  double pt1 = cell1->pt()*m_info.m_caloWeightTool->wtCell(cell1);
+  double pt2 = cell2->pt()*m_info.m_caloWeightTool->wtCell(cell2);
+  return pt1 > pt2;  
 }
 
 #endif
diff --git a/Reconstruction/tauRecTools/src/TauShotFinder.h b/Reconstruction/tauRecTools/src/TauShotFinder.h
index 6273fc619692164dbf2885e079f0f6b2b258ac24..fde52480403aa80dc4e9941b7eb382d32cb59b3d 100644
--- a/Reconstruction/tauRecTools/src/TauShotFinder.h
+++ b/Reconstruction/tauRecTools/src/TauShotFinder.h
@@ -5,69 +5,104 @@
 #ifndef TAUREC_TAUSHOTFINDER_H
 #define	TAUREC_TAUSHOTFINDER_H
 
-#include "GaudiKernel/ToolHandle.h"
 #include "tauRecTools/TauRecToolBase.h"
+
 #include "xAODPFlow/PFOAuxContainer.h"
 #include "xAODCaloEvent/CaloClusterAuxContainer.h"
 #include "CaloInterface/IHadronicCalibrationTool.h"
 
-class CaloDetDescrManager;
+#include "GaudiKernel/ToolHandle.h"
+
+/**
+ * @brief Construct the shot candidates
+ *        1. select seed cells used to construct the shot candidates 
+ *        2. create the shot PFOs by merging the neighbour seed cells in phi direction
+ *        3. the cluster of the shot PFO contains cells in a window of 2 x NCellsInEta
+ *
+ * @author Will Davey <will.davey@cern.ch> 
+ * @author Benedict Winter <benedict.tobias.winter@cern.ch>
+ * @author Stephanie Yuen <stephanie.yuen@cern.ch> 
+ */
+
 class CaloCell_ID;
 
 class TauShotFinder : public TauRecToolBase {
+
 public:
-    TauShotFinder(const std::string& name);
-    ASG_TOOL_CLASS2(TauShotFinder, TauRecToolBase, ITauToolBase);
-    virtual ~TauShotFinder();
+  
+  ASG_TOOL_CLASS2(TauShotFinder, TauRecToolBase, ITauToolBase);
+
+  TauShotFinder(const std::string& name);
+  virtual ~TauShotFinder() = default;
 
-    virtual StatusCode initialize() override;
-    virtual StatusCode executeShotFinder(xAOD::TauJet& pTau, xAOD::CaloClusterContainer& tauShotCaloClusContainer, xAOD::PFOContainer& tauShotPFOContainer) const override;
+  virtual StatusCode initialize() override;
+  virtual StatusCode executeShotFinder(xAOD::TauJet& pTau, xAOD::CaloClusterContainer& tauShotCaloClusContainer, xAOD::PFOContainer& tauShotPFOContainer) const override;
 
 private:
 
-    /** @brief tool handles */
-    ToolHandle<IHadronicCalibrationTool> m_caloWeightTool {this, "CaloWeightTool", "H1WeightToolCSC12Generic"};
-    
-    /** @brief new shot PFO container and name */
-    /** @brief calo cell navigation */
-    const CaloCell_ID* m_calo_id = NULL;
-
-    /** @brief Thanks C++ for ruining my day */
-    struct ptSort
-    { 
-         ptSort( const TauShotFinder& info );
-         const TauShotFinder& m_info;
-         bool operator()( const CaloCell* c1, const CaloCell* c2 );
-    };
-
-    /** @brief get neighbour cells */
-    std::vector<const CaloCell*> getNeighbours(const CaloCellContainer*,const CaloCell*,int /*maxDepth*/) const;
-
-    void addNeighbours(const CaloCellContainer*,
-                       const CaloCell* cell,
-                       std::vector<const CaloCell*>& cells,
-                       int depth,
-                       int maxDepth,
-                       bool next) const;
-
-    bool isPhiNeighbour(IdentifierHash cell1Hash, IdentifierHash cell2Hash, bool next) const;
-
-    /** @brief get eta bin */
-    float getEtaBin(float /*seedEta*/) const;
+  struct ptSort
+  { 
+    ptSort( const TauShotFinder& info );
+    const TauShotFinder& m_info;
+    bool operator()( const CaloCell* c1, const CaloCell* c2 );
+  };
+
+  /** @brief Apply preselection of the cells 
+   *         Cells within dR < 0.4, in EM1, and pt > 100 MeV are selected
+   */
+  std::vector<const CaloCell*> selectCells(const xAOD::TauJet& tau, const CaloCellContainer& cellContainer) const;
+
+  /** @brief Select the seed cells used to construct the shot 
+   *         Cells must sastisfy:
+   *         1. pre-selction: dR < 0.4, in EM1, and pt > 100 MeV
+   *         2. have largest pt among the neighbours in the eta direction 
+   *         3. no other seed cells as neighbors in the eta direction
+   */
+  std::vector<const CaloCell*> selectSeedCells(const xAOD::TauJet& tau, const CaloCellContainer& cellContainer) const;
+
+  /** @brief Check whether two cells are neighbours in the phi direction */
+  bool isPhiNeighbour(IdentifierHash cell1Hash, IdentifierHash cell2Hash) const;
+
+  /** @brief Get the hottest neighbour cell in the phi direction */ 
+  const CaloCell* getPhiNeighbour(const CaloCell& seedCell, const std::vector<const CaloCell*>& seedCells) const;
+  
+  /** @brief Get neighbour cells in the eta direction */
+  std::vector<const CaloCell*> getEtaNeighbours(const CaloCell& cell, const CaloCellContainer& cellContainer, int maxDepth) const;
+
+  /** @brief Get neighbour cells in the eta direction */
+  void addEtaNeighbours(const CaloCell& cell,
+                        const CaloCellContainer& cellContainer,
+                        std::vector<const CaloCell*>& cells,
+                        int depth,
+                        int maxDepth,
+                        bool next) const;
+
+  /** @brief Create the shot cluster 
+   *         Shot cluster contains 5x1 cells from the seed cell and hottestneighbour
+   *         cell in the phi direction
+   */
+  xAOD::CaloCluster* createShotCluster(const CaloCell* cell,
+                                       const CaloCell* phiNeighCell,
+                                       const CaloCellContainer& cellContainer) const;
+
+  /** @brief Get eta bin */
+  int getEtaBin(float eta) const;
+
+  /** @brief Get NPhotons in shot */
+  int getNPhotons(float eta, float energy) const;
+
+  Gaudi::Property<int> m_nCellsInEta {this, "NCellsInEta"};
+  Gaudi::Property<std::vector<float>> m_minPtCut {this, "MinPtCut"};
+  Gaudi::Property<std::vector<float>> m_doubleShotCut {this, "AutoDoubleShotCut"};
+
+  SG::ReadHandleKey<CaloCellContainer> m_caloCellInputContainer{this,"Key_caloCellInputContainer", "AllCalo", "input vertex container key"};
+  SG::WriteHandleKey<xAOD::PFOContainer> m_tauPFOOutputContainer{this,"Key_tauPFOOutputContainer", "TauShotParticleFlowObjects", "tau pfo out key"};
   
-    /** @brief get NPhotons in shot */
-    float getNPhotons(int /*etaBin*/, 
-                      float /*seedEnergy*/) const;
-
-    // number of cells in eta
-    Gaudi::Property<int> m_nCellsInEta {this, "NCellsInEta"};
-    Gaudi::Property<std::vector<float>> m_minPtCut {this, "MinPtCut"};
-    Gaudi::Property<std::vector<float>> m_autoDoubleShotCut {this, "AutoDoubleShotCut"};
+  ToolHandle<IHadronicCalibrationTool> m_caloWeightTool {this, "CaloWeightTool", "H1WeightToolCSC12Generic"};
   
-    SG::ReadHandleKey<CaloCellContainer> m_caloCellInputContainer{this,"Key_caloCellInputContainer", "AllCalo", "input vertex container key"};
-    SG::WriteHandleKey<xAOD::PFOContainer> m_tauPFOOutputContainer{this,"Key_tauPFOOutputContainer", "TauShotParticleFlowObjects", "tau pfo out key"};
-    
+  /// calo cell navigation
+  const CaloCell_ID* m_calo_id = nullptr;
+
 };
 
 #endif	/* TAUSHOTFINDER_H */
-
diff --git a/Reconstruction/tauRecTools/src/TauShotVariableHelpers.cxx b/Reconstruction/tauRecTools/src/TauShotVariableHelpers.cxx
index 3b2ef8c1dc49c3eacaa6f04b1d7e45d8362d5d03..d2e1021f2efa20e3a2b0eab39d609772e29e2146 100644
--- a/Reconstruction/tauRecTools/src/TauShotVariableHelpers.cxx
+++ b/Reconstruction/tauRecTools/src/TauShotVariableHelpers.cxx
@@ -3,6 +3,9 @@
 */
 
 #ifndef XAOD_ANALYSIS
+
+#include "TauShotVariableHelpers.h"
+
 /**
  * @brief implementation of photon shot variable calculation 
  * 
@@ -11,350 +14,174 @@
  * @author Stephanie Yuen <stephanie.yuen@cern.ch> 
  */
 
-#include "TauShotVariableHelpers.h"
+namespace TauShotVariableHelpers {
 
-using xAOD::PFO;
-using std::vector;
+ANA_MSG_SOURCE(msgTauShotVariableHelpers, "TauShotVariableHelpers")
 
-namespace TauShotVariableHelpers {
-    std::vector<std::vector<const CaloCell*> > getCellBlock(xAOD::PFO* shot, const CaloCell_ID* calo_id){
-        std::vector<std::vector<const CaloCell*> > cellVector;
-        std::vector<const CaloCell*> oneEtaLayer;
-        int nCellsInEta = 0;
-        if( shot->attribute(xAOD::PFODetails::PFOAttributes::tauShots_nCellsInEta, nCellsInEta) == false) {
-            std::cout << "WARNING: Couldn't find nCellsInEta. Return empty cell block." << std::endl;
-            return cellVector;
-        }
-        int seedHash = 0;
-        if( shot->attribute(xAOD::PFODetails::PFOAttributes::tauShots_seedHash, seedHash) == false) {
-            std::cout << "WARNING: Couldn't find seed hash. Return empty cell block." << std::endl;
-            return cellVector;
-        }
-        for(int iCell=0;iCell<nCellsInEta;++iCell) oneEtaLayer.push_back(NULL);
-        // have two layers in phi
-        cellVector.push_back(oneEtaLayer);
-        cellVector.push_back(oneEtaLayer);
-        // get cluster from shot
-        const xAOD::CaloCluster* cluster = shot->cluster(0);
-        const CaloClusterCellLink* theCellLink = cluster->getCellLinks();
-        CaloClusterCellLink::const_iterator cellItr  = theCellLink->begin();
-        CaloClusterCellLink::const_iterator cellItrE = theCellLink->end();
 
-        // get seed cell from shot cluster
-        const CaloCell* seedCell=NULL;
-        for(;cellItr!=cellItrE;++cellItr){
-            if((*cellItr)->caloDDE()->calo_hash()!=(unsigned) seedHash) continue;
-            seedCell = *cellItr;
-            break;
-        }
-        if(seedCell==NULL){
-          std::cout << "WARNING: Couldn't find seed cell in shot cluster. Return empty cell block." << std::endl;
-          return cellVector;
-        }
-        
-        // get merged cell in phi. Keep NULL if shot is not merged across phi
-        const CaloCell* mergedCell = NULL;
-        std::vector<IdentifierHash> nextInPhi;
-        std::vector<IdentifierHash> prevInPhi;
-        calo_id->get_neighbours(seedCell->caloDDE()->calo_hash(),LArNeighbours::nextInPhi,nextInPhi);
-        calo_id->get_neighbours(seedCell->caloDDE()->calo_hash(),LArNeighbours::prevInPhi,prevInPhi);
-        for(cellItr=theCellLink->begin();cellItr!=cellItrE;++cellItr){
-            std::vector<IdentifierHash>::iterator itr = nextInPhi.begin();
-            for( ; itr!=nextInPhi.end(); ++itr ){
-                if((*cellItr)->caloDDE()->calo_hash() != (*itr)) continue;
-                mergedCell = (*cellItr);
-                break;
-            }
-            if(mergedCell!=NULL) break;
-            itr = prevInPhi.begin();
-            for( ; itr!=prevInPhi.end(); ++itr ){
-                if((*cellItr)->caloDDE()->calo_hash() != (*itr)) continue;
-                mergedCell = (*cellItr);
-                break;
-            }
-            if(mergedCell!=NULL) break;
-        }
-        // store cells in the eta layer, which contains the seed cell
-        int nCellsFromSeed = 1;
-        const CaloCell* lastCell = seedCell;
-        cellVector.at(0).at(nCellsInEta/2) = seedCell; // store seed cell
-        std::vector<IdentifierHash> next;
-        while(lastCell!=NULL && nCellsFromSeed<nCellsInEta/2+1){
-            calo_id->get_neighbours(lastCell->caloDDE()->calo_hash(),LArNeighbours::nextInEta,next);
-            lastCell = NULL;
-            for(cellItr=theCellLink->begin();cellItr!=cellItrE;++cellItr){
-                std::vector<IdentifierHash>::iterator itr = next.begin();
-                for( ; itr!=next.end(); ++itr ){
-                    if((*cellItr)->caloDDE()->calo_hash() != (*itr)) continue;
-                    cellVector.at(0).at(nCellsInEta/2+nCellsFromSeed) = (*cellItr);
-                    lastCell = (*cellItr);
-                }
-            }
-            nCellsFromSeed++;
-        }
-        nCellsFromSeed = 1;
-        lastCell = seedCell;
-        while(lastCell!=NULL && nCellsFromSeed<nCellsInEta/2+1){
-            calo_id->get_neighbours(lastCell->caloDDE()->calo_hash(),LArNeighbours::prevInEta,next);
-            lastCell = NULL;
-            for(cellItr=theCellLink->begin();cellItr!=cellItrE;++cellItr){
-                std::vector<IdentifierHash>::iterator itr = next.begin();
-                for( ; itr!=next.end(); ++itr ){
-                    if((*cellItr)->caloDDE()->calo_hash() != (*itr)) continue;
-                    cellVector.at(0).at(nCellsInEta/2-nCellsFromSeed) = (*cellItr);
-                    lastCell = (*cellItr);
-                }
-            }
-            nCellsFromSeed++;
-        }
-        // store cells in the eta layer, which contains the merged cell
-        int nCellsFromMerged = 1;
-        lastCell = mergedCell; // is NULL if shot is not merged
-        cellVector.at(1).at(nCellsInEta/2) = mergedCell; // store merged cell
-        while(lastCell!=NULL && nCellsFromMerged<nCellsInEta/2+1){
-            calo_id->get_neighbours(lastCell->caloDDE()->calo_hash(),LArNeighbours::nextInEta,next);
-            lastCell = NULL;
-            for(cellItr=theCellLink->begin();cellItr!=cellItrE;++cellItr){
-                std::vector<IdentifierHash>::iterator itr = next.begin();
-                for( ; itr!=next.end(); ++itr ){
-                    if((*cellItr)->caloDDE()->calo_hash() != (*itr)) continue;
-                    cellVector.at(1).at(nCellsInEta/2+nCellsFromMerged) = (*cellItr);
-                    lastCell = (*cellItr);
-                }
-            }
-            nCellsFromMerged++;
-        }
-        nCellsFromMerged = 1;
-        lastCell = mergedCell;
-        while(lastCell!=NULL && nCellsFromMerged<nCellsInEta/2+1){
-            calo_id->get_neighbours(lastCell->caloDDE()->calo_hash(),LArNeighbours::prevInEta,next);
-            lastCell = NULL;
-            for(cellItr=theCellLink->begin();cellItr!=cellItrE;++cellItr){
-                std::vector<IdentifierHash>::iterator itr = next.begin();
-                for( ; itr!=next.end(); ++itr ){
-                    if((*cellItr)->caloDDE()->calo_hash() != (*itr)) continue;
-                    cellVector.at(1).at(nCellsInEta/2-nCellsFromMerged) = (*cellItr);
-                    lastCell = (*cellItr);
-                }
-            }
-            nCellsFromMerged++;
-        }
-        return cellVector;
-    
-    }
+const CaloCell* getNeighbour(const CaloCell* cell, 
+                             const CaloClusterCellLink& links, 
+                             const CaloCell_ID* calo_id, 
+                             const LArNeighbours::neighbourOption& option) {
+  const CaloCell* neigCell = nullptr;
 
+  std::vector<IdentifierHash> neighHashes;
+  calo_id->get_neighbours(cell->caloDDE()->calo_hash(), option, neighHashes);
+  
+  // Loop all the cells, and find the required neighbour cell
+  for (const auto& cell : links) {
+    const IdentifierHash& cellHash = cell->caloDDE()->calo_hash();
 
-    float mean_eta(vector<vector<const CaloCell*> > shotCells, const ToolHandle<IHadronicCalibrationTool>& caloWeightTool){
-        float sumEta=0.;
-        float sumWeight=0.;
-        vector<vector<const CaloCell*> >::iterator itrPhi = shotCells.begin();
-        for( ; itrPhi!=shotCells.end(); ++itrPhi ){
-            vector<const CaloCell*>::iterator itrEta = itrPhi->begin();
-            for( ; itrEta!=itrPhi->end(); ++itrEta ){
-                if((*itrEta) == NULL) continue;
-                sumWeight += (*itrEta)->pt()*caloWeightTool->wtCell(*itrEta);
-                sumEta    += (*itrEta)->pt()*caloWeightTool->wtCell(*itrEta) * (*itrEta)->eta();
-            }
-        }
-        if(sumWeight<=0.) return -99999.;
-        return sumEta/sumWeight;
+    // Check whether the cell is a neighbour cell 
+    for (const IdentifierHash& neighHash : neighHashes) {
+      if (cellHash == neighHash) {
+        neigCell = cell;
+        return neigCell;
+      }
     }
+  }
 
-    float mean_pt(vector<vector<const CaloCell*> > shotCells, const ToolHandle<IHadronicCalibrationTool>& caloWeightTool){
-        float sumPt=0.;
-        int nCells = 0;
-        vector<vector<const CaloCell*> >::iterator itrPhi = shotCells.begin();
-        for( ; itrPhi!=shotCells.end(); ++itrPhi ){
-            vector<const CaloCell*>::iterator itrEta = itrPhi->begin();
-            for( ; itrEta!=itrPhi->end(); ++itrEta ){
-                if((*itrEta) == NULL) continue;
-                sumPt  += (*itrEta)->pt()*caloWeightTool->wtCell(*itrEta);
-                nCells ++;
-            }
-        }
-        if(nCells==0) return -99999.;
-        return sumPt/nCells;
-    }
+  return neigCell;
+}
 
-    float ptWindow(vector<vector<const CaloCell*> > shotCells, int windowSize, const ToolHandle<IHadronicCalibrationTool>& caloWeightTool){
-        // window size should be odd and noti be larger than eta window of shotCells
-        int nCells_eta = shotCells.at(0).size();
-        int seedIndex = nCells_eta/2;
-        if( windowSize%2!=1 )        return 0.;
-        if( windowSize > nCells_eta) return 0.;
-        float ptWindow  = 0.;
-        for(int iCell = 0; iCell != nCells_eta; ++iCell ){
-	  if(std::abs(iCell-seedIndex)>windowSize/2) continue;
-            if(shotCells.at(0).at(iCell) != NULL) ptWindow+=shotCells.at(0).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(0).at(iCell));
-            if(shotCells.at(1).at(iCell) != NULL) ptWindow+=shotCells.at(1).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(1).at(iCell));
-        }
-        return ptWindow;
-    }
 
-    float ws5(vector<vector<const CaloCell*> > shotCells, const ToolHandle<IHadronicCalibrationTool>& caloWeightTool){
-        int nCells_eta = shotCells.at(0).size();
-        int seedIndex = nCells_eta/2;
-        float sumWeight=0.;
-        float sumDev2=0.;
-        vector<vector<const CaloCell*> >::iterator itrPhi = shotCells.begin();
-        for( ; itrPhi!=shotCells.end(); ++itrPhi ){
-            for(unsigned iCell = 0; iCell != itrPhi->size(); ++iCell ){
-                if(itrPhi->at(iCell) == NULL) continue;
-                sumWeight += itrPhi->at(iCell)->pt()*caloWeightTool->wtCell(itrPhi->at(iCell));
-                sumDev2   += itrPhi->at(iCell)->pt()*caloWeightTool->wtCell(itrPhi->at(iCell)) * pow(iCell-seedIndex,2);
-            }
-        }
-        if(sumWeight<=0. || sumDev2 <0.) return -99999.;
-        return sqrt( sumDev2 / sumWeight );
+std::vector<std::vector<const CaloCell*> > getCellBlock(const xAOD::PFO& shot, const CaloCell_ID* calo_id) {
+  using namespace TauShotVariableHelpers::msgTauShotVariableHelpers; 
+
+  std::vector<std::vector<const CaloCell*>> cellBlock;
+
+  int etaSize = 0;
+  if (shot.attribute(xAOD::PFODetails::PFOAttributes::tauShots_nCellsInEta, etaSize) == false) {
+    ANA_MSG_WARNING("Couldn't find nCellsInEta. Return empty cell block.");
+    return cellBlock;
+  }
+  
+  int seedHash = 0;
+  if (shot.attribute(xAOD::PFODetails::PFOAttributes::tauShots_seedHash, seedHash) == false) {
+    ANA_MSG_WARNING("Couldn't find seed hash. Return empty cell block.");
+    return cellBlock;
+  }
+  
+  // Initialize the cell block
+  std::vector<const CaloCell*> etaLayer;
+  for (int etaIndex = 0; etaIndex < etaSize; ++ etaIndex) {
+    etaLayer.push_back(nullptr);
+  }
+  int phiSize = 2;
+  for (int phiIndex = 0; phiIndex < phiSize; ++phiIndex) {
+    cellBlock.push_back(etaLayer);
+  }
+
+  // Get seed cell from shot cluster
+  const xAOD::CaloCluster* cluster = shot.cluster(0);
+  const CaloClusterCellLink* cellLinks = cluster->getCellLinks();
+
+  const CaloCell* seedCell = nullptr;
+  for (const auto& cell : *cellLinks) {
+     if (cell->caloDDE()->calo_hash() != (unsigned) seedHash) continue;
+     seedCell = cell;
+     break;
+  }
+  if (seedCell==nullptr) {
+    ANA_MSG_WARNING("Couldn't find seed cell in shot cluster. Return empty cell block.");
+    return cellBlock;
+  }
+  int mediumEtaIndex = etaSize/2;
+  cellBlock.at(0).at(mediumEtaIndex) = seedCell;
+
+  // Obtain the neighbour cells in the eta direction
+  // -- Next in eta
+  const CaloCell* lastCell = seedCell;
+  int maxDepth = etaSize - mediumEtaIndex - 1;
+  for (int depth = 1; depth < maxDepth + 1; ++depth) {
+    lastCell = getNeighbour(lastCell, *cellLinks, calo_id, LArNeighbours::nextInEta);
+    if (lastCell != nullptr) {
+      cellBlock.at(0).at(mediumEtaIndex + depth) = lastCell;
     }
-
-    float sdevEta_WRTmean(vector<vector<const CaloCell*> > shotCells, const ToolHandle<IHadronicCalibrationTool>& caloWeightTool){
-        float mean = mean_eta(shotCells, caloWeightTool); 
-        float sumWeight=0.;
-        float sumDev2=0.;
-        vector<vector<const CaloCell*> >::iterator itrPhi = shotCells.begin();
-        for( ; itrPhi!=shotCells.end(); ++itrPhi ){
-            vector<const CaloCell*>::iterator itrEta = itrPhi->begin();
-            for( ; itrEta!=itrPhi->end(); ++itrEta ){
-                if((*itrEta) == NULL) continue;
-                sumWeight += (*itrEta)->pt()*caloWeightTool->wtCell(*itrEta);
-                sumDev2   += (*itrEta)->pt()*caloWeightTool->wtCell(*itrEta) * pow((*itrEta)->eta() - mean,2);
-            }
-        }
-        if(sumWeight<=0. || sumDev2 <0.) return -99999.;
-        return sqrt( sumDev2 / sumWeight );
+    else {
+      break;
     }
-
-    float sdevEta_WRTmode(vector<vector<const CaloCell*> > shotCells, const ToolHandle<IHadronicCalibrationTool>& caloWeightTool){
-        int nCells_eta = shotCells.at(0).size();
-        int seedIndex = nCells_eta/2;
-        float mode = shotCells.at(0).at(seedIndex)->eta();
-        float sumWeight=0.;
-        float sumDev2=0.;
-        vector<vector<const CaloCell*> >::iterator itrPhi = shotCells.begin();
-        for( ; itrPhi!=shotCells.end(); ++itrPhi ){
-            vector<const CaloCell*>::iterator itrEta = itrPhi->begin();
-            for( ; itrEta!=itrPhi->end(); ++itrEta ){
-                if((*itrEta) == NULL) continue;
-                sumWeight += (*itrEta)->pt()*caloWeightTool->wtCell(*itrEta);
-                sumDev2   += (*itrEta)->pt()*caloWeightTool->wtCell(*itrEta) * pow((*itrEta)->eta() - mode,2);
-            }
-        }
-        if(sumWeight<=0. || sumDev2 <0.) return -99999.;
-        return sqrt( sumDev2 / sumWeight );
+  }
+
+  // -- Previous in eta
+  lastCell = seedCell;
+  for (int depth = 1; depth < maxDepth + 1; ++depth) {
+    lastCell = getNeighbour(lastCell, *cellLinks, calo_id, LArNeighbours::prevInEta);
+    if (lastCell != nullptr) {
+      cellBlock.at(0).at(mediumEtaIndex - depth) = lastCell;
     }
-
-    float sdevPt(vector<vector<const CaloCell*> > shotCells, const ToolHandle<IHadronicCalibrationTool>& caloWeightTool){
-        float mean = mean_pt(shotCells, caloWeightTool);
-        float sumWeight=0.;
-        float sumDev2=0.;
-        vector<vector<const CaloCell*> >::iterator itrPhi = shotCells.begin();
-        for( ; itrPhi!=shotCells.end(); ++itrPhi ){
-            vector<const CaloCell*>::iterator itrEta = itrPhi->begin();
-            for( ; itrEta!=itrPhi->end(); ++itrEta ){
-                if((*itrEta) == NULL) continue;
-                sumWeight += (*itrEta)->pt()*caloWeightTool->wtCell(*itrEta);
-                sumDev2   += pow((*itrEta)->pt()*caloWeightTool->wtCell(*itrEta) - mean,2);
-            }
-        }
-        if(sumWeight<=0. || sumDev2 <0.) return -99999.;
-        return sqrt(sumDev2)/sumWeight;
+    else {
+      break;
     }
-
-    float deltaPt12_min(vector<vector<const CaloCell*> > shotCells, const ToolHandle<IHadronicCalibrationTool>& caloWeightTool){
-        int nCells_eta = shotCells.at(0).size();
-        int seedIndex = nCells_eta/2;
-        bool haveLeft  = false;
-        bool haveRight = false;
-        float deltaPt_left  = 0.;
-        float deltaPt_right = 0.;
-        if(shotCells.at(0).at(seedIndex-1)!=NULL && shotCells.at(0).at(seedIndex-2)!=NULL){
-            haveLeft  = true;
-            deltaPt_left =  shotCells.at(0).at(seedIndex-1)->pt()*caloWeightTool->wtCell(shotCells.at(0).at(seedIndex-1))
-                           -shotCells.at(0).at(seedIndex-2)->pt()*caloWeightTool->wtCell(shotCells.at(0).at(seedIndex-2));
-            if(shotCells.at(1).at(seedIndex-1)!=NULL && shotCells.at(1).at(seedIndex-2)!=NULL){
-                deltaPt_left += shotCells.at(1).at(seedIndex-1)->pt()*caloWeightTool->wtCell(shotCells.at(1).at(seedIndex-1))
-                               -shotCells.at(1).at(seedIndex-2)->pt()*caloWeightTool->wtCell(shotCells.at(1).at(seedIndex-2));
-            }
-        }
-        if(shotCells.at(0).at(seedIndex+1)!=NULL && shotCells.at(0).at(seedIndex+2)!=NULL){
-            haveRight = true;
-            deltaPt_right =  shotCells.at(0).at(seedIndex+1)->pt()*caloWeightTool->wtCell(shotCells.at(0).at(seedIndex+1))
-                            -shotCells.at(0).at(seedIndex+2)->pt()*caloWeightTool->wtCell(shotCells.at(0).at(seedIndex+2));
-            if(shotCells.at(1).at(seedIndex+1)!=NULL && shotCells.at(1).at(seedIndex+2)!=NULL){
-                deltaPt_right += shotCells.at(1).at(seedIndex+1)->pt()*caloWeightTool->wtCell(shotCells.at(1).at(seedIndex+1))
-                                -shotCells.at(1).at(seedIndex+2)->pt()*caloWeightTool->wtCell(shotCells.at(1).at(seedIndex+1));
-            }
-        }
-        if(haveLeft && haveRight) return fmin(deltaPt_left,deltaPt_right);
-        if(haveLeft)              return deltaPt_left;
-        if(haveRight)             return deltaPt_right;
-        else                      return -1.;
+  }
+
+  // Merged cell
+  const CaloCell* mergedCell = getNeighbour(seedCell, *cellLinks, calo_id, LArNeighbours::nextInPhi);
+  if (mergedCell == nullptr) {
+    mergedCell = getNeighbour(seedCell, *cellLinks, calo_id, LArNeighbours::prevInPhi);
+  }
+
+  if (mergedCell != nullptr) {
+    cellBlock.at(1).at(mediumEtaIndex) = mergedCell;
+  
+    // Obtain the neighbour cells in the eta direction
+    // -- Next in eta
+    lastCell = mergedCell;
+    for (int depth = 1; depth < maxDepth + 1; ++depth) {
+      lastCell = getNeighbour(lastCell, *cellLinks, calo_id, LArNeighbours::nextInEta);
+      if (lastCell != nullptr) {
+        cellBlock.at(1).at(mediumEtaIndex + depth) = lastCell;
+      }
+      else {
+        break;
+      }
     }
+  
+    // -- Previous in eta
+    lastCell = mergedCell;
+    for (int depth = 1; depth < maxDepth + 1; ++depth) {
+      lastCell = getNeighbour(lastCell, *cellLinks, calo_id, LArNeighbours::prevInEta);
+      if (lastCell != nullptr) {
+        cellBlock.at(1).at(mediumEtaIndex - depth) = lastCell;
+      }
+      else {
+        break;
+      }
+    }
+  } // End of mergedCell != nullptr
 
+  return cellBlock;
+}
 
-    float Fside(vector<vector<const CaloCell*> > shotCells, int largerWindow, int smallerWindow, const ToolHandle<IHadronicCalibrationTool>& caloWeightTool){
-        // window sizes should be odd and windows should be not larger than eta window of shotCells
-        int nCells_eta = shotCells.at(0).size();
-        int seedIndex = nCells_eta/2;
-        if( largerWindow%2!=1 || smallerWindow%2!=1) return 0.;
-        if( largerWindow <= smallerWindow)           return 0.;
-        if( largerWindow > nCells_eta)   return 0.;
-        float pt_largerWindow  = 0.;
-        float pt_smallerWindow = 0.;
-        for(int iCell = 0; iCell != nCells_eta; ++iCell ){
-	    if(std::abs(iCell-seedIndex)>largerWindow/2) continue;
-            if(shotCells.at(0).at(iCell)!=NULL) pt_largerWindow+=shotCells.at(0).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(0).at(iCell));
-            if(shotCells.at(1).at(iCell)!=NULL) pt_largerWindow+=shotCells.at(1).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(1).at(iCell));
-            if(std::abs(iCell-seedIndex)>smallerWindow/2) continue;
-            if(shotCells.at(0).at(iCell)!=NULL) pt_smallerWindow+=shotCells.at(0).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(0).at(iCell));
-            if(shotCells.at(1).at(iCell)!=NULL) pt_smallerWindow+=shotCells.at(1).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(1).at(iCell));
-        }
-        if(pt_smallerWindow==0.) return -99999.;
-        return (pt_largerWindow-pt_smallerWindow)/pt_smallerWindow;
-    }
 
-    float fracSide(vector<vector<const CaloCell*> > shotCells, int largerWindow, int smallerWindow, const ToolHandle<IHadronicCalibrationTool>& caloWeightTool){
-        // window sizes should be odd and windows should be not larger than eta window of shotCells
-        int nCells_eta = shotCells.at(0).size();
-        int seedIndex = nCells_eta/2;
-        if( largerWindow%2!=1 || smallerWindow%2!=1) return 0.;
-        if( largerWindow <= smallerWindow)           return 0.;
-        if( largerWindow > nCells_eta)   return 0.;
-        float pt_largerWindow  = 0.;
-        float pt_smallerWindow = 0.;
-        for(int iCell = 0; iCell != nCells_eta; ++iCell ){
-            if(std::abs(iCell-seedIndex)>largerWindow/2) continue;
-            if(shotCells.at(0).at(iCell)!=NULL) pt_largerWindow+=shotCells.at(0).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(0).at(iCell));
-            if(shotCells.at(1).at(iCell)!=NULL) pt_largerWindow+=shotCells.at(1).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(1).at(iCell));
-            if(std::abs(iCell-seedIndex)>smallerWindow/2) continue;
-            if(shotCells.at(0).at(iCell)!=NULL) pt_smallerWindow+=shotCells.at(0).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(0).at(iCell));
-            if(shotCells.at(1).at(iCell)!=NULL) pt_smallerWindow+=shotCells.at(1).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(1).at(iCell));
-        }
-        if(pt_largerWindow==0.) return -99999.;
-        return (pt_largerWindow-pt_smallerWindow)/pt_largerWindow;
-    }
 
-    float ptWindowFrac(vector<vector<const CaloCell*> > shotCells, int largerWindow, int smallerWindow, const ToolHandle<IHadronicCalibrationTool>& caloWeightTool){
-        // window sizes should be odd and windows should be not larger than eta window of shotCells
-        int nCells_eta = shotCells.at(0).size();
-        int seedIndex = nCells_eta/2;
-        if( largerWindow%2!=1 || smallerWindow%2!=1) return 0.;
-        if( largerWindow <= smallerWindow)           return 0.;
-        if( largerWindow > nCells_eta)   return 0.;
-        float pt_largerWindow  = 0.;
-        float pt_smallerWindow = 0.;
-        for(int iCell = 0; iCell != nCells_eta; ++iCell ){
-            if(std::abs(iCell-seedIndex)>largerWindow/2) continue;
-            if(shotCells.at(0).at(iCell)!=NULL) pt_largerWindow+=shotCells.at(0).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(0).at(iCell));
-            if(shotCells.at(1).at(iCell)!=NULL) pt_largerWindow+=shotCells.at(1).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(1).at(iCell));
-            if(std::abs(iCell-seedIndex)>smallerWindow/2) continue;
-            if(shotCells.at(0).at(iCell)!=NULL) pt_smallerWindow+=shotCells.at(0).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(0).at(iCell));
-            if(shotCells.at(1).at(iCell)!=NULL) pt_smallerWindow+=shotCells.at(1).at(iCell)->pt()*caloWeightTool->wtCell(shotCells.at(1).at(iCell));
-        }
-        if(pt_largerWindow==0.) return -99999.;
-        return pt_smallerWindow/pt_largerWindow;
+float ptWindow(const std::vector<std::vector<const CaloCell*>>& shotCells, 
+               int windowSize, 
+               const ToolHandle<IHadronicCalibrationTool>& caloWeightTool) {
+  // window size should be odd and smaller than eta window of shotCells
+  if (windowSize%2 != 1) return 0.;
+  
+  int etaSize = shotCells.at(0).size();
+  if (windowSize > etaSize) return 0.;
+  
+  int seedIndex = etaSize/2;
+  int phiSize = shotCells.size();
+
+  float ptWindow  = 0.;
+  for (int etaIndex = 0; etaIndex != etaSize; ++etaIndex) {
+    if (std::abs(etaIndex-seedIndex) > windowSize/2) continue;
+    
+    for (int phiIndex = 0; phiIndex != phiSize; ++phiIndex) {
+      const CaloCell* cell = shotCells.at(phiIndex).at(etaIndex);
+      if (cell != nullptr) {
+        ptWindow += cell->pt() * caloWeightTool->wtCell(cell);
+      }     
     }
+  }
+
+  return ptWindow;
 }
 
+} // End of namespace TauShotVariableHelpers 
+
 #endif
diff --git a/Reconstruction/tauRecTools/src/TauShotVariableHelpers.h b/Reconstruction/tauRecTools/src/TauShotVariableHelpers.h
index fdc21ce85d0941f815d48d1e9de3e75c2aaa44ca..8e5f867fa6d222cb949b53ba8e38097c32a39aab 100644
--- a/Reconstruction/tauRecTools/src/TauShotVariableHelpers.h
+++ b/Reconstruction/tauRecTools/src/TauShotVariableHelpers.h
@@ -14,65 +14,33 @@
 #define TAUSHOTVARIABLEHELPERS_H
 
 #include "xAODPFlow/PFO.h"
-#include "GaudiKernel/ToolHandle.h"
 #include "CaloInterface/IHadronicCalibrationTool.h"
+#include "CaloIdentifier/LArNeighbours.h"
 
-namespace TauShotVariableHelpers {
-
-    /** @brief get cell block with (currently) 5x2 cells in correct order for variable calculations */
-    std::vector<std::vector<const CaloCell*> > getCellBlock(xAOD::PFO* shot,
-                                                            const CaloCell_ID* calo_id);
-
-    /** @brief mean eta, used by other functions */
-    float mean_eta(std::vector<std::vector<const CaloCell*> > /*shotCells*/, 
-                   const ToolHandle<IHadronicCalibrationTool>& /*caloWeightTool*/);
-
-    /** @brief mean pt, used by other functions */ 
-    float mean_pt(std::vector<std::vector<const CaloCell*> > /*shotCells*/, 
-                  const ToolHandle<IHadronicCalibrationTool>& /*caloWeightTool*/);
-
-    /** @brief pt in windows */
-    float ptWindow(std::vector<std::vector<const CaloCell*> > /*shotCells*/, 
-                   int /*windowSize*/, 
-                   const ToolHandle<IHadronicCalibrationTool>& /*caloWeightTool*/);
-
-    /** @brief ws5 variable (egamma) */
-    float ws5(std::vector<std::vector<const CaloCell*> > /*shotCells*/, 
-	      const ToolHandle<IHadronicCalibrationTool>& /*caloWeightTool*/);
-
-    /** @brief standard deviation in eta WRT mean */
-    float sdevEta_WRTmean(std::vector<std::vector<const CaloCell*> > /*shotCells*/, 
-                          const ToolHandle<IHadronicCalibrationTool>& /*caloWeightTool*/);
+#include "AsgMessaging/MessageCheck.h"
+#include "GaudiKernel/ToolHandle.h"
 
-    /** @brief standard deviation in eta WRT mode */
-    float sdevEta_WRTmode(std::vector<std::vector<const CaloCell*> > /*shotCells*/, 
-                          const ToolHandle<IHadronicCalibrationTool>& /*caloWeightTool*/);
+class CaloCell_ID;
 
-    /** @brief normalized standard deviation in pt */
-    float sdevPt(std::vector<std::vector<const CaloCell*> > /*shotCells*/, 
-		 const ToolHandle<IHadronicCalibrationTool>& /*caloWeightTool*/);
+namespace TauShotVariableHelpers {
+  
+  ANA_MSG_HEADER(msgHelperFunction)
 
-    /** @brief pT diff b/w lead and sub-lead cell */
-    float deltaPt12_min(std::vector<std::vector<const CaloCell*> > /*shotCells*/, 
-                        const ToolHandle<IHadronicCalibrationTool>& /*caloWeightTool*/);
+  /** @brief Obtain the required neighbour cell */
+  const CaloCell* getNeighbour(const CaloCell* cell,
+                               const CaloClusterCellLink& links,
+                               const CaloCell_ID* calo_id,
+                               const LArNeighbours::neighbourOption& option);
 
-    /** @brief Fside variable (egamma) */
-    float Fside(std::vector<std::vector<const CaloCell*> > /*shotCells*/, 
-                int /*largerWindow*/, 
-                int /*smallerWindow*/, 
-                const ToolHandle<IHadronicCalibrationTool>& /*caloWeightTool*/);
+  /** @brief Get cell block with (currently) 2 x 5 cells in correct order for variable calculations */
+  std::vector<std::vector<const CaloCell*>> getCellBlock(const xAOD::PFO& shot,
+                                                         const CaloCell_ID* calo_id);
 
-    /** @brief similar than Fside but in unit of eta instead of number of cells */
-    float fracSide(std::vector<std::vector<const CaloCell*> > /*shotCells*/, 
-                   int /*largerWindow*/, 
-                   int /*smallerWindow*/, 
-                   ToolHandle<IHadronicCalibrationTool>& /*caloWeightTool*/);
+  /** @brief pt in a window of (currently) 2 x windowSize cells */
+  float ptWindow(const std::vector<std::vector<const CaloCell*>>& shotCells, 
+                 int windowSize, 
+                 const ToolHandle<IHadronicCalibrationTool>& caloWeightTool);
 
-    /** @brief pt window fraction */
-    float ptWindowFrac(std::vector<std::vector<const CaloCell*> > /*shotCells*/, 
-                       int /*largerWindow*/, 
-                       int /*smallerWindow*/, 
-                       const ToolHandle<IHadronicCalibrationTool>& /*caloWeightTool*/);
 }
 
 #endif // TAUSHOTVARIABLEHELPERS_H
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauPi0ScoreCalculator.h b/Reconstruction/tauRecTools/tauRecTools/TauPi0ScoreCalculator.h
index ba8b377e9a6fba52417b92dfe37cc0160df8533f..b57c013c456aedaa1b7a7e4a78db2ecd12817317 100644
--- a/Reconstruction/tauRecTools/tauRecTools/TauPi0ScoreCalculator.h
+++ b/Reconstruction/tauRecTools/tauRecTools/TauPi0ScoreCalculator.h
@@ -5,12 +5,13 @@
 #ifndef TAUREC_TAUPI0SCORECALCULATOR_H
 #define	TAUREC_TAUPI0SCORECALCULATOR_H
 
-#include <string>
-#include <map>
 #include "tauRecTools/TauRecToolBase.h"
-#include "xAODPFlow/PFO.h"
 #include "tauRecTools/BDTHelper.h"
 
+#include "xAODPFlow/PFO.h"
+
+#include <string>
+
 /**
  * @brief Selectes pi0Candidates (Pi0 Finder).
  * 
@@ -21,20 +22,24 @@
  */
 
 class TauPi0ScoreCalculator : public TauRecToolBase {
+
 public:
-    TauPi0ScoreCalculator(const std::string& name);
-    ASG_TOOL_CLASS2(TauPi0ScoreCalculator, TauRecToolBase, ITauToolBase)
-    virtual ~TauPi0ScoreCalculator();
 
-    virtual StatusCode initialize() override;
-    virtual StatusCode executePi0nPFO(xAOD::TauJet& pTau, xAOD::PFOContainer& pNeutralPFOContainer) const override;
+  ASG_TOOL_CLASS2(TauPi0ScoreCalculator, TauRecToolBase, ITauToolBase)
+  
+  TauPi0ScoreCalculator(const std::string& name);
+  virtual ~TauPi0ScoreCalculator() = default;
+
+  virtual StatusCode initialize() override;
+  virtual StatusCode executePi0nPFO(xAOD::TauJet& pTau, xAOD::PFOContainer& pNeutralPFOContainer) const override;
 
 private:
-    /** @brief function used to calculate BDT score */
-    float calculateScore(const xAOD::PFO* neutralPFO) const;
+  
+  /** @brief Calculate pi0 BDT score */
+  float calculateScore(const xAOD::PFO* neutralPFO) const;
 
-    std::string m_weightfile;    
-    std::unique_ptr<tauRecTools::BDTHelper> m_mvaBDT;
+  std::string m_weightfile = "";
+  std::unique_ptr<tauRecTools::BDTHelper> m_mvaBDT = nullptr;
 };
 
 #endif	/* TAUPI0SCORECALCULATOR_H */
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauPi0Selector.h b/Reconstruction/tauRecTools/tauRecTools/TauPi0Selector.h
index 16506a71f5ff79b3a3ddd5793e51e152b648b367..79c12f6eccf0a62fcd2ed313a1e23b449c622b29 100644
--- a/Reconstruction/tauRecTools/tauRecTools/TauPi0Selector.h
+++ b/Reconstruction/tauRecTools/tauRecTools/TauPi0Selector.h
@@ -5,11 +5,12 @@
 #ifndef TAUREC_TAUPI0SELECTOR_H
 #define	TAUREC_TAUPI0SELECTOR_H
 
-#include <string>
 #include "tauRecTools/TauRecToolBase.h"
 
+#include <string>
+
 /**
- * @brief Selects pi0s
+ * @brief Apply Et and BDT score cut to pi0s
  * 
  * @author Will Davey <will.davey@cern.ch> 
  * @author Benedict Winter <benedict.tobias.winter@cern.ch> 
@@ -17,21 +18,23 @@
  */
 
 class TauPi0Selector : public TauRecToolBase {
+
 public:
-    TauPi0Selector(const std::string& name);
-    ASG_TOOL_CLASS2(TauPi0Selector, TauRecToolBase, ITauToolBase)
-    virtual ~TauPi0Selector();
-    virtual StatusCode executePi0nPFO(xAOD::TauJet& pTau, xAOD::PFOContainer& pNeutralPFOContainer) const override;
+  
+  ASG_TOOL_CLASS2(TauPi0Selector, TauRecToolBase, ITauToolBase)
+  
+  TauPi0Selector(const std::string& name);
+  virtual ~TauPi0Selector() = default;
+  
+  virtual StatusCode executePi0nPFO(xAOD::TauJet& pTau, xAOD::PFOContainer& pNeutralPFOContainer) const override;
 
 private:
+  /** @brief Get eta bin of Pi0Cluster */
+  int getEtaBin(double eta) const;
 
-    std::vector<float> m_clusterEtCut;
-    std::vector<float> m_clusterBDTCut_1prong;
-    std::vector<float> m_clusterBDTCut_mprong;
-    /** @brief function used to get eta bin of Pi0Cluster */
-    int getPi0Cluster_etaBin(double Pi0Cluster_eta) const;
-    /** @brief function used to calculate the visible tau 4 momentum */
-    TLorentzVector getP4(const xAOD::TauJet& tauJet) const;
+  std::vector<double> m_clusterEtCut;
+  std::vector<double> m_clusterBDTCut_1prong;
+  std::vector<double> m_clusterBDTCut_mprong;
 };
 
 #endif	/* TAUPI0SELECTOR_H */