Skip to content
Snippets Groups Projects
Commit 5a815a63 authored by Dan Guest's avatar Dan Guest
Browse files

clean up multifold NN

parent d457921f
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !71323. Comments created here will be created in the context of that merge request.
......@@ -15,10 +15,13 @@ namespace FlavorTagDiscriminants
{
public:
DeclareInterfaceID(FlavorTagDiscriminants::INNSharingSvc, 1, 0);
bool contains(const std::string& nn_name);
void insert(const std::string& nn_name,
const GNNOptions& opts,
const GNN&);
virtual bool contains(const std::string& nn_name);
virtual void insert(const std::string& nn_name,
const GNNOptions& opts,
const std::shared_ptr<const GNN>&);
virtual std::shared_ptr<const GNN> at(
const std::string& nn_name,
const GNNOptions& opts);
}
}
......
......@@ -24,6 +24,8 @@ namespace FlavorTagDiscriminants {
MultifoldGNN(const std::vector<std::string>& folds,
const std::string& fold_hash_name,
const GNNOptions& opts);
MultifoldGNN(const std::vector<std::shared_ptr<const GNN>>& folds,
const std::string& fold_hash_name);
MultifoldGNN(MultifoldGNN&&);
MultifoldGNN(const MultifoldGNN&);
~MultifoldGNN();
......@@ -36,7 +38,7 @@ namespace FlavorTagDiscriminants {
std::set<std::string> getConstituentAuxInputKeys() const;
private:
const GNN& getFold(const SG::AuxElement& element) const;
std::vector<std::shared_ptr<GNN>> m_folds;
std::vector<std::shared_ptr<const GNN>> m_folds;
SG::AuxElement::ConstAccessor<uint32_t> m_fold_hash;
SG::AuxElement::ConstAccessor<ElementLink<xAOD::JetContainer>> m_jetLink;
};
......
......@@ -20,6 +20,17 @@ namespace {
}
return first;
}
auto getNNs(
const std::vector<std::string>& nn_files,
const FlavorTagDiscriminants::GNNOptions& o)
{
using ftd = FlavorTagDiscriminants;
std::vector<std::shared_ptr<const ftd::GNN>> nns;
for (const auto& nn_file: nn_files) {
nns.emplace_back(std::make_shared<const ftd::GNN>(nn_file, o));
}
return nns;
}
}
namespace FlavorTagDiscriminants {
......@@ -28,12 +39,16 @@ namespace FlavorTagDiscriminants {
const std::vector<std::string>& nn_files,
const std::string& fold_hash_name,
const GNNOptions& o):
MultifoldGNN(getNNs(nn_files, o), fold_hash_name)
{
}
MultifoldGNN::MultifoldGNN(
const std::vector<std::shared_ptr<const GNN>>& nns,
const std::string& fold_hash_name):
m_folds(nns),
m_fold_hash(fold_hash_name),
m_jetLink(jetLinkName)
{
for (const auto& nn_file: nn_files) {
m_folds.emplace_back(std::make_shared<GNN>(nn_file, o));
}
}
MultifoldGNN::MultifoldGNN(MultifoldGNN&&) = default;
MultifoldGNN::MultifoldGNN(const MultifoldGNN&) = default;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment