diff --git a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/DerivationFrameworkTau/TauIDDecoratorWrapper.h b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/DerivationFrameworkTau/TauIDDecoratorWrapper.h
index 9033f475eb45ca3ab0ca92da1214a9bacf60b3b3..87641570482e06e05ddf590d77076db8ad2f211e 100644
--- a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/DerivationFrameworkTau/TauIDDecoratorWrapper.h
+++ b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/DerivationFrameworkTau/TauIDDecoratorWrapper.h
@@ -12,6 +12,7 @@
 #include "StoreGate/ReadHandleKey.h"
 #include "StoreGate/WriteDecorHandleKeyArray.h"
 #include "xAODTau/TauJetContainer.h"
+#include "xAODTracking/VertexContainer.h"
 
 #include <string>
 #include <vector>
@@ -32,6 +33,7 @@ namespace DerivationFramework {
 
     private:
       SG::ReadHandleKey<xAOD::TauJetContainer> m_tauContainerKey { this, "TauContainerName", "TauJets", "Input tau container key" };
+      SG::ReadHandleKey<xAOD::VertexContainer> m_vtxContainerKey { this, "VertexContainerName", "PrimaryVertices", "Input PV container key" };
       SG::WriteDecorHandleKeyArray<xAOD::TauJetContainer> m_decorKeys{ this, "DecorationKeys", {}, "List of decorations added to the tau"};
 
       ToolHandleArray<TauRecToolBase> m_tauIDTools { this, "TauIDTools", {}, "" };
diff --git a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauCommonConfig.py b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauCommonConfig.py
index 10873c78b1ad38086574a987e33a1135b172787d..c6a3495fd9bee888576c4bfd31b244ab14ec18ec 100644
--- a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauCommonConfig.py
+++ b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauCommonConfig.py
@@ -103,6 +103,7 @@ def AddTauIDDecorationCfg(flags, **kwargs):
 
     kwargs.setdefault("evetoFix",         True)
     kwargs.setdefault("DeepSetID",        True)
+    kwargs.setdefault("GNNTauID",        True)
     kwargs.setdefault("TauContainerName", "TauJets")
     kwargs.setdefault("prefix",           kwargs['TauContainerName'])
 
@@ -123,6 +124,11 @@ def AddTauIDDecorationCfg(flags, **kwargs):
         tools.append( acc.popToolsAndMerge(tauTools.TauJetDeepSetEvaluatorCfg(flags, version="v2")) )
         tools.append( acc.popToolsAndMerge(tauTools.TauWPDecoratorJetDeepSetCfg(flags, version="v2")) )
 
+    if kwargs['GNNTauID']:    
+        # Add in GNTau!
+        tools.append( acc.popToolsAndMerge(tauTools.TauGNNEvaluatorCfg(flags)) )
+        tools.append( acc.popToolsAndMerge(tauTools.TauWPDecoratorGNNCfg(flags)) )
+
     if tools:
         for tool in tools:
             acc.addPublicTool(tool)
diff --git a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauJetsCPContent.py b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauJetsCPContent.py
index a7067535a37e47e013e8c0a8a1dadae420fcff57..05602ed2898a3b5fa785e9f5ace959018baf1429 100644
--- a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauJetsCPContent.py
+++ b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/python/TauJetsCPContent.py
@@ -4,7 +4,7 @@
 
 TauJetsCPContent = [
     "TauJets",
-    "TauJetsAux.pt.eta.phi.m.ptFinalCalib.etaFinalCalib.ptTauEnergyScale.etaTauEnergyScale.charge.nChargedTracks.nIsolatedTracks.nAllTracks.isTauFlags.PanTau_DecayMode.NNDecayMode.NNDecayModeProb_1p0n.NNDecayModeProb_1p1n.NNDecayModeProb_1pXn.NNDecayModeProb_3p0n.RNNJetScore.RNNJetScoreSigTrans.JetDeepSetScore.JetDeepSetScoreTrans.JetDeepSetVeryLoose.JetDeepSetLoose.JetDeepSetMedium.JetDeepSetTight.JetDeepSetScore_v2.JetDeepSetScoreTrans_v2.JetDeepSetVeryLoose_v2.JetDeepSetLoose_v2.JetDeepSetMedium_v2.JetDeepSetTight_v2.RNNEleScore.RNNEleScoreSigTrans_v1.EleRNNLoose_v1.EleRNNMedium_v1.EleRNNTight_v1.tauTrackLinks.vertexLink.secondaryVertexLink.neutralPFOLinks.pi0PFOLinks.truthParticleLink.truthJetLink.trackWidth.centFrac.etOverPtLeadTrk.innerTrkAvgDist.absipSigLeadTrk.SumPtTrkFrac.EMPOverTrkSysP.ptRatioEflowApprox.mEflowApprox.dRmax.trFlightPathSig.massTrkSys.leadTrackProbNNorHT.EMFracFixed.etHotShotWinOverPtLeadTrk.hadLeakFracFixed.PSFrac.ClustersMeanCenterLambda.ClustersMeanFirstEngDens.ClustersMeanPresamplerFrac",
+    "TauJetsAux.pt.eta.phi.m.ptFinalCalib.etaFinalCalib.ptTauEnergyScale.etaTauEnergyScale.charge.nChargedTracks.nIsolatedTracks.nAllTracks.isTauFlags.PanTau_DecayMode.NNDecayMode.NNDecayModeProb_1p0n.NNDecayModeProb_1p1n.NNDecayModeProb_1pXn.NNDecayModeProb_3p0n.RNNJetScore.RNNJetScoreSigTrans.JetDeepSetScore.JetDeepSetScoreTrans.JetDeepSetVeryLoose.JetDeepSetLoose.JetDeepSetMedium.JetDeepSetTight.JetDeepSetScore_v2.JetDeepSetScoreTrans_v2.JetDeepSetVeryLoose_v2.JetDeepSetLoose_v2.JetDeepSetMedium_v2.JetDeepSetTight_v2.RNNEleScore.RNNEleScoreSigTrans_v1.EleRNNLoose_v1.EleRNNMedium_v1.EleRNNTight_v1.tauTrackLinks.vertexLink.secondaryVertexLink.neutralPFOLinks.pi0PFOLinks.truthParticleLink.truthJetLink.trackWidth.centFrac.etOverPtLeadTrk.innerTrkAvgDist.absipSigLeadTrk.SumPtTrkFrac.EMPOverTrkSysP.ptRatioEflowApprox.mEflowApprox.dRmax.trFlightPathSig.massTrkSys.leadTrackProbNNorHT.EMFracFixed.etHotShotWinOverPtLeadTrk.hadLeakFracFixed.PSFrac.ClustersMeanCenterLambda.ClustersMeanFirstEngDens.ClustersMeanPresamplerFrac.GNTauScore.GNTauScoreSigTrans_v0.GNTauVL_v0.GNTauL_v0.GNTauM_v0.GNTauT_v0",
     "TauTracks",
     "TauTracksAux.pt.eta.phi.flagSet.trackLinks.rnn_chargedScore.rnn_isolationScore.rnn_conversionScore",
     "InDetTrackParticles",
diff --git a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/src/TauIDDecoratorWrapper.cxx b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/src/TauIDDecoratorWrapper.cxx
index ee95314f5f9a86879b35760ffd088eea7a625c12..4f218cf68678ca47ad3cfec87897d5a3af90d229 100644
--- a/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/src/TauIDDecoratorWrapper.cxx
+++ b/PhysicsAnalysis/DerivationFramework/DerivationFrameworkTau/src/TauIDDecoratorWrapper.cxx
@@ -21,7 +21,7 @@ namespace DerivationFramework {
 
     // parse the properties of TauWPDecorator tools
     for (const auto& tool : m_tauIDTools) {
-      if (tool->type() != "TauWPDecorator") continue;
+      if ((tool->type() != "TauWPDecorator" )) continue;
 
       // check whether we must compute eVeto WPs, as this requires the recalculation of a variable
       BooleanProperty useAbsEta("UseAbsEta", false);
@@ -57,6 +57,7 @@ namespace DerivationFramework {
     
     // initialize read/write handle keys
     ATH_CHECK( m_tauContainerKey.initialize() );
+    ATH_CHECK( m_vtxContainerKey.initialize() );
     ATH_CHECK( m_decorKeys.initialize() );
 
     return StatusCode::SUCCESS;
@@ -77,30 +78,59 @@ namespace DerivationFramework {
     }
     const xAOD::TauJetContainer* tauContainer = tauJetsReadHandle.cptr();
 
-  static const SG::AuxElement::Decorator<float> acc_trackWidth("trackWidth");
-
-  for (const auto tau : *tauContainer) {
-    float tauTrackBasedWidth = 0;
-    // equivalent to
-    // tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedCharged)
-    std::vector<const xAOD::TauTrack *> tauTracks = tau->tracks();
-    for (const xAOD::TauTrack *trk : tau->tracks(
-             xAOD::TauJetParameters::TauTrackFlag::classifiedIsolation)) {
-      tauTracks.push_back(trk);
-    }
-    double sumWeightedDR = 0.;
-    double ptSum = 0.;
-    for (const xAOD::TauTrack *track : tauTracks) {
-        double deltaR = tau->p4().DeltaR(track->p4());
-        sumWeightedDR += deltaR * track->pt();
-        ptSum += track->pt();
+    // retrieve PrimaryVertices container
+    SG::ReadHandle<xAOD::VertexContainer> vtxReadHandle(m_vtxContainerKey);
+    if (!vtxReadHandle.isValid()) {
+      ATH_MSG_ERROR ("Could not retrieve VertexContainer with key " << vtxReadHandle.key());
+      return StatusCode::FAILURE;
     }
-    if (ptSum > 0) {
-      tauTrackBasedWidth = sumWeightedDR / ptSum;
+    const xAOD::VertexContainer* vtxContainer = vtxReadHandle.cptr();
+    const xAOD::Vertex* pVtx = nullptr;
+
+    // Check that PV container exists and is non-empty, find the PV if possible
+    if(vtxContainer != nullptr && vtxContainer->size()>0) {
+      ATH_MSG_DEBUG("Found vtx container for decorating taus!");
+      auto itrVtx = std::find_if(vtxContainer->begin(), vtxContainer->end(),
+                                  [](const xAOD::Vertex* vtx) {
+                                      return vtx->vertexType() == xAOD::VxType::PriVtx;
+                                  });
+      pVtx = (itrVtx == vtxContainer->end() ? 0 : *itrVtx);
+      if(!pVtx){
+        ATH_MSG_WARNING("No PV found, using the first element instead!");
+        pVtx = vtxContainer->at(0);
+      }
     }
+    
+    //Create accessors  
+    static const SG::AuxElement::Decorator<float> acc_trackWidth("trackWidth");
+    static const SG::AuxElement::Accessor<float> acc_dz0_TV_PV0("dz0_TV_PV0");
+    static const SG::AuxElement::Accessor<float> acc_log_sumpt_TV("log_sumpt_TV");
+    static const SG::AuxElement::Accessor<float> acc_log_sumpt2_TV("log_sumpt2_TV");
+    static const SG::AuxElement::Accessor<float> acc_log_sumpt_PV0("log_sumpt_PV0");
+    static const SG::AuxElement::Accessor<float> acc_log_sumpt2_PV0("log_sumpt2_PV0");
+
+    for (const auto tau : *tauContainer) {
+      float tauTrackBasedWidth = 0;
+      // equivalent to
+      // tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedCharged)
+      std::vector<const xAOD::TauTrack *> tauTracks = tau->tracks();
+      for (const xAOD::TauTrack *trk : tau->tracks(
+              xAOD::TauJetParameters::TauTrackFlag::classifiedIsolation)) {
+        tauTracks.push_back(trk);
+      }
+      double sumWeightedDR = 0.;
+      double ptSum = 0.;
+      for (const xAOD::TauTrack *track : tauTracks) {
+          double deltaR = tau->p4().DeltaR(track->p4());
+          sumWeightedDR += deltaR * track->pt();
+          ptSum += track->pt();
+      }
+      if (ptSum > 0) {
+        tauTrackBasedWidth = sumWeightedDR / ptSum;
+      }
 
-    acc_trackWidth(*tau) = tauTrackBasedWidth;
-  }
+      acc_trackWidth(*tau) = tauTrackBasedWidth;
+    }
 
     // create shallow copy
     auto shallowCopy = xAOD::shallowCopyContainer (*tauContainer);
@@ -108,6 +138,26 @@ namespace DerivationFramework {
     static const SG::AuxElement::Accessor<float> acc_absEtaLead("ABS_ETA_LEAD_TRACK");
 
     for (auto tau : *shallowCopy.first) {
+      
+      //Add in the TV/PV0 vertex variables needed for some calculators in TauGNNUtils.cxx (for GNTau)
+      float dz0_TV_PV0 = -999., sumpt_TV = 0., sumpt2_TV = 0., sumpt_PV0 = 0., sumpt2_PV0 = 0.;
+      if(pVtx!=nullptr) {
+        dz0_TV_PV0 = tau->vertex()->z() - pVtx->z();
+        for (const ElementLink<xAOD::TrackParticleContainer>& trk : pVtx->trackParticleLinks()) {
+          sumpt_PV0 += (*trk)->pt();
+          sumpt2_PV0 += pow((*trk)->pt(), 2.);
+        }
+        for (const ElementLink<xAOD::TrackParticleContainer>& trk : tau->vertex()->trackParticleLinks()) {
+          sumpt_TV += (*trk)->pt();
+          sumpt2_TV += pow((*trk)->pt(), 2.);
+        }
+      }
+      acc_dz0_TV_PV0(*tau) = dz0_TV_PV0;
+      acc_log_sumpt_TV(*tau) = (sumpt_TV>0.) ? std::log(sumpt_TV) : 0.;
+      acc_log_sumpt2_TV(*tau) = (sumpt2_TV>0.) ? std::log(sumpt2_TV) : 0.;
+      acc_log_sumpt_PV0(*tau) = (sumpt_PV0>0.) ? std::log(sumpt_PV0) : 0.;
+      acc_log_sumpt2_PV0(*tau) = (sumpt2_PV0>0.) ? std::log(sumpt2_PV0) : 0.;
+      //End of vertex variable addition block  
 
       // ABS_ETA_LEAD_TRACK is removed from the AOD content and must be redecorated when computing eVeto WPs
       // note: this redecoration is not robust against charged track thinning, but charged tracks should never be thinned      
diff --git a/Reconstruction/tauRec/python/TauConfigFlags.py b/Reconstruction/tauRec/python/TauConfigFlags.py
index f4c7ff650c38985a3f7eb8936ba7d2879c78156b..d17807f2bf9667ee81c67d1a18046181b2430c6d 100644
--- a/Reconstruction/tauRec/python/TauConfigFlags.py
+++ b/Reconstruction/tauRec/python/TauConfigFlags.py
@@ -60,6 +60,9 @@ def createTauConfigFlags():
     # R22 DeepSet tau ID tune without track RNN scores, for now define a second set of flags, but ultimately we'll choose one and drop the other
     tau_cfg.addFlag("Tau.TauJetDeepSetConfig_v2", ["tauid_1p_R22_dpst_noTrackScore.json", "tauid_2p_R22_dpst_noTrackScore.json", "tauid_3p_R22_dpst_noTrackScore.json"])
     tau_cfg.addFlag("Tau.TauJetDeepSetWP_v2", ["model_1p_R22_dpst_noTrackScore.root", "model_2p_R22_dpst_noTrackScore.root", "model_3p_R22_dpst_noTrackScore.root"])
+    # GNTau ID tune file (need to add another version for noAux)
+    tau_cfg.addFlag("Tau.TauGNNConfig", ["GNTau_noAux_simplified.onnx"])
+    tau_cfg.addFlag("Tau.TauGNNWP_v0", ["GNTauNA_flat_model_1p.root", "GNTauNA_flat_model_2p.root", "GNTauNA_flat_model_3p.root"])
 
     # PanTau config flags
     from PanTauAlgs.PanTauConfigFlags import createPanTauConfigFlags
diff --git a/Reconstruction/tauRec/python/TauToolHolder.py b/Reconstruction/tauRec/python/TauToolHolder.py
index 1a9b19d3f770dda57a20555413f8fe7924be9ec7..338a1ad7a67942987cf2913c9f355d28c9324c74 100644
--- a/Reconstruction/tauRec/python/TauToolHolder.py
+++ b/Reconstruction/tauRec/python/TauToolHolder.py
@@ -858,6 +858,54 @@ def TauWPDecoratorJetDeepSetCfg(flags, version=None):
     result.setPrivateTools(myTauWPDecorator)
     return result
 
+def TauGNNEvaluatorCfg(flags):
+    result = ComponentAccumulator()
+    _name = flags.Tau.ActiveConfig.prefix + 'TauGNN'
+
+    TauGNNEvaluator = CompFactory.getComp("TauGNNEvaluator")
+    GNNConf = flags.Tau.TauGNNConfig
+    myTauGNNEvaluator = TauGNNEvaluator(name = _name,
+                                              NetworkFile = GNNConf[0],
+                                              OutputVarname = "GNTauScore",
+                                              OutputPTau = "GNTauProbTau",
+                                              OutputPJet = "GNTauProbJet",
+                                              MaxTracks = 30,
+                                              MaxClusters = 20,
+                                              MaxClusterDR = 15.0,
+                                              VertexCorrection = True,
+                                              DecorateTracks = False,
+                                              InputLayerScalar = "tau_vars",
+                                              InputLayerTracks = "track_vars",
+                                              InputLayerClusters = "cluster_vars",
+                                              NodeNameTau="GN2TauNoAux_pb",
+                                              NodeNameJet="GN2TauNoAux_pu")
+
+    result.setPrivateTools(myTauGNNEvaluator)
+    return result
+
+def TauWPDecoratorGNNCfg(flags):
+    result = ComponentAccumulator()
+    _name = flags.Tau.ActiveConfig.prefix + 'TauWPDecoratorGNN'
+
+    TauWPDecorator = CompFactory.getComp("TauWPDecorator")
+    WPConf = flags.Tau.TauGNNWP_v0
+    decorWPNames = ["GNTauVL_v0", "GNTauL_v0", "GNTauM_v0", "GNTauT_v0"]
+    scoreName = "GNTauScore"
+    newScoreName = "GNTauScoreSigTrans_v0"
+    myTauWPDecorator = TauWPDecorator(name=_name,
+                                      flatteningFile1Prong = WPConf[0],
+                                      flatteningFile2Prong = WPConf[1],
+                                      flatteningFile3Prong = WPConf[2],
+                                      DecorWPNames = decorWPNames,
+                                      DecorWPCutEffs1P = [0.95, 0.85, 0.75, 0.60],
+                                      DecorWPCutEffs2P = [0.95, 0.75, 0.60, 0.45],
+                                      DecorWPCutEffs3P = [0.95, 0.75, 0.60, 0.45],
+                                      ScoreName = scoreName,
+                                      NewScoreName = newScoreName,
+                                      DefineWPs = True)
+    result.setPrivateTools(myTauWPDecorator)
+    return result
+
 def TauEleRNNEvaluatorCfg(flags):
     result = ComponentAccumulator()
     _name = flags.Tau.ActiveConfig.prefix + 'TauEleRNN'
diff --git a/Reconstruction/tauRecTools/CMakeLists.txt b/Reconstruction/tauRecTools/CMakeLists.txt
index 40a34801c23a244adb2a9be46ffa39c086df3fe2..87b3c12ef094a880e7a73652436f1f74dc33a39f 100644
--- a/Reconstruction/tauRecTools/CMakeLists.txt
+++ b/Reconstruction/tauRecTools/CMakeLists.txt
@@ -8,6 +8,7 @@ find_package( Boost )
 find_package( Eigen )
 find_package( ROOT COMPONENTS Core Tree Hist RIO )
 find_package( lwtnn )
+find_package( onnxruntime )
 
 # Optional dependencies.
 set( extra_public_libs )
@@ -28,12 +29,12 @@ atlas_add_library( tauRecToolsLib
    tauRecTools/*.h Root/*.cxx tauRecTools/lwtnn/*.h Root/lwtnn/*.cxx
    PUBLIC_HEADERS tauRecTools
    INCLUDE_DIRS ${Boost_INCLUDE_DIRS} ${EIGEN_INCLUDE_DIRS}
-   ${LWTNN_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS}
+   ${LWTNN_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} ${ONNXRUNTIME_INCLUDE_DIRS}
    LINK_LIBRARIES ${Boost_LIBRARIES} ${EIGEN_LIBRARIES} ${LWTNN_LIBRARIES}
-   ${ROOT_LIBRARIES}
+   ${ROOT_LIBRARIES} ${ONNXRUNTIME_LIBRARIES}
    CxxUtils AthLinks AsgMessagingLib AsgDataHandlesLib AsgTools xAODCaloEvent
    xAODEventInfo xAODJet xAODParticleEvent xAODPFlow xAODTau xAODTracking xAODEventShape
-   MVAUtils ${extra_public_libs}
+   MVAUtils FlavorTagDiscriminants ${extra_public_libs}
    PRIVATE_LINK_LIBRARIES CaloGeoHelpers FourMomUtils PathResolver )
 
 atlas_add_dictionary( tauRecToolsDict
@@ -46,5 +47,5 @@ if( NOT XAOD_STANDALONE )
       INCLUDE_DIRS ${Boost_INCLUDE_DIRS}
       LINK_LIBRARIES ${Boost_LIBRARIES} AsgTools AsgMessagingLib xAODBase
       xAODCaloEvent xAODJet xAODPFlow xAODTau FourMomUtils tauRecToolsLib PFlowUtilsLib
-      ${extra_private_libs} )
+      FlavorTagDiscriminants ${extra_private_libs} )
 endif()
diff --git a/Reconstruction/tauRecTools/Root/TauGNN.cxx b/Reconstruction/tauRecTools/Root/TauGNN.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..65f13c4c2f6bdde1fd3fd0901579612eb2e44bb4
--- /dev/null
+++ b/Reconstruction/tauRecTools/Root/TauGNN.cxx
@@ -0,0 +1,206 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "tauRecTools/TauGNN.h"
+#include "FlavorTagDiscriminants/OnnxUtil.h"
+#include "PathResolver/PathResolver.h"
+
+#include <algorithm>
+#include <fstream>
+
+#include "lwtnn/LightweightGraph.hh"
+#include "lwtnn/Exceptions.hh"
+//#include "lwtnn/parse_json.hh"
+
+#include "tauRecTools/TauGNNUtils.h"
+
+TauGNN::TauGNN(const std::string &nnFile, const Config &config):
+    asg::AsgMessaging("TauGNN"),
+    m_onnxUtil(nullptr)
+  {
+    //==================================================//
+    // This part is ported from FTagDiscriminant GNN.cxx//
+    //==================================================//
+
+    m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile);
+
+    // get the configuration of the model outputs
+    FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig();
+    
+    //Let's see the output!
+    for (const auto& out_node: gnn_output_config) {
+        if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name);
+        if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name);
+        if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name);
+    }
+
+    //Get model config (for inputs)
+    auto lwtnn_config = m_onnxUtil->getLwtConfig();
+    
+    //===================================================//
+    // This part is ported from tauRecTools TauJetRNN.cxx//
+    //===================================================//
+
+    // Search for input layer names specified in 'config'
+    auto node_is_scalar = [&config](const lwt::InputNodeConfig &in_node) {
+        return in_node.name == config.input_layer_scalar;
+    };
+    auto node_is_track = [&config](const lwt::InputNodeConfig &in_node) {
+        return in_node.name == config.input_layer_tracks;
+    };
+    auto node_is_cluster = [&config](const lwt::InputNodeConfig &in_node) {
+        return in_node.name == config.input_layer_clusters;
+    };
+
+    auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
+                                    lwtnn_config.inputs.cend(),
+                                    node_is_scalar);
+
+    auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
+                                   lwtnn_config.input_sequences.cend(),
+                                   node_is_track);
+
+    auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
+                                     lwtnn_config.input_sequences.cend(),
+                                     node_is_cluster);
+
+    // Check which input layers were found
+    auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
+    auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
+    auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();
+    if(!has_scalar_node) ATH_MSG_WARNING("No scalar node with name "<<config.input_layer_scalar<<" found!");
+    if(!has_track_node) ATH_MSG_WARNING("No track node with name "<<config.input_layer_tracks<<" found!");
+    if(!has_cluster_node) ATH_MSG_WARNING("No cluster node with name "<<config.input_layer_clusters<<" found!");
+    
+    // Fill the variable names of each input layer into the corresponding vector
+    if (has_scalar_node) {
+        for (const auto &in : scalar_node->variables) {
+            std::string name = in.name;
+            m_scalarCalc_inputs.push_back(name);
+        }
+    }
+
+    if (has_track_node) {
+        for (const auto &in : track_node->variables) {
+            std::string name = in.name;
+            m_trackCalc_inputs.push_back(name);
+        }
+    }
+
+    if (has_cluster_node) {
+        for (const auto &in : cluster_node->variables) {
+            std::string name = in.name;
+            m_clusterCalc_inputs.push_back(name);
+        }
+    }
+    // Load the variable calculator
+    m_var_calc = TauGNNUtils::get_calculator(m_scalarCalc_inputs, m_trackCalc_inputs, m_clusterCalc_inputs);
+    ATH_MSG_INFO("TauGNN object initialized successfully!");
+}
+
+TauGNN::~TauGNN() {}
+
+std::tuple<
+    std::map<std::string, float>,
+    std::map<std::string, std::vector<char>>,
+    std::map<std::string, std::vector<float>> >
+TauGNN::compute(const xAOD::TauJet &tau,
+                         const std::vector<const xAOD::TauTrack *> &tracks,
+                         const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
+    InputMap scalarInputs;
+    InputSequenceMap vectorInputs;
+    std::map<std::string, input_pair> gnn_input;
+    ATH_MSG_DEBUG("Starting compute...");
+    //Prepare input variables
+    if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) {
+        ATH_MSG_FATAL("Failed calculateInputVariables");
+        throw StatusCode::FAILURE;
+    }
+
+    // Add TauJet-level features to the input
+    std::vector<float> tau_feats;
+    for (const auto &varname : m_scalarCalc_inputs) {
+        tau_feats.push_back(static_cast<float>(scalarInputs[m_config.input_layer_scalar][varname]));
+    }
+    std::vector<int64_t> tau_feats_dim = {1, static_cast<int64_t>(tau_feats.size())};
+    input_pair tau_info (tau_feats, tau_feats_dim);
+    gnn_input.insert({"tau_vars", tau_info});
+
+    //Add track-level features to the input
+    std::vector<float> trk_feats;
+    int num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_tracks][m_trackCalc_inputs.at(0)].size());
+    int num_node_vars=static_cast<int>(m_trackCalc_inputs.size());
+    trk_feats.resize(num_nodes * num_node_vars);
+    int var_idx=0;
+    for (const auto &varname : m_trackCalc_inputs) {
+        for (int node_idx=0; node_idx<num_nodes; node_idx++){    
+            trk_feats.at(node_idx*num_node_vars + var_idx)
+              = static_cast<float>(vectorInputs[m_config.input_layer_tracks][varname].at(node_idx));
+        }
+        var_idx++;
+    }
+    std::vector<int64_t> trk_feats_dim = {num_nodes, num_node_vars};
+    input_pair trk_info (trk_feats, trk_feats_dim);
+    gnn_input.insert({"track_vars", trk_info});
+    
+    //Add cluster-level features to the input
+    std::vector<float> cls_feats;
+    num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_clusters][m_clusterCalc_inputs.at(0)].size());
+    num_node_vars=static_cast<int>(m_clusterCalc_inputs.size());
+    cls_feats.resize(num_nodes * num_node_vars);
+    var_idx=0;
+    for (const auto &varname : m_clusterCalc_inputs) {
+        for (int node_idx=0; node_idx<num_nodes; node_idx++){    
+            cls_feats.at(node_idx*num_node_vars + var_idx)
+              = static_cast<float>(vectorInputs[m_config.input_layer_clusters][varname].at(node_idx));
+        }
+        var_idx++;
+    }
+    std::vector<int64_t> cls_feats_dim = {num_nodes, num_node_vars};
+    input_pair cls_info (cls_feats, cls_feats_dim);
+    gnn_input.insert({"cluster_vars", cls_info});    
+
+    //RUN THE INFERENCE!!!
+    ATH_MSG_DEBUG("Prepared inputs, running inference...");
+    auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input);
+    ATH_MSG_DEBUG("Finished compute!");
+    return std::make_tuple(out_f, out_vc, out_vf);
+}
+
+bool TauGNN::calculateInputVariables(const xAOD::TauJet &tau,
+                  const std::vector<const xAOD::TauTrack *> &tracks,
+                  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
+                  std::map<std::string, std::map<std::string, double>>& scalarInputs,
+                  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const {
+    scalarInputs.clear();
+    vectorInputs.clear();
+    // Populate input (sequence) map with input variables
+    for (const auto &varname : m_scalarCalc_inputs) {
+        if (!m_var_calc->compute(varname, tau,
+                                 scalarInputs[m_config.input_layer_scalar][varname])) {
+            ATH_MSG_WARNING("Error computing '" << varname
+                            << "' returning default");
+            return false;
+        }
+    }
+
+    for (const auto &varname : m_trackCalc_inputs) {
+        if (!m_var_calc->compute(varname, tau, tracks,
+                                 vectorInputs[m_config.input_layer_tracks][varname])) {
+            ATH_MSG_WARNING("Error computing '" << varname
+                            << "' returning default");
+            return false;
+        }
+    }
+
+    for (const auto &varname : m_clusterCalc_inputs) {
+        if (!m_var_calc->compute(varname, tau, clusters,
+                                 vectorInputs[m_config.input_layer_clusters][varname])) {
+            ATH_MSG_WARNING("Error computing '" << varname
+                            << "' returning default");
+            return false;
+        }
+    }
+    return true;
+}
diff --git a/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..0be8819d79ec52275dfb7c251872439bec964221
--- /dev/null
+++ b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx
@@ -0,0 +1,175 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "tauRecTools/TauGNNEvaluator.h"
+#include "tauRecTools/TauGNN.h"
+#include "tauRecTools/HelperFunctions.h"
+
+#include "PathResolver/PathResolver.h"
+
+#include <algorithm>
+
+
+TauGNNEvaluator::TauGNNEvaluator(const std::string &name): 
+  TauRecToolBase(name),
+  m_net(nullptr){
+    
+  declareProperty("NetworkFile", m_weightfile = "");
+  declareProperty("OutputVarname", m_output_varname = "GNTauScore");
+  declareProperty("OutputPTau", m_output_ptau = "GNTauProbTau");
+  declareProperty("OutputPJet", m_output_pjet = "GNTauProbJet");
+  declareProperty("MaxTracks", m_max_tracks = 30);
+  declareProperty("MaxClusters", m_max_clusters = 20);
+  declareProperty("MaxClusterDR", m_max_cluster_dr = 1.0f);
+  declareProperty("VertexCorrection", m_doVertexCorrection = true);
+  declareProperty("DecorateTracks", m_decorateTracks = false);
+  declareProperty("TrackClassification", m_doTrackClassification = true);
+
+  // Naming conventions for the network weight files:
+  declareProperty("InputLayerScalar", m_input_layer_scalar = "tau_vars");
+  declareProperty("InputLayerTracks", m_input_layer_tracks = "track_vars");
+  declareProperty("InputLayerClusters", m_input_layer_clusters = "cluster_vars");
+  declareProperty("NodeNameTau", m_outnode_tau = "GN2TauNoAux_pb");
+  declareProperty("NodeNameJet", m_outnode_jet = "GN2TauNoAux_pu");
+  }
+
+TauGNNEvaluator::~TauGNNEvaluator() {}
+
+StatusCode TauGNNEvaluator::initialize() {
+  ATH_MSG_INFO("Initializing TauGNNEvaluator");
+  
+  std::string weightfile("");
+
+  // Use PathResolver to search for the weight files
+  if (!m_weightfile.empty()) {
+    weightfile = find_file(m_weightfile);
+    if (weightfile.empty()) {
+      ATH_MSG_ERROR("Could not find network weights: " << m_weightfile);
+      return StatusCode::FAILURE;
+    } else {
+      ATH_MSG_INFO("Using network config: " << weightfile);
+    }
+  }
+
+  // Set the layer and node names in the weight file
+  TauGNN::Config config;
+  config.input_layer_scalar = m_input_layer_scalar;
+  config.input_layer_tracks = m_input_layer_tracks;
+  config.input_layer_clusters = m_input_layer_clusters;
+  config.output_node_tau = m_outnode_tau;
+  config.output_node_jet = m_outnode_jet;
+
+  // Load the weights and create the network
+  if (!weightfile.empty()) {
+    m_net = std::make_unique<TauGNN>(weightfile, config);
+    if (!m_net) {
+      ATH_MSG_ERROR("No network configured.");
+      return StatusCode::FAILURE;
+    }
+  }
+
+  return StatusCode::SUCCESS;
+}
+
+StatusCode TauGNNEvaluator::execute(xAOD::TauJet &tau) const {
+  // Output variable Decorators
+  const SG::AuxElement::Accessor<float> output(m_output_varname);
+  const SG::AuxElement::Accessor<float> out_ptau(m_output_ptau);
+  const SG::AuxElement::Accessor<float> out_pjet(m_output_pjet);
+  const SG::AuxElement::Decorator<char> out_trkclass("GNTau_TrackClass");
+  // Set default score and overwrite later
+  output(tau) = -1111.0f;
+  out_ptau(tau) = -1111.0f;
+  out_pjet(tau) = -1111.0f;
+  //Skip execution for low-pT taus to save resources
+  if(tau.pt()<13000) {
+    return StatusCode::SUCCESS;
+  }
+
+  // Get input objects
+  std::vector<const xAOD::TauTrack *> tracks;
+  ATH_CHECK(get_tracks(tau, tracks));
+  std::vector<xAOD::CaloVertexedTopoCluster> clusters;
+  ATH_CHECK(get_clusters(tau, clusters));
+
+  // Truncate tracks
+  int numTracksMax = std::min(m_max_tracks, tracks.size());
+  std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax);
+  // Evaluate networks
+  if (m_net) {
+    auto [out_f, out_vc, out_vf] = m_net->compute(tau, trackVec, clusters);
+    output(tau)=std::log10(1/(1-out_f.at(m_outnode_tau)));
+    out_ptau(tau)=out_f.at(m_outnode_tau);
+    out_pjet(tau)=out_f.at(m_outnode_jet);
+    if (m_decorateTracks){
+      for(unsigned int i=0;i<tracks.size();i++){
+        if(i<out_vc.at("track_class").size()){out_trkclass(*tracks.at(i))=out_vc.at("track_class").at(i);}
+        else{out_trkclass(*tracks.at(i))='9';} //Dummy value for tracks outside range of out_vc
+      }
+    }
+  }
+  
+  return StatusCode::SUCCESS;
+}
+
+const TauGNN* TauGNNEvaluator::get_gnn() const {
+  return m_net.get();
+}
+
+
+StatusCode TauGNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const {
+  std::vector<const xAOD::TauTrack*> tracks = tau.allTracks();
+
+  // Skip unclassified tracks:
+  // - the track is a LRT and classifyLRT = false
+  // - the track is not among the MaxNtracks highest-pt tracks in the track classifier
+  // - track classification is not run (trigger)
+  if(m_doTrackClassification) {
+    std::vector<const xAOD::TauTrack*>::iterator it = tracks.begin();
+    while(it != tracks.end()) {
+      if((*it)->flag(xAOD::TauJetParameters::unclassified)) {
+  it = tracks.erase(it);
+      }
+      else {
+	++it;
+      }
+    }
+  }
+
+  // Sort by descending pt
+  auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
+    return lhs->pt() > rhs->pt();
+  };
+  std::sort(tracks.begin(), tracks.end(), cmp_pt);
+  out = std::move(tracks);
+
+  return StatusCode::SUCCESS;
+}
+
+StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
+
+  TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection);
+
+  std::vector<xAOD::CaloVertexedTopoCluster> vertexedClusterList = tau.vertexedClusters();
+  for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : vertexedClusterList) {
+    TLorentzVector clusterP4 = vertexedCluster.p4();
+    if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue;
+      
+    clusters.push_back(vertexedCluster);
+  }
+
+  // Sort by descending et
+  auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs,
+		   const xAOD::CaloVertexedTopoCluster& rhs) {
+    return lhs.p4().Et() > rhs.p4().Et();
+  };
+  std::sort(clusters.begin(), clusters.end(), et_cmp);
+
+  // Truncate clusters
+  if (clusters.size() > m_max_clusters) {
+    clusters.resize(m_max_clusters, clusters[0]);
+  }
+
+  return StatusCode::SUCCESS;
+}
diff --git a/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..8da2fe5ea6ec2056d79a99a00b21908155379630
--- /dev/null
+++ b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx
@@ -0,0 +1,914 @@
+/*
+  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "tauRecTools/TauGNNUtils.h"
+#include "tauRecTools/HelperFunctions.h"
+#include <algorithm>
+#include <iostream>
+#define GeV 1000
+
+namespace TauGNNUtils {
+
+GNNVarCalc::GNNVarCalc() : asg::AsgMessaging("TauGNNUtils::GNNVarCalc") {
+}
+
+bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau,
+                      double &out) const {
+    // Retrieve calculator function
+    ScalarCalc func = nullptr;
+    try {
+        func = m_scalar_map.at(name);
+    } catch (const std::out_of_range &e) {
+        ATH_MSG_ERROR("Variable '" << name << "' not defined");
+        throw;
+    }
+
+    // Calculate variable
+    return func(tau, out);
+}
+
+bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau,
+                      const std::vector<const xAOD::TauTrack *> &tracks,
+                      std::vector<double> &out) const {
+    out.clear();
+
+    // Retrieve calculator function
+    TrackCalc func = nullptr;
+    try {
+        func = m_track_map.at(name);
+    } catch (const std::out_of_range &e) {
+        ATH_MSG_ERROR("Variable '" << name << "' not defined");
+        throw;
+    }
+
+    // Calculate variables for selected tracks
+    bool success = true;
+    double value;
+    for (const auto *const trk : tracks) {
+        success = success && func(tau, *trk, value);
+        out.push_back(value);
+    }
+
+    return success;
+}
+
+bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau,
+                      const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
+                      std::vector<double> &out) const {
+    out.clear();
+
+    // Retrieve calculator function
+    ClusterCalc func = nullptr;
+    try {
+        func = m_cluster_map.at(name);
+    } catch (const std::out_of_range &e) {
+        ATH_MSG_ERROR("Variable '" << name << "' not defined");
+        throw;
+    }
+
+    // Calculate variables for selected clusters
+    bool success = true;
+    double value;
+    for (const xAOD::CaloVertexedTopoCluster& cluster : clusters) {
+        success = success && func(tau, cluster, value);
+        out.push_back(value);
+    }
+
+    return success;
+}
+
+void GNNVarCalc::insert(const std::string &name, ScalarCalc func, const std::vector<std::string>& scalar_vars) {
+    if (std::find(scalar_vars.begin(), scalar_vars.end(), name) == scalar_vars.end()) {
+      return;
+    }
+    if (!func) {
+        throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert");
+    }
+    m_scalar_map[name] = func;
+}
+
+void GNNVarCalc::insert(const std::string &name, TrackCalc func, const std::vector<std::string>& track_vars) {
+    if (std::find(track_vars.begin(), track_vars.end(), name) == track_vars.end()) {
+      return;
+    }
+    if (!func) {
+        throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert");
+    }
+    m_track_map[name] = func;
+}
+
+void GNNVarCalc::insert(const std::string &name, ClusterCalc func, const std::vector<std::string>& cluster_vars) {
+    if (std::find(cluster_vars.begin(), cluster_vars.end(), name) == cluster_vars.end()) {
+      return;
+    }
+    if (!func) {
+        throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert");
+    }
+    m_cluster_map[name] = func;
+}
+
+std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scalar_vars,
+					const std::vector<std::string>& track_vars,
+					const std::vector<std::string>& cluster_vars) {
+    auto calc = std::make_unique<GNNVarCalc>();
+
+    // Scalar variable calculator functions
+    calc->insert("absEta", Variables::absEta, scalar_vars);
+    calc->insert("isolFrac", Variables::isolFrac, scalar_vars);
+    calc->insert("centFrac", Variables::centFrac, scalar_vars);
+    calc->insert("etOverPtLeadTrk", Variables::etOverPtLeadTrk, scalar_vars);
+    calc->insert("innerTrkAvgDist", Variables::innerTrkAvgDist, scalar_vars);
+    calc->insert("absipSigLeadTrk", Variables::absipSigLeadTrk, scalar_vars);
+    calc->insert("SumPtTrkFrac", Variables::SumPtTrkFrac, scalar_vars);
+    calc->insert("sumEMCellEtOverLeadTrkPt", Variables::sumEMCellEtOverLeadTrkPt, scalar_vars);
+    calc->insert("EMPOverTrkSysP", Variables::EMPOverTrkSysP, scalar_vars);
+    calc->insert("ptRatioEflowApprox", Variables::ptRatioEflowApprox, scalar_vars);
+    calc->insert("mEflowApprox", Variables::mEflowApprox, scalar_vars);
+    calc->insert("dRmax", Variables::dRmax, scalar_vars);
+    calc->insert("trFlightPathSig", Variables::trFlightPathSig, scalar_vars);
+    calc->insert("massTrkSys", Variables::massTrkSys, scalar_vars);
+    calc->insert("pt", Variables::pt, scalar_vars);
+    calc->insert("pt_tau_log", Variables::pt_tau_log, scalar_vars);
+    calc->insert("ptDetectorAxis", Variables::ptDetectorAxis, scalar_vars);
+    calc->insert("ptIntermediateAxis", Variables::ptIntermediateAxis, scalar_vars);
+    //---added for the eVeto
+    calc->insert("ptJetSeed_log",              Variables::ptJetSeed_log, scalar_vars);
+    calc->insert("absleadTrackEta",            Variables::absleadTrackEta, scalar_vars);
+    calc->insert("leadTrackDeltaEta",          Variables::leadTrackDeltaEta, scalar_vars);
+    calc->insert("leadTrackDeltaPhi",          Variables::leadTrackDeltaPhi, scalar_vars);
+    calc->insert("leadTrackProbNNorHT",        Variables::leadTrackProbNNorHT, scalar_vars);
+    calc->insert("EMFracFixed",                Variables::EMFracFixed, scalar_vars);
+    calc->insert("etHotShotWinOverPtLeadTrk",  Variables::etHotShotWinOverPtLeadTrk, scalar_vars);
+    calc->insert("hadLeakFracFixed",           Variables::hadLeakFracFixed, scalar_vars);
+    calc->insert("PSFrac",                     Variables::PSFrac, scalar_vars);
+    calc->insert("ClustersMeanCenterLambda",   Variables::ClustersMeanCenterLambda, scalar_vars);
+    calc->insert("ClustersMeanFirstEngDens",   Variables::ClustersMeanFirstEngDens, scalar_vars);
+    calc->insert("ClustersMeanPresamplerFrac", Variables::ClustersMeanPresamplerFrac, scalar_vars);
+
+    // Track variable calculator functions
+    calc->insert("pt_log", Variables::Track::pt_log, track_vars);
+    calc->insert("trackPt", Variables::Track::trackPt, track_vars);
+    calc->insert("trackEta", Variables::Track::trackEta, track_vars);
+    calc->insert("trackPhi", Variables::Track::trackPhi, track_vars);
+    calc->insert("pt_tau_log", Variables::Track::pt_tau_log, track_vars);
+    calc->insert("pt_jetseed_log", Variables::Track::pt_jetseed_log, track_vars);
+    calc->insert("d0_abs_log", Variables::Track::d0_abs_log, track_vars);
+    calc->insert("z0sinThetaTJVA_abs_log", Variables::Track::z0sinThetaTJVA_abs_log, track_vars);
+    calc->insert("z0sinthetaTJVA", Variables::Track::z0sinthetaTJVA, track_vars);
+    calc->insert("z0sinthetaSigTJVA", Variables::Track::z0sinthetaSigTJVA, track_vars);
+    calc->insert("d0TJVA", Variables::Track::d0TJVA, track_vars);
+    calc->insert("d0SigTJVA", Variables::Track::d0SigTJVA, track_vars);
+    calc->insert("dEta", Variables::Track::dEta, track_vars);
+    calc->insert("dPhi", Variables::Track::dPhi, track_vars);
+    calc->insert("nInnermostPixelHits", Variables::Track::nInnermostPixelHits, track_vars);
+    calc->insert("nPixelHits", Variables::Track::nPixelHits, track_vars);
+    calc->insert("nSCTHits", Variables::Track::nSCTHits, track_vars);
+    calc->insert("nIBLHitsAndExp", Variables::Track::nIBLHitsAndExp, track_vars);
+    calc->insert("nPixelHitsPlusDeadSensors", Variables::Track::nPixelHitsPlusDeadSensors, track_vars);
+    calc->insert("nSCTHitsPlusDeadSensors", Variables::Track::nSCTHitsPlusDeadSensors, track_vars);
+    calc->insert("eProbabilityHT", Variables::Track::eProbabilityHT, track_vars);
+    calc->insert("eProbabilityNN", Variables::Track::eProbabilityNN, track_vars);
+    calc->insert("eProbabilityNNorHT", Variables::Track::eProbabilityNNorHT, track_vars);
+    calc->insert("chargedScoreRNN", Variables::Track::chargedScoreRNN, track_vars);
+    calc->insert("isolationScoreRNN", Variables::Track::isolationScoreRNN, track_vars);
+    calc->insert("conversionScoreRNN", Variables::Track::conversionScoreRNN, track_vars);
+    calc->insert("fakeScoreRNN", Variables::Track::fakeScoreRNN, track_vars);
+    //Extension - variables for GNTau
+    calc->insert("numberOfInnermostPixelLayerHits", Variables::Track::numberOfInnermostPixelLayerHits, track_vars);
+    calc->insert("numberOfPixelHits", Variables::Track::numberOfPixelHits, track_vars);
+    calc->insert("numberOfPixelSharedHits", Variables::Track::numberOfPixelSharedHits, track_vars);
+    calc->insert("numberOfPixelDeadSensors", Variables::Track::numberOfPixelDeadSensors, track_vars);
+    calc->insert("numberOfSCTHits", Variables::Track::numberOfSCTHits, track_vars);
+    calc->insert("numberOfSCTSharedHits", Variables::Track::numberOfSCTSharedHits, track_vars);
+    calc->insert("numberOfSCTDeadSensors", Variables::Track::numberOfSCTDeadSensors, track_vars);
+    calc->insert("numberOfTRTHighThresholdHits", Variables::Track::numberOfTRTHighThresholdHits, track_vars);
+    calc->insert("numberOfTRTHits", Variables::Track::numberOfTRTHits, track_vars);
+    calc->insert("nSiHits", Variables::Track::nSiHits, track_vars);
+    calc->insert("expectInnermostPixelLayerHit", Variables::Track::expectInnermostPixelLayerHit, track_vars);
+    calc->insert("expectNextToInnermostPixelLayerHit", Variables::Track::expectNextToInnermostPixelLayerHit, track_vars);
+    calc->insert("numberOfContribPixelLayers", Variables::Track::numberOfContribPixelLayers, track_vars);
+    calc->insert("numberOfPixelHoles", Variables::Track::numberOfPixelHoles, track_vars);
+    calc->insert("d0_old", Variables::Track::d0_old, track_vars);
+    calc->insert("qOverP", Variables::Track::qOverP, track_vars);
+    calc->insert("theta", Variables::Track::theta, track_vars);
+    calc->insert("z0TJVA", Variables::Track::z0TJVA, track_vars);
+    calc->insert("charge", Variables::Track::charge, track_vars);
+    calc->insert("dz0_TV_PV0", Variables::Track::dz0_TV_PV0, track_vars);
+    calc->insert("log_sumpt_TV", Variables::Track::log_sumpt_TV, track_vars);
+    calc->insert("log_sumpt2_TV", Variables::Track::log_sumpt2_TV, track_vars);
+    calc->insert("log_sumpt_PV0", Variables::Track::log_sumpt_PV0, track_vars);
+    calc->insert("log_sumpt2_PV0", Variables::Track::log_sumpt2_PV0, track_vars);
+
+    // Cluster variable calculator functions
+    calc->insert("et_log", Variables::Cluster::et_log, cluster_vars);
+    calc->insert("pt_tau_log", Variables::Cluster::pt_tau_log, cluster_vars);
+    calc->insert("pt_jetseed_log", Variables::Cluster::pt_jetseed_log, cluster_vars);
+    calc->insert("dEta", Variables::Cluster::dEta, cluster_vars);
+    calc->insert("dPhi", Variables::Cluster::dPhi, cluster_vars);
+    calc->insert("SECOND_R", Variables::Cluster::SECOND_R, cluster_vars);
+    calc->insert("SECOND_LAMBDA", Variables::Cluster::SECOND_LAMBDA, cluster_vars);
+    calc->insert("CENTER_LAMBDA", Variables::Cluster::CENTER_LAMBDA, cluster_vars);
+    //---added for the eVeto
+    calc->insert("SECOND_LAMBDAOverClustersMeanSecondLambda", Variables::Cluster::SECOND_LAMBDAOverClustersMeanSecondLambda, cluster_vars);
+    calc->insert("CENTER_LAMBDAOverClustersMeanCenterLambda", Variables::Cluster::CENTER_LAMBDAOverClustersMeanCenterLambda, cluster_vars);
+    calc->insert("FirstEngDensOverClustersMeanFirstEngDens" , Variables::Cluster::FirstEngDensOverClustersMeanFirstEngDens, cluster_vars);
+
+    //Extension - Variables for GNTau
+    calc->insert("e", Variables::Cluster::e, cluster_vars);
+    calc->insert("et", Variables::Cluster::et, cluster_vars);
+    calc->insert("FIRST_ENG_DENS", Variables::Cluster::FIRST_ENG_DENS, cluster_vars);
+    calc->insert("EM_PROBABILITY", Variables::Cluster::EM_PROBABILITY, cluster_vars);
+    calc->insert("CENTER_MAG", Variables::Cluster::CENTER_MAG, cluster_vars);
+    return calc;
+}
+
+
+namespace Variables {
+using TauDetail = xAOD::TauJetParameters::Detail;
+
+bool absEta(const xAOD::TauJet &tau, double &out) {
+    out = std::abs(tau.eta());
+    return true;
+}
+
+bool centFrac(const xAOD::TauJet &tau, double &out) {
+    float centFrac;
+    const auto success = tau.detail(TauDetail::centFrac, centFrac);
+    //out = std::min(centFrac, 1.0f);
+    out = centFrac;
+    return success;
+}
+
+bool isolFrac(const xAOD::TauJet &tau, double &out) {
+    float isolFrac;
+    const auto success = tau.detail(TauDetail::isolFrac, isolFrac);
+    //out = std::min(isolFrac, 1.0f);
+    out = isolFrac;
+    return success;
+}
+
+bool etOverPtLeadTrk(const xAOD::TauJet &tau, double &out) {
+    float etOverPtLeadTrk;
+    const auto success = tau.detail(TauDetail::etOverPtLeadTrk, etOverPtLeadTrk);
+    out = etOverPtLeadTrk;
+    return success;
+}
+
+bool innerTrkAvgDist(const xAOD::TauJet &tau, double &out) {
+    float innerTrkAvgDist;
+    const auto success = tau.detail(TauDetail::innerTrkAvgDist, innerTrkAvgDist);
+    out = innerTrkAvgDist;
+    return success;
+}
+
+bool absipSigLeadTrk(const xAOD::TauJet &tau, double &out) {
+    float ipSigLeadTrk = (tau.nTracks()>0) ? tau.track(0)->d0SigTJVA() : 0.;
+    //out = std::min(std::abs(ipSigLeadTrk), 30.0f);
+    out = std::abs(ipSigLeadTrk);
+    return true;
+}
+
+bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, double &out) {
+    float sumEMCellEtOverLeadTrkPt;
+    const auto success = tau.detail(TauDetail::sumEMCellEtOverLeadTrkPt, sumEMCellEtOverLeadTrkPt);
+    out = sumEMCellEtOverLeadTrkPt;
+    return success;
+}
+
+bool SumPtTrkFrac(const xAOD::TauJet &tau, double &out) {
+    float SumPtTrkFrac;
+    const auto success = tau.detail(TauDetail::SumPtTrkFrac, SumPtTrkFrac);
+    out = SumPtTrkFrac;
+    return success;
+}
+
+bool EMPOverTrkSysP(const xAOD::TauJet &tau, double &out) {
+    float EMPOverTrkSysP;
+    const auto success = tau.detail(TauDetail::EMPOverTrkSysP, EMPOverTrkSysP);
+    out = EMPOverTrkSysP;
+    return success;
+}
+
+bool ptRatioEflowApprox(const xAOD::TauJet &tau, double &out) {
+    float ptRatioEflowApprox;
+    const auto success = tau.detail(TauDetail::ptRatioEflowApprox, ptRatioEflowApprox);
+    //out = std::min(ptRatioEflowApprox, 4.0f);
+    out = ptRatioEflowApprox;
+    return success;
+}
+
+bool mEflowApprox(const xAOD::TauJet &tau, double &out) {
+    float mEflowApprox;
+    const auto success = tau.detail(TauDetail::mEflowApprox, mEflowApprox);
+    out = mEflowApprox;
+    return success;
+}
+
+bool dRmax(const xAOD::TauJet &tau, double &out) {
+    float dRmax;
+    const auto success = tau.detail(TauDetail::dRmax, dRmax);
+    out = dRmax;
+    return success;
+}
+
+bool trFlightPathSig(const xAOD::TauJet &tau, double &out) {
+    float trFlightPathSig;
+    const auto success = tau.detail(TauDetail::trFlightPathSig, trFlightPathSig);
+    out = trFlightPathSig;
+    return success;
+}
+
+bool massTrkSys(const xAOD::TauJet &tau, double &out) {
+    float massTrkSys;
+    const auto success = tau.detail(TauDetail::massTrkSys, massTrkSys);
+    out = massTrkSys;
+    return success;
+}
+
+bool pt(const xAOD::TauJet &tau, double &out) {
+    out = tau.pt();
+    return true;
+}
+
+bool pt_tau_log(const xAOD::TauJet &tau, double &out) {
+    out = std::log10(std::max(tau.pt() / GeV, 1e-6));
+    return true;
+}
+
+bool ptDetectorAxis(const xAOD::TauJet &tau, double &out) {
+    out = tau.ptDetectorAxis();
+    return true;
+}
+
+bool ptIntermediateAxis(const xAOD::TauJet &tau, double &out) {
+    out = tau.ptIntermediateAxis();
+    return true;
+}
+
+bool ptJetSeed_log(const xAOD::TauJet &tau, double &out) {
+  out = std::log10(std::max(tau.ptJetSeed(), 1e-3));
+  return true;
+}
+
+bool absleadTrackEta(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_absEtaLeadTrack("ABS_ETA_LEAD_TRACK");
+  float absEtaLeadTrack = acc_absEtaLeadTrack(tau);
+  out = std::max(0.f, absEtaLeadTrack);
+  return true;
+}
+
+bool leadTrackDeltaEta(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_absDeltaEta("TAU_ABSDELTAETA");
+  float absDeltaEta = acc_absDeltaEta(tau);
+  out = std::max(0.f, absDeltaEta);
+  return true;
+}
+
+bool leadTrackDeltaPhi(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_absDeltaPhi("TAU_ABSDELTAPHI");
+  float absDeltaPhi = acc_absDeltaPhi(tau);
+  out = std::max(0.f, absDeltaPhi);
+  return true;
+}
+
+bool leadTrackProbNNorHT(const xAOD::TauJet &tau, double &out){
+  auto tracks = tau.allTracks();
+
+  // Sort tracks in descending pt order
+  if (!tracks.empty()) {
+    auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
+      return lhs->pt() > rhs->pt();
+    };
+    std::sort(tracks.begin(), tracks.end(), cmp_pt);
+
+    const xAOD::TauTrack* tauLeadTrack = tracks.at(0);
+    const xAOD::TrackParticle* xTrackParticle = tauLeadTrack->track();
+    float eProbabilityHT = xTrackParticle->summaryValue(eProbabilityHT, xAOD::eProbabilityHT);
+    static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN");
+    float eProbabilityNN = acc_eProbabilityNN(*xTrackParticle);
+    out = (tauLeadTrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT;
+  }
+  else {
+    out = 0.;
+  }
+  return true;
+}
+
+bool EMFracFixed(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_emFracFixed("EMFracFixed");
+  float emFracFixed = acc_emFracFixed(tau);
+  out = std::max(emFracFixed, 0.0f);
+  return true;
+}
+
+bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_etHotShotWinOverPtLeadTrk("etHotShotWinOverPtLeadTrk");
+  float etHotShotWinOverPtLeadTrk = acc_etHotShotWinOverPtLeadTrk(tau);
+  out = std::max(etHotShotWinOverPtLeadTrk, 1e-6f);
+  return true;
+}
+
+bool hadLeakFracFixed(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_hadLeakFracFixed("hadLeakFracFixed");
+  float hadLeakFracFixed = acc_hadLeakFracFixed(tau);
+  out = std::max(0.f, hadLeakFracFixed);
+  return true;
+}
+
+bool PSFrac(const xAOD::TauJet &tau, double &out){
+  float PSFrac;
+  const auto success = tau.detail(TauDetail::PSSFraction, PSFrac);
+  out = std::max(0.f,PSFrac);  
+  return success;
+}
+
+bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, double &out){
+  float ClustersMeanCenterLambda;
+  const auto success = tau.detail(TauDetail::ClustersMeanCenterLambda, ClustersMeanCenterLambda);
+  out = std::max(0.f, ClustersMeanCenterLambda);
+  return success;
+}
+
+bool ClustersMeanEMProbability(const xAOD::TauJet &tau, double &out){
+  float ClustersMeanEMProbability;
+  const auto success = tau.detail(TauDetail::ClustersMeanEMProbability, ClustersMeanEMProbability);
+  out = std::max(0.f, ClustersMeanEMProbability);
+  return success;
+}
+
+bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, double &out){
+  float ClustersMeanFirstEngDens;
+  const auto success = tau.detail(TauDetail::ClustersMeanFirstEngDens, ClustersMeanFirstEngDens);
+  out =  std::max(-10.f, ClustersMeanFirstEngDens);
+  return success;
+}
+
+bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, double &out){
+  float ClustersMeanPresamplerFrac;
+  const auto success = tau.detail(TauDetail::ClustersMeanPresamplerFrac, ClustersMeanPresamplerFrac);
+  out = std::max(0.f, ClustersMeanPresamplerFrac);
+  return success;
+}
+
+bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, double &out){
+  float ClustersMeanSecondLambda;
+  const auto success = tau.detail(TauDetail::ClustersMeanSecondLambda, ClustersMeanSecondLambda);
+  out = std::max(0.f, ClustersMeanSecondLambda);
+  return success;
+}
+
+namespace Track {
+
+bool pt_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = std::log10(track.pt());
+    return true;
+}
+
+bool trackPt(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.pt();
+    return true;
+}
+
+bool trackEta(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.eta();
+    return true;
+}
+
+bool trackPhi(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.phi();
+    return true;
+}
+
+bool pt_tau_log(const xAOD::TauJet &tau, const xAOD::TauTrack& /*track*/, double &out) {
+    out = std::log10(std::max(tau.pt(), 1e-6));
+    return true;
+}
+
+bool pt_jetseed_log(const xAOD::TauJet &tau, const xAOD::TauTrack& /*track*/, double &out) {
+    out = std::log10(tau.ptJetSeed());
+    return true;
+}
+
+bool d0_abs_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = std::log10(std::abs(track.d0TJVA()) + 1e-6);
+    return true;
+}
+
+bool z0sinThetaTJVA_abs_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = std::log10(std::abs(track.z0sinthetaTJVA()) + 1e-6);
+    return true;
+}
+
+bool z0sinthetaTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.z0sinthetaTJVA();
+    return true;
+}
+
+bool z0sinthetaSigTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.z0sinthetaSigTJVA();
+    return true;
+}
+
+bool d0TJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.d0TJVA();
+    return true;
+}
+
+bool d0SigTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.d0SigTJVA();
+    return true;
+}
+
+bool dEta(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) {
+    out = track.eta() - tau.eta();
+    return true;
+}
+
+bool dPhi(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) {
+    out = track.p4().DeltaPhi(tau.p4());
+    return true;
+}
+
+bool nInnermostPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t inner_pixel_hits;
+    const auto success = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits);
+    out = inner_pixel_hits;
+    return success;
+}
+
+bool nPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t pixel_hits;
+    const auto success = track.track()->summaryValue(pixel_hits, xAOD::numberOfPixelHits);
+    out = pixel_hits;
+    return success;
+}
+
+bool nSCTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t sct_hits;
+    const auto success = track.track()->summaryValue(sct_hits, xAOD::numberOfSCTHits);
+    out = sct_hits;
+    return success;
+}
+
+// same as in tau track classification for trigger
+bool nIBLHitsAndExp(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t inner_pixel_hits, inner_pixel_exp;
+    const auto success1 = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits);
+    const auto success2 = track.track()->summaryValue(inner_pixel_exp, xAOD::expectInnermostPixelLayerHit);
+    out =  inner_pixel_exp ? inner_pixel_hits : 1.;
+    return success1 && success2;
+}
+
+bool nPixelHitsPlusDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t pixel_hits, pixel_dead;
+    const auto success1 = track.track()->summaryValue(pixel_hits, xAOD::numberOfPixelHits);
+    const auto success2 = track.track()->summaryValue(pixel_dead, xAOD::numberOfPixelDeadSensors);
+    out = pixel_hits + pixel_dead;
+    return success1 && success2;
+}
+
+bool nSCTHitsPlusDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t sct_hits, sct_dead;
+    const auto success1 = track.track()->summaryValue(sct_hits, xAOD::numberOfSCTHits);
+    const auto success2 = track.track()->summaryValue(sct_dead, xAOD::numberOfSCTDeadSensors);
+    out = sct_hits + sct_dead;
+    return success1 && success2;
+}
+
+bool eProbabilityHT(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    float eProbabilityHT;
+    const auto success = track.track()->summaryValue(eProbabilityHT, xAOD::eProbabilityHT);
+    out = eProbabilityHT;
+    return success;
+}
+
+bool eProbabilityNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {  
+    static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN");
+    out = acc_eProbabilityNN(track);
+    return true;
+}
+
+bool eProbabilityNNorHT(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {  
+  auto atrack = track.track();
+  float eProbabilityHT = atrack->summaryValue(eProbabilityHT, xAOD::eProbabilityHT);
+  static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN");
+  float eProbabilityNN = acc_eProbabilityNN(*atrack);
+  out = (atrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT;
+  return true;
+}
+
+bool chargedScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_chargedScoreRNN("rnn_chargedScore");
+  out = acc_chargedScoreRNN(track);
+  return true;
+}
+
+bool isolationScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_isolationScoreRNN("rnn_isolationScore");
+  out = acc_isolationScoreRNN(track);
+  return true;
+}
+
+bool conversionScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_conversionScoreRNN("rnn_conversionScore");
+  out = acc_conversionScoreRNN(track);
+  return true;
+}
+
+bool fakeScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_fakeScoreRNN("rnn_fakeScore");
+  out = acc_fakeScoreRNN(track);
+  return true;
+}
+
+//Extension - variables for GNTau
+bool numberOfInnermostPixelLayerHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfInnermostPixelLayerHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfPixelSharedHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelSharedHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfPixelDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelDeadSensors);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfSCTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfSCTSharedHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTSharedHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfSCTDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTDeadSensors);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfTRTHighThresholdHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfTRTHighThresholdHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfTRTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfTRTHits);
+    out = trk_val;
+    return success;
+}
+
+bool nSiHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t pix_hit = 0;uint8_t pix_dead = 0;uint8_t sct_hit = 0;uint8_t sct_dead = 0;
+    const auto success1 = track.track()->summaryValue(pix_hit, xAOD::numberOfPixelHits);
+    const auto success2 = track.track()->summaryValue(pix_dead, xAOD::numberOfPixelDeadSensors);
+    const auto success3 = track.track()->summaryValue(sct_hit, xAOD::numberOfSCTHits);
+    const auto success4 = track.track()->summaryValue(sct_dead, xAOD::numberOfSCTDeadSensors);
+    out = pix_hit + pix_dead + sct_hit + sct_dead;
+    return success1 && success2 && success3 && success4;
+}
+
+bool expectInnermostPixelLayerHit(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::expectInnermostPixelLayerHit);
+    out = trk_val;
+    return success;
+}
+
+bool expectNextToInnermostPixelLayerHit(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::expectNextToInnermostPixelLayerHit);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfContribPixelLayers(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfContribPixelLayers);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfPixelHoles(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelHoles);
+    out = trk_val;
+    return success;
+}
+
+bool d0_old(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.track()->d0();
+    //out = trk_val;
+    return true;
+}
+
+bool qOverP(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.track()->qOverP();
+    return true;
+}
+
+bool theta(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.track()->theta();
+    return true;
+}
+
+bool z0TJVA(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out) {
+    out = track.track()->z0() + track.track()->vz() - tau.vertex()->z();
+    return true;
+}
+
+bool charge(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.track()->charge();
+    return true;
+}
+
+bool dz0_TV_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) {
+    out = 0.;
+    static const SG::AuxElement::ConstAccessor<float> acc_dz0TVPV0("dz0_TV_PV0");
+    if (tau.isAvailable<float>("dz0_TV_PV0")){out = acc_dz0TVPV0(tau);}
+    return true;
+}
+
+bool log_sumpt_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) {
+    out=0.;
+    static const SG::AuxElement::ConstAccessor<float> acc_logsumptTV("log_sumpt_TV");
+    if (tau.isAvailable<float>("log_sumpt_TV")){out=acc_logsumptTV(tau);}
+    return true;
+}
+
+bool log_sumpt2_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) {
+    out=0.;
+    static const SG::AuxElement::ConstAccessor<float> acc_logsumpt2TV("log_sumpt2_TV");
+    if (tau.isAvailable<float>("log_sumpt2_TV")){out=acc_logsumpt2TV(tau);}
+    return true;
+}
+
+bool log_sumpt_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) {
+    out=0.;
+    static const SG::AuxElement::ConstAccessor<float> acc_logsumptPV0("log_sumpt_PV0");
+    if (tau.isAvailable<float>("log_sumpt_PV0")){out=acc_logsumptPV0(tau);}
+    return true;
+}
+
+bool log_sumpt2_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) {
+    out=0.;
+    static const SG::AuxElement::ConstAccessor<float> acc_logsumpt2PV0("log_sumpt2_PV0");
+    if (tau.isAvailable<float>("log_sumpt2_PV0")){out=acc_logsumpt2PV0(tau);}
+    return true;
+}
+
+} // namespace Track
+
+
+namespace Cluster {
+using MomentType = xAOD::CaloCluster::MomentType;
+
+bool et_log(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    out = std::log10(cluster.p4().Et());
+    return true;
+}
+
+bool pt_tau_log(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster& /*cluster*/, double &out) {
+    out = std::log10(std::max(tau.pt(), 1e-6));
+    return true;
+}
+
+bool pt_jetseed_log(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster& /*cluster*/, double &out) {
+    out = std::log10(tau.ptJetSeed());
+    return true;
+}
+
+bool dEta(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    out = cluster.eta() - tau.eta();
+    return true;
+}
+
+bool dPhi(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    out = cluster.p4().DeltaPhi(tau.p4());
+    return true;
+}
+
+bool SECOND_R(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_R, out);
+    return success;
+}
+
+bool SECOND_LAMBDA(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_LAMBDA, out);
+    return success;
+}
+
+bool CENTER_LAMBDA(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    const auto success = cluster.clust().retrieveMoment(MomentType::CENTER_LAMBDA, out);
+    return success;
+}
+
+bool SECOND_LAMBDAOverClustersMeanSecondLambda(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanSecondLambda("ClustersMeanSecondLambda");
+  float ClustersMeanSecondLambda = acc_ClustersMeanSecondLambda(tau);
+  double secondLambda(0);
+  const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_LAMBDA, secondLambda);
+  out = (ClustersMeanSecondLambda != 0.) ? secondLambda/ClustersMeanSecondLambda : 0.;
+  return success;
+}
+
+bool CENTER_LAMBDAOverClustersMeanCenterLambda(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanCenterLambda("ClustersMeanCenterLambda");
+  float ClustersMeanCenterLambda = acc_ClustersMeanCenterLambda(tau);
+  double centerLambda(0);
+  const auto success = cluster.clust().retrieveMoment(MomentType::CENTER_LAMBDA, centerLambda);
+  if (ClustersMeanCenterLambda == 0.){
+    out = 250.;
+  }else {
+    out = centerLambda/ClustersMeanCenterLambda;
+  }
+
+  out = std::min(out, 250.);
+
+  return success;
+}
+
+
+bool FirstEngDensOverClustersMeanFirstEngDens(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+  // the ClustersMeanFirstEngDens is the log10 of the energy weighted average of the First_ENG_DENS 
+  // divided by ETot to make it dimension-less, 
+  // so we need to evaluate the difference of log10(clusterFirstEngDens/clusterTotalEnergy) and the ClustersMeanFirstEngDens
+  double clusterFirstEngDens = 0.0;
+  bool status = cluster.clust().retrieveMoment(MomentType::FIRST_ENG_DENS, clusterFirstEngDens);
+  if (clusterFirstEngDens < 1e-6) clusterFirstEngDens = 1e-6;
+
+  static const SG::AuxElement::ConstAccessor<float> acc_ClusterTotalEnergy("ClusterTotalEnergy");
+  float clusterTotalEnergy = acc_ClusterTotalEnergy(tau);
+  if (clusterTotalEnergy < 1e-6) clusterTotalEnergy = 1e-6;
+
+  static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanFirstEngDens("ClustersMeanFirstEngDens");
+  float clustersMeanFirstEngDens = acc_ClustersMeanFirstEngDens(tau);
+
+  out = std::log10(clusterFirstEngDens/clusterTotalEnergy) - clustersMeanFirstEngDens;
+  
+  return status;
+}
+
+//Extension - Variables for GNTau
+bool e(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    out = cluster.p4().E();
+    return true;
+}
+
+bool et(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    out = cluster.p4().Et();
+    return true;
+}
+
+bool FIRST_ENG_DENS(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    double clusterFirstEngDens = 0.0;
+    bool status = cluster.clust().retrieveMoment(MomentType::FIRST_ENG_DENS, clusterFirstEngDens);
+    out = clusterFirstEngDens;
+    return status;
+}
+
+bool EM_PROBABILITY(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    double clusterEMprob = 0.0;
+    bool status = cluster.clust().retrieveMoment(MomentType::EM_PROBABILITY, clusterEMprob);
+    out = clusterEMprob;
+    return status;
+}
+
+bool CENTER_MAG(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    double clusterCenterMag = 0.0;
+    bool status = cluster.clust().retrieveMoment(MomentType::CENTER_MAG, clusterCenterMag);
+    out = clusterCenterMag;
+    return status;
+}
+
+} // namespace Cluster
+} // namespace Variables
+} // namespace TauGNNUtils
diff --git a/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx b/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx
index d4d5db31542ab15ccd7487f6dacb303eb48da1b3..0e78bf7d357e50f2de792c78d4ffb30d5469dab2 100644
--- a/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx
+++ b/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx
@@ -25,6 +25,7 @@
 #include "tauRecTools/TauWPDecorator.h"
 #include "tauRecTools/TauIDVarCalculator.h"
 #include "tauRecTools/TauJetRNNEvaluator.h"
