Skip to content
Snippets Groups Projects

Improve GNN4ITk pipeline

Merged Xiangyang Ju requested to merge xju/athena:mr_walkthrough into main
1 file
+ 1
26
Compare changes
  • Side-by-side
  • Inline
@@ -116,7 +116,6 @@ StatusCode InDet::SiGNNTrackFinderTool::getTracks(
spacepointIDs.push_back(sp_idx++);
}
ATH_MSG_INFO("after filling features: " << eNodeFeatures.size() / numSpacepoints << " " << fNodeFeatures.size() / numSpacepoints << " " << gNodeFeatures.size() / numSpacepoints);
// ************
// Embedding
// ************
@@ -129,12 +128,7 @@ StatusCode InDet::SiGNNTrackFinderTool::getTracks(
ATH_CHECK( m_embedSessionTool->addOutput(eOutputTensor, eOutputData, 0, numSpacepoints) );
ATH_CHECK( m_embedSessionTool->inference(eInputTensor, eOutputTensor) );
ATH_MSG_INFO("after embedding");
// print the embedding for the first space point.
ATH_MSG_INFO("Embedding for the first space point:");
for(size_t i = 0; i < m_embeddingDim; i++){
ATH_MSG_INFO(" " << eOutputData[i]);
}
// ************
// Building Edges
// ************
@@ -148,7 +142,6 @@ StatusCode InDet::SiGNNTrackFinderTool::getTracks(
eInputTensor.clear();
eOutputData.clear();
eOutputTensor.clear();
ATH_MSG_INFO("Building edges" << m_rVal << " " << m_knnVal << " " << m_embeddingDim << " " << numEdges);
// sort the edge list and remove duplicate edges.
std::vector<std::pair<int64_t, int64_t>> edgePairs;
@@ -178,8 +171,6 @@ StatusCode InDet::SiGNNTrackFinderTool::getTracks(
}
edgePairs.clear();
ATH_MSG_INFO("after sorting and shuffling.");
ATH_MSG_INFO("numEdges: " << numEdges << " numSpacepoints: " << numSpacepoints);
// ************
// Filtering
@@ -193,27 +184,20 @@ StatusCode InDet::SiGNNTrackFinderTool::getTracks(
ATH_CHECK( m_filterSessionTool->addInput(fInputTensor, edgeList, 1, numEdges) );
ATH_MSG_INFO("Filtering: after adding inputs");
std::vector<float> fOutputData;
std::vector<Ort::Value> fOutputTensor;
ATH_CHECK( m_filterSessionTool->addOutput(fOutputTensor, fOutputData, 0, numEdges) );
ATH_MSG_INFO("Filtering: after adding outputs");
ATH_CHECK( m_filterSessionTool->inference(fInputTensor, fOutputTensor) );
ATH_MSG_INFO("Filtering: after inference");
// apply sigmoid to the filtering output data
// and remove edges with score < filterCut
// and sort the edge list so that sender idx < receiver.
std::vector<int64_t> rowIndices;
std::vector<int64_t> colIndices;
ATH_MSG_INFO("Filtering scores for the first 10 edges:");
for (int64_t i = 0; i < numEdges; i++){
float v = 1.f / (1.f + std::exp(-fOutputData[i])); // sigmoid, float type
if (i < 10) {
ATH_MSG_INFO("\t " << senders[i] << " -> " << receivers[i] << " " << v);
}
if (v >= m_filterCut){
auto src = edgeList[i];
auto dst = edgeList[numEdges + i];
@@ -225,7 +209,6 @@ StatusCode InDet::SiGNNTrackFinderTool::getTracks(
};
};
int64_t numEdgesAfterF = rowIndices.size();
ATH_MSG_INFO("after filtering: " << numEdgesAfterF);
// clean up filtering data.
fNodeFeatures.clear();
@@ -258,11 +241,9 @@ StatusCode InDet::SiGNNTrackFinderTool::getTracks(
ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, edgesAfterFiltering, 1, numEdgesAfterF) );
// calculate the edge features.
ATH_MSG_INFO("before calculate edge features");
std::vector<float> gnnEdgeFeatures;
ExaTrkXUtils::calculateEdgeFeatures(gNodeFeatures, numSpacepoints, rowIndices, colIndices, gnnEdgeFeatures);
ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, gnnEdgeFeatures, 2, numEdgesAfterF) );
ATH_MSG_INFO("after calculate edge features");
// gnn outputs
std::vector<float> gOutputData;
@@ -279,12 +260,6 @@ StatusCode InDet::SiGNNTrackFinderTool::getTracks(
gNodeFeatures.clear();
gInputTensor.clear();
edgesAfterFiltering.clear();
ATH_MSG_INFO("after GNN");
// print the gnn output for the first 10 edges.
ATH_MSG_INFO("GNN output for the first 10 edges:");
for(size_t i = 0; i < std::min((size_t)10, gOutputData.size()); i++){
ATH_MSG_INFO("\t" << rowIndices[i] << " -> " << colIndices[i] << " " << gOutputData[i]);
}
// ************
// Track Labeling with cugraph::connected_components
Loading