Skip to content
Snippets Groups Projects
Commit 197e1d43 authored by Alexander Froch's avatar Alexander Froch
Browse files

Adding more flexibility to the MultiFoldTagger python config

parent 610f2994
No related branches found
No related tags found
No related merge requests found
Pipeline #9967132 passed
...@@ -15,6 +15,9 @@ class MultifoldTagger(BaseBlock): ...@@ -15,6 +15,9 @@ class MultifoldTagger(BaseBlock):
---------- ----------
nn_paths : list[str] nn_paths : list[str]
List of paths to the neural network files. List of paths to the neural network files.
alg_name : str
Name of the algorithm. If None, the name will be
extracted from the nn_paths. By default None
target : str target : str
Whether to tag the BTagging object or the Jet. Whether to tag the BTagging object or the Jet.
jet_collection : str | None jet_collection : str | None
...@@ -26,13 +29,17 @@ class MultifoldTagger(BaseBlock): ...@@ -26,13 +29,17 @@ class MultifoldTagger(BaseBlock):
Name of the fold hash variable. Name of the fold hash variable.
constituents : str constituents : str
Name of the constituent container. Name of the constituent container.
default_zero_tracks : bool
Bool to decide how to work with zero tracks. By default False.
""" """
nn_paths: list[str] nn_paths: list[str]
alg_name: str = None
target: str = "BTagging" target: str = "BTagging"
jet_collection: str = None jet_collection: str = None
remap: dict = None remap: dict = None
fold_hash_name: str = "jetFoldHash" fold_hash_name: str = "jetFoldHash"
constituents: str = "InDetTrackParticles" constituents: str = "InDetTrackParticles"
default_zero_tracks: bool = False
def __post_init__(self): def __post_init__(self):
...@@ -52,25 +59,27 @@ class MultifoldTagger(BaseBlock): ...@@ -52,25 +59,27 @@ class MultifoldTagger(BaseBlock):
self.track_link_type = "IPARTICLE" self.track_link_type = "IPARTICLE"
else: else:
raise ValueError(f"Unknown target {self.target}") raise ValueError(f"Unknown target {self.target}")
self.name = '_'.join(Path(self.nn_paths[0]).parts[2:-1]) if self.alg_name is None:
self.alg_name = '_'.join(Path(self.nn_paths[0]).parts[2:-1])
if "BTagTrackToJetAssociator" in self.remap: if "BTagTrackToJetAssociator" in self.remap:
self.name = f'{self.remap["BTagTrackToJetAssociator"]}_{self.name}' self.alg_name = f'{self.remap["BTagTrackToJetAssociator"]}_{self.alg_name}'
def to_ca(self): def to_ca(self):
ca = ComponentAccumulator() ca = ComponentAccumulator()
ca.addEventAlgo( ca.addEventAlgo(
self.deco_alg( self.deco_alg(
name=f'{self.name}_Alg', name=f'{self.alg_name}_Alg',
container=self.container, container=self.container,
constituentContainer=self.constituents, constituentContainer=self.constituents,
decorator=CompFactory.FlavorTagDiscriminants.MultifoldGNNTool( decorator=CompFactory.FlavorTagDiscriminants.MultifoldGNNTool(
name=f'{self.name}_Tool', name=f'{self.alg_name}_Tool',
foldHashName=self.fold_hash_name, foldHashName=self.fold_hash_name,
nnFiles=self.nn_paths, nnFiles=self.nn_paths,
variableRemapping=self.remap, variableRemapping=self.remap,
trackLinkType=self.track_link_type trackLinkType=self.track_link_type,
defaultZeroTracks=self.default_zero_tracks,
) )
) )
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment