diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
index 1dc95e4df2bfe756de81b78a91523e081ccaea2d..2a84f885ef5f658a3da0cc6bbf41f31a2a62c812 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
@@ -134,6 +134,15 @@ namespace InDet {
     // Handling lwtnn inputs
     typedef std::map<std::string, std::map<std::string, double> > InputMap;
 
+    /* Estimate number of particles for both with and w/o tracks */
+    /* Method 1: using older TTrainedNetworks */
+    std::vector<double> estimateNumberOfParticlesTTN(const TTrainedNetworkCollection &nn_collection,
+                                                     std::vector<double> inputData) const;
+
+    /* Estimate number of particles for both with and w/o tracks */
+    /* Method 2: using lwtnn for more flexible interfacing */
+    std::vector<double> estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputMap & input) const;
+
     /* Estimate position for both with and w/o tracks */
     /* Method 1: using older TTrainedNetworks */
     std::vector<Amg::Vector2D> estimatePositionsTTN(
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx b/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
index 529e02c1b0eb58b3dc5d5f65e0e72c0ed1adb10b..716f562270135c53d979161dfe6a907a3e56bf0c 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
@@ -94,9 +94,9 @@ namespace InDet {
           else {
             if (m_nParticleGroup[network_i]>0) {
 	      if (m_nParticleGroup[network_i]>=match_result.size()) {
-		      std::stringstream msg; msg << "Regex and match group of particle multiplicity do not coincide (groups=" << match_result.size() << " n particle group=" << m_nParticleGroup[network_i]
-			                        << "; type=" << network_i << ")";
-		      throw std::logic_error(msg.str());
+		ATH_MSG_ERROR("Regex and match group of particle multiplicity do not coincide (groups=" << match_result.size()
+			      << " n particle group=" << m_nParticleGroup[network_i]
+			      << "; type=" << network_i << ")");
 	      }
               int n_particles=atoi( match_result[m_nParticleGroup[network_i]].str().c_str());
               if (n_particles<=0 || static_cast<unsigned int>(n_particles)>m_maxSubClusters) {
@@ -299,11 +299,6 @@ namespace InDet {
 
     double tanl=0;
 
-    SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack );
-    if (!nn_collection.isValid()) {
-      std::stringstream msg; msg  << "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key();
-      throw std::runtime_error(msg.str() );
-    }
     NNinput input( createInput(pCluster,
                                beamSpotPosition,
                                tanl,
@@ -317,14 +312,21 @@ namespace InDet {
 
 
     std::vector<double> inputData=(this->*m_assembleInput)(input,sizeX,sizeY);
-    // dereference unique_ptr<TTrainedNetwork> then call calculateOutput :
-    std::vector<double> resultNN_NoTrack( ((*(nn_collection->at(m_nParticleNNId))).*m_calculateOutput)(inputData) );  
 
-    ATH_MSG_VERBOSE(" NOTRACK Prob of n. particles (1): " << resultNN_NoTrack[0] <<
-                    " (2): " << resultNN_NoTrack[1] <<
-                    " (3): " << resultNN_NoTrack[2]);
+    // If using old TTrainedNetworks, fetch correct ones for the
+    // without-track situation and call them now.
+    if (m_useTTrainedNetworks) {
+      SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack );
+      if (!nn_collection.isValid()) {
+	ATH_MSG_ERROR( "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key() );
+      }
+      return estimateNumberOfParticlesTTN(**nn_collection, inputData);
+    }
+
+    // Otherwise, prepare lwtnn input map and use new networks.
+    NnClusterizationFactory::InputMap nnInputData = flattenInput(input);
+    return estimateNumberOfParticlesLWTNN(nnInputData);
 
-    return resultNN_NoTrack;
   }
 
   std::vector<double> NnClusterizationFactory::estimateNumberOfParticles(const InDet::PixelCluster& pCluster,
@@ -334,12 +336,6 @@ namespace InDet {
                                                                          int sizeY) const
   {
 
-    SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithTrack );
-    if (!nn_collection.isValid()) {
-      std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithTrack.key();
-      throw std::runtime_error(msg.str() );
-    }
-
     Amg::Vector3D dummyBS(0,0,0);
 
     double tanl=0;
@@ -358,12 +354,60 @@ namespace InDet {
     addTrackInfoToInput(input,pixelSurface,trackParsAtSurface,tanl);
 
     std::vector<double> inputData=(this->*m_assembleInput)(input,sizeX,sizeY);
+
+    // If using old TTrainedNetworks, fetch correct ones for the
+    // with-track situation and call them now.
+    if (m_useTTrainedNetworks) {
+      SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithTrack );
+      if (!nn_collection.isValid()) {
+	ATH_MSG_ERROR( "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key() );
+      }
+      return estimateNumberOfParticlesTTN(**nn_collection, inputData);
+    }
+
+    // Otherwise, prepare lwtnn input map and use new networks.
+    NnClusterizationFactory::InputMap nnInputData = flattenInput(input);
+    return estimateNumberOfParticlesLWTNN(nnInputData);
+
+  }
+
+  std::vector<double> NnClusterizationFactory::estimateNumberOfParticlesTTN(const TTrainedNetworkCollection &nn_collection,
+                                                                            std::vector<double> inputData) const
+  {
+    ATH_MSG_DEBUG("Using TTN number network");
     // dereference unique_ptr<TTrainedNetwork> then call calculateOutput :
-    std::vector<double> resultNN( ( ( *(nn_collection->at(m_nParticleNNId))).*m_calculateOutput)(inputData) );
+    std::vector<double> resultNN_TTN( ((*(nn_collection.at(m_nParticleNNId))).*m_calculateOutput)(inputData) );
 
-    ATH_MSG_VERBOSE(" Prob of n. particles (1): " << resultNN[0] << " (2): " << resultNN[1] << " (3): " << resultNN[2]);
+    ATH_MSG_VERBOSE(" TTN Prob of n. particles (1): " << resultNN_TTN[0] <<
+                                             " (2): " << resultNN_TTN[1] <<
+                                             " (3): " << resultNN_TTN[2]);
+
+    return resultNN_TTN;
+  }
 
-    return resultNN;
+  std::vector<double> NnClusterizationFactory::estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputMap & input) const
+  {
+    SG::ReadCondHandle<LWTNNCollection> lwtnn_collection(m_readKeyJSON) ;
+    if (!lwtnn_collection.isValid()) {
+      ATH_MSG_ERROR( "Failed to get LWTNN network collection with key " << m_readKeyJSON.key() );
+    }
+    ATH_MSG_DEBUG("Using lwtnn number network");
+    // Evaluate the number network once per cluster
+    lwt::ValueMap discriminant = lwtnn_collection->at(0)->compute(input);
+    double num0 = discriminant["output_number0"];
+    double num1 = discriminant["output_number1"];
+    double num2 = discriminant["output_number2"];
+    // Get normalized predictions
+    double prob1 = num0/(num0+num1+num2);
+    double prob2 = num1/(num0+num1+num2);
+    double prob3 = num2/(num0+num1+num2);
+    std::vector<double> number_probabilities{prob1, prob2, prob3};
+
+    ATH_MSG_VERBOSE(" LWTNN Prob of n. particles (1): " << number_probabilities[0] <<
+                                               " (2): " << number_probabilities[1] <<
+                                               " (3): " << number_probabilities[2]);
+
+    return number_probabilities;
   }
 
   std::vector<Amg::Vector2D> NnClusterizationFactory::estimatePositions(const InDet::PixelCluster& pCluster,
@@ -396,8 +440,7 @@ namespace InDet {
     if (m_useTTrainedNetworks) {
       SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack );
       if (!nn_collection.isValid()) {
-        std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key();
-        throw std::runtime_error( msg.str() );
+	ATH_MSG_ERROR( "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key() );
       }
       // *(ReadCondHandle<>) returns a pointer rather than a reference ...
       return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,sizeX,sizeY,numberSubClusters,errors);
@@ -445,8 +488,7 @@ namespace InDet {
     if (m_useTTrainedNetworks) {
       SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithTrack );
       if (!nn_collection.isValid()) {
-        std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithTrack.key();
-        throw std::runtime_error( msg.str() );
+	ATH_MSG_ERROR( "Failed to get trained network collection with key " << m_readKeyWithTrack.key() );
       }
 
       return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,sizeX,sizeY,numberSubClusters,errors);
@@ -521,8 +563,7 @@ namespace InDet {
     
     SG::ReadCondHandle<LWTNNCollection> lwtnn_collection(m_readKeyJSON) ;
     if (!lwtnn_collection.isValid()) {
-      std::stringstream msg; msg << "Failed to get LWTNN network collection with key " << m_readKeyJSON.key();
-      throw std::runtime_error( msg.str() );      
+      ATH_MSG_ERROR(  "Failed to get LWTNN network collection with key " << m_readKeyJSON.key() );
     }
 
     // Need to evaluate the correct network once per cluster we're interested in.
@@ -536,8 +577,7 @@ namespace InDet {
       // Check that the network is defined. 
       // If not, we are outside an IOV and should fail
       if (not lwtnn_collection->at(numberSubClusters)) {
-        std::stringstream msg; msg << "No lwtnn network configured for this run! If you are outside the valid range for lwtnn-based configuration, plesae run with useNNTTrainedNetworks instead." << m_readKeyJSON.key();
-        throw std::runtime_error( msg.str() );
+	ATH_MSG_ERROR( "No lwtnn network configured for this run! If you are outside the valid range for lwtnn-based configuration, plesae run with useNNTTrainedNetworks instead." << m_readKeyJSON.key() );
       }
 
       std::string outNodeName = "merge_"+std::to_string(cluster);