diff --git a/Reconstruction/tauRec/python/TauConfigFlags.py b/Reconstruction/tauRec/python/TauConfigFlags.py
index f4c7ff650c38985a3f7eb8936ba7d2879c78156b..ea91293697d628a6ce4fec9558ef4a641fbdcc7c 100644
--- a/Reconstruction/tauRec/python/TauConfigFlags.py
+++ b/Reconstruction/tauRec/python/TauConfigFlags.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
+# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 
 import unittest
 from AthenaConfiguration.AthConfigFlags import AthConfigFlags
@@ -48,6 +48,7 @@ def createTauConfigFlags():
     tau_cfg.addFlag("Tau.MvaTESConfig", "MvaTES_R23.root")
     tau_cfg.addFlag("Tau.MinPt0p", 9.25*Units.GeV)
     tau_cfg.addFlag("Tau.MinPt", 6.75*Units.GeV)
+    tau_cfg.addFlag("Tau.MinPtDAOD", 13*Units.GeV)
     tau_cfg.addFlag("Tau.TauJetRNNConfig", ["tauid_rnn_1p_R22_v1.json", "tauid_rnn_2p_R22_v1.json", "tauid_rnn_3p_R22_v1.json"])
     tau_cfg.addFlag("Tau.TauJetRNNWPConfig", ["tauid_rnnWP_1p_R22_v0.root", "tauid_rnnWP_2p_R22_v0.root", "tauid_rnnWP_3p_R22_v0.root"])
     tau_cfg.addFlag("Tau.TauEleRNNConfig", ["taueveto_rnn_config_1P_r22.json", "taueveto_rnn_config_3P_r22.json"])
@@ -60,6 +61,25 @@ def createTauConfigFlags():
     # R22 DeepSet tau ID tune without track RNN scores, for now define a second set of flags, but ultimately we'll choose one and drop the other
     tau_cfg.addFlag("Tau.TauJetDeepSetConfig_v2", ["tauid_1p_R22_dpst_noTrackScore.json", "tauid_2p_R22_dpst_noTrackScore.json", "tauid_3p_R22_dpst_noTrackScore.json"])
     tau_cfg.addFlag("Tau.TauJetDeepSetWP_v2", ["model_1p_R22_dpst_noTrackScore.root", "model_2p_R22_dpst_noTrackScore.root", "model_3p_R22_dpst_noTrackScore.root"])
+    # GNTau ID tune file (need to add another version for noAux)
+    tau_cfg.addFlag("Tau.TauGNNConfig", ["GNTau_pruned_MC23.onnx","GNTau_trunc_MC23.onnx"])
+    tau_cfg.addFlag("Tau.TauGNNWP",
+                    [ 
+                        ["GNTauNAprune_flat_model_1p.root", "GNTauNAprune_flat_model_2p.root", "GNTauNAprune_flat_model_3p.root"],
+                        ["GNTauNAtrunc_flat_model_1p.root", "GNTauNAtrunc_flat_model_2p.root", "GNTauNAtrunc_flat_model_3p.root"]
+                    ])
+    tau_cfg.addFlag("Tau.GNTauScoreName", ["GNTauScore_v0prune","GNTauScore_v1trunc"])
+    tau_cfg.addFlag("Tau.GNTauTransScoreName", ["GNTauScoreSigTrans_v0prune","GNTauScoreSigTrans_v1trunc"])
+    tau_cfg.addFlag("Tau.GNTauMaxTracks", [30,10])
+    tau_cfg.addFlag("Tau.GNTauMaxClusters", [20,6])
+    tau_cfg.addFlag("Tau.GNTauNodeNameTau", "GN2TauNoAux_pb")
+    tau_cfg.addFlag("Tau.GNTauNodeNameJet", "GN2TauNoAux_pu")
+    tau_cfg.addFlag("Tau.GNTauDecorWPNames", 
+                    [
+                        ["GNTauVL_v0prune", "GNTauL_v0prune", "GNTauM_v0prune", "GNTauT_v0prune"],
+                        ["GNTauVL_v1trunc", "GNTauL_v1trunc", "GNTauM_v1trunc", "GNTauT_v1trunc"]
+                    ])
+
 
     # PanTau config flags
     from PanTauAlgs.PanTauConfigFlags import createPanTauConfigFlags
diff --git a/Reconstruction/tauRec/python/TauToolHolder.py b/Reconstruction/tauRec/python/TauToolHolder.py
index 679fb4c86f9fdb56745bf7b9211926cd539f3fb1..b4bafbba684bb6495bd89b21a93b7c5aac2009dc 100644
--- a/Reconstruction/tauRec/python/TauToolHolder.py
+++ b/Reconstruction/tauRec/python/TauToolHolder.py
@@ -851,6 +851,52 @@ def TauWPDecoratorJetDeepSetCfg(flags, version=None):
     result.setPrivateTools(myTauWPDecorator)
     return result
 
+def TauGNNEvaluatorCfg(flags, version=0):
+    result = ComponentAccumulator()
+    _name = flags.Tau.ActiveConfig.prefix + 'TauGNN_v' + str(version)
+
+    TauGNNEvaluator = CompFactory.getComp("TauGNNEvaluator")
+    GNNConf = flags.Tau.TauGNNConfig[version]
+    myTauGNNEvaluator = TauGNNEvaluator(name = _name,
+                                              NetworkFile = GNNConf,
+                                              OutputVarname = flags.Tau.GNTauScoreName[version],
+                                              OutputPTau = "GNTauProbTau",
+                                              OutputPJet = "GNTauProbJet",
+                                              MaxTracks = flags.Tau.GNTauMaxTracks[version], 
+                                              MaxClusters = flags.Tau.GNTauMaxClusters[version],
+                                              MaxClusterDR = 15.0,
+                                              MinTauPt = flags.Tau.MinPtDAOD,
+                                              VertexCorrection = True,
+                                              DecorateTracks = False,
+                                              InputLayerScalar = "tau_vars",
+                                              InputLayerTracks = "track_vars",
+                                              InputLayerClusters = "cluster_vars",
+                                              NodeNameTau=flags.Tau.GNTauNodeNameTau,
+                                              NodeNameJet=flags.Tau.GNTauNodeNameJet)
+
+    result.setPrivateTools(myTauGNNEvaluator)
+    return result
+
+def TauWPDecoratorGNNCfg(flags, version):
+    result = ComponentAccumulator()
+    _name = flags.Tau.ActiveConfig.prefix + 'TauWPDecoratorGNN_v' + str(version)
+
+    TauWPDecorator = CompFactory.getComp("TauWPDecorator")
+    WPConf = flags.Tau.TauGNNWP[version]
+    myTauWPDecorator = TauWPDecorator(name=_name,
+                                      flatteningFile1Prong = WPConf[0],
+                                      flatteningFile2Prong = WPConf[1],
+                                      flatteningFile3Prong = WPConf[2],
+                                      DecorWPNames = flags.Tau.GNTauDecorWPNames[version],
+                                      DecorWPCutEffs1P = [0.95, 0.85, 0.75, 0.60],
+                                      DecorWPCutEffs2P = [0.95, 0.75, 0.60, 0.45],
+                                      DecorWPCutEffs3P = [0.95, 0.75, 0.60, 0.45],
+                                      ScoreName = flags.Tau.GNTauScoreName[version],
+                                      NewScoreName = flags.Tau.GNTauTransScoreName[version],
+                                      DefineWPs = True)
+    result.setPrivateTools(myTauWPDecorator)
+    return result
+
 def TauEleRNNEvaluatorCfg(flags):
     result = ComponentAccumulator()
     _name = flags.Tau.ActiveConfig.prefix + 'TauEleRNN'
diff --git a/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx b/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx
index 02ed3802ab71ebdf64260bf7fe551b1e03a3316a..5a9a2d01c0e6f769ae321782eb30a6332e15575a 100644
--- a/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx
+++ b/Reconstruction/tauRec/src/TauAODRunnerAlg.cxx
@@ -1,5 +1,5 @@
 /*
-  Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 */
 
 #include "TauAODRunnerAlg.h"
@@ -66,6 +66,7 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const {
   xAOD::TauJetContainer *newTauCon = outputTauHandle.ptr();
 
   static const SG::AuxElement::Accessor<ElementLink<xAOD::TauJetContainer>> acc_ori_tau_link("originalTauJet");
+  static const SG::AuxElement::Accessor<char> acc_modified("ModifiedInAOD");
 
   for (const xAOD::TauJet *tau : *pTauContainer) {
     // deep copy the tau container
@@ -88,6 +89,21 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const {
       linkToTauTrack.toContainedElement(*newTauTrkCon, newTauTrk);
       newTau->addTauTrackLink(linkToTauTrack);
     }
+
+    // 'ModifiedInAOD' will be overriden by modification tools for relevant candidates
+    acc_modified(*newTau) = static_cast<char>(false);
+
+    StatusCode sc;
+    for (const ToolHandle<ITauToolBase> &tool : m_modificationTools) {
+      ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name());
+      sc = tool->execute(*newTau);
+      if (sc.isFailure()) break;
+    }
+
+    // if tau candidate was not modified, remove it from container, track cleanup performed by thinning algorithm downstream
+    if (!acc_modified(*newTau)) {
+      newTauCon->pop_back();
+    }
   }
 
   // Read the CaloClusterContainer
@@ -123,49 +139,34 @@ StatusCode TauAODRunnerAlg::execute (const EventContext& ctx) const {
   ATH_CHECK(vertOutHandle.record(std::make_unique<xAOD::VertexContainer>(), std::make_unique<xAOD::VertexAuxContainer>()));
   xAOD::VertexContainer* pSecVtxContainer = vertOutHandle.ptr();
 
-  int n_tau_modified = 0;
-  static const SG::AuxElement::Accessor<char> acc_modified("ModifiedInAOD");
-
   for (xAOD::TauJet *pTau : *newTauCon) {
-    // Loop stops when Failure indicated by one of the tools
     StatusCode sc;
-    //add a identifier of if the tau is modifed by the mod tools
-    acc_modified(*pTau) = static_cast<char>(false);
-    // iterate over the copy
-    for (const ToolHandle<ITauToolBase> &tool : m_modificationTools) {
+    for (const ToolHandle<ITauToolBase> &tool : m_officialTools) {
       ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name());
-      sc = tool->execute(*pTau);
+      if (tool->type() == "TauPi0ClusterCreator")
+	sc = tool->executePi0ClusterCreator(*pTau, *neutralPFOContainer, *hadronicClusterPFOContainer, *pi0ClusterContainer);
+      else if (tool->type() == "TauVertexVariables")
+	sc = tool->executeVertexVariables(*pTau, *pSecVtxContainer);
+      else if (tool->type() == "TauPi0ClusterScaler")
+	sc = tool->executePi0ClusterScaler(*pTau, *neutralPFOContainer, *chargedPFOContainer);
+      else if (tool->type() == "TauPi0ScoreCalculator")
+	sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer);
+      else if (tool->type() == "TauPi0Selector")
+	sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer);
+      else if (tool->type() == "PanTau::PanTauProcessor")
+	sc = tool->executePanTau(*pTau, *pi0Container, *neutralPFOContainer);
+      else if (tool->type() == "tauRecTools::TauTrackRNNClassifier")
+	sc = tool->executeTrackClassifier(*pTau, *newTauTrkCon);
+      else
+	sc = tool->execute(*pTau);
       if (sc.isFailure()) break;
     }
