From 405dd5f288330dc4acaaab433f3418ba8fa29b08 Mon Sep 17 00:00:00 2001
From: Jean Yves Beaucamp <jean.yves.beaucamp@cern.ch>
Date: Tue, 5 Nov 2024 01:45:15 +0100
Subject: [PATCH] Small GNTau update

Small GNTau update
---
 .../tauRec/python/TauConfigFlags.py           | 20 +++++++++--
 Reconstruction/tauRec/python/TauToolHolder.py | 33 ++++++++----------
 Reconstruction/tauRecTools/Root/TauGNN.cxx    |  4 +--
 .../tauRecTools/Root/TauGNNEvaluator.cxx      |  9 +++--
 .../tauRecTools/Root/TauGNNUtils.cxx          | 14 ++++++++
 .../Root/lwtnn/LightweightGraph.cxx           |  4 +--
 .../tauRecTools/Root/lwtnn/Stack.cxx          | 34 +++++++++----------
 .../tauRecTools/tauRecTools/TauGNNEvaluator.h |  4 +--
 .../tauRecTools/tauRecTools/TauGNNUtils.h     |  6 ++++
 .../tauRecTools/lwtnn/LightweightGraph.h      |  4 +--
 .../tauRecTools/tauRecTools/lwtnn/Stack.h     | 26 +++++++-------
 11 files changed, 96 insertions(+), 62 deletions(-)

diff --git a/Reconstruction/tauRec/python/TauConfigFlags.py b/Reconstruction/tauRec/python/TauConfigFlags.py
index a0b3743ef4c6..ea91293697d6 100644
--- a/Reconstruction/tauRec/python/TauConfigFlags.py
+++ b/Reconstruction/tauRec/python/TauConfigFlags.py
@@ -62,8 +62,24 @@ def createTauConfigFlags():
     tau_cfg.addFlag("Tau.TauJetDeepSetConfig_v2", ["tauid_1p_R22_dpst_noTrackScore.json", "tauid_2p_R22_dpst_noTrackScore.json", "tauid_3p_R22_dpst_noTrackScore.json"])
     tau_cfg.addFlag("Tau.TauJetDeepSetWP_v2", ["model_1p_R22_dpst_noTrackScore.root", "model_2p_R22_dpst_noTrackScore.root", "model_3p_R22_dpst_noTrackScore.root"])
     # GNTau ID tune file (need to add another version for noAux)
-    tau_cfg.addFlag("Tau.TauGNNConfig", ["GNTau_noAux_simplified.onnx"])
-    tau_cfg.addFlag("Tau.TauGNNWP_v0", ["GNTauNA_flat_model_1p.root", "GNTauNA_flat_model_2p.root", "GNTauNA_flat_model_3p.root"])
+    tau_cfg.addFlag("Tau.TauGNNConfig", ["GNTau_pruned_MC23.onnx","GNTau_trunc_MC23.onnx"])
+    tau_cfg.addFlag("Tau.TauGNNWP",
+                    [ 
+                        ["GNTauNAprune_flat_model_1p.root", "GNTauNAprune_flat_model_2p.root", "GNTauNAprune_flat_model_3p.root"],
+                        ["GNTauNAtrunc_flat_model_1p.root", "GNTauNAtrunc_flat_model_2p.root", "GNTauNAtrunc_flat_model_3p.root"]
+                    ])
+    tau_cfg.addFlag("Tau.GNTauScoreName", ["GNTauScore_v0prune","GNTauScore_v1trunc"])
+    tau_cfg.addFlag("Tau.GNTauTransScoreName", ["GNTauScoreSigTrans_v0prune","GNTauScoreSigTrans_v1trunc"])
+    tau_cfg.addFlag("Tau.GNTauMaxTracks", [30,10])
+    tau_cfg.addFlag("Tau.GNTauMaxClusters", [20,6])
+    tau_cfg.addFlag("Tau.GNTauNodeNameTau", "GN2TauNoAux_pb")
+    tau_cfg.addFlag("Tau.GNTauNodeNameJet", "GN2TauNoAux_pu")
+    tau_cfg.addFlag("Tau.GNTauDecorWPNames", 
+                    [
+                        ["GNTauVL_v0prune", "GNTauL_v0prune", "GNTauM_v0prune", "GNTauT_v0prune"],
+                        ["GNTauVL_v1trunc", "GNTauL_v1trunc", "GNTauM_v1trunc", "GNTauT_v1trunc"]
+                    ])
+
 
     # PanTau config flags
     from PanTauAlgs.PanTauConfigFlags import createPanTauConfigFlags
