From e87ac1ae5711048b7b98af64ab829dfed0bb4794 Mon Sep 17 00:00:00 2001
From: Katherine Pachal <katepachal@gmail.com>
Date: Fri, 18 Sep 2020 11:55:52 +0000
Subject: [PATCH] Add NN access for pixel clustering via lwtnn to r22

---
 .../InDetConfig/python/InDetConfigFlags.py    |   1 +
 .../python/InDetJobProperties.py              |   7 +
 .../InDetRecExample/python/TrackingCommon.py  |  48 ++++-
 .../share/InDetRecConditionsAccess.py         |   3 +-
 .../SiClusterizationTool/CMakeLists.txt       |  18 +-
 .../SiClusterizationTool/LWTNNCollection.h    |  30 +++
 .../NnClusterizationFactory.h                 |  47 +++-
 .../SiClusterizationTool/src/LWTNNCondAlg.cxx | 182 ++++++++++++++++
 .../SiClusterizationTool/src/LWTNNCondAlg.h   |  85 ++++++++
 .../src/NnClusterizationFactory.cxx           | 200 +++++++++++++++---
 .../SiClusterizationTool_entries.cxx          |   3 +-
 11 files changed, 577 insertions(+), 47 deletions(-)
 create mode 100644 InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/LWTNNCollection.h
 create mode 100644 InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.cxx
 create mode 100644 InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.h

diff --git a/InnerDetector/InDetConfig/python/InDetConfigFlags.py b/InnerDetector/InDetConfig/python/InDetConfigFlags.py
index b04a31baeb1..f5a76d284f6 100644
--- a/InnerDetector/InDetConfig/python/InDetConfigFlags.py
+++ b/InnerDetector/InDetConfig/python/InDetConfigFlags.py
@@ -173,6 +173,7 @@ def createInDetConfigFlags():
   icf.addFlag("InDet.doSLHCVeryForward", False ) # Turn running of SLHC reconstruction for Very Forward extension on and off 
   icf.addFlag("InDet.doTRTGlobalOccupancy", False) # Turn running of Event Info TRT Occupancy Filling Alg on and off (also whether it is used in TRT PID calculation) 
   icf.addFlag("InDet.doNNToTCalibration", False ) # USe ToT calibration for NN clustering rather than Charge 
+  icf.addFlag("InDet.useNNTTrainedNetworks", True ) # Use older NNs stored as TTrainedNetworks in place of default MDNs/other more recent networks. This is necessary for older configuration tags where the trainings were not available.
   icf.addFlag("InDet.keepAdditionalHitsOnTrackParticle", False) # Do not drop first/last hits on track (only for special cases - will blow up TrackParticle szie!!!) 
   icf.addFlag("InDet.doSCTModuleVeto", False) # Turn on SCT_ModuleVetoSvc, allowing it to be configured later 
   icf.addFlag("InDet.doParticleConversion", False) # In case anyone still wants to do Rec->xAOD TrackParticle Conversion 
diff --git a/InnerDetector/InDetExample/InDetRecExample/python/InDetJobProperties.py b/InnerDetector/InDetExample/InDetRecExample/python/InDetJobProperties.py
index 160c7f3fcfe..e988de6ed2e 100644
--- a/InnerDetector/InDetExample/InDetRecExample/python/InDetJobProperties.py
+++ b/InnerDetector/InDetExample/InDetRecExample/python/InDetJobProperties.py
@@ -1145,6 +1145,12 @@ class doNNToTCalibration(InDetFlagsJobProperty):
   allowedTypes = ['bool']
   StoredValue  = False
 
+class useNNTTrainedNetworks(InDetFlagsJobProperty):
+  """Use older NNs stored as TTrainedNetworks in place of default MDNs/other more recent networks. This is necessary for older configuration tags where the trainings were not available."""
+  statusOn     = True
+  allowedTypes = ['bool']
+  StoredValue  = True
+
 class keepAdditionalHitsOnTrackParticle(InDetFlagsJobProperty): 
   """Do not drop first/last hits on track (only for special cases - will blow up TrackParticle szie!!!)""" 
   statusOn     = True 
@@ -2757,6 +2763,7 @@ _list_InDetJobProperties = [Enabled,
                             doSLHCVeryForward,
                             doTRTGlobalOccupancy,
                             doNNToTCalibration,
+                            useNNTTrainedNetworks,
                             keepAdditionalHitsOnTrackParticle,
                             doSCTModuleVeto,
                             doDBMstandalone,
diff --git a/InnerDetector/InDetExample/InDetRecExample/python/TrackingCommon.py b/InnerDetector/InDetExample/InDetRecExample/python/TrackingCommon.py
index 088720d6326..082a8dcb2f9 100644
--- a/InnerDetector/InDetExample/InDetRecExample/python/TrackingCommon.py
+++ b/InnerDetector/InDetExample/InDetRecExample/python/TrackingCommon.py
@@ -274,6 +274,23 @@ def getPixelClusterNnCondAlg(**kwargs) :
     from SiClusterizationTool.SiClusterizationToolConf import InDet__TTrainedNetworkCondAlg
     return InDet__TTrainedNetworkCondAlg(kwargs.pop("name", 'PixelClusterNnCondAlg'), **kwargs)
 
+def getLWTNNCondAlg(**kwargs) :
+
+    # Check for the folder
+    from IOVDbSvc.CondDB import conddb
+    if not conddb.folderRequested('/PIXEL/PixelClustering/PixelClusNNCalibJSON'):
+        # COOL binding
+        conddb.addFolderSplitOnline("PIXEL","/PIXEL/Onl/PixelClustering/PixelNNCalibJSON",
+                                    "/PIXEL/PixelClustering/PixelNNCalibJSON",className='CondAttrListCollection')
+
+    # What we'll store it as
+    kwargs=setDefaults(kwargs,
+                       WriteKey = 'PixelClusterNNJSON')
+
+    # Set up the algorithm
+    from SiClusterizationTool.SiClusterizationToolConf import InDet__LWTNNCondAlg
+    return InDet__LWTNNCondAlg(kwargs.pop("name", "LWTNNCondAlg"),**kwargs)
+
 def getPixelClusterNnWithTrackCondAlg(**kwargs) :
 
     kwargs = setDefaults( kwargs,
@@ -306,10 +323,33 @@ def getNnClusterizationFactory(name='NnClusterizationFactory', **kwargs) :
     if 'PixelLorentzAngleTool' not in kwargs :
         kwargs = setDefaults( kwargs, PixelLorentzAngleTool = getPixelLorentzAngleTool())
 
+    from InDetRecExample.InDetJobProperties import InDetFlags
+    useTTrainedNetworks = InDetFlags.useNNTTrainedNetworks()
     from AtlasGeoModel.CommonGMJobProperties import CommonGeometryFlags as geoFlags
     do_runI = geoFlags.Run() not in ["RUN2", "RUN3"]
-    createAndAddCondAlg( getPixelClusterNnCondAlg,         'PixelClusterNnCondAlg',          GetInputsInfo = do_runI)
-    createAndAddCondAlg( getPixelClusterNnWithTrackCondAlg,'PixelClusterNnWithTrackCondAlg', GetInputsInfo = do_runI)
+    
+    if useTTrainedNetworks :
+      log.debug("Setting up TTrainedNetworks")
+      createAndAddCondAlg( getPixelClusterNnCondAlg,         'PixelClusterNnCondAlg',          GetInputsInfo = do_runI)
+      createAndAddCondAlg( getPixelClusterNnWithTrackCondAlg,'PixelClusterNnWithTrackCondAlg', GetInputsInfo = do_runI)
+    else :
+
+      ######################################
+      # Temporary - pixel clustering setup #
+      ######################################
+      # Allow use of folder that exists but is not yet in global tag.
+      # Different names in different DB instances....
+      if not ('conddb' in dir()):
+        from IOVDbSvc.CondDB import conddb
+
+      if (conddb.dbmc == "OFLP200" or (conddb.dbdata=="OFLP200" and globalflags.DataSource=='data')) :
+        conddb.addOverride("/PIXEL/PixelClustering/PixelNNCalibJSON","PixelNNCalibJSON-SIM-RUN2-000-00")
+      if ((conddb.dbmc == "CONDBR2" and globalflags.DataSource!='data') or conddb.dbdata == "CONDBR2") :
+        conddb.addOverride("/PIXEL/PixelClustering/PixelNNCalibJSON","PixelNNCalibJSON-DATA-RUN2-000-00")
+      ## End of temporary code
+
+      log.debug("Setting up lwtnn system")
+      createAndAddCondAlg( getLWTNNCondAlg,                  'LWTNNCondAlg')
 
     from InDetRecExample.InDetJobProperties import InDetFlags
     kwargs = setDefaults( kwargs,
@@ -319,8 +359,10 @@ def getNnClusterizationFactory(name='NnClusterizationFactory', **kwargs) :
                           useRecenteringNNWithTracks         = False if do_runI else False,  # default,
                           correctLorShiftBarrelWithoutTracks = 0,
                           correctLorShiftBarrelWithTracks    = 0.030 if do_runI else 0.000,  # default,
+                          useTTrainedNetworks                = useTTrainedNetworks,
                           NnCollectionReadKey                = 'PixelClusterNN',
-                          NnCollectionWithTrackReadKey       = 'PixelClusterNNWithTrack')
+                          NnCollectionWithTrackReadKey       = 'PixelClusterNNWithTrack',
+                          NnCollectionJSONReadKey            = 'PixelClusterNNJSON')
     return InDet__NnClusterizationFactory(name=the_name, **kwargs)
 
 @makePublicTool
diff --git a/InnerDetector/InDetExample/InDetRecExample/share/InDetRecConditionsAccess.py b/InnerDetector/InDetExample/InDetRecExample/share/InDetRecConditionsAccess.py
index c6282395efc..335ae6e7d84 100644
--- a/InnerDetector/InDetExample/InDetRecExample/share/InDetRecConditionsAccess.py
+++ b/InnerDetector/InDetExample/InDetRecExample/share/InDetRecConditionsAccess.py
@@ -228,8 +228,7 @@ if DetFlags.pixel_on():
                 PixeldEdxAlg.CalibrationFile="dtpar_signed_234.txt"
             else:
                 PixeldEdxAlg.CalibrationFile="mcpar_signed_234.txt"
-
-
+ 
 #
 # --- Load SCT Conditions Services
 #
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/CMakeLists.txt b/InnerDetector/InDetRecTools/SiClusterizationTool/CMakeLists.txt
index 5c1639fc50e..f9cf31ff723 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/CMakeLists.txt
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/CMakeLists.txt
@@ -41,27 +41,33 @@ atlas_depends_on_subdirs(
    Tracking/TrkEvent/VxVertex )
 
 # External dependencies:
+find_package( lwtnn )
 find_package( CLHEP )
 find_package( ROOT COMPONENTS Core MathCore Hist )
+find_package( COOL COMPONENTS CoolKernel CoolApplication )
 
 # Component(s) in the package:
 atlas_add_library( SiClusterizationToolLib
    SiClusterizationTool/*.h src/*.cxx
    PUBLIC_HEADERS SiClusterizationTool
-   INCLUDE_DIRS ${ROOT_INCLUDE_DIRS}
-   PRIVATE_INCLUDE_DIRS ${CLHEP_INCLUDE_DIRS}
+   INCLUDE_DIRS ${ROOT_INCLUDE_DIRS} ${LWTNN_INCLUDE_DIRS}
+   PRIVATE_INCLUDE_DIRS ${CLHEP_INCLUDE_DIRS} ${COOL_INCLUDE_DIRS}
    PRIVATE_DEFINITIONS ${CLHEP_DEFINITIONS}
-   LINK_LIBRARIES ${ROOT_LIBRARIES} AthenaBaseComps AthenaKernel GeoPrimitives
+   LINK_LIBRARIES ${ROOT_LIBRARIES} ${LWTNN_LIBRARIES} AthenaBaseComps AthenaKernel GeoPrimitives
    Identifier EventPrimitives GaudiKernel InDetSimData InDetIdentifier
    InDetReadoutGeometry PixelReadoutGeometry SCT_ReadoutGeometry InDetRawData InDetPrepRawData InDetRecToolInterfaces InDetConditionsSummaryService
    TrkParameters TrkNeuralNetworkUtilsLib PixelConditionsData
    PixelGeoModelLib PixelCablingLib BeamSpotConditionsData
-   PRIVATE_LINK_LIBRARIES ${CLHEP_LIBRARIES} AthenaPoolUtilities FileCatalog AtlasDetDescr
-   TrkSurfaces TrkEventPrimitives VxVertex PixelGeoModelLib PoolSvcLib DetDescrCondToolsLib )
+   PRIVATE_LINK_LIBRARIES ${CLHEP_LIBRARIES} ${COOL_LIBRARIES} AthenaPoolUtilities FileCatalog AtlasDetDescr
+   TrkSurfaces TrkEventPrimitives VxVertex PixelGeoModelLib PoolSvcLib DetDescrCondToolsLib stdc++fs)
 
 atlas_add_component( SiClusterizationTool
    src/components/*.cxx
-   LINK_LIBRARIES GaudiKernel PixelConditionsData SiClusterizationToolLib PoolSvcLib )
+   INCLUDE_DIRS ${COOL_INCLUDE_DIRS}
+   LINK_LIBRARIES ${COOL_LIBRARIES} GaudiKernel PixelConditionsData SiClusterizationToolLib PoolSvcLib )
 
 # Install files from the package:
 atlas_install_joboptions( share/*.py )
+# These files can be added by the user for testing in Grid environments,
+# in which case un-comment this line and re-cmake
+# atlas_install_data( share/*.db )
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/LWTNNCollection.h b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/LWTNNCollection.h
new file mode 100644
index 00000000000..885b24a055a
--- /dev/null
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/LWTNNCollection.h
@@ -0,0 +1,30 @@
+/*
+  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
+*/
+
+#ifndef _LWTNNCollection_H_
+#define _LWTNNCollection_H_
+
+//#include <vector>
+#include "lwtnn/LightweightGraph.hh"
+
+class LWTNNCollection 
+  :  public std::map<int, std::unique_ptr<lwt::LightweightGraph> >
+{
+public:
+
+private:
+
+};
+
+// These values produced using clid script.
+// clid LWTNNCollection
+// 1196174442 LWTNNCollection None
+#include "AthenaKernel/CLASS_DEF.h"
+CLASS_DEF(LWTNNCollection, 1196174442, 1)
+// clid -cs LWTNNCollection
+// 1226994220
+#include "AthenaKernel/CondCont.h"
+CONDCONT_DEF(LWTNNCollection, 1226994220);
+
+#endif
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
index a84861a41e9..01039990556 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
@@ -2,13 +2,13 @@
   Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
 */
 
- #ifndef BTAGTOOL_NnClusterizationFactory_C
- #define BTAGTOOL_NnClusterizationFactory_C
+ #ifndef SICLUSTERIZATIONTOOL_NnClusterizationFactory_C
+ #define SICLUSTERIZATIONTOOL_NnClusterizationFactory_C
 
  /******************************************************
      @class NnClusterizationFactory
      @author Giacinto Piacquadio (PH-ADE-ID)
-     Package : JetTagTools
+     Package : SiClusterizationTool
      Created : January 2011
      DESCRIPTION: Load neural networks used for clustering
                   and deal with:
@@ -37,6 +37,7 @@
 #include "EventPrimitives/EventPrimitives.h"
 #include "InDetCondTools/ISiLorentzAngleTool.h"
 #include "SiClusterizationTool/TTrainedNetworkCollection.h"
+#include "SiClusterizationTool/LWTNNCollection.h"
 #include "PixelCabling/IPixelCablingSvc.h"
 #include "PixelConditionsData/PixelModuleData.h"
 #include "PixelConditionsData/PixelChargeCalibCondData.h"
@@ -46,6 +47,11 @@
  class TH1;
  class ICoolHistSvc;
 
+namespace lwt {
+  class NanReplacer;    
+  class LightweightGraph;
+}
+
 namespace Trk {
   class NeuralNetworkToHistoTool;
   class Surface;
@@ -106,6 +112,7 @@ namespace InDet {
                                                   int sizeX=7,
                                                   int sizeY=7) const;
 
+    /* Public-facing method 1: no track parameters */
     std::vector<Amg::Vector2D> estimatePositions(const InDet::PixelCluster& pCluster,
                                                       Amg::Vector3D & beamSpotPosition,
                                                       std::vector<Amg::MatrixX> & errors,
@@ -113,7 +120,7 @@ namespace InDet {
                                                       int sizeX=7,
                                                       int sizeY=7) const;
 
-
+    /* Public-facing method 1: with track parameters */
     std::vector<Amg::Vector2D> estimatePositions(const InDet::PixelCluster& pCluster,
                                                       const Trk::Surface& pixelSurface,
                                                       const Trk::TrackParameters& trackParsAtSurface,
@@ -124,8 +131,13 @@ namespace InDet {
 
    private:
 
-    /* estimate position for both with and w/o tracks */
-    std::vector<Amg::Vector2D> estimatePositions(const TTrainedNetworkCollection &nn_collection,
+    // Handling lwtnn inputs
+    typedef std::map<std::string, std::map<std::string, double> > InputMap;
+
+    /* Estimate position for both with and w/o tracks */
+    /* Method 1: using older TTrainedNetworks */
+    std::vector<Amg::Vector2D> estimatePositionsTTN(
+                                                 const TTrainedNetworkCollection &nn_collection,
                                                  std::vector<double> inputData,
                                                  const NNinput& input,
                                                  const InDet::PixelCluster& pCluster,
@@ -134,6 +146,20 @@ namespace InDet {
                                                  int numberSubClusters,
                                                  std::vector<Amg::MatrixX> & errors) const;
 
+    /* Estimate position for both with and w/o tracks */
+    /* Method 2: using lwtnn for more flexible interfacing */
+    std::vector<Amg::Vector2D> estimatePositionsLWTNN(
+                                                NnClusterizationFactory::InputMap & input, 
+                                                NNinput& rawInput,
+                                                const InDet::PixelCluster& pCluster,
+                                                int numberSubClusters,
+                                                std::vector<Amg::MatrixX> & errors) const;
+
+    // For error formatting in lwtnn cases
+    double correctedRMSX(double posPixels) const;
+
+    double correctedRMSY(double posPixels, double sizeY, std::vector<float>& pitches) const; 
+
      /* algorithmic component */
     NNinput createInput(const InDet::PixelCluster& pCluster,
                         Amg::Vector3D & beamSpotPosition,
@@ -157,7 +183,7 @@ namespace InDet {
                                       int sizeX,
                                       int sizeY) const;
 
-
+    InputMap flattenInput(NNinput & input) const;
 
     std::vector<Amg::Vector2D> getPositionsFromOutput(std::vector<double> & output,
                                                       const NNinput & input,
@@ -225,6 +251,10 @@ namespace InDet {
        {this, "NnCollectionWithTrackReadKey", "PixelClusterNNWithTrack",
         "The conditions store key for the pixel cluster NNs which needs tracks as input"};
 
+    SG::ReadCondHandleKey<LWTNNCollection> m_readKeyJSON
+       {this, "NnCollectionJSONReadKey", "PixelClusterNNJSON",
+        "The conditions key for the pixel cluster NNs configured via JSON file and accessed with lwtnn"};
+
     Gaudi::Property<unsigned int> m_maxSubClusters
        {this, "MaxSubClusters", 3, "Maximum number of sub cluster supported by the networks." };
 
@@ -243,6 +273,9 @@ namespace InDet {
     Gaudi::Property<bool> m_doRunI
        {this, "doRunI", false, "Use runI style network (outputs are not normalised; add pitches; use charge if not m_useToT)"};
 
+    Gaudi::Property<bool> m_useTTrainedNetworks
+       {this, "useTTrainedNetworks", false, "Use earlier (release-21-like) neural networks stored in ROOT files and accessed via TTrainedNetowrk."};
+
     Gaudi::Property<bool> m_useRecenteringNNWithouTracks
        {this, "useRecenteringNNWithoutTracks",false,"Recenter x position when evaluating NN without track input."};
 
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.cxx b/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.cxx
new file mode 100644
index 00000000000..0a733ff61e7
--- /dev/null
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.cxx
@@ -0,0 +1,182 @@
+/*
+  Copyright (C) 2002-2018 CERN for the benefit of the ATLAS collaboration
+*/
+/*
+ *   */
+
+#include "LWTNNCondAlg.h"
+
+#include "AthenaPoolUtilities/CondAttrListCollection.h"
+#include "AthenaPoolUtilities/AthenaAttributeList.h"
+#include "CoolKernel/IObject.h"
+#include "FileCatalog/IFileCatalog.h"
+
+// NN includes
+#include "lwtnn/parse_json.hh"
+#include "lwtnn/Exceptions.hh"
+#include "lwtnn/lightweight_nn_streamers.hh"
+#include "lwtnn/NanReplacer.hh"
+
+// JSON parsers
+#include <boost/property_tree/ptree.hpp>
+#include <boost/property_tree/json_parser.hpp>
+#include "boost/property_tree/exceptions.hpp"
+
+
+// for error messages
+#include <typeinfo>
+
+namespace InDet {
+
+  LWTNNCondAlg::LWTNNCondAlg (const std::string& name, ISvcLocator* pSvcLocator)
+    : ::AthAlgorithm( name, pSvcLocator )
+  {}
+
+  StatusCode LWTNNCondAlg::initialize() {
+    ATH_CHECK( m_condSvc.retrieve() );
+
+    // Condition Handles
+    ATH_CHECK( m_readKey.initialize() );
+    ATH_CHECK( m_writeKey.initialize() );
+
+    // Register write handle
+    if (m_condSvc->regHandle(this, m_writeKey).isFailure()) {
+      ATH_MSG_ERROR("Unable to register WriteCondHandle " << m_writeKey.fullKey() << " with CondSvc");
+      return StatusCode::FAILURE;
+    }
+
+    return StatusCode::SUCCESS;
+  }
+
+  StatusCode LWTNNCondAlg::finalize()
+  {
+    return StatusCode::SUCCESS;
+  }
+
+  StatusCode LWTNNCondAlg::configureLwtnn(std::unique_ptr<lwt::LightweightGraph> & thisNN, 
+                                        std::string thisJson) {
+
+    // Read DNN weights from input json config
+    lwt::GraphConfig config;
+    try {
+      std::istringstream input_cfg( thisJson );
+      config = lwt::parse_json_graph(input_cfg);
+    } catch (boost::property_tree::ptree_error& err) {
+      ATH_MSG_ERROR("NN file unreadable!");
+      return StatusCode::FAILURE;
+    }
+
+    // Build the network
+    try {
+      thisNN.reset(new lwt::LightweightGraph(config, "merge_1"));
+    } catch (lwt::NNConfigurationException& exc) {
+      ATH_MSG_ERROR("NN configuration problem: " << exc.what());
+      return StatusCode::FAILURE;
+    }
+
+    return StatusCode::SUCCESS;   
+
+  }
+
+  StatusCode LWTNNCondAlg::execute() {
+
+    SG::WriteCondHandle<LWTNNCollection> NnWriteHandle{m_writeKey};
+    if (NnWriteHandle.isValid()) {
+      ATH_MSG_DEBUG("Write CondHandle "<< NnWriteHandle.fullKey() << " is already valid");
+      return StatusCode::SUCCESS;
+    }
+
+    SG::ReadCondHandle<CondAttrListCollection> readHandle{m_readKey};
+    if(!readHandle.isValid()) {
+      ATH_MSG_ERROR("Invalid read handle " << m_readKey.key());
+      return StatusCode::FAILURE;
+    }
+    const CondAttrListCollection* atrcol{*readHandle};
+    assert( atrcol != nullptr);
+
+    // So now we have the string containing the json. Access it.
+    // Retrieve channel 0 (only channel there is)
+    const coral::AttributeList& attrList=atrcol->attributeList(0);
+
+    // Check that it is filled as expected
+    if ((attrList["NNConfigurations"]).isNull()) {
+      ATH_MSG_ERROR( "NNConfigurations is NULL !" );
+      return StatusCode::FAILURE;
+    }
+
+    // Retrieve the string
+    // This is for a single LOB when it is all a giant block
+    const std::string megajson = attrList["NNConfigurations"].data<cool::String16M>();
+
+    // Parse the large json to extract the individual configurations for the NNs
+    std::istringstream initializerStream(megajson);
+    namespace pt = boost::property_tree;    
+    pt::ptree parentTree;
+    pt::read_json(initializerStream, parentTree);
+    std::ostringstream configStream;
+
+    // This is for handling IOVs
+    EventIDRange cdo_iov;
+    if(!readHandle.range(cdo_iov)) {
+      ATH_MSG_ERROR("Failed to get valid validity range from  " << readHandle.key());
+      return StatusCode::FAILURE;
+    }
+
+    // Here I create a pointer to the object I want to write
+    // And what I want to write is the map with my lwtnn networks.
+    std::unique_ptr<LWTNNCollection> writeCdo{std::make_unique<LWTNNCollection>()};
+
+    // First, extract configuration for the number network.
+    pt::ptree subtreeNumberNetwork = parentTree.get_child("NumberNetwork");
+    writeCdo->insert(std::make_pair(0,std::unique_ptr<lwt::LightweightGraph>(nullptr)));
+    // If this json is empty, just fill a null pointer.
+    if(subtreeNumberNetwork.empty()) {
+      ATH_MSG_INFO("Not using lwtnn for number network.");
+    }
+    // Otherwise, set up lwtnn.
+    else {      
+      ATH_MSG_INFO("Setting up lwtnn for number network...");
+      pt::write_json(configStream, subtreeNumberNetwork);
+      std::string numberNetworkConfig = configStream.str();
+      if ((configureLwtnn(writeCdo->at(0), numberNetworkConfig)).isFailure())
+        return StatusCode::FAILURE;     
+    }
+
+    // Now extract configuration for each position network.
+    // For simplicity, we'll require all three configurations
+    // in order to use lwtnn for positions.
+    for (int i=1; i<4; i++) {
+      const std::string key = "PositionNetwork_N"+std::to_string(i);
+      configStream.str("");
+      pt::ptree subtreePosNetwork = parentTree.get_child(key);
+      pt::write_json(configStream, subtreePosNetwork);
+      std::string posNetworkConfig = configStream.str();
+      
+      // Put a lwt network into the map
+      writeCdo->insert(std::make_pair(i,std::unique_ptr<lwt::LightweightGraph>(nullptr)));
+
+      // Now do empty check: if any one of these is empty we won't use lwtnn
+      if(subtreePosNetwork.empty()) {
+        ATH_MSG_INFO("Not using lwtnn for position networks.");
+      } else {
+        // Otherwise, set up lwtnn
+        ATH_MSG_INFO("Setting up lwtnn for n = " << i << " position network...");
+        if ((configureLwtnn(writeCdo->at(i), posNetworkConfig)).isFailure())
+          return StatusCode::FAILURE;
+      }
+
+    }
+    
+    // Write the networks to the store
+
+    if(NnWriteHandle.record(cdo_iov,std::move(writeCdo)).isFailure()) {
+      ATH_MSG_ERROR("Failed to record Trained network collection to " 
+                    << NnWriteHandle.key()
+                    << " with IOV " << cdo_iov );
+      return StatusCode::FAILURE;
+    }
+
+    return StatusCode::SUCCESS;
+  }
+
+}
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.h b/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.h
new file mode 100644
index 00000000000..23a94d535fa
--- /dev/null
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/src/LWTNNCondAlg.h
@@ -0,0 +1,85 @@
+/*
+  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
+*/
+
+#ifndef _InDet_LWTNNCondAlg_H_
+#define _InDet_LWTNNCondAlg_H_
+
+#include "AthenaBaseComps/AthAlgorithm.h"
+#include "StoreGate/ReadCondHandleKey.h"
+#include "StoreGate/WriteCondHandleKey.h"
+
+#include "GaudiKernel/ICondSvc.h"
+#include "PoolSvc/IPoolSvc.h"
+
+//#include "TrkNeuralNetworkUtils/NeuralNetworkToHistoTool.h"
+#include "SiClusterizationTool/LWTNNCollection.h"
+#include "AthenaPoolUtilities/CondAttrListCollection.h"
+#include <string>
+
+class IPoolSvc;
+
+namespace lwt {
+  class NanReplacer;    
+  class LightweightGraph;
+}
+
+namespace InDet {
+
+  /**
+  */
+class LWTNNCondAlg : public AthAlgorithm {
+
+ public:
+
+  LWTNNCondAlg (const std::string& name, ISvcLocator* pSvcLocator);
+  ~LWTNNCondAlg() = default;
+
+  StatusCode initialize();
+  StatusCode execute();
+  StatusCode finalize();
+
+ private:
+//  TTrainedNetwork* retrieveNetwork(TFile &input_file, const std::string& folder) const;
+
+  ServiceHandle<ICondSvc> m_condSvc
+    {this, "CondSvc", "CondSvc", "The conditions service to register new conditions data."};
+/*  ServiceHandle<IPoolSvc> m_poolsvc
+    {this, "PoolSvc", "PoolSvc", "The service to retrieve files by GUID."};
+  ToolHandle<Trk::NeuralNetworkToHistoTool> m_networkToHistoTool
+    {this,"NetworkToHistoTool", "Trk::NeuralNetworkToHistoTool/NeuralNetworkToHistoTool", "Tool to create a neural network from a set of histograms." };
+*/
+  StatusCode configureLwtnn(std::unique_ptr<lwt::LightweightGraph> & thisNN, std::string thisJson);
+
+  SG::ReadCondHandleKey<CondAttrListCollection> m_readKey
+    {this, "ReadKey", "/PIXEL/PixelClustering/PixelNNCalibJSON", "Cool folder name for the cluster NN input histogram file."};
+
+  SG::WriteCondHandleKey<LWTNNCollection> m_writeKey
+    {this, "WriteKey", "PixelClusterNNJSON", "The conditions statore key for the pixel cluster NNs"};
+
+  Gaudi::Property< std::vector<std::string> > m_nnOrder
+    {this, "NetworkNames", {
+          "NumberNetwork",
+          "PositionNetwork_N1",
+          "PositionNetwork_N2",
+          "PositionNetwork_N3"},
+        "List of network names, which are indexe in map in this order"};
+/*
+  Gaudi::Property<std::string> m_layerInfoHistogram
+  {this, "LayerInfoHistogram",      "LayersInfo","Name about the layer info histogram."};
+
+  Gaudi::Property<std::string> m_layerPrefix
+  {this, "LayerPrefix",             "Layer",     "Prefix of the pre layer weight and threshold histograms."}; 
+
+  Gaudi::Property<std::string> m_weightIndicator
+  {this, "LayerWeightIndicator",    "_weights",  "Suffix of the weight histograms."};
+
+  Gaudi::Property<std::string> m_thresholdIndicator
+  {this, "LayerThresholdIndicator", "_thresholds","Suffix of the threshold histograms."};
+
+  Gaudi::Property<bool> m_getInputsInfo
+  {this, "GetInputsInfo", false,"Also read a histogram which contains information about the inputs (Run I)."};
+  */
+};
+}
+#endif
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx b/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
index e847f6f1ad5..28feb5fa685 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
@@ -156,6 +156,7 @@ namespace InDet {
 
     ATH_CHECK( m_readKeyWithoutTrack.initialize( !m_readKeyWithoutTrack.key().empty() ) );
     ATH_CHECK( m_readKeyWithTrack.initialize( !m_readKeyWithTrack.key().empty() ) );
+    ATH_CHECK( m_readKeyJSON.initialize( !m_readKeyJSON.key().empty() ) );
 
     return StatusCode::SUCCESS;
   }
@@ -255,6 +256,43 @@ namespace InDet {
     return inputData;
   }
 
+  NnClusterizationFactory::InputMap NnClusterizationFactory::flattenInput(NNinput & input) const
+  {
+
+    // Format for use with lwtnn
+    std::map<std::string, std::map<std::string, double> > flattened;
+
+    // Fill it!
+    // Variable names here need to match the ones in the configuration.    
+
+    std::map<std::string, double> simpleInputs;
+    for (unsigned int x = 0; x < input.matrixOfToT.size(); x++) {
+      for (unsigned int y = 0; y < input.matrixOfToT.at(0).size(); y++) {
+        unsigned int index = x*input.matrixOfToT.at(0).size()+y;
+        std::string varname = "NN_matrix"+std::to_string(index);
+        simpleInputs[varname] = input.matrixOfToT.at(x).at(y);
+      }
+    }
+
+    for (unsigned int p = 0; p < input.vectorOfPitchesY.size(); p++) {
+      std::string varname = "NN_pitches" + std::to_string(p);
+      simpleInputs[varname] = input.vectorOfPitchesY.at(p);
+    }
+
+    simpleInputs["NN_layer"] = input.ClusterPixLayer;
+    simpleInputs["NN_barrelEC"] = input.ClusterPixBarrelEC;
+    simpleInputs["NN_phi"] = input.phi;
+    simpleInputs["NN_theta"] = input.theta;
+
+    if (input.useTrackInfo) simpleInputs["NN_etaModule"] = input.etaModule;
+
+    // We have only one node for now, so we just store things there.
+    flattened["NNinputs"] = simpleInputs;
+
+    return flattened;
+
+
+  }
 
   std::vector<double> NnClusterizationFactory::estimateNumberOfParticles(const InDet::PixelCluster& pCluster,
                                                                          Amg::Vector3D & beamSpotPosition,
@@ -339,12 +377,6 @@ namespace InDet {
                                                                              int sizeY) const
   {
 
-    SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack );
-    if (!nn_collection.isValid()) {
-      std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key();
-      throw std::runtime_error( msg.str() );
-    }
-
     ATH_MSG_VERBOSE(" Starting to estimate positions...");
 
     double tanl=0;
@@ -360,14 +392,23 @@ namespace InDet {
       return std::vector<Amg::Vector2D>();
     }
 
-
     std::vector<double> inputData=(this->*m_assembleInput)(input,sizeX,sizeY);
 
+    // If using old TTrainedNetworks, fetch correct ones for the
+    // without-track situation and call them now.
+    if (m_useTTrainedNetworks) {
+      SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack );
+      if (!nn_collection.isValid()) {
+        std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key();
+        throw std::runtime_error( msg.str() );
+      }
+      // *(ReadCondHandle<>) returns a pointer rather than a reference ...
+      return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,sizeX,sizeY,numberSubClusters,errors);
+    }
 
-    // *(ReadCondHandle<>) returns a pointer rather than a reference ...
-    return estimatePositions(**nn_collection, inputData,input,pCluster,sizeX,sizeY,numberSubClusters,errors);
-
-
+    // Otherwise, prepare lwtnn input map and use new networks.
+    NnClusterizationFactory::InputMap nnInputData = flattenInput(input);
+    return estimatePositionsLWTNN(nnInputData,input,pCluster,numberSubClusters,errors);
 
   }
 
@@ -380,12 +421,8 @@ namespace InDet {
                                                                         int sizeX,
                                                                         int sizeY) const
   {
-    SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithTrack );
-    if (!nn_collection.isValid()) {
-      std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithTrack.key();
-      throw std::runtime_error( msg.str() );
-    }
 
+    ATH_MSG_VERBOSE(" Starting to estimate positions...");
 
     Amg::Vector3D dummyBS(0,0,0);
 
@@ -402,24 +439,39 @@ namespace InDet {
        return std::vector<Amg::Vector2D>();
     }
 
-
     addTrackInfoToInput(input,pixelSurface,trackParsAtSurface,tanl);
 
     std::vector<double> inputData=(this->*m_assembleInput)(input,sizeX,sizeY);
 
+    // If using old TTrainedNetworks, fetch correct ones for the
+    // without-track situation and call them now.
+    if (m_useTTrainedNetworks) {
+      SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithTrack );
+      if (!nn_collection.isValid()) {
+        std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithTrack.key();
+        throw std::runtime_error( msg.str() );
+      }
+
+      return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,sizeX,sizeY,numberSubClusters,errors);
+    }
+
+    // Otherwise, prepare lwtnn input map and use new networks.
+    NnClusterizationFactory::InputMap nnInputData = flattenInput(input);
+    return estimatePositionsLWTNN(nnInputData,input,pCluster,numberSubClusters,errors);
 
-    return estimatePositions(**nn_collection, inputData,input,pCluster,sizeX,sizeY,numberSubClusters,errors);
   }
 
-  std::vector<Amg::Vector2D> NnClusterizationFactory::estimatePositions(const TTrainedNetworkCollection &nn_collection,
-                                                                        std::vector<double> inputData,
-                                                                        const NNinput& input,
-                                                                        const InDet::PixelCluster& pCluster,
-                                                                        int sizeX,
-                                                                        int sizeY,
-                                                                        int numberSubClusters,
-                                                                        std::vector<Amg::MatrixX> & errors) const
+  std::vector<Amg::Vector2D> NnClusterizationFactory::estimatePositionsTTN(
+                                                const TTrainedNetworkCollection &nn_collection,
+                                                std::vector<double> inputData,
+                                                const NNinput& input,
+                                                const InDet::PixelCluster& pCluster,
+                                                int sizeX,
+                                                int sizeY,
+                                                int numberSubClusters,
+                                                std::vector<Amg::MatrixX> & errors) const
   {
+
     bool applyRecentering=(!input.useTrackInfo && m_useRecenteringNNWithouTracks)  || (input.useTrackInfo && m_useRecenteringNNWithTracks);
 
     std::vector<Amg::Vector2D> allPositions;
@@ -432,8 +484,8 @@ namespace InDet {
 
       assert( position1P.size() % 2 == 0);
       for (unsigned int i=0; i<position1P.size()/2 ; ++i) {
-        ATH_MSG_VERBOSE(" RAW Estimated positions (" << i << ") x: " << back_posX(position1P[0+i*2],applyRecentering) << " y: " << back_posY(position1P[1+i*2]));
-        ATH_MSG_VERBOSE(" Estimated myPositions ("   << i << ") x: " << myPosition1[i][Trk::locX] << " y: " << myPosition1[i][Trk::locY]);
+        ATH_MSG_DEBUG(" Original RAW Estimated positions (" << i << ") x: " << back_posX(position1P[0+i*2],applyRecentering) << " y: " << back_posY(position1P[1+i*2]));
+        ATH_MSG_DEBUG(" Original estimated myPositions ("   << i << ") x: " << myPosition1[i][Trk::locX] << " y: " << myPosition1[i][Trk::locY]);
       }
 
       std::vector<double> inputDataNew=inputData;
@@ -462,7 +514,99 @@ namespace InDet {
     return allPositions;
   }
 
+  std::vector<Amg::Vector2D> NnClusterizationFactory::estimatePositionsLWTNN(
+                                                                NnClusterizationFactory::InputMap & input, 
+                                                                NNinput& rawInput,
+                                                                const InDet::PixelCluster& pCluster,
+                                                                int numberSubClusters,
+                                                                std::vector<Amg::MatrixX> & errors) const 
+    {
+    
+    SG::ReadCondHandle<LWTNNCollection> lwtnn_collection(m_readKeyJSON) ;
+    if (!lwtnn_collection.isValid()) {
+      std::stringstream msg; msg << "Failed to get LWTNN network collection with key " << m_readKeyJSON.key();
+      throw std::runtime_error( msg.str() );      
+    }
+
+    // Need to evaluate the correct network once per cluster we're interested in.
+    // Save the output
+    std::vector<double> positionValues;
+    std::vector<Amg::MatrixX> errorMatrices;
+    for (int cluster = 1; cluster < numberSubClusters+1; cluster++) {
+
+      // Check that the network is defined. 
+      // If not, we are outside an IOV and should fail
+      if (not lwtnn_collection->at(numberSubClusters)) {
+        std::stringstream msg; msg << "No lwtnn network configured for this run! If you are outside the valid range for lwtnn-based configuration, plesae run with useNNTTrainedNetworks instead." << m_readKeyJSON.key();
+        throw std::runtime_error( msg.str() );
+      }
+
+      std::string outNodeName = "merge_"+std::to_string(cluster);
+      std::map<std::string, double> position = lwtnn_collection->at(numberSubClusters)->compute(input, {},outNodeName);
+
+      ATH_MSG_DEBUG("Testing for numberSubClusters " << numberSubClusters << " and cluster " << cluster);
+      for (auto item : position) {
+        ATH_MSG_DEBUG(item.first << ": " << item.second);
+      }
+      positionValues.push_back(position["mean_x"]);
+      positionValues.push_back(position["mean_y"]);
+
+      // Fill errors.
+      // Values returned by NN are inverse of variance, and we want variances.
+      float rawRmsX = sqrt(1.0/position["prec_x"]);
+      float rawRmsY = sqrt(1.0/position["prec_y"]);
+      // Now convert to real space units
+      double rmsX = correctedRMSX(rawRmsX);
+      double rmsY = correctedRMSY(rawRmsY, 7., rawInput.vectorOfPitchesY);
+      ATH_MSG_DEBUG(" Estimated RMS errors (1) x: " << rmsX << ", y: " << rmsY);  
+
+      // Fill matrix    
+      Amg::MatrixX erm(2,2);
+      erm.setZero();
+      erm(0,0)=rmsX*rmsX;
+      erm(1,1)=rmsY*rmsY;
+      errorMatrices.push_back(erm); 
+
+    }
+
+    std::vector<Amg::Vector2D> myPositions = getPositionsFromOutput(positionValues,rawInput,pCluster);
+    ATH_MSG_DEBUG(" Estimated myPositions (1) x: " << myPositions[0][Trk::locX] << " y: " << myPositions[0][Trk::locY]);
+    
+    for (unsigned int index = 0; index < errorMatrices.size(); index++) errors.push_back(errorMatrices.at(index));
+
+    return myPositions;
+
+  }
+
+  double NnClusterizationFactory::correctedRMSX(double posPixels) const
+  {
+
+    // This gives location in pixels
+    double pitch = 0.05;
+    double corrected = posPixels * pitch;
+
+    return corrected;
+  }
+
+  double NnClusterizationFactory::correctedRMSY(double posPixels,
+         double sizeY,
+        std::vector<float>& pitches) const
+  {
+    double p = posPixels + (sizeY - 1) / 2.0;
+    double p_Y = -100;
+    double p_center = -100;
+    double p_actual = 0;
+
+    for (int i = 0; i < sizeY; i++) {
+      if (p >= i && p <= (i + 1))
+        p_Y = p_actual + (p - i + 0.5) * pitches.at(i);
+      if (i == (sizeY - 1) / 2)
+        p_center = p_actual + 0.5 * pitches.at(i);
+      p_actual += pitches.at(i);
+    }
 
+    return abs(p_Y - p_center);
+  }  
 
   void NnClusterizationFactory::getErrorMatrixFromOutput(std::vector<double>& outputX,
                                                          std::vector<double>& outputY,
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/src/components/SiClusterizationTool_entries.cxx b/InnerDetector/InDetRecTools/SiClusterizationTool/src/components/SiClusterizationTool_entries.cxx
index 19e4588a84f..c6f4dc085f7 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/src/components/SiClusterizationTool_entries.cxx
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/src/components/SiClusterizationTool_entries.cxx
@@ -10,7 +10,7 @@
 #include "SiClusterizationTool/TruthPixelClusterSplitter.h"
 #include "SiClusterizationTool/TruthClusterizationFactory.h"
 #include "SiClusterizationTool/TruthPixelClusterSplitProbTool.h"
-
+#include "../LWTNNCondAlg.h"
 #include "../TTrainedNetworkCondAlg.h"
 
 DECLARE_COMPONENT( InDet::MergedPixelsTool )
@@ -21,6 +21,7 @@ DECLARE_COMPONENT( InDet::TotPixelClusterSplitter )
 DECLARE_COMPONENT( InDet::NnPixelClusterSplitter )
 DECLARE_COMPONENT( InDet::NnClusterizationFactory )
 DECLARE_COMPONENT( InDet::TTrainedNetworkCondAlg )
+DECLARE_COMPONENT( InDet::LWTNNCondAlg )
 DECLARE_COMPONENT( InDet::NnPixelClusterSplitProbTool )
 DECLARE_COMPONENT( InDet::TruthPixelClusterSplitter )
 DECLARE_COMPONENT( InDet::TruthClusterizationFactory )
-- 
GitLab