From 116b75be6db21edf6d35a52187f94335225fb6a4 Mon Sep 17 00:00:00 2001
From: Sebastien Rettie <srettie@lxplus700.cern.ch>
Date: Tue, 29 Sep 2020 11:22:38 +0200
Subject: [PATCH] First pass at lwtnn implementation of number network; only
 modularize old function and run a small sanity check.

---
 .../NnClusterizationFactory.h                 |  9 +++
 .../src/NnClusterizationFactory.cxx           | 55 +++++++++++++++++++
 2 files changed, 64 insertions(+)

diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
index 01039990556a..90ee9b29a9fc 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h
@@ -134,6 +134,15 @@ namespace InDet {
     // Handling lwtnn inputs
     typedef std::map<std::string, std::map<std::string, double> > InputMap;
 
+    /* Estimate number of particles for both with and w/o tracks */
+    /* Method 1: using older TTrainedNetworks */
+    std::vector<double> estimateNumberOfParticlesTTN(const TTrainedNetworkCollection &nn_collection,
+                                                     std::vector<double> inputData) const;
+
+    /* Estimate number of particles for both with and w/o tracks */
+    /* Method 2: using lwtnn for more flexible interfacing */
+    //std::vector<double> estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputMap & input) const;
+
     /* Estimate position for both with and w/o tracks */
     /* Method 1: using older TTrainedNetworks */
     std::vector<Amg::Vector2D> estimatePositionsTTN(
diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx b/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
index 28feb5fa6854..5eef23f8e37d 100644
--- a/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
+++ b/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx
@@ -302,6 +302,7 @@ namespace InDet {
 
     double tanl=0;
 
+    // Move this to if (m_useTTrainedNetworks) clause when validated
     SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack );
     if (!nn_collection.isValid()) {
       std::stringstream msg; msg  << "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key();
@@ -327,6 +328,20 @@ namespace InDet {
                     " (2): " << resultNN_NoTrack[1] <<
                     " (3): " << resultNN_NoTrack[2]);
 
+    std::vector<double> resultNN_NoTrack_sanity = estimateNumberOfParticlesTTN(**nn_collection, inputData);
+
+    std::cout<<"yoyoma_p1_orig,"<<resultNN_NoTrack[0]<<std::endl;
+    std::cout<<"yoyoma_p2_orig,"<<resultNN_NoTrack[1]<<std::endl;
+    std::cout<<"yoyoma_p3_orig,"<<resultNN_NoTrack[2]<<std::endl;
+
+    std::cout<<"yoyoma_p1_sanity,"<<resultNN_NoTrack_sanity[0]<<std::endl;
+    std::cout<<"yoyoma_p2_sanity,"<<resultNN_NoTrack_sanity[1]<<std::endl;
+    std::cout<<"yoyoma_p3_sanity,"<<resultNN_NoTrack_sanity[2]<<std::endl;
+
+    std::cout<<"yoyoma_p1_sanitydiff,"<<resultNN_NoTrack[0] - resultNN_NoTrack_sanity[0]<<std::endl;
+    std::cout<<"yoyoma_p2_sanitydiff,"<<resultNN_NoTrack[1] - resultNN_NoTrack_sanity[1]<<std::endl;
+    std::cout<<"yoyoma_p3_sanitydiff,"<<resultNN_NoTrack[2] - resultNN_NoTrack_sanity[2]<<std::endl;
+
     return resultNN_NoTrack;
   }
 
@@ -337,6 +352,7 @@ namespace InDet {
                                                                          int sizeY) const
   {
 
+    // Move this to if (m_useTTrainedNetworks) clause when validated
     SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithTrack );
     if (!nn_collection.isValid()) {
       std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithTrack.key();
@@ -366,9 +382,48 @@ namespace InDet {
 
     ATH_MSG_VERBOSE(" Prob of n. particles (1): " << resultNN[0] << " (2): " << resultNN[1] << " (3): " << resultNN[2]);
 
+    std::vector<double> resultNN_sanity = estimateNumberOfParticlesTTN(**nn_collection, inputData);
+
+    std::cout<<"yoyoma_p1_orig,"<<resultNN[0]<<std::endl;
+    std::cout<<"yoyoma_p2_orig,"<<resultNN[1]<<std::endl;
+    std::cout<<"yoyoma_p3_orig,"<<resultNN[2]<<std::endl;
+
+    std::cout<<"yoyoma_p1_sanity,"<<resultNN_sanity[0]<<std::endl;
+    std::cout<<"yoyoma_p2_sanity,"<<resultNN_sanity[1]<<std::endl;
+    std::cout<<"yoyoma_p3_sanity,"<<resultNN_sanity[2]<<std::endl;
+
+    std::cout<<"yoyoma_p1_sanitydiff,"<<resultNN[0] - resultNN_sanity[0]<<std::endl;
+    std::cout<<"yoyoma_p2_sanitydiff,"<<resultNN[1] - resultNN_sanity[1]<<std::endl;
+    std::cout<<"yoyoma_p3_sanitydiff,"<<resultNN[2] - resultNN_sanity[2]<<std::endl;
+
     return resultNN;
   }
 
+  std::vector<double> NnClusterizationFactory::estimateNumberOfParticlesTTN(const TTrainedNetworkCollection &nn_collection,
+                                                                            std::vector<double> inputData) const
+  {
+    ATH_MSG_DEBUG("Using TTN number network");
+    // dereference unique_ptr<TTrainedNetwork> then call calculateOutput :
+    std::vector<double> resultNN_TTN( ((*(nn_collection.at(m_nParticleNNId))).*m_calculateOutput)(inputData) );  
+    return resultNN_TTN;
+  }
+
+  // std::vector<double> NnClusterizationFactory::estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputMap & input)
+  // {
+  //   ATH_MSG_DEBUG("Using lwtnn number network");
+  //   // Evaluate the number network once per cluster
+  //   lwt::ValueMap discriminant = m_lwnnNumber->compute(input);
+  //   double num0 = discriminant["output_number0"];
+  //   double num1 = discriminant["output_number1"];
+  //   double num2 = discriminant["output_number2"];
+  //   // Get normalized predictions
+  //   double prob1 = num0/(num0+num1+num2);
+  //   double prob2 = num1/(num0+num1+num2);
+  //   double prob3 = num2/(num0+num1+num2);
+  //   std::vector<double> number_probabilities{prob1, prob2, prob3};
+  //   return number_probabilities;
+  // }
+
   std::vector<Amg::Vector2D> NnClusterizationFactory::estimatePositions(const InDet::PixelCluster& pCluster,
                                                                              Amg::Vector3D & beamSpotPosition,
                                                                              std::vector<Amg::MatrixX> & errors,
-- 
GitLab