-    if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked modification tools.");
-    // if tau is not modified by the above tools, never mind running the tools afterward
-    if (static_cast<bool>(isTauModified(pTau))) {
-      n_tau_modified++;
-      for (const ToolHandle<ITauToolBase> &tool : m_officialTools) {
-	ATH_MSG_DEBUG("RunnerAlg Invoking tool " << tool->name());
-	if (tool->type() == "TauPi0ClusterCreator")
-	  sc = tool->executePi0ClusterCreator(*pTau, *neutralPFOContainer, *hadronicClusterPFOContainer, *pi0ClusterContainer);
-	else if (tool->type() == "TauVertexVariables")
-	  sc = tool->executeVertexVariables(*pTau, *pSecVtxContainer);
-	else if (tool->type() == "TauPi0ClusterScaler")
-	  sc = tool->executePi0ClusterScaler(*pTau, *neutralPFOContainer, *chargedPFOContainer);
-	else if (tool->type() == "TauPi0ScoreCalculator")
-	  sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer);
-	else if (tool->type() == "TauPi0Selector")
-	  sc = tool->executePi0nPFO(*pTau, *neutralPFOContainer);
-	else if (tool->type() == "PanTau::PanTauProcessor")
-	  sc = tool->executePanTau(*pTau, *pi0Container, *neutralPFOContainer);
-	else if (tool->type() == "tauRecTools::TauTrackRNNClassifier")
-	  sc = tool->executeTrackClassifier(*pTau, *newTauTrkCon);
-	else
-	  sc = tool->execute(*pTau);
-	if (sc.isFailure()) break;
-      }
-      if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked official tools.");
-    }
+    if (sc.isSuccess()) ATH_MSG_VERBOSE("The tau candidate has been modified successfully by the invoked official tools.");
   }
+
   ATH_MSG_VERBOSE("The tau candidate container has been modified by the rest of the tools");
-  ATH_MSG_DEBUG(n_tau_modified << " / " << pTauContainer->size() <<" taus were modified");
+  ATH_MSG_DEBUG(newTauCon->size() << " / " << pTauContainer->size() <<" taus were modified");
+
   return StatusCode::SUCCESS;
 }
 
diff --git a/Reconstruction/tauRecTools/CMakeLists.txt b/Reconstruction/tauRecTools/CMakeLists.txt
index 40a34801c23a244adb2a9be46ffa39c086df3fe2..87b3c12ef094a880e7a73652436f1f74dc33a39f 100644
--- a/Reconstruction/tauRecTools/CMakeLists.txt
+++ b/Reconstruction/tauRecTools/CMakeLists.txt
@@ -8,6 +8,7 @@ find_package( Boost )
 find_package( Eigen )
 find_package( ROOT COMPONENTS Core Tree Hist RIO )
 find_package( lwtnn )
+find_package( onnxruntime )
 
 # Optional dependencies.
 set( extra_public_libs )
@@ -28,12 +29,12 @@ atlas_add_library( tauRecToolsLib
    tauRecTools/*.h Root/*.cxx tauRecTools/lwtnn/*.h Root/lwtnn/*.cxx
    PUBLIC_HEADERS tauRecTools
    INCLUDE_DIRS ${Boost_INCLUDE_DIRS} ${EIGEN_INCLUDE_DIRS}
-   ${LWTNN_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS}
+   ${LWTNN_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} ${ONNXRUNTIME_INCLUDE_DIRS}
    LINK_LIBRARIES ${Boost_LIBRARIES} ${EIGEN_LIBRARIES} ${LWTNN_LIBRARIES}
-   ${ROOT_LIBRARIES}
+   ${ROOT_LIBRARIES} ${ONNXRUNTIME_LIBRARIES}
    CxxUtils AthLinks AsgMessagingLib AsgDataHandlesLib AsgTools xAODCaloEvent
    xAODEventInfo xAODJet xAODParticleEvent xAODPFlow xAODTau xAODTracking xAODEventShape
-   MVAUtils ${extra_public_libs}
+   MVAUtils FlavorTagDiscriminants ${extra_public_libs}
    PRIVATE_LINK_LIBRARIES CaloGeoHelpers FourMomUtils PathResolver )
 
 atlas_add_dictionary( tauRecToolsDict
@@ -46,5 +47,5 @@ if( NOT XAOD_STANDALONE )
       INCLUDE_DIRS ${Boost_INCLUDE_DIRS}
       LINK_LIBRARIES ${Boost_LIBRARIES} AsgTools AsgMessagingLib xAODBase
       xAODCaloEvent xAODJet xAODPFlow xAODTau FourMomUtils tauRecToolsLib PFlowUtilsLib
-      ${extra_private_libs} )
+      FlavorTagDiscriminants ${extra_private_libs} )
 endif()
diff --git a/Reconstruction/tauRecTools/Root/TauGNN.cxx b/Reconstruction/tauRecTools/Root/TauGNN.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..a0603fa2957bba17acef8230b14702c59b3fe3be
--- /dev/null
+++ b/Reconstruction/tauRecTools/Root/TauGNN.cxx
@@ -0,0 +1,202 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "tauRecTools/TauGNN.h"
+#include "FlavorTagDiscriminants/OnnxUtil.h"
+#include "lwtnn/parse_json.hh"
+#include "PathResolver/PathResolver.h"
+
+#include <algorithm>
+#include <fstream>
+
+#include "tauRecTools/TauGNNUtils.h"
+
+TauGNN::TauGNN(const std::string &nnFile, const Config &config):
+    asg::AsgMessaging("TauGNN"),
+    m_onnxUtil(std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile)),
+    m_config{config}
+  {
+    //==================================================//
+    // This part is ported from FTagDiscriminant GNN.cxx//
+    //==================================================//
+
+    // get the configuration of the model outputs
+    FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig();
+    
+    //Let's see the output!
+    for (const auto& out_node: gnn_output_config) {
+        if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::FLOAT) ATH_MSG_INFO("Found output FLOAT node named:" << out_node.name);
+        if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECCHAR) ATH_MSG_INFO("Found output VECCHAR node named:" << out_node.name);
+        if(out_node.type==FlavorTagDiscriminants::OnnxOutput::OutputType::VECFLOAT) ATH_MSG_INFO("Found output VECFLOAT node named:" << out_node.name);
+    }
+
+    //Get model config (for inputs)
+    auto lwtnn_config = m_onnxUtil->getLwtConfig();
+    
+    //===================================================//
+    // This part is ported from tauRecTools TauJetRNN.cxx//
+    //===================================================//
+
+    // Search for input layer names specified in 'config'
+    auto node_is_scalar = [&config](const lwt::InputNodeConfig &in_node) {
+        return in_node.name == config.input_layer_scalar;
+    };
+    auto node_is_track = [&config](const lwt::InputNodeConfig &in_node) {
+        return in_node.name == config.input_layer_tracks;
+    };
+    auto node_is_cluster = [&config](const lwt::InputNodeConfig &in_node) {
+        return in_node.name == config.input_layer_clusters;
+    };
+
+    auto scalar_node = std::find_if(lwtnn_config.inputs.cbegin(),
+                                    lwtnn_config.inputs.cend(),
+                                    node_is_scalar);
+
+    auto track_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
+                                   lwtnn_config.input_sequences.cend(),
+                                   node_is_track);
+
+    auto cluster_node = std::find_if(lwtnn_config.input_sequences.cbegin(),
+                                     lwtnn_config.input_sequences.cend(),
+                                     node_is_cluster);
+
+    // Check which input layers were found
+    auto has_scalar_node = scalar_node != lwtnn_config.inputs.cend();
+    auto has_track_node = track_node != lwtnn_config.input_sequences.cend();
+    auto has_cluster_node = cluster_node != lwtnn_config.input_sequences.cend();
+    if(!has_scalar_node) ATH_MSG_WARNING("No scalar node with name "<<config.input_layer_scalar<<" found!");
+    if(!has_track_node) ATH_MSG_WARNING("No track node with name "<<config.input_layer_tracks<<" found!");
+    if(!has_cluster_node) ATH_MSG_WARNING("No cluster node with name "<<config.input_layer_clusters<<" found!");
+    
+    // Fill the variable names of each input layer into the corresponding vector
+    if (has_scalar_node) {
+        for (const auto &in : scalar_node->variables) {
+            std::string name = in.name;
+            m_scalarCalc_inputs.push_back(name);
+        }
+    }
+
+    if (has_track_node) {
+        for (const auto &in : track_node->variables) {
+            std::string name = in.name;
+            m_trackCalc_inputs.push_back(name);
+        }
+    }
+
+    if (has_cluster_node) {
+        for (const auto &in : cluster_node->variables) {
+            std::string name = in.name;
+            m_clusterCalc_inputs.push_back(name);
+        }
+    }
+    // Load the variable calculator
+    m_var_calc = TauGNNUtils::get_calculator(m_scalarCalc_inputs, m_trackCalc_inputs, m_clusterCalc_inputs);
+    ATH_MSG_INFO("TauGNN object initialized successfully!");
+}
+
+TauGNN::~TauGNN() {}
+
+std::tuple<
+    std::map<std::string, float>,
+    std::map<std::string, std::vector<char>>,
+    std::map<std::string, std::vector<float>> >
+TauGNN::compute(const xAOD::TauJet &tau,
+		const std::vector<const xAOD::TauTrack *> &tracks,
+		const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
+    InputMap scalarInputs;
+    InputSequenceMap vectorInputs;
+    std::map<std::string, Inputs> gnn_input;
+    ATH_MSG_DEBUG("Starting compute...");
+    //Prepare input variables
+    if (!calculateInputVariables(tau, tracks, clusters, scalarInputs, vectorInputs)) {
+        ATH_MSG_FATAL("Failed calculateInputVariables");
+        throw StatusCode::FAILURE;
+    }
+
+    // Add TauJet-level features to the input
+    std::vector<float> tau_feats;
+    for (const auto &varname : m_scalarCalc_inputs) {
+        tau_feats.push_back(static_cast<float>(scalarInputs[m_config.input_layer_scalar][varname]));
+    }
+    std::vector<int64_t> tau_feats_dim = {1, static_cast<int64_t>(tau_feats.size())};
+    Inputs tau_info (tau_feats, tau_feats_dim);
+    gnn_input.insert({"tau_vars", tau_info});
+
+    //Add track-level features to the input
+    std::vector<float> trk_feats;
+    int num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_tracks][m_trackCalc_inputs.at(0)].size());
+    int num_node_vars=static_cast<int>(m_trackCalc_inputs.size());
+    trk_feats.resize(num_nodes * num_node_vars);
+    int var_idx=0;
+    for (const auto &varname : m_trackCalc_inputs) {
+        for (int node_idx=0; node_idx<num_nodes; node_idx++){    
+            trk_feats.at(node_idx*num_node_vars + var_idx)
+              = static_cast<float>(vectorInputs[m_config.input_layer_tracks][varname].at(node_idx));
+        }
+        var_idx++;
+    }
+    std::vector<int64_t> trk_feats_dim = {num_nodes, num_node_vars};
+    Inputs trk_info (trk_feats, trk_feats_dim);
+    gnn_input.insert({"track_vars", trk_info});
+    
+    //Add cluster-level features to the input
+    std::vector<float> cls_feats;
+    num_nodes=static_cast<int>(vectorInputs[m_config.input_layer_clusters][m_clusterCalc_inputs.at(0)].size());
+    num_node_vars=static_cast<int>(m_clusterCalc_inputs.size());
+    cls_feats.resize(num_nodes * num_node_vars);
+    var_idx=0;
+    for (const auto &varname : m_clusterCalc_inputs) {
+        for (int node_idx=0; node_idx<num_nodes; node_idx++){    
+            cls_feats.at(node_idx*num_node_vars + var_idx)
+              = static_cast<float>(vectorInputs[m_config.input_layer_clusters][varname].at(node_idx));
+        }
+        var_idx++;
+    }
+    std::vector<int64_t> cls_feats_dim = {num_nodes, num_node_vars};
+    Inputs cls_info (cls_feats, cls_feats_dim);
+    gnn_input.insert({"cluster_vars", cls_info});    
+
+    //RUN THE INFERENCE!!!
+    ATH_MSG_DEBUG("Prepared inputs, running inference...");
+    auto [out_f, out_vc, out_vf] = m_onnxUtil->runInference(gnn_input);
+    ATH_MSG_DEBUG("Finished compute!");
+    return std::make_tuple(out_f, out_vc, out_vf);
+}
+
+bool TauGNN::calculateInputVariables(const xAOD::TauJet &tau,
+                  const std::vector<const xAOD::TauTrack *> &tracks,
+                  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
+                  std::map<std::string, std::map<std::string, double>>& scalarInputs,
+                  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const {
+    scalarInputs.clear();
+    vectorInputs.clear();
+    // Populate input (sequence) map with input variables
+    for (const auto &varname : m_scalarCalc_inputs) {
+        if (!m_var_calc->compute(varname, tau,
+                                 scalarInputs[m_config.input_layer_scalar][varname])) {
+            ATH_MSG_WARNING("Error computing '" << varname
+                            << "' returning default");
+            return false;
+        }
+    }
+
+    for (const auto &varname : m_trackCalc_inputs) {
+        if (!m_var_calc->compute(varname, tau, tracks,
+                                 vectorInputs[m_config.input_layer_tracks][varname])) {
+            ATH_MSG_WARNING("Error computing '" << varname
+                            << "' returning default");
+            return false;
+        }
+    }
+
+    for (const auto &varname : m_clusterCalc_inputs) {
+        if (!m_var_calc->compute(varname, tau, clusters,
+                                 vectorInputs[m_config.input_layer_clusters][varname])) {
+            ATH_MSG_WARNING("Error computing '" << varname
+                            << "' returning default");
+            return false;
+        }
+    }
+    return true;
+}
diff --git a/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..bc6b27d0b346277438be6cf56e730589fac9be3f
--- /dev/null
+++ b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx
@@ -0,0 +1,179 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "tauRecTools/TauGNNEvaluator.h"
+#include "tauRecTools/TauGNN.h"
+#include "tauRecTools/HelperFunctions.h"
+
+#include "PathResolver/PathResolver.h"
+
+#include <algorithm>
+
+
+TauGNNEvaluator::TauGNNEvaluator(const std::string &name): 
+  TauRecToolBase(name),
+  m_net(nullptr){
+    
+  declareProperty("NetworkFile", m_weightfile = "");
+  declareProperty("OutputVarname", m_output_varname = "GNTauScore");
+  declareProperty("OutputPTau", m_output_ptau = "GNTauProbTau");
+  declareProperty("OutputPJet", m_output_pjet = "GNTauProbJet");
+  declareProperty("MaxTracks", m_max_tracks = 30);
+  declareProperty("MaxClusters", m_max_clusters = 20);
+  declareProperty("MaxClusterDR", m_max_cluster_dr = 1.0f);
+  declareProperty("VertexCorrection", m_doVertexCorrection = true);
+  declareProperty("DecorateTracks", m_decorateTracks = false);
+  declareProperty("TrackClassification", m_doTrackClassification = true);
+  declareProperty("MinTauPt", m_minTauPt = 0.);
+
+  // Naming conventions for the network weight files:
+  declareProperty("InputLayerScalar", m_input_layer_scalar = "tau_vars");
+  declareProperty("InputLayerTracks", m_input_layer_tracks = "track_vars");
+  declareProperty("InputLayerClusters", m_input_layer_clusters = "cluster_vars");
+  declareProperty("NodeNameTau", m_outnode_tau = "GN2TauNoAux_pb");
+  declareProperty("NodeNameJet", m_outnode_jet = "GN2TauNoAux_pu");
+  }
+
+TauGNNEvaluator::~TauGNNEvaluator() {}
+
+StatusCode TauGNNEvaluator::initialize() {
+  ATH_MSG_INFO("Initializing TauGNNEvaluator with "<<m_max_tracks<<" tracks and "<<m_max_clusters<<" clusters...");
+  
+  std::string weightfile("");
+
+  // Use PathResolver to search for the weight files
+  if (!m_weightfile.empty()) {
+    weightfile = find_file(m_weightfile);
+    if (weightfile.empty()) {
+      ATH_MSG_ERROR("Could not find network weights: " << m_weightfile);
+      return StatusCode::FAILURE;
+    } else {
+      ATH_MSG_INFO("Using network config: " << weightfile);
+    }
+  }
+
+  // Set the layer and node names in the weight file
+  TauGNN::Config config;
+  config.input_layer_scalar = m_input_layer_scalar;
+  config.input_layer_tracks = m_input_layer_tracks;
+  config.input_layer_clusters = m_input_layer_clusters;
+  config.output_node_tau = m_outnode_tau;
+  config.output_node_jet = m_outnode_jet;
+
+  // Load the weights and create the network
+  if (!weightfile.empty()) {
+    m_net = std::make_unique<TauGNN>(weightfile, config);
+    if (!m_net) {
+      ATH_MSG_ERROR("No network configured.");
+      return StatusCode::FAILURE;
+    }
+  }
+
+  return StatusCode::SUCCESS;
+}
+
+StatusCode TauGNNEvaluator::execute(xAOD::TauJet &tau) const {
+  // Output variable Decorators
+  const SG::AuxElement::Accessor<float> output(m_output_varname);
+  const SG::AuxElement::Accessor<float> out_ptau(m_output_ptau);
+  const SG::AuxElement::Accessor<float> out_pjet(m_output_pjet);
+  const SG::AuxElement::Decorator<char> out_trkclass("GNTau_TrackClass");
+  // Set default score and overwrite later
+  output(tau) = -1111.0f;
+  out_ptau(tau) = -1111.0f;
+  out_pjet(tau) = -1111.0f;
+
+  //Skip execution for low-pT taus to save resources
+  if (tau.pt() < m_minTauPt) {
+    return StatusCode::SUCCESS;
+  }
+
+  // Get input objects
+  ATH_MSG_DEBUG("Fetching Tracks");
+  std::vector<const xAOD::TauTrack *> tracks;
+  ATH_CHECK(get_tracks(tau, tracks));
+  ATH_MSG_DEBUG("Fetching clusters");
+  std::vector<xAOD::CaloVertexedTopoCluster> clusters;
+  ATH_CHECK(get_clusters(tau, clusters));
+  ATH_MSG_DEBUG("Constituent fetching done...");
+
+  // Truncate tracks
+  int numTracksMax = std::min(m_max_tracks, static_cast<int>(tracks.size()));
+  std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax);
+  // Evaluate networks
+  if (m_net) {
+    auto [out_f, out_vc, out_vf] = m_net->compute(tau, trackVec, clusters);
+    output(tau)=std::log10(1/(1-out_f.at(m_outnode_tau)));
+    out_ptau(tau)=out_f.at(m_outnode_tau);
+    out_pjet(tau)=out_f.at(m_outnode_jet);
+    if (m_decorateTracks){
+      for(unsigned int i=0;i<tracks.size();i++){
+        if(i<out_vc.at("track_class").size()){out_trkclass(*tracks.at(i))=out_vc.at("track_class").at(i);}
+        else{out_trkclass(*tracks.at(i))='9';} //Dummy value for tracks outside range of out_vc
+      }
+    }
+  }
+  
+  return StatusCode::SUCCESS;
+}
+
+const TauGNN* TauGNNEvaluator::get_gnn() const {
+  return m_net.get();
+}
+
+
+StatusCode TauGNNEvaluator::get_tracks(const xAOD::TauJet &tau, std::vector<const xAOD::TauTrack *> &out) const {
+  std::vector<const xAOD::TauTrack*> tracks = tau.allTracks();
+
+  // Skip unclassified tracks:
+  // - the track is a LRT and classifyLRT = false
+  // - the track is not among the MaxNtracks highest-pt tracks in the track classifier
+  // - track classification is not run (trigger)
+  if(m_doTrackClassification) {
+    std::vector<const xAOD::TauTrack*>::iterator it = tracks.begin();
+    while(it != tracks.end()) {
+      if((*it)->flag(xAOD::TauJetParameters::unclassified)) {
+	it = tracks.erase(it);
+      }
+      else {
+	++it;
+      }
+    }
+  }
+
+  // Sort by descending pt
+  auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
+    return lhs->pt() > rhs->pt();
+  };
+  std::sort(tracks.begin(), tracks.end(), cmp_pt);
+  out = std::move(tracks);
+
+  return StatusCode::SUCCESS;
+}
+
+StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const {
+
+  TLorentzVector tauAxis = tauRecTools::getTauAxis(tau, m_doVertexCorrection);
+
+  for (const xAOD::CaloVertexedTopoCluster& vertexedCluster : tau.vertexedClusters()) {
+    TLorentzVector clusterP4 = vertexedCluster.p4();
+    if (clusterP4.DeltaR(tauAxis) > m_max_cluster_dr) continue;
+      
+    clusters.push_back(vertexedCluster);
+  }
+
+  // Sort by descending et
+  auto et_cmp = [](const xAOD::CaloVertexedTopoCluster& lhs,
+		   const xAOD::CaloVertexedTopoCluster& rhs) {
+    return lhs.p4().Et() > rhs.p4().Et();
+  };
+  std::sort(clusters.begin(), clusters.end(), et_cmp);
+
+  // Truncate clusters
+  if (static_cast<int>(clusters.size()) > m_max_clusters) {
+    clusters.resize(m_max_clusters, clusters[0]);
+  }
+
+  return StatusCode::SUCCESS;
+}
diff --git a/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..32b4bc2afb2b815683f91c246ffc3ad276111fcc
--- /dev/null
+++ b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx
@@ -0,0 +1,930 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#include "tauRecTools/TauGNNUtils.h"
+#include "tauRecTools/HelperFunctions.h"
+#include <algorithm>
+#include <iostream>
+#define GeV 1000
+
+namespace TauGNNUtils {
+
+GNNVarCalc::GNNVarCalc() : asg::AsgMessaging("TauGNNUtils::GNNVarCalc") {
+}
+
+bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau,
+                      double &out) const {
+    // Retrieve calculator function
+    ScalarCalc func = nullptr;
+    try {
+        func = m_scalar_map.at(name);
+    } catch (const std::out_of_range &e) {
+        ATH_MSG_ERROR("Variable '" << name << "' not defined");
+        throw;
+    }
+
+    // Calculate variable
+    return func(tau, out);
+}
+
+bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau,
+                      const std::vector<const xAOD::TauTrack *> &tracks,
+                      std::vector<double> &out) const {
+    out.clear();
+    out.reserve(tracks.size());
+
+    // Retrieve calculator function
+    TrackCalc func = nullptr;
+    try {
+        func = m_track_map.at(name);
+    } catch (const std::out_of_range &e) {
+        ATH_MSG_ERROR("Variable '" << name << "' not defined");
+        throw;
+    }
+
+    // Calculate variables for selected tracks
+    bool success = true;
+    double value;
+    for (const auto *const trk : tracks) {
+        success = success && func(tau, *trk, value);
+        out.push_back(value);
+    }
+
+    return success;
+}
+
+bool GNNVarCalc::compute(const std::string &name, const xAOD::TauJet &tau,
+                      const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
+                      std::vector<double> &out) const {
+    out.clear();
+    out.reserve(clusters.size());
+
+    // Retrieve calculator function
+    ClusterCalc func = nullptr;
+    try {
+        func = m_cluster_map.at(name);
+    } catch (const std::out_of_range &e) {
+        ATH_MSG_ERROR("Variable '" << name << "' not defined");
+        throw;
+    }
+
+    // Calculate variables for selected clusters
+    bool success = true;
+    double value;
+    for (const xAOD::CaloVertexedTopoCluster& cluster : clusters) {
+        success = success && func(tau, cluster, value);
+        out.push_back(value);
+    }
+
+    return success;
+}
+
+void GNNVarCalc::insert(const std::string &name, ScalarCalc func, const std::vector<std::string>& scalar_vars) {
+    if (std::find(scalar_vars.begin(), scalar_vars.end(), name) == scalar_vars.end()) {
+      return;
+    }
+    if (!func) {
+        throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert");
+    }
+    m_scalar_map[name] = func;
+}
+
+void GNNVarCalc::insert(const std::string &name, TrackCalc func, const std::vector<std::string>& track_vars) {
+    if (std::find(track_vars.begin(), track_vars.end(), name) == track_vars.end()) {
+      return;
+    }
+    if (!func) {
+        throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert");
+    }
+    m_track_map[name] = func;
+}
+
+void GNNVarCalc::insert(const std::string &name, ClusterCalc func, const std::vector<std::string>& cluster_vars) {
+    if (std::find(cluster_vars.begin(), cluster_vars.end(), name) == cluster_vars.end()) {
+      return;
+    }
+    if (!func) {
+        throw std::invalid_argument("Nullptr passed to GNNVarCalc::insert");
+    }
+    m_cluster_map[name] = func;
+}
+
+std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scalar_vars,
+					const std::vector<std::string>& track_vars,
+					const std::vector<std::string>& cluster_vars) {
+    auto calc = std::make_unique<GNNVarCalc>();
+
+    // Scalar variable calculator functions
+    calc->insert("absEta", Variables::absEta, scalar_vars);
+    calc->insert("isolFrac", Variables::isolFrac, scalar_vars);
+    calc->insert("centFrac", Variables::centFrac, scalar_vars);
+    calc->insert("etOverPtLeadTrk", Variables::etOverPtLeadTrk, scalar_vars);
+    calc->insert("innerTrkAvgDist", Variables::innerTrkAvgDist, scalar_vars);
+    calc->insert("absipSigLeadTrk", Variables::absipSigLeadTrk, scalar_vars);
+    calc->insert("SumPtTrkFrac", Variables::SumPtTrkFrac, scalar_vars);
+    calc->insert("sumEMCellEtOverLeadTrkPt", Variables::sumEMCellEtOverLeadTrkPt, scalar_vars);
+    calc->insert("EMPOverTrkSysP", Variables::EMPOverTrkSysP, scalar_vars);
+    calc->insert("ptRatioEflowApprox", Variables::ptRatioEflowApprox, scalar_vars);
+    calc->insert("mEflowApprox", Variables::mEflowApprox, scalar_vars);
+    calc->insert("dRmax", Variables::dRmax, scalar_vars);
+    calc->insert("trFlightPathSig", Variables::trFlightPathSig, scalar_vars);
+    calc->insert("massTrkSys", Variables::massTrkSys, scalar_vars);
+    calc->insert("pt", Variables::pt, scalar_vars);
+    calc->insert("pt_tau_log", Variables::pt_tau_log, scalar_vars);
+    calc->insert("ptDetectorAxis", Variables::ptDetectorAxis, scalar_vars);
+    calc->insert("ptIntermediateAxis", Variables::ptIntermediateAxis, scalar_vars);
+    //---added for the eVeto
+    calc->insert("ptJetSeed_log",              Variables::ptJetSeed_log, scalar_vars);
+    calc->insert("absleadTrackEta",            Variables::absleadTrackEta, scalar_vars);
+    calc->insert("leadTrackDeltaEta",          Variables::leadTrackDeltaEta, scalar_vars);
+    calc->insert("leadTrackDeltaPhi",          Variables::leadTrackDeltaPhi, scalar_vars);
+    calc->insert("leadTrackProbNNorHT",        Variables::leadTrackProbNNorHT, scalar_vars);
+    calc->insert("EMFracFixed",                Variables::EMFracFixed, scalar_vars);
+    calc->insert("etHotShotWinOverPtLeadTrk",  Variables::etHotShotWinOverPtLeadTrk, scalar_vars);
+    calc->insert("hadLeakFracFixed",           Variables::hadLeakFracFixed, scalar_vars);
+    calc->insert("PSFrac",                     Variables::PSFrac, scalar_vars);
+    calc->insert("ClustersMeanCenterLambda",   Variables::ClustersMeanCenterLambda, scalar_vars);
+    calc->insert("ClustersMeanFirstEngDens",   Variables::ClustersMeanFirstEngDens, scalar_vars);
+    calc->insert("ClustersMeanPresamplerFrac", Variables::ClustersMeanPresamplerFrac, scalar_vars);
+
+    // Track variable calculator functions
+    calc->insert("pt_log", Variables::Track::pt_log, track_vars);
+    calc->insert("trackPt", Variables::Track::trackPt, track_vars);
+    calc->insert("trackEta", Variables::Track::trackEta, track_vars);
+    calc->insert("trackPhi", Variables::Track::trackPhi, track_vars);
+    calc->insert("pt_tau_log", Variables::Track::pt_tau_log, track_vars);
+    calc->insert("pt_jetseed_log", Variables::Track::pt_jetseed_log, track_vars);
+    calc->insert("d0_abs_log", Variables::Track::d0_abs_log, track_vars);
+    calc->insert("z0sinThetaTJVA_abs_log", Variables::Track::z0sinThetaTJVA_abs_log, track_vars);
+    calc->insert("z0sinthetaTJVA", Variables::Track::z0sinthetaTJVA, track_vars);
+    calc->insert("z0sinthetaSigTJVA", Variables::Track::z0sinthetaSigTJVA, track_vars);
+    calc->insert("d0TJVA", Variables::Track::d0TJVA, track_vars);
+    calc->insert("d0SigTJVA", Variables::Track::d0SigTJVA, track_vars);
+    calc->insert("dEta", Variables::Track::dEta, track_vars);
+    calc->insert("dEtaJetSeedAxis", Variables::Track::dEtaJetSeedAxis, track_vars);
+    calc->insert("dPhi", Variables::Track::dPhi, track_vars);
+    calc->insert("dPhiJetSeedAxis", Variables::Track::dPhiJetSeedAxis, track_vars);
+    calc->insert("nInnermostPixelHits", Variables::Track::nInnermostPixelHits, track_vars);
+    calc->insert("nPixelHits", Variables::Track::nPixelHits, track_vars);
+    calc->insert("nSCTHits", Variables::Track::nSCTHits, track_vars);
+    calc->insert("nIBLHitsAndExp", Variables::Track::nIBLHitsAndExp, track_vars);
+    calc->insert("nPixelHitsPlusDeadSensors", Variables::Track::nPixelHitsPlusDeadSensors, track_vars);
+    calc->insert("nSCTHitsPlusDeadSensors", Variables::Track::nSCTHitsPlusDeadSensors, track_vars);
+    calc->insert("eProbabilityHT", Variables::Track::eProbabilityHT, track_vars);
+    calc->insert("eProbabilityNN", Variables::Track::eProbabilityNN, track_vars);
+    calc->insert("eProbabilityNNorHT", Variables::Track::eProbabilityNNorHT, track_vars);
+    calc->insert("chargedScoreRNN", Variables::Track::chargedScoreRNN, track_vars);
+    calc->insert("isolationScoreRNN", Variables::Track::isolationScoreRNN, track_vars);
+    calc->insert("conversionScoreRNN", Variables::Track::conversionScoreRNN, track_vars);
+    calc->insert("fakeScoreRNN", Variables::Track::fakeScoreRNN, track_vars);
+    //Extension - variables for GNTau
+    calc->insert("numberOfInnermostPixelLayerHits", Variables::Track::numberOfInnermostPixelLayerHits, track_vars);
+    calc->insert("numberOfPixelHits", Variables::Track::numberOfPixelHits, track_vars);
+    calc->insert("numberOfPixelSharedHits", Variables::Track::numberOfPixelSharedHits, track_vars);
+    calc->insert("numberOfPixelDeadSensors", Variables::Track::numberOfPixelDeadSensors, track_vars);
+    calc->insert("numberOfSCTHits", Variables::Track::numberOfSCTHits, track_vars);
+    calc->insert("numberOfSCTSharedHits", Variables::Track::numberOfSCTSharedHits, track_vars);
+    calc->insert("numberOfSCTDeadSensors", Variables::Track::numberOfSCTDeadSensors, track_vars);
+    calc->insert("numberOfTRTHighThresholdHits", Variables::Track::numberOfTRTHighThresholdHits, track_vars);
+    calc->insert("numberOfTRTHits", Variables::Track::numberOfTRTHits, track_vars);
+    calc->insert("nSiHits", Variables::Track::nSiHits, track_vars);
+    calc->insert("expectInnermostPixelLayerHit", Variables::Track::expectInnermostPixelLayerHit, track_vars);
+    calc->insert("expectNextToInnermostPixelLayerHit", Variables::Track::expectNextToInnermostPixelLayerHit, track_vars);
+    calc->insert("numberOfContribPixelLayers", Variables::Track::numberOfContribPixelLayers, track_vars);
+    calc->insert("numberOfPixelHoles", Variables::Track::numberOfPixelHoles, track_vars);
+    calc->insert("d0_old", Variables::Track::d0_old, track_vars);
+    calc->insert("qOverP", Variables::Track::qOverP, track_vars);
+    calc->insert("theta", Variables::Track::theta, track_vars);
+    calc->insert("z0TJVA", Variables::Track::z0TJVA, track_vars);
+    calc->insert("charge", Variables::Track::charge, track_vars);
+    calc->insert("dz0_TV_PV0", Variables::Track::dz0_TV_PV0, track_vars);
+    calc->insert("log_sumpt_TV", Variables::Track::log_sumpt_TV, track_vars);
+    calc->insert("log_sumpt2_TV", Variables::Track::log_sumpt2_TV, track_vars);
+    calc->insert("log_sumpt_PV0", Variables::Track::log_sumpt_PV0, track_vars);
+    calc->insert("log_sumpt2_PV0", Variables::Track::log_sumpt2_PV0, track_vars);
+
+    // Cluster variable calculator functions
+    calc->insert("et_log", Variables::Cluster::et_log, cluster_vars);
+    calc->insert("pt_tau_log", Variables::Cluster::pt_tau_log, cluster_vars);
+    calc->insert("pt_jetseed_log", Variables::Cluster::pt_jetseed_log, cluster_vars);
+    calc->insert("dEta", Variables::Cluster::dEta, cluster_vars);
+    calc->insert("dPhi", Variables::Cluster::dPhi, cluster_vars);
+    calc->insert("SECOND_R", Variables::Cluster::SECOND_R, cluster_vars);
+    calc->insert("SECOND_LAMBDA", Variables::Cluster::SECOND_LAMBDA, cluster_vars);
+    calc->insert("CENTER_LAMBDA", Variables::Cluster::CENTER_LAMBDA, cluster_vars);
+    //---added for the eVeto
+    calc->insert("SECOND_LAMBDAOverClustersMeanSecondLambda", Variables::Cluster::SECOND_LAMBDAOverClustersMeanSecondLambda, cluster_vars);
+    calc->insert("CENTER_LAMBDAOverClustersMeanCenterLambda", Variables::Cluster::CENTER_LAMBDAOverClustersMeanCenterLambda, cluster_vars);
+    calc->insert("FirstEngDensOverClustersMeanFirstEngDens" , Variables::Cluster::FirstEngDensOverClustersMeanFirstEngDens, cluster_vars);
+
+    //Extension - Variables for GNTau
+    calc->insert("e", Variables::Cluster::e, cluster_vars);
+    calc->insert("et", Variables::Cluster::et, cluster_vars);
+    calc->insert("FIRST_ENG_DENS", Variables::Cluster::FIRST_ENG_DENS, cluster_vars);
+    calc->insert("EM_PROBABILITY", Variables::Cluster::EM_PROBABILITY, cluster_vars);
+    calc->insert("CENTER_MAG", Variables::Cluster::CENTER_MAG, cluster_vars);
+    return calc;
+}
+
+
+namespace Variables {
+using TauDetail = xAOD::TauJetParameters::Detail;
+
+bool absEta(const xAOD::TauJet &tau, double &out) {
+    out = std::abs(tau.eta());
+    return true;
+}
+
+bool centFrac(const xAOD::TauJet &tau, double &out) {
+    float centFrac;
+    const auto success = tau.detail(TauDetail::centFrac, centFrac);
+    //out = std::min(centFrac, 1.0f);
+    out = centFrac;
+    return success;
+}
+
+bool isolFrac(const xAOD::TauJet &tau, double &out) {
+    float isolFrac;
+    const auto success = tau.detail(TauDetail::isolFrac, isolFrac);
+    //out = std::min(isolFrac, 1.0f);
+    out = isolFrac;
+    return success;
+}
+
+bool etOverPtLeadTrk(const xAOD::TauJet &tau, double &out) {
+    float etOverPtLeadTrk;
+    const auto success = tau.detail(TauDetail::etOverPtLeadTrk, etOverPtLeadTrk);
+    out = etOverPtLeadTrk;
+    return success;
+}
+
+bool innerTrkAvgDist(const xAOD::TauJet &tau, double &out) {
+    float innerTrkAvgDist;
+    const auto success = tau.detail(TauDetail::innerTrkAvgDist, innerTrkAvgDist);
+    out = innerTrkAvgDist;
+    return success;
+}
+
+bool absipSigLeadTrk(const xAOD::TauJet &tau, double &out) {
+    float ipSigLeadTrk = (tau.nTracks()>0) ? tau.track(0)->d0SigTJVA() : 0.;
+    //out = std::min(std::abs(ipSigLeadTrk), 30.0f);
+    out = std::abs(ipSigLeadTrk);
+    return true;
+}
+
+bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, double &out) {
+    float sumEMCellEtOverLeadTrkPt;
+    const auto success = tau.detail(TauDetail::sumEMCellEtOverLeadTrkPt, sumEMCellEtOverLeadTrkPt);
+    out = sumEMCellEtOverLeadTrkPt;
+    return success;
+}
+
+bool SumPtTrkFrac(const xAOD::TauJet &tau, double &out) {
+    float SumPtTrkFrac;
+    const auto success = tau.detail(TauDetail::SumPtTrkFrac, SumPtTrkFrac);
+    out = SumPtTrkFrac;
+    return success;
+}
+
+bool EMPOverTrkSysP(const xAOD::TauJet &tau, double &out) {
+    float EMPOverTrkSysP;
+    const auto success = tau.detail(TauDetail::EMPOverTrkSysP, EMPOverTrkSysP);
+    out = EMPOverTrkSysP;
+    return success;
+}
+
+bool ptRatioEflowApprox(const xAOD::TauJet &tau, double &out) {
+    float ptRatioEflowApprox;
+    const auto success = tau.detail(TauDetail::ptRatioEflowApprox, ptRatioEflowApprox);
+    //out = std::min(ptRatioEflowApprox, 4.0f);
+    out = ptRatioEflowApprox;
+    return success;
+}
+
+bool mEflowApprox(const xAOD::TauJet &tau, double &out) {
+    float mEflowApprox;
+    const auto success = tau.detail(TauDetail::mEflowApprox, mEflowApprox);
+    out = mEflowApprox;
+    return success;
+}
+
+bool dRmax(const xAOD::TauJet &tau, double &out) {
+    float dRmax;
+    const auto success = tau.detail(TauDetail::dRmax, dRmax);
+    out = dRmax;
+    return success;
+}
+
+bool trFlightPathSig(const xAOD::TauJet &tau, double &out) {
+    float trFlightPathSig;
+    const auto success = tau.detail(TauDetail::trFlightPathSig, trFlightPathSig);
+    out = trFlightPathSig;
+    return success;
+}
+
+bool massTrkSys(const xAOD::TauJet &tau, double &out) {
+    float massTrkSys;
+    const auto success = tau.detail(TauDetail::massTrkSys, massTrkSys);
+    out = massTrkSys;
+    return success;
+}
+
+bool pt(const xAOD::TauJet &tau, double &out) {
+    out = tau.pt();
+    return true;
+}
+
+bool pt_tau_log(const xAOD::TauJet &tau, double &out) {
+    out = std::log10(std::max(tau.pt() / GeV, 1e-6));
+    return true;
+}
+
+bool ptDetectorAxis(const xAOD::TauJet &tau, double &out) {
+    out = tau.ptDetectorAxis();
+    return true;
+}
+
+bool ptIntermediateAxis(const xAOD::TauJet &tau, double &out) {
+    out = tau.ptIntermediateAxis();
+    return true;
+}
+
+bool ptJetSeed_log(const xAOD::TauJet &tau, double &out) {
+  out = std::log10(std::max(tau.ptJetSeed(), 1e-3));
+  return true;
+}
+
+bool absleadTrackEta(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_absEtaLeadTrack("ABS_ETA_LEAD_TRACK");
+  float absEtaLeadTrack = acc_absEtaLeadTrack(tau);
+  out = std::max(0.f, absEtaLeadTrack);
+  return true;
+}
+
+bool leadTrackDeltaEta(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_absDeltaEta("TAU_ABSDELTAETA");
+  float absDeltaEta = acc_absDeltaEta(tau);
+  out = std::max(0.f, absDeltaEta);
+  return true;
+}
+
+bool leadTrackDeltaPhi(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_absDeltaPhi("TAU_ABSDELTAPHI");
+  float absDeltaPhi = acc_absDeltaPhi(tau);
+  out = std::max(0.f, absDeltaPhi);
+  return true;
+}
+
+bool leadTrackProbNNorHT(const xAOD::TauJet &tau, double &out){
+  auto tracks = tau.allTracks();
+
+  // Sort tracks in descending pt order
+  if (!tracks.empty()) {
+    auto cmp_pt = [](const xAOD::TauTrack *lhs, const xAOD::TauTrack *rhs) {
+      return lhs->pt() > rhs->pt();
+    };
+    std::sort(tracks.begin(), tracks.end(), cmp_pt);
+
+    const xAOD::TauTrack* tauLeadTrack = tracks.at(0);
+    const xAOD::TrackParticle* xTrackParticle = tauLeadTrack->track();
+    float eProbabilityHT = xTrackParticle->summaryValue(eProbabilityHT, xAOD::eProbabilityHT);
+    static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN");
+    float eProbabilityNN = acc_eProbabilityNN(*xTrackParticle);
+    out = (tauLeadTrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT;
+  }
+  else {
+    out = 0.;
+  }
+  return true;
+}
+
+bool EMFracFixed(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_emFracFixed("EMFracFixed");
+  float emFracFixed = acc_emFracFixed(tau);
+  out = std::max(emFracFixed, 0.0f);
+  return true;
+}
+
+bool etHotShotWinOverPtLeadTrk(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_etHotShotWinOverPtLeadTrk("etHotShotWinOverPtLeadTrk");
+  float etHotShotWinOverPtLeadTrk = acc_etHotShotWinOverPtLeadTrk(tau);
+  out = std::max(etHotShotWinOverPtLeadTrk, 1e-6f);
+  return true;
+}
+
+bool hadLeakFracFixed(const xAOD::TauJet &tau, double &out){
+  static const SG::AuxElement::ConstAccessor<float> acc_hadLeakFracFixed("hadLeakFracFixed");
+  float hadLeakFracFixed = acc_hadLeakFracFixed(tau);
+  out = std::max(0.f, hadLeakFracFixed);
+  return true;
+}
+
+bool PSFrac(const xAOD::TauJet &tau, double &out){
+  float PSFrac;
+  const auto success = tau.detail(TauDetail::PSSFraction, PSFrac);
+  out = std::max(0.f,PSFrac);  
+  return success;
+}
+
+bool ClustersMeanCenterLambda(const xAOD::TauJet &tau, double &out){
+  float ClustersMeanCenterLambda;
+  const auto success = tau.detail(TauDetail::ClustersMeanCenterLambda, ClustersMeanCenterLambda);
+  out = std::max(0.f, ClustersMeanCenterLambda);
+  return success;
+}
+
+bool ClustersMeanEMProbability(const xAOD::TauJet &tau, double &out){
+  float ClustersMeanEMProbability;
+  const auto success = tau.detail(TauDetail::ClustersMeanEMProbability, ClustersMeanEMProbability);
+  out = std::max(0.f, ClustersMeanEMProbability);
+  return success;
+}
+
+bool ClustersMeanFirstEngDens(const xAOD::TauJet &tau, double &out){
+  float ClustersMeanFirstEngDens;
+  const auto success = tau.detail(TauDetail::ClustersMeanFirstEngDens, ClustersMeanFirstEngDens);
+  out =  std::max(-10.f, ClustersMeanFirstEngDens);
+  return success;
+}
+
+bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, double &out){
+  float ClustersMeanPresamplerFrac;
+  const auto success = tau.detail(TauDetail::ClustersMeanPresamplerFrac, ClustersMeanPresamplerFrac);
+  out = std::max(0.f, ClustersMeanPresamplerFrac);
+  return success;
+}
+
+bool ClustersMeanSecondLambda(const xAOD::TauJet &tau, double &out){
+  float ClustersMeanSecondLambda;
+  const auto success = tau.detail(TauDetail::ClustersMeanSecondLambda, ClustersMeanSecondLambda);
+  out = std::max(0.f, ClustersMeanSecondLambda);
+  return success;
+}
+
+namespace Track {
+
+bool pt_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = std::log10(track.pt());
+    return true;
+}
+
+bool trackPt(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.pt();
+    return true;
+}
+
+bool trackEta(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.eta();
+    return true;
+}
+
+bool trackPhi(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.phi();
+    return true;
+}
+
+bool pt_tau_log(const xAOD::TauJet &tau, const xAOD::TauTrack& /*track*/, double &out) {
+    out = std::log10(std::max(tau.pt(), 1e-6));
+    return true;
+}
+
+bool pt_jetseed_log(const xAOD::TauJet &tau, const xAOD::TauTrack& /*track*/, double &out) {
+    out = std::log10(tau.ptJetSeed());
+    return true;
+}
+
+bool d0_abs_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = std::log10(std::abs(track.d0TJVA()) + 1e-6);
+    return true;
+}
+
+bool z0sinThetaTJVA_abs_log(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = std::log10(std::abs(track.z0sinthetaTJVA()) + 1e-6);
+    return true;
+}
+
+bool z0sinthetaTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.z0sinthetaTJVA();
+    return true;
+}
+
+bool z0sinthetaSigTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.z0sinthetaSigTJVA();
+    return true;
+}
+
+bool d0TJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.d0TJVA();
+    return true;
+}
+
+bool d0SigTJVA(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.d0SigTJVA();
+    return true;
+}
+
+bool dEta(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) {
+    out = track.eta() - tau.eta();
+    return true;
+}
+
+bool dEtaJetSeedAxis(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) {
+    TLorentzVector tlvSeedJet = tau.p4(xAOD::TauJetParameters::JetSeed);
+    out = std::abs(tlvSeedJet.Eta() - track.eta());
+    return true;
+}
+
+bool dPhi(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) {
+    out = track.p4().DeltaPhi(tau.p4());
+    return true;
+}
+
+bool dPhiJetSeedAxis(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) {
+    TLorentzVector tlvSeedJet = tau.p4(xAOD::TauJetParameters::JetSeed);
+    out = tlvSeedJet.DeltaPhi(track.p4());
+    return true;
+}
+
+bool nInnermostPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t inner_pixel_hits;
+    const auto success = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits);
+    out = inner_pixel_hits;
+    return success;
+}
+
+bool nPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t pixel_hits;
+    const auto success = track.track()->summaryValue(pixel_hits, xAOD::numberOfPixelHits);
+    out = pixel_hits;
+    return success;
+}
+
+bool nSCTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t sct_hits;
+    const auto success = track.track()->summaryValue(sct_hits, xAOD::numberOfSCTHits);
+    out = sct_hits;
+    return success;
+}
+
+// same as in tau track classification for trigger
+bool nIBLHitsAndExp(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t inner_pixel_hits, inner_pixel_exp;
+    const auto success1 = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits);
+    const auto success2 = track.track()->summaryValue(inner_pixel_exp, xAOD::expectInnermostPixelLayerHit);
+    out =  inner_pixel_exp ? inner_pixel_hits : 1.;
+    return success1 && success2;
+}
+
+bool nPixelHitsPlusDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t pixel_hits, pixel_dead;
+    const auto success1 = track.track()->summaryValue(pixel_hits, xAOD::numberOfPixelHits);
+    const auto success2 = track.track()->summaryValue(pixel_dead, xAOD::numberOfPixelDeadSensors);
+    out = pixel_hits + pixel_dead;
+    return success1 && success2;
+}
+
+bool nSCTHitsPlusDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t sct_hits, sct_dead;
+    const auto success1 = track.track()->summaryValue(sct_hits, xAOD::numberOfSCTHits);
+    const auto success2 = track.track()->summaryValue(sct_dead, xAOD::numberOfSCTDeadSensors);
+    out = sct_hits + sct_dead;
+    return success1 && success2;
+}
+
+bool eProbabilityHT(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    float eProbabilityHT;
+    const auto success = track.track()->summaryValue(eProbabilityHT, xAOD::eProbabilityHT);
+    out = eProbabilityHT;
+    return success;
+}
+
+bool eProbabilityNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {  
+    static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN");
+    out = acc_eProbabilityNN(track);
+    return true;
+}
+
+bool eProbabilityNNorHT(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {  
+  auto atrack = track.track();
+  float eProbabilityHT = atrack->summaryValue(eProbabilityHT, xAOD::eProbabilityHT);
+  static const SG::AuxElement::ConstAccessor<float> acc_eProbabilityNN("eProbabilityNN");
+  float eProbabilityNN = acc_eProbabilityNN(*atrack);
+  out = (atrack->pt()>2000.) ? eProbabilityNN : eProbabilityHT;
+  return true;
+}
+
+bool chargedScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_chargedScoreRNN("rnn_chargedScore");
+  out = acc_chargedScoreRNN(track);
+  return true;
+}
+
+bool isolationScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_isolationScoreRNN("rnn_isolationScore");
+  out = acc_isolationScoreRNN(track);
+  return true;
+}
+
+bool conversionScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_conversionScoreRNN("rnn_conversionScore");
+  out = acc_conversionScoreRNN(track);
+  return true;
+}
+
+bool fakeScoreRNN(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_fakeScoreRNN("rnn_fakeScore");
+  out = acc_fakeScoreRNN(track);
+  return true;
+}
+
+//Extension - variables for GNTau
+bool numberOfInnermostPixelLayerHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfInnermostPixelLayerHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfPixelSharedHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelSharedHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfPixelDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelDeadSensors);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfSCTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfSCTSharedHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTSharedHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfSCTDeadSensors(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfSCTDeadSensors);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfTRTHighThresholdHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfTRTHighThresholdHits);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfTRTHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfTRTHits);
+    out = trk_val;
+    return success;
+}
+
+bool nSiHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t pix_hit = 0;uint8_t pix_dead = 0;uint8_t sct_hit = 0;uint8_t sct_dead = 0;
+    const auto success1 = track.track()->summaryValue(pix_hit, xAOD::numberOfPixelHits);
+    const auto success2 = track.track()->summaryValue(pix_dead, xAOD::numberOfPixelDeadSensors);
+    const auto success3 = track.track()->summaryValue(sct_hit, xAOD::numberOfSCTHits);
+    const auto success4 = track.track()->summaryValue(sct_dead, xAOD::numberOfSCTDeadSensors);
+    out = pix_hit + pix_dead + sct_hit + sct_dead;
+    return success1 && success2 && success3 && success4;
+}
+
+bool expectInnermostPixelLayerHit(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::expectInnermostPixelLayerHit);
+    out = trk_val;
+    return success;
+}
+
+bool expectNextToInnermostPixelLayerHit(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::expectNextToInnermostPixelLayerHit);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfContribPixelLayers(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfContribPixelLayers);
+    out = trk_val;
+    return success;
+}
+
+bool numberOfPixelHoles(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    uint8_t trk_val = 0;
+    const auto success = track.track()->summaryValue(trk_val, xAOD::numberOfPixelHoles);
+    out = trk_val;
+    return success;
+}
+
+bool d0_old(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.track()->d0();
+    //out = trk_val;
+    return true;
+}
+
+bool qOverP(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.track()->qOverP();
+    return true;
+}
+
+bool theta(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.track()->theta();
+    return true;
+}
+
+bool z0TJVA(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out) {
+    out = track.track()->z0() + track.track()->vz() - tau.vertex()->z();
+    return true;
+}
+
+bool charge(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
+    out = track.track()->charge();
+    return true;
+}
+
+bool dz0_TV_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) {
+    out = 0.;
+    static const SG::AuxElement::ConstAccessor<float> acc_dz0TVPV0("dz0_TV_PV0");
+    if (tau.isAvailable<float>("dz0_TV_PV0")){out = acc_dz0TVPV0(tau);}
+    return true;
+}
+
+bool log_sumpt_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) {
+    out=0.;
+    static const SG::AuxElement::ConstAccessor<float> acc_logsumptTV("log_sumpt_TV");
+    if (tau.isAvailable<float>("log_sumpt_TV")){out=acc_logsumptTV(tau);}
+    return true;
+}
+
+bool log_sumpt2_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) {
+    out=0.;
+    static const SG::AuxElement::ConstAccessor<float> acc_logsumpt2TV("log_sumpt2_TV");
+    if (tau.isAvailable<float>("log_sumpt2_TV")){out=acc_logsumpt2TV(tau);}
+    return true;
+}
+
+bool log_sumpt_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) {
+    out=0.;
+    static const SG::AuxElement::ConstAccessor<float> acc_logsumptPV0("log_sumpt_PV0");
+    if (tau.isAvailable<float>("log_sumpt_PV0")){out=acc_logsumptPV0(tau);}
+    return true;
+}
+
+bool log_sumpt2_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &/*track*/, double &out) {
+    out=0.;
+    static const SG::AuxElement::ConstAccessor<float> acc_logsumpt2PV0("log_sumpt2_PV0");
+    if (tau.isAvailable<float>("log_sumpt2_PV0")){out=acc_logsumpt2PV0(tau);}
+    return true;
+}
+
+} // namespace Track
+
+
+namespace Cluster {
+using MomentType = xAOD::CaloCluster::MomentType;
+
+bool et_log(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    out = std::log10(cluster.p4().Et());
+    return true;
+}
+
+bool pt_tau_log(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster& /*cluster*/, double &out) {
+    out = std::log10(std::max(tau.pt(), 1e-6));
+    return true;
+}
+
+bool pt_jetseed_log(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster& /*cluster*/, double &out) {
+    out = std::log10(tau.ptJetSeed());
+    return true;
+}
+
+bool dEta(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    out = cluster.eta() - tau.eta();
+    return true;
+}
+
+bool dPhi(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    out = cluster.p4().DeltaPhi(tau.p4());
+    return true;
+}
+
+bool SECOND_R(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_R, out);
+    return success;
+}
+
+bool SECOND_LAMBDA(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_LAMBDA, out);
+    return success;
+}
+
+bool CENTER_LAMBDA(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    const auto success = cluster.clust().retrieveMoment(MomentType::CENTER_LAMBDA, out);
+    return success;
+}
+
+bool SECOND_LAMBDAOverClustersMeanSecondLambda(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanSecondLambda("ClustersMeanSecondLambda");
+  float ClustersMeanSecondLambda = acc_ClustersMeanSecondLambda(tau);
+  double secondLambda(0);
+  const auto success = cluster.clust().retrieveMoment(MomentType::SECOND_LAMBDA, secondLambda);
+  out = (ClustersMeanSecondLambda != 0.) ? secondLambda/ClustersMeanSecondLambda : 0.;
+  return success;
+}
+
+bool CENTER_LAMBDAOverClustersMeanCenterLambda(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+  static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanCenterLambda("ClustersMeanCenterLambda");
+  float ClustersMeanCenterLambda = acc_ClustersMeanCenterLambda(tau);
+  double centerLambda(0);
+  const auto success = cluster.clust().retrieveMoment(MomentType::CENTER_LAMBDA, centerLambda);
+  if (ClustersMeanCenterLambda == 0.){
+    out = 250.;
+  }else {
+    out = centerLambda/ClustersMeanCenterLambda;
+  }
+
+  out = std::min(out, 250.);
+
+  return success;
+}
+
+
+bool FirstEngDensOverClustersMeanFirstEngDens(const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+  // the ClustersMeanFirstEngDens is the log10 of the energy weighted average of the First_ENG_DENS 
+  // divided by ETot to make it dimension-less, 
+  // so we need to evaluate the difference of log10(clusterFirstEngDens/clusterTotalEnergy) and the ClustersMeanFirstEngDens
+  double clusterFirstEngDens = 0.0;
+  bool status = cluster.clust().retrieveMoment(MomentType::FIRST_ENG_DENS, clusterFirstEngDens);
+  if (clusterFirstEngDens < 1e-6) clusterFirstEngDens = 1e-6;
+
+  static const SG::AuxElement::ConstAccessor<float> acc_ClusterTotalEnergy("ClusterTotalEnergy");
+  float clusterTotalEnergy = acc_ClusterTotalEnergy(tau);
+  if (clusterTotalEnergy < 1e-6) clusterTotalEnergy = 1e-6;
+
+  static const SG::AuxElement::ConstAccessor<float> acc_ClustersMeanFirstEngDens("ClustersMeanFirstEngDens");
+  float clustersMeanFirstEngDens = acc_ClustersMeanFirstEngDens(tau);
+
+  out = std::log10(clusterFirstEngDens/clusterTotalEnergy) - clustersMeanFirstEngDens;
+  
+  return status;
+}
+
+//Extension - Variables for GNTau
+bool e(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    out = cluster.p4().E();
+    return true;
+}
+
+bool et(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    out = cluster.p4().Et();
+    return true;
+}
+
+bool FIRST_ENG_DENS(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    double clusterFirstEngDens = 0.0;
+    bool status = cluster.clust().retrieveMoment(MomentType::FIRST_ENG_DENS, clusterFirstEngDens);
+    out = clusterFirstEngDens;
+    return status;
+}
+
+bool EM_PROBABILITY(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    double clusterEMprob = 0.0;
+    bool status = cluster.clust().retrieveMoment(MomentType::EM_PROBABILITY, clusterEMprob);
+    out = clusterEMprob;
+    return status;
+}
+
+bool CENTER_MAG(const xAOD::TauJet& /*tau*/, const xAOD::CaloVertexedTopoCluster &cluster, double &out) {
+    double clusterCenterMag = 0.0;
+    bool status = cluster.clust().retrieveMoment(MomentType::CENTER_MAG, clusterCenterMag);
+    out = clusterCenterMag;
+    return status;
+}
+
+} // namespace Cluster
+} // namespace Variables
+} // namespace TauGNNUtils
diff --git a/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx b/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx
index 94a4fd2a6c110f946a77a00a5901c8682cf14f99..d6b83c9a7c56f9c3c0ba1e6dd9acbc114707e299 100644
--- a/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx
+++ b/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx
@@ -1,5 +1,5 @@
 /*
-  Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 */
 
 #include "tauRecTools/lwtnn/LightweightGraph.h"