+#include "tauRecTools/TauGNNEvaluator.h"
 #include "tauRecTools/TauDecayModeNNClassifier.h"
 #include "tauRecTools/TauVertexedClusterDecorator.h"
 #include "tauRecTools/TauAODSelector.h"
@@ -59,6 +60,7 @@ DECLARE_COMPONENT( TauPi0Selector )
 DECLARE_COMPONENT( TauWPDecorator )
 DECLARE_COMPONENT( TauIDVarCalculator )
 DECLARE_COMPONENT( TauJetRNNEvaluator )
+DECLARE_COMPONENT( TauGNNEvaluator )
 DECLARE_COMPONENT( TauDecayModeNNClassifier )
 DECLARE_COMPONENT( TauVertexedClusterDecorator )
 DECLARE_COMPONENT( TauAODSelector )
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNN.h b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2e39c02cbbc1444547c103d3b6dbcf8c059615d
--- /dev/null
+++ b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h
@@ -0,0 +1,110 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#ifndef TAURECTOOLS_TAUGNN_H
+#define TAURECTOOLS_TAUGNN_H
+
+#include "xAODTau/TauJet.h"
+#include "xAODCaloEvent/CaloVertexedTopoCluster.h"
+
+#include "AsgMessaging/AsgMessaging.h"
+
+#include "FlavorTagDiscriminants/OnnxUtil.h"
+
+#include <memory>
+#include <string>
+#include <map>
+
+// Forward declaration
+namespace lwt {
+    class LightweightGraph;
+}
+
+namespace TauGNNUtils {
+    class GNNVarCalc;
+}
+
+namespace FlavorTagDiscriminants{
+    class OnnxUtil;
+}
+
+/**
+ * @brief Wrapper around ONNXUtil to compute the output score of a model
+ *
+ *   Configures the network and computes the network outputs given the input
+ *   objects. Retrieval of input variables is handled internally.
+ *
+ * @author N.M. Tamir
+ *
+ */
+class TauGNN : public asg::AsgMessaging {
+public:
+    // Configuration of the weight file structure
+    struct Config {
+        std::string input_layer_scalar;
+        std::string input_layer_tracks;
+        std::string input_layer_clusters;
+        std::string output_node_tau;
+        std::string output_node_jet;
+    };
+    std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil;
+public:
+    TauGNN(const std::string &nnFile, const Config &config);
+    ~TauGNN();
+
+    // Output the OnnxUtil tuple 
+    std::tuple<
+        std::map<std::string, float>,
+        std::map<std::string, std::vector<char>>,
+        std::map<std::string, std::vector<float>> > 
+    compute(const xAOD::TauJet &tau,
+                  const std::vector<const xAOD::TauTrack *> &tracks,
+                  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const;
+
+    // Compute all input variables and store them in the maps that are passed by reference
+    bool calculateInputVariables(const xAOD::TauJet &tau,
+                  const std::vector<const xAOD::TauTrack *> &tracks,
+                  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
+                  std::map<std::string, std::map<std::string, double>>& scalarInputs,
+                  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const;
+
+    // Getter for the variable calculator
+    const TauGNNUtils::GNNVarCalc* variable_calculator() const {
+        return m_var_calc.get();
+    }
+
+    explicit operator bool() const {
+        return static_cast<bool>(m_graph);
+    }
+
+    //Make the output config transparent to external tools
+    FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config;
+
+private:
+    using input_pair = FlavorTagDiscriminants::input_pair;
+    // Abbreviations for lwtnn
+    using VariableMap = std::map<std::string, double>;
+    using VectorMap = std::map<std::string, std::vector<double>>;
+
+    using InputMap = std::map<std::string, VariableMap>;
+    using InputSequenceMap = std::map<std::string, VectorMap>;
+
+private:
+    const Config m_config;
+    std::unique_ptr<const lwt::LightweightGraph> m_graph;
+
+    // Names of the input variables
+    std::vector<std::string> m_scalar_inputs;
+    std::vector<std::string> m_track_inputs;
+    std::vector<std::string> m_cluster_inputs;
+    // Names passed to the variable calculator
+    std::vector<std::string> m_scalarCalc_inputs;
+    std::vector<std::string> m_trackCalc_inputs;
+    std::vector<std::string> m_clusterCalc_inputs;
+
+    // Variable calculator to calculate input variables on the fly
+    std::unique_ptr<TauGNNUtils::GNNVarCalc> m_var_calc;
+};
+
+#endif // TAURECTOOLS_TAUGNN_H
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h
new file mode 100644
index 0000000000000000000000000000000000000000..4e5cebf53ad5e99907784d96586788ed55ceec28
--- /dev/null
+++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h
@@ -0,0 +1,69 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#ifndef TAURECTOOLS_TAUGNNEVALUATOR_H
+#define TAURECTOOLS_TAUGNNEVALUATOR_H
+
+#include "tauRecTools/TauRecToolBase.h"
+
+#include "xAODTau/TauJet.h"
+#include "xAODCaloEvent/CaloVertexedTopoCluster.h"
+
+#include <memory>
+
+class TauGNN;
+
+/**
+ * @brief Tool to calculate tau identification score from .onnx inputs
+ *
+ *   The network configuration is supplied in .onnx format. 
+ *   Currently runs on a prongness-inclusive model 
+ *   Based off of TauJetRNNEvaluator.h format!
+ * @author N.M. Tamir
+ *
+ */
+class TauGNNEvaluator : public TauRecToolBase {
+public:
+    ASG_TOOL_CLASS2(TauGNNEvaluator, TauRecToolBase, ITauToolBase)
+
+    TauGNNEvaluator(const std::string &name = "TauGNNEvaluator");
+    virtual ~TauGNNEvaluator();
+
+    virtual StatusCode initialize() override;
+    virtual StatusCode execute(xAOD::TauJet &tau) const override;
+    // Getter for the underlying RNN implementation
+    const TauGNN* get_gnn() const;
+
+    // Selects tracks to be used as input to the network
+    StatusCode get_tracks(const xAOD::TauJet &tau,
+                          std::vector<const xAOD::TauTrack *> &out) const;
+
+    // Selects clusters to be used as input to the network
+    StatusCode get_clusters(const xAOD::TauJet &tau,
+                            std::vector<xAOD::CaloVertexedTopoCluster> &out) const;
+
+private:
+    std::string m_output_varname;
+    std::string m_output_ptau;
+    std::string m_output_pjet;
+    std::string m_weightfile;
+    std::size_t m_max_tracks;
+    std::size_t m_max_clusters;
+    float m_max_cluster_dr;
+    bool m_doVertexCorrection;
+    bool m_doTrackClassification;
+    bool m_decorateTracks;
+
+    // Configuration of the network file
+    std::string m_input_layer_scalar;
+    std::string m_input_layer_tracks;
+    std::string m_input_layer_clusters;
+    std::string m_outnode_tau;
+    std::string m_outnode_jet;
+
+    // Wrappers for lwtnn
+    std::unique_ptr<TauGNN> m_net; //!
+};
+
+#endif // TAURECTOOLS_TAUGNNEVALUATOR_H
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..28a983dd5e8a84ca17a2b61fa1b72428fc504232
--- /dev/null
+++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h
@@ -0,0 +1,311 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#ifndef TAURECTOOLS_TAUGNNUTILS_H
+#define TAURECTOOLS_TAUGNNUTILS_H
+
+#include "xAODTau/TauJet.h"
+#include "xAODCaloEvent/CaloVertexedTopoCluster.h"
+#include "xAODEventInfo/EventInfo.h"
+#include "xAODTracking/VertexContainer.h"
+#include "AsgTools/AsgTool.h"
+#include "AsgMessaging/AsgMessaging.h"
+#include <unordered_map>
+
+
+namespace TauGNNUtils {
+
+/**
+ * @brief Tool to calculate input variables for the GNN-based tau identification
+ *
+ *   Used to calculate input variables for (onnx)GNN-based tau identification on
+ *   the fly by providing a mapping between variable names (strings) and
+ *   functions to calculate these variables.
+ *
+ * @author C. Deutsch
+ * @author W. Davey
+ * @author N.M. Tamir
+ *
+ */
+class GNNVarCalc : public asg::AsgMessaging {
+public:
+    // Pointers to calculator functions
+    using ScalarCalc = bool (*)(const xAOD::TauJet &, double &);
+
+    using TrackCalc = bool (*)(const xAOD::TauJet &, const xAOD::TauTrack &,
+                               double &);
+
+    using ClusterCalc = bool (*)(const xAOD::TauJet &,
+                                 const xAOD::CaloVertexedTopoCluster &, double &);
+
+public:
+    GNNVarCalc();
+    ~GNNVarCalc() = default;
+
+    // Methods to compute the output (vector) based on the variable name
+
+    // Computes high-level ID variables
+    bool compute(const std::string &name, const xAOD::TauJet &tau, double &out) const;
+
+    // Computes track variables
+    bool compute(const std::string &name, const xAOD::TauJet &tau,
+                 const std::vector<const xAOD::TauTrack *> &tracks,
+                 std::vector<double> &out) const;
+
+    // Computes cluster variables
+    bool compute(const std::string &name, const xAOD::TauJet &tau,
+                 const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
+                 std::vector<double> &out) const;
+
+    // Methods to insert calculator functions into the lookup table
+    void insert(const std::string &name, ScalarCalc func, const std::vector<std::string>& scalar_vars);
+    void insert(const std::string &name, TrackCalc func, const std::vector<std::string>& track_vars);
+    void insert(const std::string &name, ClusterCalc func, const std::vector<std::string>& cluster_vars);
+
+private:
+    // Lookup tables
+    std::unordered_map<std::string, ScalarCalc> m_scalar_map;
+    std::unordered_map<std::string, TrackCalc> m_track_map;
+    std::unordered_map<std::string, ClusterCalc> m_cluster_map;
+};
+
+// Factory function to create a variable calculator populated with default
+// variables
+std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scalar_vars,
+					const std::vector<std::string>& track_vars,
+					const std::vector<std::string>& cluster_vars);
+
+
+namespace Variables {
+
+// Functions to calculate (scalar) input variables
+// Returns a status code indicating success
+bool absEta(const xAOD::TauJet &tau, double &out);
+
+bool centFrac(const xAOD::TauJet &tau, double &out);
+
+bool isolFrac(const xAOD::TauJet &tau, double &out); 
+
+bool etOverPtLeadTrk(const xAOD::TauJet &tau, double &out);
+
+bool innerTrkAvgDist(const xAOD::TauJet &tau, double &out);
+
+bool absipSigLeadTrk(const xAOD::TauJet &tau, double &out);
+
+bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, double &out);
+
+bool SumPtTrkFrac(const xAOD::TauJet &tau, double &out);
+
+bool EMPOverTrkSysP(const xAOD::TauJet &tau, double &out);
+
+bool ptRatioEflowApprox(const xAOD::TauJet &tau, double &out);
+
+bool mEflowApprox(const xAOD::TauJet &tau, double &out);
+
+bool dRmax(const xAOD::TauJet &tau, double &out);
+
+bool trFlightPathSig(const xAOD::TauJet &tau, double &out);
+
+bool massTrkSys(const xAOD::TauJet &tau, double &out);
+
+bool pt(const xAOD::TauJet &tau, double &out);
+
+bool pt_tau_log(const xAOD::TauJet &tau, double &out);
+
+bool ptDetectorAxis(const xAOD::TauJet &tau, double &out);
+
+bool ptIntermediateAxis(const xAOD::TauJet &tau, double &out);
+
+//functions to calculate input variables needed for the eVeto RNN
+bool ptJetSeed_log             (const xAOD::TauJet &tau, double &out);
+bool absleadTrackEta           (const xAOD::TauJet &tau, double &out);
+bool leadTrackDeltaEta         (const xAOD::TauJet &tau, double &out);
+bool leadTrackDeltaPhi         (const xAOD::TauJet &tau, double &out);
+bool leadTrackProbNNorHT       (const xAOD::TauJet &tau, double &out);
+bool EMFracFixed               (const xAOD::TauJet &tau, double &out);
+bool etHotShotWinOverPtLeadTrk (const xAOD::TauJet &tau, double &out);
+bool hadLeakFracFixed          (const xAOD::TauJet &tau, double &out);
+bool PSFrac                    (const xAOD::TauJet &tau, double &out);
+bool ClustersMeanCenterLambda  (const xAOD::TauJet &tau, double &out);
+bool ClustersMeanEMProbability (const xAOD::TauJet &tau, double &out);
+bool ClustersMeanFirstEngDens  (const xAOD::TauJet &tau, double &out);
+bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, double &out);
+bool ClustersMeanSecondLambda  (const xAOD::TauJet &tau, double &out);
+bool EMPOverTrkSysP            (const xAOD::TauJet &tau, double &out);
+
+
+namespace Track {
+
+// Functions to calculate input variables for each track
+// Returns a status code indicating success
+
+bool pt_log(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool trackPt(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool trackEta(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool trackPhi(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+    
+bool pt_tau_log(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool pt_jetseed_log(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool d0_abs_log(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool z0sinThetaTJVA_abs_log(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool z0sinthetaTJVA(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool z0sinthetaSigTJVA(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool d0TJVA(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool d0SigTJVA(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool dEta(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool dPhi(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool nInnermostPixelHits(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool nPixelHits(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool nSCTHits(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+// trigger variants
+bool nIBLHitsAndExp (
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool nPixelHitsPlusDeadSensors (
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool nSCTHitsPlusDeadSensors (
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool eProbabilityHT(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool eProbabilityNN(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool eProbabilityNNorHT(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool chargedScoreRNN(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool isolationScoreRNN(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool conversionScoreRNN(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool fakeScoreRNN(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+//Extension - variables for GNTau
+bool numberOfInnermostPixelLayerHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfPixelHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfPixelSharedHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfPixelDeadSensors(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfSCTHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfSCTSharedHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfSCTDeadSensors(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfTRTHighThresholdHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfTRTHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool nSiHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool expectInnermostPixelLayerHit(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool expectNextToInnermostPixelLayerHit(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfContribPixelLayers(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfPixelHoles(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool d0_old(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool qOverP(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool theta(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool z0TJVA(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool charge(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool dz0_TV_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool log_sumpt_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool log_sumpt2_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool log_sumpt_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool log_sumpt2_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+} // namespace Track
+
+
+namespace Cluster {
+
+// Functions to calculate input variables for each cluster
+// Returns a status code indicating success
+
+bool et_log(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool pt_tau_log(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool pt_jetseed_log(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool dEta(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool dPhi(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool SECOND_R(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool SECOND_LAMBDA(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool CENTER_LAMBDA(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool SECOND_LAMBDAOverClustersMeanSecondLambda(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool CENTER_LAMBDAOverClustersMeanCenterLambda(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool FirstEngDensOverClustersMeanFirstEngDens(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+//Extension - Variables for GNTau
+bool e(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool et(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool FIRST_ENG_DENS(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool EM_PROBABILITY(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool CENTER_MAG(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+} // namespace Cluster
+} // namespace Variables
+} // namespace TauJetGNNUtils
+
+#endif // TAURECTOOLS_TAUGNNUTILS_H
diff --git a/Tools/WorkflowTestRunner/python/References.py b/Tools/WorkflowTestRunner/python/References.py
index 8aa92ba7fa1b9c35fcad19e43224854eca676ce8..4dd63adb04203f730449d850cc0bb8a609862577 100644
--- a/Tools/WorkflowTestRunner/python/References.py
+++ b/Tools/WorkflowTestRunner/python/References.py
@@ -29,14 +29,14 @@ references_map = {
     "q452": "v9",
     "q454": "v15",
     # Derivations
-    "data_PHYS_Run2": "v20",
+    "data_PHYS_Run2": "v21",
     "data_PHYSLITE_Run2": "v2",
-    "data_PHYS_Run3": "v19",
+    "data_PHYS_Run3": "v20",
     "data_PHYSLITE_Run3": "v2",
-    "mc_PHYS_Run2": "v24",
+    "mc_PHYS_Run2": "v25",
     "mc_PHYSLITE_Run2": "v3",
-    "mc_PHYS_Run3": "v25",
+    "mc_PHYS_Run3": "v26",
     "mc_PHYSLITE_Run3": "v3",
-    "af3_PHYS_Run3": "v6",
+    "af3_PHYS_Run3": "v7",
     "af3_PHYSLITE_Run3": "v3",
 }