Skip to content
Snippets Groups Projects
Commit 116b75be authored by Sebastien Rettie's avatar Sebastien Rettie
Browse files

First pass at lwtnn implementation of number network; only modularize old...

First pass at lwtnn implementation of number network; only modularize old function and run a small sanity check.
parent 841c4ce8
6 merge requests!58791DataQualityConfigurations: Modify L1Calo config for web display,!46784MuonCondInterface: Enable thread-safety checking.,!46776Updated LArMonitoring config file for WD to match new files produced using MT,!45405updated ART test cron job,!42417Draft: DIRE and VINCIA Base Fragments for Pythia 8.3,!36851Add pixel clustering number network lwtnn implementation in r22
......@@ -134,6 +134,15 @@ namespace InDet {
// Handling lwtnn inputs
typedef std::map<std::string, std::map<std::string, double> > InputMap;
/* Estimate number of particles for both with and w/o tracks */
/* Method 1: using older TTrainedNetworks */
std::vector<double> estimateNumberOfParticlesTTN(const TTrainedNetworkCollection &nn_collection,
std::vector<double> inputData) const;
/* Estimate number of particles for both with and w/o tracks */
/* Method 2: using lwtnn for more flexible interfacing */
//std::vector<double> estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputMap & input) const;
/* Estimate position for both with and w/o tracks */
/* Method 1: using older TTrainedNetworks */
std::vector<Amg::Vector2D> estimatePositionsTTN(
......
......@@ -302,6 +302,7 @@ namespace InDet {
double tanl=0;
// Move this to if (m_useTTrainedNetworks) clause when validated
SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithoutTrack );
if (!nn_collection.isValid()) {
std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithoutTrack.key();
......@@ -327,6 +328,20 @@ namespace InDet {
" (2): " << resultNN_NoTrack[1] <<
" (3): " << resultNN_NoTrack[2]);
std::vector<double> resultNN_NoTrack_sanity = estimateNumberOfParticlesTTN(**nn_collection, inputData);
std::cout<<"yoyoma_p1_orig,"<<resultNN_NoTrack[0]<<std::endl;
std::cout<<"yoyoma_p2_orig,"<<resultNN_NoTrack[1]<<std::endl;
std::cout<<"yoyoma_p3_orig,"<<resultNN_NoTrack[2]<<std::endl;
std::cout<<"yoyoma_p1_sanity,"<<resultNN_NoTrack_sanity[0]<<std::endl;
std::cout<<"yoyoma_p2_sanity,"<<resultNN_NoTrack_sanity[1]<<std::endl;
std::cout<<"yoyoma_p3_sanity,"<<resultNN_NoTrack_sanity[2]<<std::endl;
std::cout<<"yoyoma_p1_sanitydiff,"<<resultNN_NoTrack[0] - resultNN_NoTrack_sanity[0]<<std::endl;
std::cout<<"yoyoma_p2_sanitydiff,"<<resultNN_NoTrack[1] - resultNN_NoTrack_sanity[1]<<std::endl;
std::cout<<"yoyoma_p3_sanitydiff,"<<resultNN_NoTrack[2] - resultNN_NoTrack_sanity[2]<<std::endl;
return resultNN_NoTrack;
}
......@@ -337,6 +352,7 @@ namespace InDet {
int sizeY) const
{
// Move this to if (m_useTTrainedNetworks) clause when validated
SG::ReadCondHandle<TTrainedNetworkCollection> nn_collection( m_readKeyWithTrack );
if (!nn_collection.isValid()) {
std::stringstream msg; msg << "Failed to get trained network collection with key " << m_readKeyWithTrack.key();
......@@ -366,9 +382,48 @@ namespace InDet {
ATH_MSG_VERBOSE(" Prob of n. particles (1): " << resultNN[0] << " (2): " << resultNN[1] << " (3): " << resultNN[2]);
std::vector<double> resultNN_sanity = estimateNumberOfParticlesTTN(**nn_collection, inputData);
std::cout<<"yoyoma_p1_orig,"<<resultNN[0]<<std::endl;
std::cout<<"yoyoma_p2_orig,"<<resultNN[1]<<std::endl;
std::cout<<"yoyoma_p3_orig,"<<resultNN[2]<<std::endl;
std::cout<<"yoyoma_p1_sanity,"<<resultNN_sanity[0]<<std::endl;
std::cout<<"yoyoma_p2_sanity,"<<resultNN_sanity[1]<<std::endl;
std::cout<<"yoyoma_p3_sanity,"<<resultNN_sanity[2]<<std::endl;
std::cout<<"yoyoma_p1_sanitydiff,"<<resultNN[0] - resultNN_sanity[0]<<std::endl;
std::cout<<"yoyoma_p2_sanitydiff,"<<resultNN[1] - resultNN_sanity[1]<<std::endl;
std::cout<<"yoyoma_p3_sanitydiff,"<<resultNN[2] - resultNN_sanity[2]<<std::endl;
return resultNN;
}
std::vector<double> NnClusterizationFactory::estimateNumberOfParticlesTTN(const TTrainedNetworkCollection &nn_collection,
std::vector<double> inputData) const
{
ATH_MSG_DEBUG("Using TTN number network");
// dereference unique_ptr<TTrainedNetwork> then call calculateOutput :
std::vector<double> resultNN_TTN( ((*(nn_collection.at(m_nParticleNNId))).*m_calculateOutput)(inputData) );
return resultNN_TTN;
}
// std::vector<double> NnClusterizationFactory::estimateNumberOfParticlesLWTNN(NnClusterizationFactory::InputMap & input)
// {
// ATH_MSG_DEBUG("Using lwtnn number network");
// // Evaluate the number network once per cluster
// lwt::ValueMap discriminant = m_lwnnNumber->compute(input);
// double num0 = discriminant["output_number0"];
// double num1 = discriminant["output_number1"];
// double num2 = discriminant["output_number2"];
// // Get normalized predictions
// double prob1 = num0/(num0+num1+num2);
// double prob2 = num1/(num0+num1+num2);
// double prob3 = num2/(num0+num1+num2);
// std::vector<double> number_probabilities{prob1, prob2, prob3};
// return number_probabilities;
// }
std::vector<Amg::Vector2D> NnClusterizationFactory::estimatePositions(const InDet::PixelCluster& pCluster,
Amg::Vector3D & beamSpotPosition,
std::vector<Amg::MatrixX> & errors,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment