diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/CMakeLists.txt b/InnerDetector/InDetRecTools/SiClusterizationTool/CMakeLists.txt
index cce538a2ded0d823f829e60ba43fb513e2e6dd0f..8f2ad316714b4248e5ed7b1dd3a3462082b4f62a 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/CMakeLists.txt
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/CMakeLists.txt
@@ -5,6 +5,7 @@ atlas_subdir( SiClusterizationTool )
 
 # External dependencies:
 find_package( lwtnn )
+find_package( Eigen )
 find_package( CLHEP )
 find_package( ROOT COMPONENTS Core MathCore Hist )
 find_package( COOL COMPONENTS CoolKernel CoolApplication )
@@ -13,10 +14,10 @@ find_package( COOL COMPONENTS CoolKernel CoolApplication )
 atlas_add_library( SiClusterizationToolLib
    SiClusterizationTool/*.h src/*.cxx
    PUBLIC_HEADERS SiClusterizationTool
-   INCLUDE_DIRS ${ROOT_INCLUDE_DIRS} ${LWTNN_INCLUDE_DIRS}
+   INCLUDE_DIRS ${ROOT_INCLUDE_DIRS} ${LWTNN_INCLUDE_DIRS} ${EIGEN_INCLUDE_DIRS}
    PRIVATE_INCLUDE_DIRS ${CLHEP_INCLUDE_DIRS} ${COOL_INCLUDE_DIRS}
    PRIVATE_DEFINITIONS ${CLHEP_DEFINITIONS}
-   LINK_LIBRARIES ${CLHEP_LIBRARIES} ${LWTNN_LIBRARIES} ${ROOT_LIBRARIES} AthenaBaseComps AthenaKernel BeamSpotConditionsData EventPrimitives GaudiKernel GeoPrimitives Identifier InDetCondTools InDetConditionsSummaryService InDetIdentifier InDetPrepRawData InDetRawData InDetReadoutGeometry InDetRecToolInterfaces InDetSimData PixelCablingLib PixelConditionsData PixelGeoModelLib PoolSvcLib StoreGateLib TrkNeuralNetworkUtilsLib TrkParameters TrkSurfaces
+   LINK_LIBRARIES ${CLHEP_LIBRARIES} ${LWTNN_LIBRARIES} ${EIGEN_LIBRARIES} ${ROOT_LIBRARIES} AthenaBaseComps AthenaKernel BeamSpotConditionsData EventPrimitives GaudiKernel GeoPrimitives Identifier InDetCondTools InDetConditionsSummaryService InDetIdentifier InDetPrepRawData InDetRawData InDetReadoutGeometry InDetRecToolInterfaces InDetSimData PixelCablingLib PixelConditionsData PixelGeoModelLib PoolSvcLib StoreGateLib TrkNeuralNetworkUtilsLib TrkParameters TrkSurfaces LwtnnUtils
    PRIVATE_LINK_LIBRARIES ${Boost_LIBRARIES} ${COOL_LIBRARIES} AthenaPoolUtilities AtlasDetDescr AtlasHepMCLib DetDescrCondToolsLib FileCatalog PixelReadoutGeometry SCT_ReadoutGeometry TrkEventPrimitives VxVertex )
 
 atlas_add_component( SiClusterizationTool
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/LWTNNCollection.h b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/LWTNNCollection.h
index 885b24a055afa75a0856b938d9c8b7c4ea294759..fb0e97ed20946f23d861a4074ef88ee03b0622a3 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/LWTNNCollection.h
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/LWTNNCollection.h
@@ -6,10 +6,10 @@
 #define _LWTNNCollection_H_
 
 //#include <vector>
-#include "lwtnn/LightweightGraph.hh"
+#include "LwtnnUtils/FastGraph.h"
 
 class LWTNNCollection 
-  :  public std::map<int, std::unique_ptr<lwt::LightweightGraph> >
+  :  public std::map<int, std::unique_ptr<lwt::atlas::FastGraph> >
 {
 public:
 
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
index 2a84f885ef5f658a3da0cc6bbf41f31a2a62c812..8c375be5135ba492db3aa4e1dc4a4cd339c31f19 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
@@ -43,6 +43,9 @@
 #include "PixelConditionsData/PixelChargeCalibCondData.h"
 #include "StoreGate/ReadCondHandleKey.h"
 
+#include <Eigen/Dense>
+
+
  class TTrainedNetwork;
  class TH1;
  class ICoolHistSvc;
@@ -102,37 +105,29 @@ namespace InDet {
     virtual StatusCode finalize() { return StatusCode::SUCCESS; };
 
     std::vector<double> estimateNumberOfParticles(const InDet::PixelCluster& pCluster,
-                                                  Amg::Vector3D & beamSpotPosition,
-                                                  int sizeX=7,
-                                                  int sizeY=7) const;
+                                                  Amg::Vector3D & beamSpotPosition) const;
 
     std::vector<double> estimateNumberOfParticles(const InDet::PixelCluster& pCluster,
                                                   const Trk::Surface& pixelSurface,
-                                                  const Trk::TrackParameters& trackParsAtSurface,
-                                                  int sizeX=7,
-                                                  int sizeY=7) const;
+                                                  const Trk::TrackParameters& trackParsAtSurface) const;
 
     /* Public-facing method 1: no track parameters */
     std::vector<Amg::Vector2D> estimatePositions(const InDet::PixelCluster& pCluster,
                                                       Amg::Vector3D & beamSpotPosition,
                                                       std::vector<Amg::MatrixX> & errors,
-                                                      int numberSubClusters,
-                                                      int sizeX=7,
-                                                      int sizeY=7) const;
+                                                      int numberSubClusters) const;
 
     /* Public-facing method 1: with track parameters */
     std::vector<Amg::Vector2D> estimatePositions(const InDet::PixelCluster& pCluster,
                                                       const Trk::Surface& pixelSurface,
                                                       const Trk::TrackParameters& trackParsAtSurface,
                                                       std::vector<Amg::MatrixX> & errors,
-                                                      int numberSubClusters,
-                                                      int sizeX=7,
-                                                      int sizeY=7) const;
+                                                      int numberSubClusters) const;
 
    private:
 
     // Handling lwtnn inputs
-    typedef std::map<std::string, std::map<std::string, double> > InputMap;
+    typedef std::vector<Eigen::VectorXd> InputVector;
 
     /* Estimate number of particles for both with and w/o tracks */
     /* Method 1: using older TTrainedNetworks */
@@ -140,8 +135,9 @@ namespace InDet {
                                                      std::vector<double> inputData) const;
 
     /* Estimate number of particles for both with and w/o tracks */
-    /* Method 2: using lwtnn for more flexible interfacing */
-    std::vector<double> estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputMap & input) const;
+    /* Method 2: using lwtnn for more flexible interfacing with an ordered vector
+     * Vector order MUST match variable order. */
+    std::vector<double> estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputVector & input) const;
 
     /* Estimate position for both with and w/o tracks */
     /* Method 1: using older TTrainedNetworks */
@@ -150,15 +146,14 @@ namespace InDet {
                                                  const std::vector<double>& inputData,
                                                  const NNinput& input,
                                                  const InDet::PixelCluster& pCluster,
-                                                 int sizeX,
-                                                 int sizeY,
                                                  int numberSubClusters,
                                                  std::vector<Amg::MatrixX> & errors) const;
 
     /* Estimate position for both with and w/o tracks */
-    /* Method 2: using lwtnn for more flexible interfacing */
+    /* Method 2: using lwtnn for more flexible interfacing with an ordered vector
+     * Vector order MUST match variable order. */
     std::vector<Amg::Vector2D> estimatePositionsLWTNN(
-                                                NnClusterizationFactory::InputMap & input, 
+                                                NnClusterizationFactory::InputVector & input, 
                                                 NNinput& rawInput,
                                                 const InDet::PixelCluster& pCluster,
                                                 int numberSubClusters,
@@ -167,14 +162,12 @@ namespace InDet {
     // For error formatting in lwtnn cases
     double correctedRMSX(double posPixels) const;
 
-    double correctedRMSY(double posPixels, double sizeY, std::vector<float>& pitches) const; 
+    double correctedRMSY(double posPixels, std::vector<float>& pitches) const; 
 
      /* algorithmic component */
     NNinput createInput(const InDet::PixelCluster& pCluster,
                         Amg::Vector3D & beamSpotPosition,
-                        double & tanl,
-                        int sizeX=7,
-                        int sizeY=7) const;
+                        double & tanl) const;
 
     void addTrackInfoToInput(NNinput& input,
                              const Trk::Surface& pixelSurface,
@@ -182,23 +175,16 @@ namespace InDet {
                              const double tanl) const;
 
 
-  std::vector<double> assembleInputRunI(NNinput& input,
-                                      int sizeX,
-                                      int sizeY) const;
+  std::vector<double> assembleInputRunI(NNinput& input) const;
 
 
+  std::vector<double> assembleInputRunII(NNinput& input) const;
 
-  std::vector<double> assembleInputRunII(NNinput& input,
-                                      int sizeX,
-                                      int sizeY) const;
-
-    InputMap flattenInput(NNinput & input) const;
+    InputVector eigenInput(NNinput & input) const;
 
     std::vector<Amg::Vector2D> getPositionsFromOutput(std::vector<double> & output,
                                                       const NNinput & input,
-                                                      const InDet::PixelCluster& pCluster,
-                                                      int sizeX=7,
-                                                      int sizeY=7) const;
+                                                      const InDet::PixelCluster& pCluster) const;
 
 
     void getErrorMatrixFromOutput(std::vector<double>& outputX,
@@ -239,7 +225,7 @@ namespace InDet {
     std::vector< std::vector<unsigned int> > m_NNId;
 
     // Function to be called to assemble the inputs
-    std::vector<double> (InDet::NnClusterizationFactory:: *m_assembleInput)(NNinput& input,int sizeX, int sizeY) const {&NnClusterizationFactory::assembleInputRunII};
+    std::vector<double> (InDet::NnClusterizationFactory:: *m_assembleInput)(NNinput& input) const {&NnClusterizationFactory::assembleInputRunII};
 
     // Function to be called to compute the output
     std::vector<Double_t> (::TTrainedNetwork:: *m_calculateOutput)(const std::vector<Double_t> &input) const {&TTrainedNetwork::calculateNormalized};
@@ -264,6 +250,22 @@ namespace InDet {
        {this, "NnCollectionJSONReadKey", "PixelClusterNNJSON",
         "The conditions key for the pixel cluster NNs configured via JSON file and accessed with lwtnn"};
 
+    //  this is written into the JSON config "node_index"
+    //  this can be found from the LWTNN GraphConfig object used to initalize the collection objects
+    //     opiton size_t index = graph_config.outputs.at("output_node_name").node_index
+    //   
+    Gaudi::Property< std::size_t > m_outputNodesPos1
+    {this, "OutputNodePos1", 7,
+        "Output node for the 1 position networks (LWTNN)"};
+
+    Gaudi::Property< std::vector<std::size_t> > m_outputNodesPos2
+    {this, "OutputNodePos2", { 10, 11 },
+        "List of output nodes for the 2 position network (LWTNN)"};
+
+    Gaudi::Property< std::vector<std::size_t> > m_outputNodesPos3
+    {this, "OutputNodePos3", { 13, 14, 15 },
+        "List of output nodes for the 3 position networks (LWTNN)"};
+
     Gaudi::Property<unsigned int> m_maxSubClusters
        {this, "MaxSubClusters", 3, "Maximum number of sub cluster supported by the networks." };
 
@@ -291,6 +293,11 @@ namespace InDet {
     Gaudi::Property<bool> m_useRecenteringNNWithTracks
        {this, "useRecenteringNNWithTracks",false,"Recenter x position when evaluating NN with track input."};
 
+    Gaudi::Property<int> m_sizeX
+       {this, "sizeX",7,"Size of pixel matrix along X"};
+
+    Gaudi::Property<int> m_sizeY
+       {this, "sizeY",7,"Size of pixel matrix along Y"};
 
    };
 
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.cxx b/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.cxx
index 994403f4fc4d075d43080eeb1edb4a3bb4f05100..43d1533eadae55eba743b814e487d43e1cf93a5d 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.cxx
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.cxx
@@ -5,6 +5,7 @@
  *   */
 
 #include "LWTNNCondAlg.h"
+#include "LwtnnUtils/InputOrder.h"
 
 #include "AthenaPoolUtilities/CondAttrListCollection.h"
 #include "AthenaPoolUtilities/AthenaAttributeList.h"
@@ -53,7 +54,7 @@ namespace InDet {
     return StatusCode::SUCCESS;
   }
 
-  StatusCode LWTNNCondAlg::configureLwtnn(std::unique_ptr<lwt::LightweightGraph> & thisNN, 
+  StatusCode LWTNNCondAlg::configureLwtnn(std::unique_ptr<lwt::atlas::FastGraph> & thisNN, 
                                         const std::string& thisJson) {
 
     // Read DNN weights from input json config
@@ -66,9 +67,15 @@ namespace InDet {
       return StatusCode::FAILURE;
     }
 
+    // pass the input order for the FastGraph
+    lwt::atlas::InputOrder order;
+    order.scalar.push_back( std::make_pair("NNinputs", m_variableOrder) );
+    // sequence not needed for NN (more for RNN, but set anyway)
+    order.sequence.push_back( std::make_pair("NNinputs", m_variableOrder) );
+
     // Build the network
     try {
-      thisNN.reset(new lwt::LightweightGraph(config, "merge_1"));
+      thisNN.reset(new lwt::atlas::FastGraph(config, order, "merge_1"));
     } catch (lwt::NNConfigurationException& exc) {
       ATH_MSG_ERROR("NN configuration problem: " << exc.what());
       return StatusCode::FAILURE;
@@ -128,7 +135,7 @@ namespace InDet {
 
     // First, extract configuration for the number network.
     pt::ptree subtreeNumberNetwork = parentTree.get_child("NumberNetwork");
-    writeCdo->insert(std::make_pair(0,std::unique_ptr<lwt::LightweightGraph>(nullptr)));
+    writeCdo->insert(std::make_pair(0,std::unique_ptr<lwt::atlas::FastGraph>(nullptr)));
     // If this json is empty, just fill a null pointer.
     if(subtreeNumberNetwork.empty()) {
       ATH_MSG_INFO("Not using lwtnn for number network.");
@@ -153,7 +160,7 @@ namespace InDet {
       std::string posNetworkConfig = configStream.str();
       
       // Put a lwt network into the map
-      writeCdo->insert(std::make_pair(i,std::unique_ptr<lwt::LightweightGraph>(nullptr)));
+      writeCdo->insert(std::make_pair(i,std::unique_ptr<lwt::atlas::FastGraph>(nullptr)));
 
       // Now do empty check: if any one of these is empty we won't use lwtnn
       if(subtreePosNetwork.empty()) {
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.h b/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.h
index 2678f02c10395a614100d4c96eeac5aca60d08d4..515fba6eb94dbdeb0ce9c4992a2f7cf5fa39f670 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.h
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.h
@@ -21,7 +21,8 @@ class IPoolSvc;
 
 namespace lwt {
   class NanReplacer;    
-  class LightweightGraph;
+  //class LightweightGraph;
+  namespace atlas { class FastGraph; }
 }
 
 namespace InDet {
@@ -44,12 +45,9 @@ class LWTNNCondAlg : public AthAlgorithm {
 
   ServiceHandle<ICondSvc> m_condSvc
     {this, "CondSvc", "CondSvc", "The conditions service to register new conditions data."};
-/*  ServiceHandle<IPoolSvc> m_poolsvc
-    {this, "PoolSvc", "PoolSvc", "The service to retrieve files by GUID."};
-  ToolHandle<Trk::NeuralNetworkToHistoTool> m_networkToHistoTool
-    {this,"NetworkToHistoTool", "Trk::NeuralNetworkToHistoTool/NeuralNetworkToHistoTool", "Tool to create a neural network from a set of histograms." };
-*/
-  StatusCode configureLwtnn(std::unique_ptr<lwt::LightweightGraph> & thisNN, const std::string& thisJson);
+
+  //StatusCode configureLwtnn(std::unique_ptr<lwt::LightweightGraph> & thisNN, const std::string& thisJson);
+  StatusCode configureLwtnn(std::unique_ptr<lwt::atlas::FastGraph> & thisNN, const std::string& thisJson);
 
   SG::ReadCondHandleKey<CondAttrListCollection> m_readKey
     {this, "ReadKey", "/PIXEL/PixelClustering/PixelNNCalibJSON", "Cool folder name for the cluster NN input histogram file."};
@@ -57,13 +55,72 @@ class LWTNNCondAlg : public AthAlgorithm {
   SG::WriteCondHandleKey<LWTNNCollection> m_writeKey
     {this, "WriteKey", "PixelClusterNNJSON", "The conditions statore key for the pixel cluster NNs"};
 
-  Gaudi::Property< std::vector<std::string> > m_nnOrder
-    {this, "NetworkNames", {
-          "NumberNetwork",
-          "PositionNetwork_N1",
-          "PositionNetwork_N2",
-          "PositionNetwork_N3"},
-        "List of network names, which are indexe in map in this order"};
+  // as of now, the number and position networks all use the same variables
+  // only need one of these
+  Gaudi::Property< std::vector<std::string> > m_variableOrder
+    {this, "VariableOrder", {
+          "NN_matrix0",
+          "NN_matrix1",
+          "NN_matrix2",
+          "NN_matrix3",
+          "NN_matrix4",
+          "NN_matrix5",
+          "NN_matrix6",
+          "NN_matrix7",
+          "NN_matrix8",
+          "NN_matrix9",
+          "NN_matrix10",
+          "NN_matrix11",
+          "NN_matrix12",
+          "NN_matrix13",
+          "NN_matrix14",
+          "NN_matrix15",
+          "NN_matrix16",
+          "NN_matrix17",
+          "NN_matrix18",
+          "NN_matrix19",
+          "NN_matrix20",
+          "NN_matrix21",
+          "NN_matrix22",
+          "NN_matrix23",
+          "NN_matrix24",
+          "NN_matrix25",
+          "NN_matrix26",
+          "NN_matrix27",
+          "NN_matrix28",
+          "NN_matrix29",
+          "NN_matrix30",
+          "NN_matrix31",
+          "NN_matrix32",
+          "NN_matrix33",
+          "NN_matrix34",
+          "NN_matrix35",
+          "NN_matrix36",
+          "NN_matrix37",
+          "NN_matrix38",
+          "NN_matrix39",
+          "NN_matrix40",
+          "NN_matrix41",
+          "NN_matrix42",
+          "NN_matrix43",
+          "NN_matrix44",
+          "NN_matrix45",
+          "NN_matrix46",
+          "NN_matrix47",
+          "NN_matrix48",
+          "NN_pitches0",
+          "NN_pitches1",
+          "NN_pitches2",
+          "NN_pitches3",
+          "NN_pitches4",
+          "NN_pitches5",
+          "NN_pitches6",
+          "NN_layer",
+          "NN_barrelEC",
+          "NN_phi",
+          "NN_theta"},
+        "List of training variables for the LWTNN networks in the order they are fed to evaluate the networks"};
+
 /*
   Gaudi::Property<std::string> m_layerInfoHistogram
   {this, "LayerInfoHistogram",      "LayersInfo","Name about the layer info histogram."};
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx b/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
index 716f562270135c53d979161dfe6a907a3e56bf0c..470836d3e58d58f33e84b35a9c3599745724c4fd 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
@@ -159,27 +159,18 @@ namespace InDet {
   }
 
 
-  std::vector<double> NnClusterizationFactory::assembleInputRunII(NNinput& input,
-                                                                  int sizeX,
-                                                                  int sizeY) const
+  std::vector<double> NnClusterizationFactory::assembleInputRunII(NNinput& input) const
 {
 
     std::vector<double> inputData;
-    for (int u=0;u<sizeX;u++)
+    for (int u=0;u<m_sizeX;u++)
     {
-      for (int s=0;s<sizeY;s++)
+      for (int s=0;s<m_sizeY;s++)
       {
-        if (m_useToT)
-        {
-          inputData.push_back(input.matrixOfToT[u][s]);
-        }
-        else
-        {
-          inputData.push_back(input.matrixOfToT[u][s]);
-        }
+        inputData.push_back(input.matrixOfToT[u][s]);
       }
     }
-    for (int s=0;s<sizeY;s++)
+    for (int s=0;s<m_sizeY;s++)
     {
       inputData.push_back(input.vectorOfPitchesY[s]);
     }
@@ -200,7 +191,6 @@ namespace InDet {
 
 
 
-
     return inputData;
 
 }
@@ -208,14 +198,12 @@ namespace InDet {
 
 
 
-  std::vector<double> NnClusterizationFactory::assembleInputRunI(NNinput& input,
-                                                                 int sizeX,
-                                                                 int sizeY) const
+  std::vector<double> NnClusterizationFactory::assembleInputRunI(NNinput& input) const
   {
     std::vector<double> inputData;
-    for (int u=0;u<sizeX;u++)
+    for (int u=0;u<m_sizeX;u++)
     {
-      for (int s=0;s<sizeY;s++)
+      for (int s=0;s<m_sizeY;s++)
       {
         if (m_useToT)
         {
@@ -227,7 +215,7 @@ namespace InDet {
         }
       }
     }
-    for (int s=0;s<sizeY;s++)
+    for (int s=0;s<m_sizeY;s++)
     {
       const double rawPitch(input.vectorOfPitchesY[s]);
         const double normPitch(norm_pitch(rawPitch,m_addIBL));
@@ -253,57 +241,68 @@ namespace InDet {
     return inputData;
   }
 
-  NnClusterizationFactory::InputMap NnClusterizationFactory::flattenInput(NNinput & input) const
+  NnClusterizationFactory::InputVector NnClusterizationFactory::eigenInput(NNinput & input) const
   {
 
-    // Format for use with lwtnn
-    std::map<std::string, std::map<std::string, double> > flattened;
+    // we know the size to be
+    //  - m_sizeX x m_sizeY pixel ToT values
+    //  - m_sizeY pitch sizes in y
+    //  - 2 values: detector location 
+    //  - 2 values: track incidence angles 
+    //  - optional: eta module
+    int vecSize = m_sizeX*m_sizeY + m_sizeY + 4;
+    if(!input.useTrackInfo) { vecSize += 1; }
+    Eigen::VectorXd valuesVector( vecSize );
 
     // Fill it!
-    // Variable names here need to match the ones in the configuration.    
+    // Variable names here need to match the ones in the configuration...
+    // ...IN THE SAME ORDER!!!
+    // location in eigen matrix object where next element goes
+    int location(0);
 
-    std::map<std::string, double> simpleInputs;
     for (unsigned int x = 0; x < input.matrixOfToT.size(); x++) {
       for (unsigned int y = 0; y < input.matrixOfToT.at(0).size(); y++) {
-        unsigned int index = x*input.matrixOfToT.at(0).size()+y;
-        std::string varname = "NN_matrix"+std::to_string(index);
-        simpleInputs[varname] = input.matrixOfToT.at(x).at(y);
+        valuesVector[location] = input.matrixOfToT.at(x).at(y);
+        location++;
       }
     }
 
     for (unsigned int p = 0; p < input.vectorOfPitchesY.size(); p++) {
-      std::string varname = "NN_pitches" + std::to_string(p);
-      simpleInputs[varname] = input.vectorOfPitchesY.at(p);
+      valuesVector[location] = input.vectorOfPitchesY.at(p);
+      location++;
     }
 
-    simpleInputs["NN_layer"] = input.ClusterPixLayer;
-    simpleInputs["NN_barrelEC"] = input.ClusterPixBarrelEC;
-    simpleInputs["NN_phi"] = input.phi;
-    simpleInputs["NN_theta"] = input.theta;
-
-    if (input.useTrackInfo) simpleInputs["NN_etaModule"] = input.etaModule;
+    valuesVector[location] = input.ClusterPixLayer;
+    location++;
+    valuesVector[location] = input.ClusterPixBarrelEC;
+    location++;
+    valuesVector[location] = input.phi;
+    location++;
+    valuesVector[location] = input.theta;
+    location++;
+
+    if (!input.useTrackInfo) { 
+      valuesVector[location] = input.etaModule;
+      location++;
+    }
 
     // We have only one node for now, so we just store things there.
-    flattened["NNinputs"] = simpleInputs;
-
-    return flattened;
-
+    // Format for use with lwtnn
+    std::vector<Eigen::VectorXd> vectorOfEigen;
+    vectorOfEigen.push_back(valuesVector);
 
+    return vectorOfEigen;
   }
 
   std::vector<double> NnClusterizationFactory::estimateNumberOfParticles(const InDet::PixelCluster& pCluster,
-                                                                         Amg::Vector3D & beamSpotPosition,
-                                                                         int sizeX,
-                                                                         int sizeY) const
+                                                                         Amg::Vector3D & beamSpotPosition) const
   {
 
     double tanl=0;
 
     NNinput input( createInput(pCluster,
                                beamSpotPosition,
-                               tanl,
-                               sizeX,
-                               sizeY) );
+                               tanl) );
 
     if (!input)
     {
@@ -311,11 +310,12 @@ namespace InDet {
     }
 
 
-    std::vector<double> inputData=(this->*m_assembleInput)(input,sizeX,sizeY);
-
     // If using old TTrainedNetworks, fetch correct ones for the
     // without-track situation and call them now.
     if (m_useTTrainedNetworks) {
+
+      std::vector<double> inputData=(this->*m_assembleInput)(input);
+
       SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack );
       if (!nn_collection.isValid()) {
 	ATH_MSG_ERROR( "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key() );
@@ -324,16 +324,14 @@ namespace InDet {
     }
 
     // Otherwise, prepare lwtnn input map and use new networks.
-    NnClusterizationFactory::InputMap nnInputData = flattenInput(input);
-    return estimateNumberOfParticlesLWTNN(nnInputData);
+    NnClusterizationFactory::InputVector nnInputVector = eigenInput(input);
+    return estimateNumberOfParticlesLWTNN(nnInputVector);
 
   }
 
   std::vector<double> NnClusterizationFactory::estimateNumberOfParticles(const InDet::PixelCluster& pCluster,
                                                                          const Trk::Surface& pixelSurface,
-                                                                         const Trk::TrackParameters& trackParsAtSurface,
-                                                                         int sizeX,
-                                                                         int sizeY) const
+                                                                         const Trk::TrackParameters& trackParsAtSurface) const
   {
 
     Amg::Vector3D dummyBS(0,0,0);
@@ -342,9 +340,7 @@ namespace InDet {
 
     NNinput input( createInput(pCluster,
                                dummyBS,
-                               tanl,
-                               sizeX,
-                               sizeY) );
+                               tanl) );
 
     if (!input)
     {
@@ -353,7 +349,7 @@ namespace InDet {
 
     addTrackInfoToInput(input,pixelSurface,trackParsAtSurface,tanl);
 
-    std::vector<double> inputData=(this->*m_assembleInput)(input,sizeX,sizeY);
+    std::vector<double> inputData=(this->*m_assembleInput)(input);
 
     // If using old TTrainedNetworks, fetch correct ones for the
     // with-track situation and call them now.
@@ -366,8 +362,8 @@ namespace InDet {
     }
 
     // Otherwise, prepare lwtnn input map and use new networks.
-    NnClusterizationFactory::InputMap nnInputData = flattenInput(input);
-    return estimateNumberOfParticlesLWTNN(nnInputData);
+    NnClusterizationFactory::InputVector nnInputVector = eigenInput(input);
+    return estimateNumberOfParticlesLWTNN(nnInputVector);
 
   }
 
@@ -385,18 +381,20 @@ namespace InDet {
     return resultNN_TTN;
   }
 
-  std::vector<double> NnClusterizationFactory::estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputMap & input) const
+
+  std::vector<double> NnClusterizationFactory::estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputVector & input) const
   {
     SG::ReadCondHandle<LWTNNCollection> lwtnn_collection(m_readKeyJSON) ;
     if (!lwtnn_collection.isValid()) {
       ATH_MSG_ERROR( "Failed to get LWTNN network collection with key " << m_readKeyJSON.key() );
     }
     ATH_MSG_DEBUG("Using lwtnn number network");
-    // Evaluate the number network once per cluster
-    lwt::ValueMap discriminant = lwtnn_collection->at(0)->compute(input);
-    double num0 = discriminant["output_number0"];
-    double num1 = discriminant["output_number1"];
-    double num2 = discriminant["output_number2"];
+    // Order of output matches order in JSON config in "outputs"
+    // Only 1 node here, simple compute function
+    Eigen::VectorXd discriminant = lwtnn_collection->at(0)->compute(input);
+    double num0 = discriminant[0];
+    double num1 = discriminant[1];
+    double num2 = discriminant[2];
     // Get normalized predictions
     double prob1 = num0/(num0+num1+num2);
     double prob2 = num1/(num0+num1+num2);
@@ -408,14 +406,14 @@ namespace InDet {
                                                " (3): " << number_probabilities[2]);
 
     return number_probabilities;
+
   }
 
+
   std::vector<Amg::Vector2D> NnClusterizationFactory::estimatePositions(const InDet::PixelCluster& pCluster,
                                                                              Amg::Vector3D & beamSpotPosition,
                                                                              std::vector<Amg::MatrixX> & errors,
-                                                                             int numberSubClusters,
-                                                                             int sizeX,
-                                                                             int sizeY) const
+                                                                             int numberSubClusters) const
   {
 
     ATH_MSG_VERBOSE(" Starting to estimate positions...");
@@ -424,32 +422,30 @@ namespace InDet {
 
     NNinput input( createInput(pCluster,
                                beamSpotPosition,
-                               tanl,
-                               sizeX,
-                               sizeY) );
+                               tanl) );
 
     if (!input)
     {
       return std::vector<Amg::Vector2D>();
     }
 
-    std::vector<double> inputData=(this->*m_assembleInput)(input,sizeX,sizeY);
-
     // If using old TTrainedNetworks, fetch correct ones for the
     // without-track situation and call them now.
     if (m_useTTrainedNetworks) {
+
+      std::vector<double> inputData=(this->*m_assembleInput)(input);
+
       SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack );
       if (!nn_collection.isValid()) {
 	ATH_MSG_ERROR( "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key() );
       }
       // *(ReadCondHandle<>) returns a pointer rather than a reference ...
-      return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,sizeX,sizeY,numberSubClusters,errors);
+      return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,numberSubClusters,errors);
     }
 
     // Otherwise, prepare lwtnn input map and use new networks.
-    NnClusterizationFactory::InputMap nnInputData = flattenInput(input);
-    return estimatePositionsLWTNN(nnInputData,input,pCluster,numberSubClusters,errors);
-
+    NnClusterizationFactory::InputVector nnInputVector = eigenInput(input);
+    return estimatePositionsLWTNN(nnInputVector,input,pCluster,numberSubClusters,errors);
   }
 
 
@@ -457,9 +453,7 @@ namespace InDet {
                                                                         const Trk::Surface& pixelSurface,
                                                                         const Trk::TrackParameters& trackParsAtSurface,
                                                                         std::vector<Amg::MatrixX> & errors,
-                                                                        int numberSubClusters,
-                                                                        int sizeX,
-                                                                        int sizeY) const
+                                                                        int numberSubClusters) const
   {
 
     ATH_MSG_VERBOSE(" Starting to estimate positions...");
@@ -470,9 +464,7 @@ namespace InDet {
 
     NNinput input( createInput(pCluster,
                                dummyBS,
-                               tanl,
-                               sizeX,
-                               sizeY) );
+                               tanl) );
 
     if (!input)
     {
@@ -481,22 +473,22 @@ namespace InDet {
 
     addTrackInfoToInput(input,pixelSurface,trackParsAtSurface,tanl);
 
-    std::vector<double> inputData=(this->*m_assembleInput)(input,sizeX,sizeY);
 
     // If using old TTrainedNetworks, fetch correct ones for the
     // without-track situation and call them now.
     if (m_useTTrainedNetworks) {
+      std::vector<double> inputData=(this->*m_assembleInput)(input);
       SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithTrack );
       if (!nn_collection.isValid()) {
 	ATH_MSG_ERROR( "Failed to get trained network collection with key " << m_readKeyWithTrack.key() );
       }
 
-      return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,sizeX,sizeY,numberSubClusters,errors);
+      return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,numberSubClusters,errors);
     }
 
     // Otherwise, prepare lwtnn input map and use new networks.
-    NnClusterizationFactory::InputMap nnInputData = flattenInput(input);
-    return estimatePositionsLWTNN(nnInputData,input,pCluster,numberSubClusters,errors);
+    NnClusterizationFactory::InputVector nnInputVector = eigenInput(input);
+    return estimatePositionsLWTNN(nnInputVector,input,pCluster,numberSubClusters,errors);
 
   }
 
@@ -505,8 +497,6 @@ namespace InDet {
                                                 const std::vector<double>& inputData,
                                                 const NNinput& input,
                                                 const InDet::PixelCluster& pCluster,
-                                                int sizeX,
-                                                int sizeY,
                                                 int numberSubClusters,
                                                 std::vector<Amg::MatrixX> & errors) const
   {
@@ -519,7 +509,7 @@ namespace InDet {
       // get position network id for the given cluster multiplicity then
       // dereference unique_ptr<TTrainedNetwork> then call calculateOutput :
       std::vector<double> position1P(((*(nn_collection.at( m_NNId[kPositionNN-1].at(numberSubClusters-1)))).*m_calculateOutput)(inputData));
-      std::vector<Amg::Vector2D> myPosition1=getPositionsFromOutput(position1P,input,pCluster,sizeX,sizeY);
+      std::vector<Amg::Vector2D> myPosition1=getPositionsFromOutput(position1P,input,pCluster);
 
       assert( position1P.size() % 2 == 0);
       for (unsigned int i=0; i<position1P.size()/2 ; ++i) {
@@ -553,8 +543,9 @@ namespace InDet {
     return allPositions;
   }
 
+
   std::vector<Amg::Vector2D> NnClusterizationFactory::estimatePositionsLWTNN(
-                                                                NnClusterizationFactory::InputMap & input, 
+                                                                NnClusterizationFactory::InputVector & input, 
                                                                 NNinput& rawInput,
                                                                 const InDet::PixelCluster& pCluster,
                                                                 int numberSubClusters,
@@ -572,6 +563,7 @@ namespace InDet {
     std::vector<Amg::MatrixX> errorMatrices;
     errorMatrices.reserve(numberSubClusters);
     positionValues.reserve(numberSubClusters * 2);
+    std::size_t outputNode(0);
     for (int cluster = 1; cluster < numberSubClusters+1; cluster++) {
 
       // Check that the network is defined. 
@@ -580,23 +572,35 @@ namespace InDet {
 	ATH_MSG_ERROR( "No lwtnn network configured for this run! If you are outside the valid range for lwtnn-based configuration, plesae run with useNNTTrainedNetworks instead." << m_readKeyJSON.key() );
       }
 
-      std::string outNodeName = "merge_"+std::to_string(cluster);
-      std::map<std::string, double> position = lwtnn_collection->at(numberSubClusters)->compute(input, {},outNodeName);
+      if(numberSubClusters==1) {
+        outputNode = m_outputNodesPos1; }
+      else if(numberSubClusters==2) {
+        outputNode = m_outputNodesPos2[cluster-1]; 
+      } else if(numberSubClusters==3) {
+        outputNode = m_outputNodesPos3[cluster-1]; 
+      } else {
+        ATH_MSG_ERROR( "Cannot evaluate LWTNN networks with " << numberSubClusters << " numberSubClusters" );
+      }
+      
+      // Order of output matches order in JSON config in "outputs"
+      // "alpha", "mean_x", "mean_y", "prec_x", "prec_y"
+      // Assume here that 1 particle network is in position 1, 2 at 2, and 3 at 3.
+      Eigen::VectorXd position = lwtnn_collection->at(numberSubClusters)->compute(input, {}, outputNode);
 
       ATH_MSG_DEBUG("Testing for numberSubClusters " << numberSubClusters << " and cluster " << cluster);
-      for (const auto& item : position) {
-        ATH_MSG_DEBUG(item.first << ": " << item.second);
+      for (int i=0; i<position.rows(); i++) {
+        ATH_MSG_DEBUG(" position " << position[i]);
       }
-      positionValues.push_back(position["mean_x"]);
-      positionValues.push_back(position["mean_y"]);
+      positionValues.push_back(position[1]); //mean_x
+      positionValues.push_back(position[2]); //mean_y
 
       // Fill errors.
       // Values returned by NN are inverse of variance, and we want variances.
-      float rawRmsX = sqrt(1.0/position["prec_x"]);
-      float rawRmsY = sqrt(1.0/position["prec_y"]);
+      float rawRmsX = sqrt(1.0/position[3]); //prec_x
+      float rawRmsY = sqrt(1.0/position[4]); //prec_y
       // Now convert to real space units
       double rmsX = correctedRMSX(rawRmsX);
-      double rmsY = correctedRMSY(rawRmsY, 7., rawInput.vectorOfPitchesY);
+      double rmsY = correctedRMSY(rawRmsY, rawInput.vectorOfPitchesY);
       ATH_MSG_DEBUG(" Estimated RMS errors (1) x: " << rmsX << ", y: " << rmsY);  
 
       // Fill matrix    
@@ -614,7 +618,6 @@ namespace InDet {
     for (unsigned int index = 0; index < errorMatrices.size(); index++) errors.push_back(errorMatrices.at(index));
 
     return myPositions;
-
   }
 
   double NnClusterizationFactory::correctedRMSX(double posPixels) const
@@ -628,18 +631,17 @@ namespace InDet {
   }
 
   double NnClusterizationFactory::correctedRMSY(double posPixels,
-         double sizeY,
         std::vector<float>& pitches) const
   {
-    double p = posPixels + (sizeY - 1) / 2.0;
+    double p = posPixels + (m_sizeY - 1) / 2.0;
     double p_Y = -100;
     double p_center = -100;
     double p_actual = 0;
 
-    for (int i = 0; i < sizeY; i++) {
+    for (int i = 0; i < m_sizeY; i++) {
       if (p >= i && p <= (i + 1))
         p_Y = p_actual + (p - i + 0.5) * pitches.at(i);
-      if (i == (sizeY - 1) / 2)
+      if (i == (m_sizeY - 1) / 2)
         p_center = p_actual + 0.5 * pitches.at(i);
       p_actual += pitches.at(i);
     }
@@ -758,10 +760,8 @@ namespace InDet {
 
 
   std::vector<Amg::Vector2D> NnClusterizationFactory::getPositionsFromOutput(std::vector<double> & output,
-                                                                             const NNinput & input,
-                                                                             const InDet::PixelCluster& pCluster,
-                                                                             int /* sizeX */,
-                                                                             int /* sizeY */) const
+      const NNinput & input,
+      const InDet::PixelCluster& pCluster) const
   {
     std::vector<Amg::Vector2D> invalidResult;
     ATH_MSG_VERBOSE(" Translating output back into a position " );
@@ -904,9 +904,7 @@ namespace InDet {
 
   NNinput NnClusterizationFactory::createInput(const InDet::PixelCluster& pCluster,
                                                 Amg::Vector3D & beamSpotPosition,
-                                                double & tanl,
-                                                int sizeX,
-                                                int sizeY) const
+                                                double & tanl) const
 {
   NNinput input;
 
@@ -1071,8 +1069,8 @@ namespace InDet {
 
   ATH_MSG_VERBOSE(" weighted pos row: " << rowWeightedPosition << " col: " << columnWeightedPosition );
 
-  int centralIndexX=(sizeX-1)/2;
-  int centralIndexY=(sizeY-1)/2;
+  int centralIndexX=(m_sizeX-1)/2;
+  int centralIndexY=(m_sizeY-1)/2;
 
 
   if (abs(rowWeightedPosition-rowMin)>centralIndexX ||
@@ -1089,12 +1087,12 @@ namespace InDet {
     return input;
   }
 
-  input.matrixOfToT.reserve(sizeX);
-  for (int a=0;a<sizeX;a++)
+  input.matrixOfToT.reserve(m_sizeX);
+  for (int a=0;a<m_sizeX;a++)
   {
-    input.matrixOfToT.emplace_back(sizeY, 0.0);
+    input.matrixOfToT.emplace_back(m_sizeY, 0.0);
   }
-  input.vectorOfPitchesY.assign(sizeY, 0.4);
+  input.vectorOfPitchesY.assign(m_sizeY, 0.4);
 
   rdosBegin = rdos.begin();
   //charge = chList.size() ? chList.begin() : chListRecreated.begin();
@@ -1132,14 +1130,14 @@ namespace InDet {
     }
 
 
-    if (absrow <0 || absrow > sizeX)
+    if (absrow <0 || absrow > m_sizeX)
     {
-      ATH_MSG_WARNING(" problem with index: " << absrow << " min: " << 0 << " max: " << sizeX);
+      ATH_MSG_WARNING(" problem with index: " << absrow << " min: " << 0 << " max: " << m_sizeX);
       return input;
     }
-    if (abscol <0 || abscol > sizeY)
+    if (abscol <0 || abscol > m_sizeY)
     {
-      ATH_MSG_WARNING(" problem with index: " << abscol << " min: " << 0 << " max: " << sizeY);
+      ATH_MSG_WARNING(" problem with index: " << abscol << " min: " << 0 << " max: " << m_sizeY);
       return input;
     }
     InDetDD::SiCellId cellId = element->cellIdFromIdentifier(*rdosBegin);
diff --git a/Reconstruction/LwtnnUtils/CMakeLists.txt b/Reconstruction/LwtnnUtils/CMakeLists.txt
index a10e47c1630b2f0d1d22cde527f3896610ed0d0d..f1ccf35ee2fd8fd2a7062c258fa8b442d1269496 100644
--- a/Reconstruction/LwtnnUtils/CMakeLists.txt
+++ b/Reconstruction/LwtnnUtils/CMakeLists.txt
@@ -17,3 +17,7 @@ atlas_add_library( LwtnnUtils
   INCLUDE_DIRS ${LWTNN_INCLUDE_DIRS} ${EIGEN_INCLUDE_DIRS}
   LINK_LIBRARIES ${LWTNN_LIBRARIES} ${EIGEN_LIBRARIES} )
 
+atlas_add_executable(
+  test_lwtnn_fastgraph
+  utils/test_lwtnn_fastgraph.cxx
+  LINK_LIBRARIES LwtnnUtils )
diff --git a/Reconstruction/LwtnnUtils/LwtnnUtils/FastGraph.h b/Reconstruction/LwtnnUtils/LwtnnUtils/FastGraph.h
index 89e26d44d7fe31a2bda6c99504904ce35944874d..59ae7c0779e8d1c35a85d7fa875f36f98ec071ec 100644
--- a/Reconstruction/LwtnnUtils/LwtnnUtils/FastGraph.h
+++ b/Reconstruction/LwtnnUtils/LwtnnUtils/FastGraph.h
@@ -53,13 +53,15 @@ namespace lwt::atlas {
     // The simpler "compute" function
     Eigen::VectorXd compute(const NodeVec&, const SeqNodeVec& = {}) const;
 
+    // the other "compute" which allows you to select an arbitrary output
+    Eigen::VectorXd compute(const NodeVec&, const SeqNodeVec&, size_t) const;
+
   private:
     typedef FastInputPreprocessor IP;
     typedef FastInputVectorPreprocessor IVP;
     typedef std::vector<IP*> Preprocs;
     typedef std::vector<IVP*> VecPreprocs;
 
-    Eigen::VectorXd compute(const NodeVec&, const SeqNodeVec&, size_t) const;
     Graph* m_graph;
     Preprocs m_preprocs;
     VecPreprocs m_vec_preprocs;
diff --git a/Reconstruction/LwtnnUtils/utils/test_lwtnn_fastgraph.cxx b/Reconstruction/LwtnnUtils/utils/test_lwtnn_fastgraph.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..7371afeaf77681317209b01fed91758fd73c2741
--- /dev/null
+++ b/Reconstruction/LwtnnUtils/utils/test_lwtnn_fastgraph.cxx
@@ -0,0 +1,32 @@
+#include "LwtnnUtils/FastGraph.h"
+#include "LwtnnUtils/InputOrder.h"
+
+#include "lwtnn/lightweight_network_config.hh"
+#include "lwtnn/parse_json.hh"
+
+#include <string>
+#include <vector>
+#include <fstream>
+
+struct Args
+{
+  std::string nn_file;
+};
+
+Args getArgs(int nargs, char* argv[]) {
+  Args args;
+  if (nargs != 1) return args;
+  args.nn_file = argv[1];
+  return args;
+}
+
+int main(int nargs, char* argv[]) {
+  Args args = getArgs(nargs, argv);
+  if (args.nn_file.size() == 0) return 1;
+  auto nn_file = std::ifstream(args.nn_file);
+  auto graph_config = lwt::parse_json_graph(nn_file);
+
+  lwt::atlas::FastGraph graph(graph_config, {}, "");
+
+  return 0;
+}