@@ -66,7 +66,7 @@ namespace lwtDev {
 
   typedef LightweightGraph::NodeMap NodeMap;
   LightweightGraph::LightweightGraph(const GraphConfig& config,
-                                     std::string default_output):
+                                     const std::string& default_output):
     m_graph(new Graph(config.nodes, config.layers))
   {
     for (const auto& node: config.inputs) {
diff --git a/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx b/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx
index ad895f07649f9c1193633baa283bf5815a1cd5d0..fff78982eb2f275a44617fc2881e1f8b616169e9 100644
--- a/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx
+++ b/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx
@@ -1,5 +1,5 @@
 /*
-  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 */
 
 #include "tauRecTools/lwtnn/Stack.h"
@@ -424,7 +424,7 @@ namespace lwtDev {
   // __________________________________________________________________
   // Recurrent layers
 
-  EmbeddingLayer::EmbeddingLayer(int var_row_index, MatrixXd W):
+  EmbeddingLayer::EmbeddingLayer(int var_row_index, const MatrixXd & W):
     m_var_row_index(var_row_index),
     m_W(W)
   {
@@ -470,14 +470,14 @@ namespace lwtDev {
 
 
   // LSTM layer
-  LSTMLayer::LSTMLayer(ActivationConfig activation,
-                       ActivationConfig inner_activation,
-                       MatrixXd W_i, MatrixXd U_i, VectorXd b_i,
-                       MatrixXd W_f, MatrixXd U_f, VectorXd b_f,
-                       MatrixXd W_o, MatrixXd U_o, VectorXd b_o,
-                       MatrixXd W_c, MatrixXd U_c, VectorXd b_c,
-                       bool go_backwards,
-                       bool return_sequence):
+  LSTMLayer::LSTMLayer(const ActivationConfig & activation,
+              const ActivationConfig & inner_activation,
+              const MatrixXd & W_i, const MatrixXd & U_i, const VectorXd & b_i,
+              const MatrixXd & W_f, const MatrixXd & U_f, const VectorXd & b_f,
+              const MatrixXd & W_o, const MatrixXd & U_o, const VectorXd & b_o,
+              const MatrixXd & W_c, const MatrixXd & U_c, const VectorXd & b_c,
+              bool go_backwards,
+              bool return_sequence):
     m_W_i(W_i),
     m_U_i(U_i),
     m_b_i(b_i),
@@ -547,11 +547,11 @@ namespace lwtDev {
 
 
   // GRU layer
-  GRULayer::GRULayer(ActivationConfig activation,
-                     ActivationConfig inner_activation,
-                     MatrixXd W_z, MatrixXd U_z, VectorXd b_z,
-                     MatrixXd W_r, MatrixXd U_r, VectorXd b_r,
-                     MatrixXd W_h, MatrixXd U_h, VectorXd b_h):
+  GRULayer::GRULayer(const ActivationConfig & activation,
+                     const ActivationConfig & inner_activation,
+                     const MatrixXd & W_z, const  MatrixXd & U_z, const VectorXd & b_z,
+                     const MatrixXd & W_r, const MatrixXd & U_r, const VectorXd & b_r,
+                     const MatrixXd & W_h, const MatrixXd & U_h, const VectorXd & b_h):
     m_W_z(W_z),
     m_U_z(U_z),
     m_b_z(b_z),
@@ -621,8 +621,8 @@ namespace lwtDev {
   }
 
   MatrixXd BidirectionalLayer::scan( const MatrixXd& x) const{
-    MatrixXd forward = m_forward_layer->scan(x);
-    MatrixXd backward = m_backward_layer->scan(x);
+    const MatrixXd & forward = m_forward_layer->scan(x);
+    const MatrixXd & backward = m_backward_layer->scan(x);
     MatrixXd backward_rev;
     if (m_return_sequence){
       backward_rev = backward.rowwise().reverse();
diff --git a/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx b/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx
index d4d5db31542ab15ccd7487f6dacb303eb48da1b3..0e78bf7d357e50f2de792c78d4ffb30d5469dab2 100644
--- a/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx
+++ b/Reconstruction/tauRecTools/src/components/tauRecTools_entries.cxx
@@ -25,6 +25,7 @@
 #include "tauRecTools/TauWPDecorator.h"
 #include "tauRecTools/TauIDVarCalculator.h"
 #include "tauRecTools/TauJetRNNEvaluator.h"
+#include "tauRecTools/TauGNNEvaluator.h"
 #include "tauRecTools/TauDecayModeNNClassifier.h"
 #include "tauRecTools/TauVertexedClusterDecorator.h"
 #include "tauRecTools/TauAODSelector.h"
@@ -59,6 +60,7 @@ DECLARE_COMPONENT( TauPi0Selector )
 DECLARE_COMPONENT( TauWPDecorator )
 DECLARE_COMPONENT( TauIDVarCalculator )
 DECLARE_COMPONENT( TauJetRNNEvaluator )
+DECLARE_COMPONENT( TauGNNEvaluator )
 DECLARE_COMPONENT( TauDecayModeNNClassifier )
 DECLARE_COMPONENT( TauVertexedClusterDecorator )
 DECLARE_COMPONENT( TauAODSelector )
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNN.h b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h
new file mode 100644
index 0000000000000000000000000000000000000000..e318762111a0edc19e187a7fddab22032a6738a4
--- /dev/null
+++ b/Reconstruction/tauRecTools/tauRecTools/TauGNN.h
@@ -0,0 +1,100 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#ifndef TAURECTOOLS_TAUGNN_H
+#define TAURECTOOLS_TAUGNN_H
+
+#include "xAODTau/TauJet.h"
+#include "xAODCaloEvent/CaloVertexedTopoCluster.h"
+
+#include "AsgMessaging/AsgMessaging.h"
+
+#include "FlavorTagDiscriminants/OnnxUtil.h"
+
+#include <memory>
+#include <string>
+#include <map>
+
+namespace TauGNNUtils {
+    class GNNVarCalc;
+}
+
+namespace FlavorTagDiscriminants{
+    class OnnxUtil;
+}
+
+/**
+ * @brief Wrapper around ONNXUtil to compute the output score of a model
+ *
+ *   Configures the network and computes the network outputs given the input
+ *   objects. Retrieval of input variables is handled internally.
+ *
+ * @author N.M. Tamir
+ *
+ */
+class TauGNN : public asg::AsgMessaging {
+public:
+    // Configuration of the weight file structure
+    struct Config {
+        std::string input_layer_scalar;
+        std::string input_layer_tracks;
+        std::string input_layer_clusters;
+        std::string output_node_tau;
+        std::string output_node_jet;
+    };
+    std::shared_ptr<const FlavorTagDiscriminants::OnnxUtil> m_onnxUtil;
+public:
+    TauGNN(const std::string &nnFile, const Config &config);
+    ~TauGNN();
+
+    // Output the OnnxUtil tuple 
+    std::tuple<
+        std::map<std::string, float>,
+        std::map<std::string, std::vector<char>>,
+        std::map<std::string, std::vector<float>> > 
+    compute(const xAOD::TauJet &tau,
+                  const std::vector<const xAOD::TauTrack *> &tracks,
+                  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters) const;
+
+    // Compute all input variables and store them in the maps that are passed by reference
+    bool calculateInputVariables(const xAOD::TauJet &tau,
+                  const std::vector<const xAOD::TauTrack *> &tracks,
+                  const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
+                  std::map<std::string, std::map<std::string, double>>& scalarInputs,
+                  std::map<std::string, std::map<std::string, std::vector<double>>>& vectorInputs) const;
+
+    // Getter for the variable calculator
+    const TauGNNUtils::GNNVarCalc* variable_calculator() const {
+        return m_var_calc.get();
+    }
+
+    //Make the output config transparent to external tools
+    FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config;
+
+private:
+    using Inputs = FlavorTagDiscriminants::Inputs;
+    // Abbreviations for lwtnn
+    using VariableMap = std::map<std::string, double>;
+    using VectorMap = std::map<std::string, std::vector<double>>;
+
+    using InputMap = std::map<std::string, VariableMap>;
+    using InputSequenceMap = std::map<std::string, VectorMap>;
+
+private:
+    const Config m_config;
+
+    // Names of the input variables
+    std::vector<std::string> m_scalar_inputs;
+    std::vector<std::string> m_track_inputs;
+    std::vector<std::string> m_cluster_inputs;
+    // Names passed to the variable calculator
+    std::vector<std::string> m_scalarCalc_inputs;
+    std::vector<std::string> m_trackCalc_inputs;
+    std::vector<std::string> m_clusterCalc_inputs;
+
+    // Variable calculator to calculate input variables on the fly
+    std::unique_ptr<TauGNNUtils::GNNVarCalc> m_var_calc;
+};
+
+#endif // TAURECTOOLS_TAUGNN_H
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h
new file mode 100644
index 0000000000000000000000000000000000000000..776a6a96bd700081c8b5e7df17f0eba8c7c98040
--- /dev/null
+++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h
@@ -0,0 +1,70 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#ifndef TAURECTOOLS_TAUGNNEVALUATOR_H
+#define TAURECTOOLS_TAUGNNEVALUATOR_H
+
+#include "tauRecTools/TauRecToolBase.h"
+
+#include "xAODTau/TauJet.h"
+#include "xAODCaloEvent/CaloVertexedTopoCluster.h"
+
+#include <memory>
+
+class TauGNN;
+
+/**
+ * @brief Tool to calculate tau identification score from .onnx inputs
+ *
+ *   The network configuration is supplied in .onnx format. 
+ *   Currently runs on a prongness-inclusive model 
+ *   Based off of TauJetRNNEvaluator.h format!
+ * @author N.M. Tamir
+ *
+ */
+class TauGNNEvaluator : public TauRecToolBase {
+public:
+    ASG_TOOL_CLASS2(TauGNNEvaluator, TauRecToolBase, ITauToolBase)
+
+    TauGNNEvaluator(const std::string &name = "TauGNNEvaluator");
+    virtual ~TauGNNEvaluator();
+
+    virtual StatusCode initialize() override;
+    virtual StatusCode execute(xAOD::TauJet &tau) const override;
+    // Getter for the underlying RNN implementation
+    const TauGNN* get_gnn() const;
+
+    // Selects tracks to be used as input to the network
+    StatusCode get_tracks(const xAOD::TauJet &tau,
+                          std::vector<const xAOD::TauTrack *> &out) const;
+
+    // Selects clusters to be used as input to the network
+    StatusCode get_clusters(const xAOD::TauJet &tau,
+                            std::vector<xAOD::CaloVertexedTopoCluster> &out) const;
+
+private:
+    std::string m_output_varname;
+    std::string m_output_ptau;
+    std::string m_output_pjet;
+    std::string m_weightfile;
+    int m_max_tracks;
+    int m_max_clusters;
+    float m_max_cluster_dr;
+    float m_minTauPt;
+    bool m_doVertexCorrection;
+    bool m_doTrackClassification;
+    bool m_decorateTracks;
+
+    // Configuration of the network file
+    std::string m_input_layer_scalar;
+    std::string m_input_layer_tracks;
+    std::string m_input_layer_clusters;
+    std::string m_outnode_tau;
+    std::string m_outnode_jet;
+
+    // Wrappers for lwtnn
+    std::unique_ptr<TauGNN> m_net; //!
+};
+
+#endif // TAURECTOOLS_TAUGNNEVALUATOR_H
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h
new file mode 100644
index 0000000000000000000000000000000000000000..7e4913018524c332970af7c8c55d3be1f64c6e98
--- /dev/null
+++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h
@@ -0,0 +1,317 @@
+/*
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
+*/
+
+#ifndef TAURECTOOLS_TAUGNNUTILS_H
+#define TAURECTOOLS_TAUGNNUTILS_H
+
+#include "xAODTau/TauJet.h"
+#include "xAODCaloEvent/CaloVertexedTopoCluster.h"
+#include "xAODEventInfo/EventInfo.h"
+#include "xAODTracking/VertexContainer.h"
+#include "AsgTools/AsgTool.h"
+#include "AsgMessaging/AsgMessaging.h"
+#include <unordered_map>
+
+
+namespace TauGNNUtils {
+
+/**
+ * @brief Tool to calculate input variables for the GNN-based tau identification
+ *
+ *   Used to calculate input variables for (onnx)GNN-based tau identification on
+ *   the fly by providing a mapping between variable names (strings) and
+ *   functions to calculate these variables.
+ *
+ * @author C. Deutsch
+ * @author W. Davey
+ * @author N.M. Tamir
+ *
+ */
+class GNNVarCalc : public asg::AsgMessaging {
+public:
+    // Pointers to calculator functions
+    using ScalarCalc = bool (*)(const xAOD::TauJet &, double &);
+
+    using TrackCalc = bool (*)(const xAOD::TauJet &, const xAOD::TauTrack &,
+                               double &);
+
+    using ClusterCalc = bool (*)(const xAOD::TauJet &,
+                                 const xAOD::CaloVertexedTopoCluster &, double &);
+
+public:
+    GNNVarCalc();
+    ~GNNVarCalc() = default;
+
+    // Methods to compute the output (vector) based on the variable name
+
+    // Computes high-level ID variables
+    bool compute(const std::string &name, const xAOD::TauJet &tau, double &out) const;
+
+    // Computes track variables
+    bool compute(const std::string &name, const xAOD::TauJet &tau,
+                 const std::vector<const xAOD::TauTrack *> &tracks,
+                 std::vector<double> &out) const;
+
+    // Computes cluster variables
+    bool compute(const std::string &name, const xAOD::TauJet &tau,
+                 const std::vector<xAOD::CaloVertexedTopoCluster> &clusters,
+                 std::vector<double> &out) const;
+
+    // Methods to insert calculator functions into the lookup table
+    void insert(const std::string &name, ScalarCalc func, const std::vector<std::string>& scalar_vars);
+    void insert(const std::string &name, TrackCalc func, const std::vector<std::string>& track_vars);
+    void insert(const std::string &name, ClusterCalc func, const std::vector<std::string>& cluster_vars);
+
+private:
+    // Lookup tables
+    std::unordered_map<std::string, ScalarCalc> m_scalar_map;
+    std::unordered_map<std::string, TrackCalc> m_track_map;
+    std::unordered_map<std::string, ClusterCalc> m_cluster_map;
+};
+
+// Factory function to create a variable calculator populated with default
+// variables
+std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scalar_vars,
+					const std::vector<std::string>& track_vars,
+					const std::vector<std::string>& cluster_vars);
+
+
+namespace Variables {
+
+// Functions to calculate (scalar) input variables
+// Returns a status code indicating success
+bool absEta(const xAOD::TauJet &tau, double &out);
+
+bool centFrac(const xAOD::TauJet &tau, double &out);
+
+bool isolFrac(const xAOD::TauJet &tau, double &out); 
+
+bool etOverPtLeadTrk(const xAOD::TauJet &tau, double &out);
+
+bool innerTrkAvgDist(const xAOD::TauJet &tau, double &out);
+
+bool absipSigLeadTrk(const xAOD::TauJet &tau, double &out);
+
+bool sumEMCellEtOverLeadTrkPt(const xAOD::TauJet &tau, double &out);
+
+bool SumPtTrkFrac(const xAOD::TauJet &tau, double &out);
+
+bool EMPOverTrkSysP(const xAOD::TauJet &tau, double &out);
+
+bool ptRatioEflowApprox(const xAOD::TauJet &tau, double &out);
+
+bool mEflowApprox(const xAOD::TauJet &tau, double &out);
+
+bool dRmax(const xAOD::TauJet &tau, double &out);
+
+bool trFlightPathSig(const xAOD::TauJet &tau, double &out);
+
+bool massTrkSys(const xAOD::TauJet &tau, double &out);
+
+bool pt(const xAOD::TauJet &tau, double &out);
+
+bool pt_tau_log(const xAOD::TauJet &tau, double &out);
+
+bool ptDetectorAxis(const xAOD::TauJet &tau, double &out);
+
+bool ptIntermediateAxis(const xAOD::TauJet &tau, double &out);
+
+//functions to calculate input variables needed for the eVeto RNN
+bool ptJetSeed_log             (const xAOD::TauJet &tau, double &out);
+bool absleadTrackEta           (const xAOD::TauJet &tau, double &out);
+bool leadTrackDeltaEta         (const xAOD::TauJet &tau, double &out);
+bool leadTrackDeltaPhi         (const xAOD::TauJet &tau, double &out);
+bool leadTrackProbNNorHT       (const xAOD::TauJet &tau, double &out);
+bool EMFracFixed               (const xAOD::TauJet &tau, double &out);
+bool etHotShotWinOverPtLeadTrk (const xAOD::TauJet &tau, double &out);
+bool hadLeakFracFixed          (const xAOD::TauJet &tau, double &out);
+bool PSFrac                    (const xAOD::TauJet &tau, double &out);
+bool ClustersMeanCenterLambda  (const xAOD::TauJet &tau, double &out);
+bool ClustersMeanEMProbability (const xAOD::TauJet &tau, double &out);
+bool ClustersMeanFirstEngDens  (const xAOD::TauJet &tau, double &out);
+bool ClustersMeanPresamplerFrac(const xAOD::TauJet &tau, double &out);
+bool ClustersMeanSecondLambda  (const xAOD::TauJet &tau, double &out);
+bool EMPOverTrkSysP            (const xAOD::TauJet &tau, double &out);
+
+
+namespace Track {
+
+// Functions to calculate input variables for each track
+// Returns a status code indicating success
+
+bool pt_log(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool trackPt(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool trackEta(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool trackPhi(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+    
+bool pt_tau_log(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool pt_jetseed_log(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool d0_abs_log(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool z0sinThetaTJVA_abs_log(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool z0sinthetaTJVA(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool z0sinthetaSigTJVA(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool d0TJVA(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool d0SigTJVA(
+    const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+bool dEta(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool dEtaJetSeedAxis(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool dPhi(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool dPhiJetSeedAxis(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool nInnermostPixelHits(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool nPixelHits(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool nSCTHits(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+// trigger variants
+bool nIBLHitsAndExp (
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool nPixelHitsPlusDeadSensors (
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool nSCTHitsPlusDeadSensors (
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool eProbabilityHT(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool eProbabilityNN(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool eProbabilityNNorHT(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool chargedScoreRNN(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool isolationScoreRNN(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool conversionScoreRNN(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+bool fakeScoreRNN(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
+//Extension - variables for GNTau
+bool numberOfInnermostPixelLayerHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfPixelHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfPixelSharedHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfPixelDeadSensors(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfSCTHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfSCTSharedHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfSCTDeadSensors(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfTRTHighThresholdHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfTRTHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool nSiHits(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool expectInnermostPixelLayerHit(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool expectNextToInnermostPixelLayerHit(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfContribPixelLayers(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool numberOfPixelHoles(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool d0_old(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool qOverP(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool theta(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool z0TJVA(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool charge(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool dz0_TV_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool log_sumpt_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool log_sumpt2_TV(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool log_sumpt_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+bool log_sumpt2_PV0(const xAOD::TauJet& tau, const xAOD::TauTrack &track, double &out);
+
+} // namespace Track
+
+
+namespace Cluster {
+
+// Functions to calculate input variables for each cluster
+// Returns a status code indicating success
+
+bool et_log(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool pt_tau_log(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool pt_jetseed_log(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool dEta(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool dPhi(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool SECOND_R(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool SECOND_LAMBDA(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool CENTER_LAMBDA(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool SECOND_LAMBDAOverClustersMeanSecondLambda(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool CENTER_LAMBDAOverClustersMeanCenterLambda(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool FirstEngDensOverClustersMeanFirstEngDens(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+//Extension - Variables for GNTau
+bool e(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool et(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool FIRST_ENG_DENS(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool EM_PROBABILITY(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+
+bool CENTER_MAG(
+    const xAOD::TauJet &tau, const xAOD::CaloVertexedTopoCluster &cluster, double &out);
+} // namespace Cluster
+} // namespace Variables
+} // namespace TauJetGNNUtils
+
+#endif // TAURECTOOLS_TAUGNNUTILS_H
diff --git a/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h b/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h
index 802d17df2b021d1cc6ca38a73bb9a6c8a340bdc2..16d1b29d7acd4b28009e00826a76b2b902052d20 100644
--- a/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h
+++ b/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h
@@ -1,5 +1,5 @@
 /*
-  Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 */
 
 #ifndef LIGHTWEIGHT_GRAPH_HH_TAURECTOOLS
@@ -72,7 +72,7 @@ namespace lwtDev {
     // define a "default" output, so that calling "compute" with no
     // output specified doesn't lead to ambiguity.
     LightweightGraph(const GraphConfig& config,
-                     std::string default_output = "");
+                     const std::string& default_output = "");
 
     ~LightweightGraph();
     LightweightGraph(LightweightGraph&) = delete;
diff --git a/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h b/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h
index 69ea569fb0f3addeb7f7d4c7a5f3fe4bfd852036..84d55e46c9ad5f1a6875c595e6fb590e7ff2b08b 100644
--- a/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h
+++ b/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h
@@ -1,5 +1,5 @@
 /*
-  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 */
 
 #ifndef STACK_HH_TAURECTOOLS
@@ -222,7 +222,7 @@ namespace lwtDev {
   class EmbeddingLayer : public IRecurrentLayer
   {
   public:
-    EmbeddingLayer(int var_row_index, MatrixXd W);
+    EmbeddingLayer(int var_row_index, const MatrixXd & W);
     virtual ~EmbeddingLayer() {};
     virtual MatrixXd scan( const MatrixXd&) const override;
 
@@ -236,12 +236,12 @@ namespace lwtDev {
   class LSTMLayer : public IRecurrentLayer
   {
   public:
-    LSTMLayer(ActivationConfig activation,
-              ActivationConfig inner_activation,
-              MatrixXd W_i, MatrixXd U_i, VectorXd b_i,
-              MatrixXd W_f, MatrixXd U_f, VectorXd b_f,
-              MatrixXd W_o, MatrixXd U_o, VectorXd b_o,
-              MatrixXd W_c, MatrixXd U_c, VectorXd b_c,
+    LSTMLayer(const ActivationConfig & activation,
+              const ActivationConfig & inner_activation,
+              const MatrixXd & W_i, const MatrixXd & U_i, const VectorXd & b_i,
+              const MatrixXd & W_f, const MatrixXd & U_f, const VectorXd & b_f,
+              const MatrixXd & W_o, const MatrixXd & U_o, const VectorXd & b_o,
+              const MatrixXd & W_c, const MatrixXd & U_c, const VectorXd & b_c,
               bool go_backwards,
               bool return_sequence);
 
@@ -277,11 +277,11 @@ namespace lwtDev {
   class GRULayer : public IRecurrentLayer
   {
   public:
-    GRULayer(ActivationConfig activation,
-             ActivationConfig inner_activation,
-             MatrixXd W_z, MatrixXd U_z, VectorXd b_z,
-             MatrixXd W_r, MatrixXd U_r, VectorXd b_r,
-             MatrixXd W_h, MatrixXd U_h, VectorXd b_h);
+    GRULayer(const ActivationConfig & activation,
+                     const ActivationConfig & inner_activation,
+                     const MatrixXd & W_z, const  MatrixXd & U_z, const VectorXd & b_z,
+                     const MatrixXd & W_r, const MatrixXd & U_r, const VectorXd & b_r,
+                     const MatrixXd & W_h, const MatrixXd & U_h, const VectorXd & b_h);
 
     virtual ~GRULayer() {};
     virtual MatrixXd scan( const MatrixXd&) const override;