diff --git a/Reconstruction/tauRec/python/TauToolHolder.py b/Reconstruction/tauRec/python/TauToolHolder.py
index fd96be80c82e..b4bafbba684b 100644
--- a/Reconstruction/tauRec/python/TauToolHolder.py
+++ b/Reconstruction/tauRec/python/TauToolHolder.py
@@ -851,19 +851,19 @@ def TauWPDecoratorJetDeepSetCfg(flags, version=None):
     result.setPrivateTools(myTauWPDecorator)
     return result
 
-def TauGNNEvaluatorCfg(flags):
+def TauGNNEvaluatorCfg(flags, version=0):
     result = ComponentAccumulator()
-    _name = flags.Tau.ActiveConfig.prefix + 'TauGNN'
+    _name = flags.Tau.ActiveConfig.prefix + 'TauGNN_v' + str(version)
 
     TauGNNEvaluator = CompFactory.getComp("TauGNNEvaluator")
-    GNNConf = flags.Tau.TauGNNConfig
+    GNNConf = flags.Tau.TauGNNConfig[version]
     myTauGNNEvaluator = TauGNNEvaluator(name = _name,
-                                              NetworkFile = GNNConf[0],
-                                              OutputVarname = "GNTauScore",
+                                              NetworkFile = GNNConf,
+                                              OutputVarname = flags.Tau.GNTauScoreName[version],
                                               OutputPTau = "GNTauProbTau",
                                               OutputPJet = "GNTauProbJet",
-                                              MaxTracks = 30,
-                                              MaxClusters = 20,
+                                              MaxTracks = flags.Tau.GNTauMaxTracks[version], 
+                                              MaxClusters = flags.Tau.GNTauMaxClusters[version],
                                               MaxClusterDR = 15.0,
                                               MinTauPt = flags.Tau.MinPtDAOD,
                                               VertexCorrection = True,
@@ -871,31 +871,28 @@ def TauGNNEvaluatorCfg(flags):
                                               InputLayerScalar = "tau_vars",
                                               InputLayerTracks = "track_vars",
                                               InputLayerClusters = "cluster_vars",
-                                              NodeNameTau="GN2TauNoAux_pb",
-                                              NodeNameJet="GN2TauNoAux_pu")
+                                              NodeNameTau=flags.Tau.GNTauNodeNameTau,
+                                              NodeNameJet=flags.Tau.GNTauNodeNameJet)
 
     result.setPrivateTools(myTauGNNEvaluator)
     return result
 
-def TauWPDecoratorGNNCfg(flags):
+def TauWPDecoratorGNNCfg(flags, version):
     result = ComponentAccumulator()
-    _name = flags.Tau.ActiveConfig.prefix + 'TauWPDecoratorGNN'
+    _name = flags.Tau.ActiveConfig.prefix + 'TauWPDecoratorGNN_v' + str(version)
 
     TauWPDecorator = CompFactory.getComp("TauWPDecorator")
-    WPConf = flags.Tau.TauGNNWP_v0
-    decorWPNames = ["GNTauVL_v0", "GNTauL_v0", "GNTauM_v0", "GNTauT_v0"]
-    scoreName = "GNTauScore"
-    newScoreName = "GNTauScoreSigTrans_v0"
+    WPConf = flags.Tau.TauGNNWP[version]
     myTauWPDecorator = TauWPDecorator(name=_name,
                                       flatteningFile1Prong = WPConf[0],
                                       flatteningFile2Prong = WPConf[1],
                                       flatteningFile3Prong = WPConf[2],
-                                      DecorWPNames = decorWPNames,
+                                      DecorWPNames = flags.Tau.GNTauDecorWPNames[version],
                                       DecorWPCutEffs1P = [0.95, 0.85, 0.75, 0.60],
                                       DecorWPCutEffs2P = [0.95, 0.75, 0.60, 0.45],
                                       DecorWPCutEffs3P = [0.95, 0.75, 0.60, 0.45],
-                                      ScoreName = scoreName,
-                                      NewScoreName = newScoreName,
+                                      ScoreName = flags.Tau.GNTauScoreName[version],
+                                      NewScoreName = flags.Tau.GNTauTransScoreName[version],
                                       DefineWPs = True)
     result.setPrivateTools(myTauWPDecorator)
     return result
diff --git a/Reconstruction/tauRecTools/Root/TauGNN.cxx b/Reconstruction/tauRecTools/Root/TauGNN.cxx
index 971372e865b0..08dfdc7fac4e 100644
--- a/Reconstruction/tauRecTools/Root/TauGNN.cxx
+++ b/Reconstruction/tauRecTools/Root/TauGNN.cxx
@@ -14,14 +14,12 @@
 
 TauGNN::TauGNN(const std::string &nnFile, const Config &config):
     asg::AsgMessaging("TauGNN"),
-    m_onnxUtil(nullptr)
+    m_onnxUtil(std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile))
   {
     //==================================================//
     // This part is ported from FTagDiscriminant GNN.cxx//
     //==================================================//
 
-    m_onnxUtil = std::make_shared<FlavorTagDiscriminants::OnnxUtil>(nnFile);
-
     // get the configuration of the model outputs
     FlavorTagDiscriminants::OnnxUtil::OutputConfig gnn_output_config = m_onnxUtil->getOutputConfig();
     
diff --git a/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx
index a1346d67f22b..bc6b27d0b346 100644
--- a/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx
+++ b/Reconstruction/tauRecTools/Root/TauGNNEvaluator.cxx
@@ -38,7 +38,7 @@ TauGNNEvaluator::TauGNNEvaluator(const std::string &name):
 TauGNNEvaluator::~TauGNNEvaluator() {}
 
 StatusCode TauGNNEvaluator::initialize() {
-  ATH_MSG_INFO("Initializing TauGNNEvaluator");
+  ATH_MSG_INFO("Initializing TauGNNEvaluator with "<<m_max_tracks<<" tracks and "<<m_max_clusters<<" clusters...");
   
   std::string weightfile("");
 
@@ -90,13 +90,16 @@ StatusCode TauGNNEvaluator::execute(xAOD::TauJet &tau) const {
   }
 
   // Get input objects
+  ATH_MSG_DEBUG("Fetching Tracks");
   std::vector<const xAOD::TauTrack *> tracks;
   ATH_CHECK(get_tracks(tau, tracks));
+  ATH_MSG_DEBUG("Fetching clusters");
   std::vector<xAOD::CaloVertexedTopoCluster> clusters;
   ATH_CHECK(get_clusters(tau, clusters));
+  ATH_MSG_DEBUG("Constituent fetching done...");
 
   // Truncate tracks
-  int numTracksMax = std::min(m_max_tracks, tracks.size());
+  int numTracksMax = std::min(m_max_tracks, static_cast<int>(tracks.size()));
   std::vector<const xAOD::TauTrack *> trackVec(tracks.begin(), tracks.begin()+numTracksMax);
   // Evaluate networks
   if (m_net) {
@@ -168,7 +171,7 @@ StatusCode TauGNNEvaluator::get_clusters(const xAOD::TauJet &tau, std::vector<xA
   std::sort(clusters.begin(), clusters.end(), et_cmp);
 
   // Truncate clusters
-  if (clusters.size() > m_max_clusters) {
+  if (static_cast<int>(clusters.size()) > m_max_clusters) {
     clusters.resize(m_max_clusters, clusters[0]);
   }
 
diff --git a/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx
index a372e0c019bf..32b4bc2afb2b 100644
--- a/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx
+++ b/Reconstruction/tauRecTools/Root/TauGNNUtils.cxx
@@ -162,7 +162,9 @@ std::unique_ptr<GNNVarCalc> get_calculator(const std::vector<std::string>& scala
     calc->insert("d0TJVA", Variables::Track::d0TJVA, track_vars);
     calc->insert("d0SigTJVA", Variables::Track::d0SigTJVA, track_vars);
     calc->insert("dEta", Variables::Track::dEta, track_vars);
+    calc->insert("dEtaJetSeedAxis", Variables::Track::dEtaJetSeedAxis, track_vars);
     calc->insert("dPhi", Variables::Track::dPhi, track_vars);
+    calc->insert("dPhiJetSeedAxis", Variables::Track::dPhiJetSeedAxis, track_vars);
     calc->insert("nInnermostPixelHits", Variables::Track::nInnermostPixelHits, track_vars);
     calc->insert("nPixelHits", Variables::Track::nPixelHits, track_vars);
     calc->insert("nSCTHits", Variables::Track::nSCTHits, track_vars);
@@ -527,11 +529,23 @@ bool dEta(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) {
     return true;
 }
 
+bool dEtaJetSeedAxis(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) {
+    TLorentzVector tlvSeedJet = tau.p4(xAOD::TauJetParameters::JetSeed);
+    out = std::abs(tlvSeedJet.Eta() - track.eta());
+    return true;
+}
+
 bool dPhi(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) {
     out = track.p4().DeltaPhi(tau.p4());
     return true;
 }
 
+bool dPhiJetSeedAxis(const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out) {
+    TLorentzVector tlvSeedJet = tau.p4(xAOD::TauJetParameters::JetSeed);
+    out = tlvSeedJet.DeltaPhi(track.p4());
+    return true;
+}
+
 bool nInnermostPixelHits(const xAOD::TauJet& /*tau*/, const xAOD::TauTrack &track, double &out) {
     uint8_t inner_pixel_hits;
     const auto success = track.track()->summaryValue(inner_pixel_hits, xAOD::numberOfInnermostPixelLayerHits);
diff --git a/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx b/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx
index 94a4fd2a6c11..d6b83c9a7c56 100644
--- a/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx
+++ b/Reconstruction/tauRecTools/Root/lwtnn/LightweightGraph.cxx
@@ -1,5 +1,5 @@
 /*
-  Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 */
 
 #include "tauRecTools/lwtnn/LightweightGraph.h"
@@ -66,7 +66,7 @@ namespace lwtDev {
 
   typedef LightweightGraph::NodeMap NodeMap;
   LightweightGraph::LightweightGraph(const GraphConfig& config,
-                                     std::string default_output):
+                                     const std::string& default_output):
     m_graph(new Graph(config.nodes, config.layers))
   {
     for (const auto& node: config.inputs) {
diff --git a/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx b/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx
index ad895f07649f..fff78982eb2f 100644
--- a/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx
+++ b/Reconstruction/tauRecTools/Root/lwtnn/Stack.cxx
@@ -1,5 +1,5 @@
 /*
-  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 */
 
 #include "tauRecTools/lwtnn/Stack.h"
@@ -424,7 +424,7 @@ namespace lwtDev {
   // __________________________________________________________________
   // Recurrent layers
 
-  EmbeddingLayer::EmbeddingLayer(int var_row_index, MatrixXd W):
+  EmbeddingLayer::EmbeddingLayer(int var_row_index, const MatrixXd & W):
     m_var_row_index(var_row_index),
     m_W(W)
   {
@@ -470,14 +470,14 @@ namespace lwtDev {
 
 
   // LSTM layer
-  LSTMLayer::LSTMLayer(ActivationConfig activation,
-                       ActivationConfig inner_activation,
-                       MatrixXd W_i, MatrixXd U_i, VectorXd b_i,
-                       MatrixXd W_f, MatrixXd U_f, VectorXd b_f,
-                       MatrixXd W_o, MatrixXd U_o, VectorXd b_o,
-                       MatrixXd W_c, MatrixXd U_c, VectorXd b_c,
-                       bool go_backwards,
-                       bool return_sequence):
+  LSTMLayer::LSTMLayer(const ActivationConfig & activation,
+              const ActivationConfig & inner_activation,
+              const MatrixXd & W_i, const MatrixXd & U_i, const VectorXd & b_i,
+              const MatrixXd & W_f, const MatrixXd & U_f, const VectorXd & b_f,
+              const MatrixXd & W_o, const MatrixXd & U_o, const VectorXd & b_o,
+              const MatrixXd & W_c, const MatrixXd & U_c, const VectorXd & b_c,
+              bool go_backwards,
+              bool return_sequence):
     m_W_i(W_i),
     m_U_i(U_i),
     m_b_i(b_i),
@@ -547,11 +547,11 @@ namespace lwtDev {
 
 
   // GRU layer
-  GRULayer::GRULayer(ActivationConfig activation,
-                     ActivationConfig inner_activation,
-                     MatrixXd W_z, MatrixXd U_z, VectorXd b_z,
-                     MatrixXd W_r, MatrixXd U_r, VectorXd b_r,
-                     MatrixXd W_h, MatrixXd U_h, VectorXd b_h):
+  GRULayer::GRULayer(const ActivationConfig & activation,
+                     const ActivationConfig & inner_activation,
+                     const MatrixXd & W_z, const  MatrixXd & U_z, const VectorXd & b_z,
+                     const MatrixXd & W_r, const MatrixXd & U_r, const VectorXd & b_r,
+                     const MatrixXd & W_h, const MatrixXd & U_h, const VectorXd & b_h):
     m_W_z(W_z),
     m_U_z(U_z),
     m_b_z(b_z),
@@ -621,8 +621,8 @@ namespace lwtDev {
   }
 
   MatrixXd BidirectionalLayer::scan( const MatrixXd& x) const{
-    MatrixXd forward = m_forward_layer->scan(x);
-    MatrixXd backward = m_backward_layer->scan(x);
+    const MatrixXd & forward = m_forward_layer->scan(x);
+    const MatrixXd & backward = m_backward_layer->scan(x);
     MatrixXd backward_rev;
     if (m_return_sequence){
       backward_rev = backward.rowwise().reverse();
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h
index baa5d54a181b..776a6a96bd70 100644
--- a/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h
+++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNEvaluator.h
@@ -48,8 +48,8 @@ private:
     std::string m_output_ptau;
     std::string m_output_pjet;
     std::string m_weightfile;
-    std::size_t m_max_tracks;
-    std::size_t m_max_clusters;
+    int m_max_tracks;
+    int m_max_clusters;
     float m_max_cluster_dr;
     float m_minTauPt;
     bool m_doVertexCorrection;
diff --git a/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h
index 28a983dd5e8a..7e4913018524 100644
--- a/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h
+++ b/Reconstruction/tauRecTools/tauRecTools/TauGNNUtils.h
@@ -179,9 +179,15 @@ bool d0SigTJVA(
 bool dEta(
     const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
 
+bool dEtaJetSeedAxis(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
 bool dPhi(
     const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
 
+bool dPhiJetSeedAxis(
+    const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
+
 bool nInnermostPixelHits(
     const xAOD::TauJet &tau, const xAOD::TauTrack &track, double &out);
 
diff --git a/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h b/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h
index 802d17df2b02..16d1b29d7acd 100644
--- a/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h
+++ b/Reconstruction/tauRecTools/tauRecTools/lwtnn/LightweightGraph.h
@@ -1,5 +1,5 @@
 /*
-  Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 */
 
 #ifndef LIGHTWEIGHT_GRAPH_HH_TAURECTOOLS
@@ -72,7 +72,7 @@ namespace lwtDev {
     // define a "default" output, so that calling "compute" with no
     // output specified doesn't lead to ambiguity.
     LightweightGraph(const GraphConfig& config,
-                     std::string default_output = "");
+                     const std::string& default_output = "");
 
     ~LightweightGraph();
     LightweightGraph(LightweightGraph&) = delete;
diff --git a/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h b/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h
index 69ea569fb0f3..84d55e46c9ad 100644
--- a/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h
+++ b/Reconstruction/tauRecTools/tauRecTools/lwtnn/Stack.h
@@ -1,5 +1,5 @@
 /*
-  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
+  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
 */
 
 #ifndef STACK_HH_TAURECTOOLS
@@ -222,7 +222,7 @@ namespace lwtDev {
   class EmbeddingLayer : public IRecurrentLayer
   {
   public:
-    EmbeddingLayer(int var_row_index, MatrixXd W);
+    EmbeddingLayer(int var_row_index, const MatrixXd & W);
     virtual ~EmbeddingLayer() {};
     virtual MatrixXd scan( const MatrixXd&) const override;
 
@@ -236,12 +236,12 @@ namespace lwtDev {
   class LSTMLayer : public IRecurrentLayer
   {
   public:
-    LSTMLayer(ActivationConfig activation,
-              ActivationConfig inner_activation,
-              MatrixXd W_i, MatrixXd U_i, VectorXd b_i,
-              MatrixXd W_f, MatrixXd U_f, VectorXd b_f,
-              MatrixXd W_o, MatrixXd U_o, VectorXd b_o,
-              MatrixXd W_c, MatrixXd U_c, VectorXd b_c,
+    LSTMLayer(const ActivationConfig & activation,
+              const ActivationConfig & inner_activation,
+              const MatrixXd & W_i, const MatrixXd & U_i, const VectorXd & b_i,
+              const MatrixXd & W_f, const MatrixXd & U_f, const VectorXd & b_f,
+              const MatrixXd & W_o, const MatrixXd & U_o, const VectorXd & b_o,
+              const MatrixXd & W_c, const MatrixXd & U_c, const VectorXd & b_c,
               bool go_backwards,
               bool return_sequence);
 
@@ -277,11 +277,11 @@ namespace lwtDev {
   class GRULayer : public IRecurrentLayer
   {
   public:
-    GRULayer(ActivationConfig activation,
-             ActivationConfig inner_activation,
-             MatrixXd W_z, MatrixXd U_z, VectorXd b_z,
-             MatrixXd W_r, MatrixXd U_r, VectorXd b_r,
-             MatrixXd W_h, MatrixXd U_h, VectorXd b_h);
+    GRULayer(const ActivationConfig & activation,
+                     const ActivationConfig & inner_activation,
+                     const MatrixXd & W_z, const  MatrixXd & U_z, const VectorXd & b_z,
+                     const MatrixXd & W_r, const MatrixXd & U_r, const VectorXd & b_r,
+                     const MatrixXd & W_h, const MatrixXd & U_h, const VectorXd & b_h);
 
     virtual ~GRULayer() {};
     virtual MatrixXd scan( const MatrixXd&) const override;
-- 
GitLab