diff --git a/Tracking/TrkFitter/TrkGaussianSumFilter/TrkGaussianSumFilter/KLGaussianMixtureReduction.h b/Tracking/TrkFitter/TrkGaussianSumFilter/TrkGaussianSumFilter/KLGaussianMixtureReduction.h index 54d069daeac308ca604fc5ff6ad538eb3e6ada36..df077d6be321ef458ff19e85defaad8d66853d13 100644 --- a/Tracking/TrkFitter/TrkGaussianSumFilter/TrkGaussianSumFilter/KLGaussianMixtureReduction.h +++ b/Tracking/TrkFitter/TrkGaussianSumFilter/TrkGaussianSumFilter/KLGaussianMixtureReduction.h @@ -76,9 +76,9 @@ #include "CxxUtils/features.h" #include "TrkGaussianSumFilter/GsfConstants.h" #include <array> -#include <vector> #include <cstdint> #include <utility> +#include <vector> namespace GSFUtils { @@ -107,10 +107,22 @@ struct Component1DArray int32_t numComponents = 0; }; +struct MergeArray +{ + struct merge + { + int8_t To = 0; + int8_t From = 0; + }; + std::array<merge, GSFConstants::maxComponentsAfterConvolution> merges{}; + int32_t numMerges=0; +}; + /* typedef tracking which component has been merged */ -using IsMergedArray = std::array<bool,GSFConstants::maxComponentsAfterConvolution>; +using IsMergedArray = + std::array<bool, GSFConstants::maxComponentsAfterConvolution>; /** * @brief Merge the componentsIn and return @@ -125,7 +137,7 @@ using IsMergedArray = std::array<bool,GSFConstants::maxComponentsAfterConvolutio * Furthemore, the input component array is assumed to be * GSFConstants::alignment aligned. */ -std::vector<std::pair<int8_t, int8_t>> +MergeArray findMerges(Component1DArray& componentsIn, const int8_t reducedSize); /** diff --git a/Tracking/TrkFitter/TrkGaussianSumFilter/src/GsfMaterialMixtureConvolution.cxx b/Tracking/TrkFitter/TrkGaussianSumFilter/src/GsfMaterialMixtureConvolution.cxx index a9fcc72aaa9ff0d3ea606c7c4e09f899cb52b554..ea0daa7e41f0e257c71098a78cc9362681a1adc8 100644 --- a/Tracking/TrkFitter/TrkGaussianSumFilter/src/GsfMaterialMixtureConvolution.cxx +++ b/Tracking/TrkFitter/TrkGaussianSumFilter/src/GsfMaterialMixtureConvolution.cxx @@ -342,29 +342,26 @@ Trk::GsfMaterialMixtureConvolution::update( } // Gather the merges -- order is important -- RHS is smaller than LHS - std::vector<std::pair<int8_t, int8_t>> merges; + GSFUtils::MergeArray KL; if (n > m_maximumNumberOfComponents) { - merges = findMerges(componentsArray, m_maximumNumberOfComponents); + KL = findMerges(componentsArray, m_maximumNumberOfComponents); } // Merge components MultiComponentStateAssembler::Cache assemblerCache; int nMerges(0); - GSFUtils::IsMergedArray isMerged={}; - for (const auto& mergePair : merges) { - const int8_t mini = mergePair.first; - const int8_t minj = mergePair.second; + GSFUtils::IsMergedArray isMerged = {}; + int32_t returnedMerges = KL.numMerges; + + for (int32_t i = 0; i < returnedMerges; ++i) { + const int8_t mini = KL.merges[i].To; + const int8_t minj = KL.merges[i].From; if (isMerged[minj]) { ATH_MSG_WARNING("Component is already merged " << minj); - for (const auto& mergePair2 : merges) { - ATH_MSG_DEBUG("Pairs that should be merged together: " - << mergePair2.first << " " << mergePair2.second); - } continue; } // Get the first TP size_t stateIndex = indices[mini].first; size_t materialIndex = indices[mini].second; - // Copy weight and first parameters as they are needed later on // for updating the covariance AmgVector(5) firstParameters = diff --git a/Tracking/TrkFitter/TrkGaussianSumFilter/src/KLGaussianMixtureReduction.cxx b/Tracking/TrkFitter/TrkGaussianSumFilter/src/KLGaussianMixtureReduction.cxx index 51d4ebe90eaf5e8b64bf2d15359c2dadd3ec688c..ec5e3e0a3a201ef7c4900001019b7fdb4b79441c 100644 --- a/Tracking/TrkFitter/TrkGaussianSumFilter/src/KLGaussianMixtureReduction.cxx +++ b/Tracking/TrkFitter/TrkGaussianSumFilter/src/KLGaussianMixtureReduction.cxx @@ -216,37 +216,35 @@ namespace GSFUtils { * Merge the componentsIn and return * which componets got merged. */ -std::vector<std::pair<int8_t, int8_t>> +MergeArray findMerges(Component1DArray& componentsIn, const int8_t reducedSize) { - Component1D* components = static_cast<Component1D*>(__builtin_assume_aligned( - componentsIn.components.data(), GSFConstants::alignment)); const int32_t n = componentsIn.numComponents; - // Sanity check. Function throw on invalid inputs if (n < 0 || n > GSFConstants::maxComponentsAfterConvolution || reducedSize > n) { throw std::runtime_error("findMerges :Invalid InputSize or reducedSize"); } - // We need just one for the full duration of a job - // so static and const + + Component1D* components = static_cast<Component1D*>(__builtin_assume_aligned( + componentsIn.components.data(), GSFConstants::alignment)); + + // We need just one for the full duration of a job so static const const static std::vector<triangularToIJ> convert = createToIJMaxRowCols(); //Based on the inputSize n allocate enough space for the pairwise distances - const int32_t nn = n * (n - 1) / 2; // We work with a multiple of 8*floats (32 bytes). + const int32_t nn = n * (n - 1) / 2; const int32_t nn2 = (nn & 7) == 0 ? nn : nn + (8 - (nn & 7)); AlignedDynArray<float, GSFConstants::alignment> distances( nn2, std::numeric_limits<float>::max()); - - // vector to be returned - std::vector<std::pair<int8_t, int8_t>> merges; - merges.reserve(n - reducedSize); // initial distance calculation calculateAllDistances(components, distances.buffer(), n); //keep track of where we are int32_t numberOfComponentsLeft = n; IsMergedArray ismerged={}; + //Result to returned + MergeArray result{}; // merge loop while (numberOfComponentsLeft > reducedSize) { // see if we have the next already @@ -264,10 +262,11 @@ findMerges(Component1DArray& componentsIn, const int8_t reducedSize) // re-calculate distances wrt the new component at mini recalculateDistances(components, distances.buffer(), ismerged, mini, n); // keep track and decrement - merges.emplace_back(mini, minj); + result.merges[result.numMerges]={mini, minj}; + ++result.numMerges; --numberOfComponentsLeft; } // end of merge while - return merges; + return result; } /** diff --git a/Tracking/TrkFitter/TrkGaussianSumFilter/src/QuickCloseComponentsMultiStateMerger.cxx b/Tracking/TrkFitter/TrkGaussianSumFilter/src/QuickCloseComponentsMultiStateMerger.cxx index e9a980890edd02214dad8bb47ea1f354ab32d2d1..50eaa4762bf9514a2bc5026f6e9e0c8e820d2dd2 100644 --- a/Tracking/TrkFitter/TrkGaussianSumFilter/src/QuickCloseComponentsMultiStateMerger.cxx +++ b/Tracking/TrkFitter/TrkGaussianSumFilter/src/QuickCloseComponentsMultiStateMerger.cxx @@ -43,13 +43,14 @@ mergeFullDistArray(Trk::MultiComponentStateAssembler::Cache& cache, } // Gather the merges - const std::vector<std::pair<int8_t, int8_t>> merges = + const GSFUtils::MergeArray KL = findMerges(componentsArray, maximumNumberOfComponents); // Do the full 5D calculations of the merge - for (const auto& mergePair : merges) { - const int8_t mini = mergePair.first; - const int8_t minj = mergePair.second; + const int32_t numMerges = KL.numMerges; + for (int32_t i = 0; i < numMerges; ++i) { + const int8_t mini = KL.merges[i].To; + const int8_t minj = KL.merges[i].From; Trk::MultiComponentStateCombiner::combineWithWeight(statesToMerge[mini], statesToMerge[minj]); statesToMerge[minj].first.reset(); diff --git a/Tracking/TrkFitter/TrkGaussianSumFilter/test/testMergeComponents.cxx b/Tracking/TrkFitter/TrkGaussianSumFilter/test/testMergeComponents.cxx index 104730353ac4c75b77ddbed87e0ebe71a40421e1..e77e1cb99cfb76881a2d121ca5024aba6d9993fb 100644 --- a/Tracking/TrkFitter/TrkGaussianSumFilter/test/testMergeComponents.cxx +++ b/Tracking/TrkFitter/TrkGaussianSumFilter/test/testMergeComponents.cxx @@ -98,11 +98,11 @@ main() componentsArray.components[i].invCov = input[i].invCov; componentsArray.components[i].weight = input[i].weight; } - std::vector<std::pair<int8_t, int8_t>> mergeOrder = - findMerges(componentsArray, 12); - for (const auto& i : mergeOrder) { - std::cout << "[" << static_cast<int>(i.first) << ", " - << static_cast<int>(i.second) << "]" << '\n'; + const GSFUtils::MergeArray order = findMerges(componentsArray, 12); + const int32_t numMerges = order.numMerges; + for (int32_t i = 0; i < numMerges; ++i) { + std::cout << "[" << static_cast<int>(order.merges[i].To) << ", " + << static_cast<int>(order.merges[i].From) << "]" << '\n'; } return 0; }