diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h index 1dc95e4df2bfe756de81b78a91523e081ccaea2d..2a84f885ef5f658a3da0cc6bbf41f31a2a62c812 100644 --- a/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h +++ b/InnerDetector/InDetRecTools/SiClusterizationTool/SiClusterizationTool/NnClusterizationFactory.h @@ -134,6 +134,15 @@ namespace InDet { // Handling lwtnn inputs typedef std::map<std::string, std::map<std::string, double> > InputMap; + /* Estimate number of particles for both with and w/o tracks */ + /* Method 1: using older TTrainedNetworks */ + std::vector<double> estimateNumberOfParticlesTTN(const TTrainedNetworkCollection &nn_collection, + std::vector<double> inputData) const; + + /* Estimate number of particles for both with and w/o tracks */ + /* Method 2: using lwtnn for more flexible interfacing */ + std::vector<double> estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputMap & input) const; + /* Estimate position for both with and w/o tracks */ /* Method 1: using older TTrainedNetworks */ std::vector<Amg::Vector2D> estimatePositionsTTN( diff --git a/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx b/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx index 529e02c1b0eb58b3dc5d5f65e0e72c0ed1adb10b..716f562270135c53d979161dfe6a907a3e56bf0c 100644 --- a/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx +++ b/InnerDetector/InDetRecTools/SiClusterizationTool/src/NnClusterizationFactory.cxx @@ -94,9 +94,9 @@ namespace InDet { else { if (m_nParticleGroup[network_i]>0) { if (m_nParticleGroup[network_i]>=match_result.size()) { - std::stringstream msg; msg << "Regex and match group of particle multiplicity do not coincide (groups=" << match_result.size() << " n particle group=" << m_nParticleGroup[network_i] - << "; type=" << network_i << ")"; - throw std::logic_error(msg.str()); + ATH_MSG_ERROR("Regex and match group of particle multiplicity do not coincide (groups=" << match_result.size() + << " n particle group=" << m_nParticleGroup[network_i] + << "; type=" << network_i << ")"); } int n_particles=atoi( match_result[m_nParticleGroup[network_i]].str().c_str()); if (n_particles<=0 || static_cast<unsigned int>(n_particles)>m_maxSubClusters) { @@ -299,11 +299,6 @@ namespace InDet { double tanl=0; - SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack ); - if (!nn_collection.isValid()) { - std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key(); - throw std::runtime_error(msg.str() ); - } NNinput input( createInput(pCluster, beamSpotPosition, tanl, @@ -317,14 +312,21 @@ namespace InDet { std::vector<double> inputData=(this->*m_assembleInput)(input,sizeX,sizeY); - // dereference unique_ptr<TTrainedNetwork> then call calculateOutput : - std::vector<double> resultNN_NoTrack( ((*(nn_collection->at(m_nParticleNNId))).*m_calculateOutput)(inputData) ); - ATH_MSG_VERBOSE(" NOTRACK Prob of n. particles (1): " << resultNN_NoTrack[0] << - " (2): " << resultNN_NoTrack[1] << - " (3): " << resultNN_NoTrack[2]); + // If using old TTrainedNetworks, fetch correct ones for the + // without-track situation and call them now. + if (m_useTTrainedNetworks) { + SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack ); + if (!nn_collection.isValid()) { + ATH_MSG_ERROR( "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key() ); + } + return estimateNumberOfParticlesTTN(**nn_collection, inputData); + } + + // Otherwise, prepare lwtnn input map and use new networks. + NnClusterizationFactory::InputMap nnInputData = flattenInput(input); + return estimateNumberOfParticlesLWTNN(nnInputData); - return resultNN_NoTrack; } std::vector<double> NnClusterizationFactory::estimateNumberOfParticles(const InDet::PixelCluster& pCluster, @@ -334,12 +336,6 @@ namespace InDet { int sizeY) const { - SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithTrack ); - if (!nn_collection.isValid()) { - std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithTrack.key(); - throw std::runtime_error(msg.str() ); - } - Amg::Vector3D dummyBS(0,0,0); double tanl=0; @@ -358,12 +354,60 @@ namespace InDet { addTrackInfoToInput(input,pixelSurface,trackParsAtSurface,tanl); std::vector<double> inputData=(this->*m_assembleInput)(input,sizeX,sizeY); + + // If using old TTrainedNetworks, fetch correct ones for the + // with-track situation and call them now. + if (m_useTTrainedNetworks) { + SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithTrack ); + if (!nn_collection.isValid()) { + ATH_MSG_ERROR( "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key() ); + } + return estimateNumberOfParticlesTTN(**nn_collection, inputData); + } + + // Otherwise, prepare lwtnn input map and use new networks. + NnClusterizationFactory::InputMap nnInputData = flattenInput(input); + return estimateNumberOfParticlesLWTNN(nnInputData); + + } + + std::vector<double> NnClusterizationFactory::estimateNumberOfParticlesTTN(const TTrainedNetworkCollection &nn_collection, + std::vector<double> inputData) const + { + ATH_MSG_DEBUG("Using TTN number network"); // dereference unique_ptr<TTrainedNetwork> then call calculateOutput : - std::vector<double> resultNN( ( ( *(nn_collection->at(m_nParticleNNId))).*m_calculateOutput)(inputData) ); + std::vector<double> resultNN_TTN( ((*(nn_collection.at(m_nParticleNNId))).*m_calculateOutput)(inputData) ); - ATH_MSG_VERBOSE(" Prob of n. particles (1): " << resultNN[0] << " (2): " << resultNN[1] << " (3): " << resultNN[2]); + ATH_MSG_VERBOSE(" TTN Prob of n. particles (1): " << resultNN_TTN[0] << + " (2): " << resultNN_TTN[1] << + " (3): " << resultNN_TTN[2]); + + return resultNN_TTN; + } - return resultNN; + std::vector<double> NnClusterizationFactory::estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputMap & input) const + { + SG::ReadCondHandle<LWTNNCollection> lwtnn_collection(m_readKeyJSON) ; + if (!lwtnn_collection.isValid()) { + ATH_MSG_ERROR( "Failed to get LWTNN network collection with key " << m_readKeyJSON.key() ); + } + ATH_MSG_DEBUG("Using lwtnn number network"); + // Evaluate the number network once per cluster + lwt::ValueMap discriminant = lwtnn_collection->at(0)->compute(input); + double num0 = discriminant["output_number0"]; + double num1 = discriminant["output_number1"]; + double num2 = discriminant["output_number2"]; + // Get normalized predictions + double prob1 = num0/(num0+num1+num2); + double prob2 = num1/(num0+num1+num2); + double prob3 = num2/(num0+num1+num2); + std::vector<double> number_probabilities{prob1, prob2, prob3}; + + ATH_MSG_VERBOSE(" LWTNN Prob of n. particles (1): " << number_probabilities[0] << + " (2): " << number_probabilities[1] << + " (3): " << number_probabilities[2]); + + return number_probabilities; } std::vector<Amg::Vector2D> NnClusterizationFactory::estimatePositions(const InDet::PixelCluster& pCluster, @@ -396,8 +440,7 @@ namespace InDet { if (m_useTTrainedNetworks) { SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack ); if (!nn_collection.isValid()) { - std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key(); - throw std::runtime_error( msg.str() ); + ATH_MSG_ERROR( "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key() ); } // *(ReadCondHandle<>) returns a pointer rather than a reference ... return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,sizeX,sizeY,numberSubClusters,errors); @@ -445,8 +488,7 @@ namespace InDet { if (m_useTTrainedNetworks) { SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithTrack ); if (!nn_collection.isValid()) { - std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithTrack.key(); - throw std::runtime_error( msg.str() ); + ATH_MSG_ERROR( "Failed to get trained network collection with key " << m_readKeyWithTrack.key() ); } return estimatePositionsTTN(**nn_collection, inputData,input,pCluster,sizeX,sizeY,numberSubClusters,errors); @@ -521,8 +563,7 @@ namespace InDet { SG::ReadCondHandle<LWTNNCollection> lwtnn_collection(m_readKeyJSON) ; if (!lwtnn_collection.isValid()) { - std::stringstream msg; msg << "Failed to get LWTNN network collection with key " << m_readKeyJSON.key(); - throw std::runtime_error( msg.str() ); + ATH_MSG_ERROR( "Failed to get LWTNN network collection with key " << m_readKeyJSON.key() ); } // Need to evaluate the correct network once per cluster we're interested in. @@ -536,8 +577,7 @@ namespace InDet { // Check that the network is defined. // If not, we are outside an IOV and should fail if (not lwtnn_collection->at(numberSubClusters)) { - std::stringstream msg; msg << "No lwtnn network configured for this run! If you are outside the valid range for lwtnn-based configuration, plesae run with useNNTTrainedNetworks instead." << m_readKeyJSON.key(); - throw std::runtime_error( msg.str() ); + ATH_MSG_ERROR( "No lwtnn network configured for this run! If you are outside the valid range for lwtnn-based configuration, plesae run with useNNTTrainedNetworks instead." << m_readKeyJSON.key() ); } std::string outNodeName = "merge_"+std::to_string(cluster);