diff --git a/.gitignore b/.gitignore index 2f319e19c6e6ff37c1e0e8ab49729af9aa040cd3..ce12c01cb5b536c72191410ae6c37b7db342b0d7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ __pycache__ # Data scratch/ artifacts/ +artifacts_old/ lightning_logs/ LHCb_Pipeline/output/ LHCb_Pipeline/analysis/ diff --git a/.vscode/launch.json b/.vscode/launch.json index f206ccc6b8cd378c9d36c1b4be1dabd00565950d..b3da7cf2d463bbebc5807c97b2ff9f9b2d34d52f 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,6 +12,14 @@ "console": "integratedTerminal", "justMyCode": true, "cwd": "/home/fgias/etx4velo/LHCb_Pipeline/" - } - ] + }, + { + "name": "anthonyc", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "cwd": "${fileDirname}" + }, + ], } \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 2364b7bbfea1dabaebb76063505827875cf1c732..a56d0d3c4948a14ce812af5a8c9eff03c865274d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,5 +4,8 @@ "./montetracko" ], "python.linting.flake8Enabled": true, - "python.linting.enabled": true + "python.linting.enabled": true, + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + }, } \ No newline at end of file diff --git a/LHCb_Pipeline/Embedding/build_embedding.py b/LHCb_Pipeline/Embedding/build_embedding.py index d1bc0793fd3762654ed0c10c727589766b303e97..ac2958573daa14924c704d5f209b88014406a54c 100644 --- a/LHCb_Pipeline/Embedding/build_embedding.py +++ b/LHCb_Pipeline/Embedding/build_embedding.py @@ -1,9 +1,14 @@ +from __future__ import annotations +from types import ModuleType import torch from torch_geometric.data import Data from Embedding.embedding_base import EmbeddingBase from utils.modelutils.build import ModelBuilderBase from utils.commonutils.config import load_config +from utils.graphutils.edgeutils import sort_edge_nodes + +from . 
import building_custom device = "cuda" if torch.cuda.is_available() else "cpu" @@ -35,44 +40,117 @@ class EmbeddingInferenceBuilder(ModelBuilderBase): model: EmbeddingBase, knn_max: int = 1000, radius: float = 0.1, + bidir: bool | None = None, ): super(EmbeddingInferenceBuilder, self).__init__(model=model) self.knn_max = knn_max self.radius = radius + self._bidir = bidir - def construct_downstream(self, batch: Data): - batch = self.select_data(batch) + @property + def bidir(self) -> bool: + """Whether to use a bi-directional graph (falls back to the model ``bidir`` hparam)""" + if self._bidir is None: + return self.model.hparams.get("bidir", True) + else: + # Explicit constructor value takes precedence over the model hparams + return self._bidir + def construct_downstream(self, batch: Data): y_cluster, e_spatial, e_bidir = self.get_performance( batch=batch, r_max=self.radius, k_max=self.knn_max ) - module_mask = batch.plane[e_spatial[0]] != batch.plane[e_spatial[1]] - y_cluster = y_cluster[module_mask] - e_spatial = e_spatial[:, module_mask] - # Arbitrary ordering to remove half of the duplicate edges - # TODO: if one wants to really do that, why not ordering the indices instead? - R_dist = torch.sqrt(batch.x[:, 0] ** 2 + batch.x[:, 2] ** 2) - e_spatial = e_spatial[:, (R_dist[e_spatial[0]] <= R_dist[e_spatial[1]])] + # Remove Edges within a same plane + plane_mask = batch.plane[e_spatial[0]] != batch.plane[e_spatial[1]] + y_cluster = y_cluster[plane_mask] + e_spatial = e_spatial[:, plane_mask] - e_spatial, y_cluster = self.model.get_truth(batch, e_spatial, e_bidir) + if self.bidir: + # Arbitrary ordering to remove half of the duplicate edges + # TODO: if one wants to really do that, why not ordering the indices instead? 
+ R_dist = torch.sqrt(batch.x[:, 0] ** 2 + batch.x[:, 2] ** 2) + e_spatial = e_spatial[:, (R_dist[e_spatial[0]] <= R_dist[e_spatial[1]])] - # Re-introduce random direction, to avoid training bias - random_flip = torch.randint(2, (e_spatial.shape[1],)).bool() - e_spatial[0, random_flip], e_spatial[1, random_flip] = ( - e_spatial[1, random_flip], - e_spatial[0, random_flip], - ) + e_spatial, y_cluster = self.model.get_truth(batch, e_spatial, e_bidir) - batch.edge_index = e_spatial - batch.y = y_cluster + if self.bidir: + # Re-introduce random direction, to avoid training bias + random_flip = torch.randint(2, (e_spatial.shape[1],)).bool() + e_spatial[0, random_flip], e_spatial[1, random_flip] = ( + e_spatial[1, random_flip], + e_spatial[0, random_flip], + ) + else: + # Do the opposite: enforce a direction + sort_edge_nodes(e_spatial, batch.un_z) + # Remove duplicates. `unique_inverse[i]` is the column of the unique edge + # that original edge `i` was mapped to, so the labels are scattered through + # it to stay aligned with the (sorted) unique edges. Copies of a same edge + # are expected to carry the same label. + e_spatial, unique_inverse = torch.unique( + e_spatial, dim=1, return_inverse=True + ) + y_unique = y_cluster.new_zeros(e_spatial.shape[1]) + y_unique[unique_inverse] = y_cluster + y_cluster = y_unique + + assert e_spatial.shape[1] == torch.unique(e_spatial, dim=1).shape[1] + + batch["edge_index"] = e_spatial + batch["y"] = y_cluster + return batch + + def _get_building_custom_module(self) -> ModuleType: + return building_custom + + # def filter_batch(self, batch: Data) -> Data: + # # at least 3 hits to be classified as valid edges + # # not_reconstructible_mask = batch.n_unique_planes < 3 + # not_reconstructible_mask = batch.nhits_velo < 3 + + # # Apply this transformation to `y` as well + # not_reconstructible_edge_mask = ( + # not_reconstructible_mask[batch.edge_index].min(dim=0).values + # ) + + # batch.y[not_reconstructible_edge_mask] = False + + # # Classify the hits as fake (so that the edges are also classified like so) + # batch.particle_id[not_reconstructible_mask] = 0 + + # # Remove edges in same `plane` and `z` + # edge_index_plane = batch.plane[batch.edge_index] + # no_self_edge_mask = edge_index_plane[0] != 
edge_index_plane[1] + # batch.edge_index = batch.edge_index[:, no_self_edge_mask] + # batch.y = batch.y[no_self_edge_mask] + + # return batch + + # def build_features(self, batch: Data) -> Data: + # # norm_x = batch.un_x / 14.5 + # # norm_y = batch.un_y / 14.5 + # # xe = batch.un_x[batch.edge_index] + # # ye = batch.un_y[batch.edge_index] + # # ze = batch.un_z[batch.edge_index] + + # # slopes_yz = (ye[1] - ye[0]) / (ze[1] - ze[0]) / 0.17 + # # slopes_xz = (xe[1] - xe[0]) / (ze[1] - ze[0]) / 0.17 + + # # assert not torch.isnan(slopes_yz).any() + # # assert not torch.isnan(slopes_xz).any() + # # # Modify batch definition + # # batch.x = torch.stack((norm_x, norm_y, batch["x"][:, 2]), dim=1).float() + # # batch.edge_features = torch.stack((slopes_xz, slopes_yz), dim=1).float() + # return batch + + # def build_weights(self, batch: Data) -> Data: + # node_weights = 7.0 / batch.n_unique_planes + # node_weights = torch.nan_to_num(node_weights, nan=1.0) + # edge_weights = torch.mean(node_weights[batch.edge_index], dim=0) + # batch.edge_weights = ( + # edge_weights * batch.edge_index.shape[1] / edge_weights.sum() + # ) + # assert not torch.isnan(batch.edge_weights).any(), str(batch.edge_weights) + # return batch def get_performance(self, batch: Data, r_max: float, k_max: int): with torch.no_grad(): results = self.model.shared_evaluation(batch, 0, r_max, k_max) return results["truth"], results["preds"], results["truth_graph"] - - def select_data(self, event: Data) -> Data: - event.signal_true_edges = event.modulewise_true_edges - return event diff --git a/LHCb_Pipeline/Embedding/building_custom.py b/LHCb_Pipeline/Embedding/building_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..2b0fb9b5905496327468ea61f72a8028d34ba4ec --- /dev/null +++ b/LHCb_Pipeline/Embedding/building_custom.py @@ -0,0 +1,57 @@ +"""Custom functions for filtering and alterning an event. 
+""" +import torch +from torch_geometric.data import Data + + +def edges_at_least_3_hits(batch: Data) -> Data: + # at least 3 hits to be classified as valid edges + # not_reconstructible_mask = batch.n_unique_planes < 3 + not_reconstructible_mask = batch.nhits_velo < 3 + + # Apply this transformation to `y` as well + not_reconstructible_edge_mask = ( + not_reconstructible_mask[batch.edge_index].min(dim=0).values + ) + + batch.y[not_reconstructible_edge_mask] = False + + # Classify the hits as fake (so that the edges are also classified like so) + batch.particle_id[not_reconstructible_mask] = 0 + + # Remove edges in same `plane` and `z` + edge_index_plane = batch.plane[batch.edge_index] + no_self_edge_mask = edge_index_plane[0] != edge_index_plane[1] + batch["edge_index"] = batch.edge_index[:, no_self_edge_mask] + batch["y"] = batch.y[no_self_edge_mask] + return batch + + +def edge_features_as_slope(batch: Data) -> Data: + """Build edge features that correspond to the slope.""" + norm_x = batch.un_x / 14.5 + norm_y = batch.un_y / 14.5 + xe = batch.un_x[batch.edge_index] + ye = batch.un_y[batch.edge_index] + ze = batch.un_z[batch.edge_index] + + slopes_yz = (ye[1] - ye[0]) / (ze[1] - ze[0]) / 0.17 + slopes_xz = (xe[1] - xe[0]) / (ze[1] - ze[0]) / 0.17 + + assert not torch.isnan(slopes_yz).any() + assert not torch.isnan(slopes_xz).any() + # Modify batch definition + batch["x"] = torch.stack((norm_x, norm_y, batch["x"][:, 2]), dim=1).float() + batch["edge_features"] = torch.stack((slopes_xz, slopes_yz), dim=1).float() + + return batch + + +def weights_inversely_proportional_to_nhits(batch: Data) -> Data: + """Define edge weights that are inversely proportional to the number of hits.""" + node_weights = 7.0 / batch.n_unique_planes + node_weights = torch.nan_to_num(node_weights, nan=1.0) + edge_weights = torch.mean(node_weights[batch.edge_index], dim=0) + batch.edge_weights = edge_weights * batch.edge_index.shape[1] / edge_weights.sum() + assert not 
torch.isnan(batch.edge_weights).any(), str(batch.edge_weights) + return batch diff --git a/LHCb_Pipeline/Embedding/embedding_base.py b/LHCb_Pipeline/Embedding/embedding_base.py index 89bcfb46df906e6fe1ab58efe3654e593000b5a7..27d4ecdce353aed095585dfde47ee7b9ed6272fc 100644 --- a/LHCb_Pipeline/Embedding/embedding_base.py +++ b/LHCb_Pipeline/Embedding/embedding_base.py @@ -28,7 +28,7 @@ from torch_geometric.data import Data from utils.modelutils.basemodel import ModelBase from utils.commonutils.config import load_config from .graphutils import graph_intersection, build_edges - +from utils.graphutils.edgeutils import sort_edge_nodes device = "cuda" if torch.cuda.is_available() else "cpu" @@ -36,21 +36,10 @@ device = "cuda" if torch.cuda.is_available() else "cpu" class EmbeddingBase(ModelBase): """A class that implements the metric learning model.""" - def load_dataset(self, input_path: str) -> Data: - """Load and process one PyTorch DataSet. - - Args: - input_path: path to the PyTorch dataset - - Returns: - PyTorch DataSet - """ - loaded_event = super(EmbeddingBase, self).load_dataset(input_path=input_path) - # Define which column corresponds to the true edges - loaded_event["signal_true_edges"] = loaded_event[ - self.hparams["true_edges_column"] - ] - return loaded_event + @property + def bidir(self) -> bool: + """Whether the graph to build is bidirectional.""" + return self.hparams.get("bidir", True) def get_query_points(self, batch, spatial): if "query_all_points" in self.hparams["regime"]: @@ -161,7 +150,6 @@ class EmbeddingBase(ModelBase): Returns: ``torch.tensor`` The loss function as a tensor """ - # Instantiate empty prediction edge list e_spatial = torch.empty([2, 0], dtype=torch.int64, device=self.device) @@ -182,9 +170,12 @@ class EmbeddingBase(ModelBase): e_spatial = self.append_random_pairs(e_spatial, query_indices, spatial) # Instantiate bidirectional truth (since KNN prediction will be bidirectional) - e_bidir = torch.cat( - (batch.signal_true_edges, 
batch.signal_true_edges.flip(0)), dim=-1 - ) + if self.bidir: + e_bidir = torch.cat( + (batch.signal_true_edges, batch.signal_true_edges.flip(0)), dim=-1 + ) + else: + e_bidir = batch.signal_true_edges # Calculate truth from intersection between Prediction graph and Truth graph e_spatial, y_cluster = self.get_truth(batch, e_spatial, e_bidir) @@ -194,13 +185,16 @@ class EmbeddingBase(ModelBase): e_spatial, y_cluster, new_weights = self.get_true_pairs( e_spatial, y_cluster, new_weights, e_bidir ) + if not self.bidir: # is it really what I want? + sort_edge_nodes(e_spatial, batch.z) included_hits = e_spatial.unique() spatial[included_hits] = self(input_data[included_hits]) hinge, d = self.get_hinge_distance(spatial, e_spatial, y_cluster) - # Give negative examples a weight of 1 (note that there may still be TRUE examples that are weightless) + # Give negative examples a weight of 1 + # (note that there may still be TRUE examples that are weightless) new_weights[hinge == -1] = 1 negative_loss = torch.nn.functional.hinge_embedding_loss( @@ -223,7 +217,6 @@ class EmbeddingBase(ModelBase): "train_loss", loss, on_epoch=True, - on_step=False, batch_size=e_spatial.shape[1], prog_bar=True, ) @@ -236,16 +229,23 @@ class EmbeddingBase(ModelBase): input_data = self.get_input_data(batch) spatial = self(input_data) - e_bidir = torch.cat( - (batch.signal_true_edges, batch.signal_true_edges.flip(0)), dim=-1 - ) + if self.bidir: + e_bidir = torch.cat( + (batch.signal_true_edges, batch.signal_true_edges.flip(0)), dim=-1 + ) + else: + e_bidir = batch.signal_true_edges # Build whole KNN graph e_spatial = build_edges( spatial, spatial, indices=None, r_max=knn_radius, k_max=knn_num ) + if not self.bidir: + sort_edge_nodes(e_spatial, batch.un_z) e_spatial, y_cluster = self.get_truth(batch, e_spatial, e_bidir) + if not self.bidir: + sort_edge_nodes(e_spatial, batch.un_z) hinge, d = self.get_hinge_distance( spatial, e_spatial.to(self.device), y_cluster @@ -259,7 +259,6 @@ class 
EmbeddingBase(ModelBase): cluster_true_positive = y_cluster.sum() cluster_positive = len(e_spatial[0]) - # what is this? eff = cluster_true_positive / cluster_true pur = cluster_true_positive / cluster_positive @@ -331,8 +330,8 @@ def get_example_data( metric_learning_configs = configs["metric_learning"] model = EmbeddingBase(metric_learning_configs) - model.setup(stage="fit") - training_example = model.trainset[idx] + # model.setup(stage="fit") + training_example = model.valset[idx] example_hit_inputs = model.get_input_data(training_example) example_hit_df = pd.DataFrame(example_hit_inputs.numpy()) diff --git a/LHCb_Pipeline/Embedding/embedding_plots.py b/LHCb_Pipeline/Embedding/embedding_plots.py new file mode 100644 index 0000000000000000000000000000000000000000..51e914b83efea1ed59a15527ecf4e3fa66971cd9 --- /dev/null +++ b/LHCb_Pipeline/Embedding/embedding_plots.py @@ -0,0 +1,137 @@ +"""A module that handles the validation plots for the embedding phase specifically. +""" +from __future__ import annotations +import typing +import os.path as op + +import numpy as np +from uncertainties import unumpy as unp +import matplotlib.pyplot as plt +from matplotlib.figure import Figure +from matplotlib.axes import Axes + +from Embedding.embedding_validation import ( + evaluate_embedding_performances_given_radius_knn_max, +) +from utils.plotutils.plotconfig import partition_to_color, partition_to_label +from utils.plotutils.plotools import save_fig +from utils.commonutils.cpaths import get_performance_directory +from utils.commonutils.config import load_config +from .embedding_base import EmbeddingBase +from .embedding_validation import EmbeddingRadiusExplorer + + +def plot_embedding_performance_given_radius_knn_max( + model: EmbeddingBase, + path_or_config: str | dict, + partitions: typing.List[str] = ["train", "val"], + n_events: int = 10, + radius: np.ndarray | float | None = None, + knn_max: np.ndarray | int | None = None, + show_err: bool = True, +) -> typing.Tuple[ + typing.Dict[str, typing.Tuple[Figure, 
Axes]], + typing.Dict[str, typing.Dict[str, unp.matrix]], +]: + """Plot edge efficiency, purity and graph size as a function of the maximal + radius and maximal number of neighbours in the k-nearest neighbour algorithm. + + Args: + model: Embedding model + path_or_config: YAML configuration + partitions: List of partitions to plot + n_events: Maximal number of events to use for each partition for performance + evaluation + radius: Maximal distance in the embedding space + knn_max: Maximal number of neighbours + show_err: whether to show the error bars + + Returns: + Tuple of 2 dictionaries. The first dictionary associates a metric name with + the tuple of matplotlib Figure and Axes. + The second dictionary associates a metric name with another dictionary + that associates a partition with the list of metric values, for the different + ``radius`` or ``knn_max`` given as input. + + Raises: + ValueError: if both or none of ``radius`` and ``knn_max`` are numpy arrays. + """ + knn_is_array = isinstance(knn_max, np.ndarray) + radius_is_array = isinstance(radius, np.ndarray) + if radius_is_array and knn_is_array: + raise ValueError( + "Error: Cannot vary `radius` and `knn_max` at the same time but they were " + "both provided as a list." 
+ ) + elif radius_is_array: + list_hyperparam_values = radius + hyperparam_name = "radius" + elif knn_is_array: + list_hyperparam_values = knn_max + hyperparam_name = "knn" + else: + raise ValueError("Either `radius` or `knn_max` should be a numpy array.") + + dict_metrics_partitions = evaluate_embedding_performances_given_radius_knn_max( + model=model, + partitions=partitions, + radius=radius, + knn_max=knn_max, + n_events=n_events, + ) + + dict_figs_axs = {} + + for metric_name, dict_partitions in dict_metrics_partitions.items(): + fig, ax = plt.subplots(figsize=(8, 6)) + for partition in partitions: + ax.errorbar( + x=list_hyperparam_values, + y=unp.nominal_values(dict_partitions[partition]), + yerr=unp.std_devs(dict_partitions[partition]) if show_err else None, # type: ignore + color=partition_to_color.get(partition), + label=partition_to_label.get(partition, partition), + marker=".", + ) + + if hyperparam_name == "radius": + ax.set_xlabel("Radius") + elif hyperparam_name == "knn": + ax.set_xlabel("Maximal number of neighbours") + else: + raise Exception() + ax.grid(color="grey", alpha=0.5) + ax.legend() + # Turn e.g. `edge_efficiency` into `Edge Efficiency` + ax.set_ylabel(metric_name.replace("_", " ").title()) + + performance_dir = get_performance_directory(path_or_config=path_or_config) + save_fig(fig=fig, path=op.join(performance_dir, metric_name)) + + dict_figs_axs[metric_name] = (fig, ax) + + return (dict_figs_axs, dict_metrics_partitions) + + + + +def plot_best_performances_radius( + model: EmbeddingBase, + path_or_config: str | dict, + partition: str, + list_radius: typing.Sequence[float], + knn_max: int | None = None, + n_events: int | None = None, + seed: int | None = None, +) -> typing.Tuple[Figure, Axes, typing.Dict[str, typing.Dict[str, float]]]: + + embeddingRadiusExplorer = EmbeddingRadiusExplorer(model=model) + config = load_config(path_or_config=path_or_config) + return embeddingRadiusExplorer.plot( + path_or_config=config, + partition=partition, + values=list_radius, + n_events=n_events, + 
seed=seed, + knn_max=knn_max, + building=config["metric_learning"].get("building"), + ) diff --git a/LHCb_Pipeline/Embedding/embedding_validation.py b/LHCb_Pipeline/Embedding/embedding_validation.py new file mode 100644 index 0000000000000000000000000000000000000000..fc338ba74a853d905c5274e69efba2039ebb73a3 --- /dev/null +++ b/LHCb_Pipeline/Embedding/embedding_validation.py @@ -0,0 +1,199 @@ +"""A module that defines tools to perform the validation step of the embedding step. +""" +from __future__ import annotations +import typing +import logging + +from uncertainties import ufloat +from uncertainties.core import Variable +from tqdm.auto import tqdm +from uncertainties import unumpy as unp +import numpy as np +import pandas as pd +import torch +from torch_geometric.data import Data + +from GNN.perfect_gnn import PerfectInferenceBuilder +from Scripts.Step_5_Build_Track_Candidates import TrackBuilder +from Scripts.Step_6_Evaluate_Reconstruction_MonteTracko import get_tracks_from_batch +from utils.modelutils.batches import get_batches, select_subset +from utils.modelutils.evaluation import ParamExplorer + +from .build_embedding import EmbeddingInferenceBuilder +from .embedding_base import EmbeddingBase + + +def get_default_radius(model: EmbeddingBase, radius: float | None = None) -> float: + if radius is None: + r_infer = model.hparams.get("r_infer") + if r_infer is None: + r_infer = model.hparams["r"] + return r_infer + else: + return radius + + +def evaluate_embedding_performance( + model: EmbeddingBase, + batches: typing.List[Data], + radius: float | None = None, + knn_max: int | None = None, +) -> typing.Tuple[Variable, Variable, Variable]: + """Compute the edge efficiency and edge purity of a given model, on a subset + of the train, val or test dataset. 
+ + Args: + model: PyTorch model inheriting from + :py:class:`utils.modelutils.basemodel.ModelBase` + batches: List of events (PyTorch Geometric data objects) the performance + metrics are computed on + radius: Maximal radius for the KNN. If not given, taken from the hyperparameter + in the model. + knn_max: Maximal number of neighbours for the KNN. If not given, + taken from the hyperparameter in the model. + + Returns: + A tuple of 3 ufloat numbers corresponding to the event-based average of the + edge efficiency and edge purity, and the graph size + """ + # Handle default values for `knn_max` and `radius` + radius = get_default_radius(model=model, radius=radius) + knn_max = model.hparams["knn"] if knn_max is None else knn_max + + n_batches = len(batches) + # Compute performance for each batch + with torch.no_grad(): + efficiencies = np.full(shape=n_batches, fill_value=np.nan) + purities = np.full(shape=n_batches, fill_value=np.nan) + graph_sizes = np.full(shape=n_batches, fill_value=np.nan) + for batch_idx, batch in enumerate(batches): + results = model.shared_evaluation( + batch=batch, batch_idx=batch_idx, knn_radius=radius, knn_num=knn_max + ) + efficiencies[batch_idx] = results["eff"] + purities[batch_idx] = results["pur"] + graph_sizes[batch_idx] = results["preds"].shape[1] + + return ( + ufloat(efficiencies.mean(), efficiencies.std()), + ufloat(purities.mean(), purities.std()), + ufloat(graph_sizes.mean(), graph_sizes.std()), + ) + + +def evaluate_embedding_performances_given_radius_knn_max( + model: EmbeddingBase, + partitions: typing.List[str] = ["train", "val"], + n_events: int = 10, + radius: np.ndarray | float | None = None, + knn_max: np.ndarray | int | None = None, + seed: int | None = None, +) -> typing.Dict[str, typing.Dict[str, unp.matrix]]: + """Evaluate the edge efficiency, edge purity and graph size for each partition, + while scanning either ``radius`` or ``knn_max`` (whichever is a numpy array). + + Returns: + Dictionary that associates a metric name (``edge_efficiency``, + ``edge_purity`` or ``graph_size``) with a dictionary that associates + a partition with the array of metric values, one per scanned value. + + Raises: + ValueError: if both or none of ``radius`` and ``knn_max`` are numpy arrays. + """ + knn_is_array = isinstance(knn_max, np.ndarray) + radius_is_array = isinstance(radius, np.ndarray) + if knn_is_array and radius_is_array: + # Only one hyperparameter can be scanned at a time + raise ValueError( + "Both `knn_max` and `radius` are provided as array but only one " + "is supported." + ) + elif knn_is_array: + 
list_hyperparam_values = knn_max + hyperparam_name = "knn_max" + elif radius_is_array: + list_hyperparam_values = radius + hyperparam_name = "radius" + else: + raise ValueError( + "Either `knn_max` or `radius` must be provided as a numpy array." + ) + + dict_metrics_partitions = { + "edge_efficiency": {}, + "edge_purity": {}, + "graph_size": {}, + } + for partition in partitions: + logging.info(f"Compute edge performance metrics for {partition}") + batches = model.fetch_partition( + partition=partition, + n_events=n_events, + shuffle=True, + seed=seed, + map_location=model.device, + ) + + # Move batches to same device as model + # batches = [batch.to(model.device) for batch in batches] # type: ignore + + efficiencies = [] + purities = [] + graph_sizes = [] + for hyperparam_value in (pbar := tqdm(list_hyperparam_values)): + pbar.set_description( + f"Loop over {hyperparam_name} (current value: {hyperparam_value})" + ) + ( + efficiency, + purity, + graph_size, + ) = evaluate_embedding_performance( + model=model, + batches=batches, + radius=hyperparam_value if hyperparam_name == "radius" else radius, # type: ignore + knn_max=int(hyperparam_value) if hyperparam_name == "knn_max" else knn_max, # type: ignore + ) + efficiencies.append(efficiency) + purities.append(purity) + graph_sizes.append(graph_size) + + dict_metrics_partitions["edge_efficiency"][partition] = np.array(efficiencies) + dict_metrics_partitions["edge_purity"][partition] = np.array(purities) + dict_metrics_partitions["graph_size"][partition] = np.array(graph_sizes) + + return dict_metrics_partitions + + +class EmbeddingRadiusExplorer(ParamExplorer): + """A class that allows to vary the maximal radius and compare the best metric + performances of track finding, in the case where all the fake edges are filtered + out. 
+ """ + + def __init__(self, model: EmbeddingBase) -> None: + super().__init__(model, varname="radius", varlabel=r"$r_{\max}$") + + def get_tracks( + self, + value: float, + batches: typing.List[Data], + knn_max: int | None = None, + building: str | None = None, + ) -> pd.DataFrame: + # Run embedding inference + embeddingInferenceBuilder = EmbeddingInferenceBuilder( + model=self.model, + knn_max=self.model.hparams["knn"] if knn_max is None else knn_max, + radius=value, + ) + batches = [ + embeddingInferenceBuilder.process_one_step( + batch=batch.clone(), + building=building, + ) + for batch in tqdm(batches, desc="Graph Building") + ] + + # Run perfect GNN inference + perfectInferenceBuilder = PerfectInferenceBuilder() + batches = [ + perfectInferenceBuilder.construct_downstream(batch=batch) + for batch in batches + ] + + # Run track reconstruction + trackBuilder = TrackBuilder(score_cut=0.5) + batches = [ + trackBuilder.construct_downstream(batch=batch.cpu()) for batch in batches + ] + + # Define dataframe of tracks + return pd.concat( + tuple(get_tracks_from_batch(batch=batch) for batch in batches) + ).drop_duplicates() diff --git a/LHCb_Pipeline/Embedding/models/layerless_embedding.py b/LHCb_Pipeline/Embedding/models/layerless_embedding.py index 7bd08e14092efd44af790008b852ad9b033fb55a..8ea10afbd5d2e64c0ef3199594bdeb66586fe607 100644 --- a/LHCb_Pipeline/Embedding/models/layerless_embedding.py +++ b/LHCb_Pipeline/Embedding/models/layerless_embedding.py @@ -3,7 +3,7 @@ from ..embedding_base import EmbeddingBase import torch.nn.functional as F # Local imports -from utils.modelutils.mpl import make_mlp +from utils.modelutils.mlp import make_mlp from utils.commonutils.cfeatures import get_number_input_features diff --git a/LHCb_Pipeline/GNN/build_gnn.py b/LHCb_Pipeline/GNN/build_gnn.py index d2c40418e836f7706367449d7e28232e0889156e..1a9ad0c3b25e20872c5ac40a7d10588370c5435c 100644 --- a/LHCb_Pipeline/GNN/build_gnn.py +++ b/LHCb_Pipeline/GNN/build_gnn.py @@ -19,4 +19,8 
@@ class GNNInferenceBuilder(ModelBuilderBase): particle_ids=batch.particle_id, ) output = self.model.shared_evaluation(batch, 0, log=False) - batch.scores = output["score"][: int(len(output["score"]) / 2)] + if self.model.hparams.get("bidir", True): + batch.scores = output["score"][: int(len(output["score"]) / 2)] + else: + batch.scores = output["score"] + return batch diff --git a/LHCb_Pipeline/GNN/gnn_base.py b/LHCb_Pipeline/GNN/gnn_base.py index 8648c11a665bd4c9d8285c782fe488bb6e650d4d..6b167b123405bad2f078fa06388249be5cf76302 100644 --- a/LHCb_Pipeline/GNN/gnn_base.py +++ b/LHCb_Pipeline/GNN/gnn_base.py @@ -1,3 +1,4 @@ +import typing import numpy as np from sklearn.metrics import roc_auc_score import torch.nn.functional as F @@ -36,7 +37,11 @@ def compute_edge_labels( class GNNBase(ModelBase): - def load_dataset(self, input_path: str) -> Data: + @property + def bidir(self) -> bool: + return self.hparams.get("bidir", True) + + def fetch_dataset(self, input_path: str, **kwargs) -> Data: """Load and process one PyTorch DataSet. Args: @@ -45,7 +50,9 @@ class GNNBase(ModelBase): Returns: PyTorch DataSet """ - loaded_event = super(GNNBase, self).load_dataset(input_path=input_path) + loaded_event = super(GNNBase, self).fetch_dataset( + input_path=input_path, **kwargs + ) # Add `y_pid` column if not already there if "y_pid" not in loaded_event: @@ -53,55 +60,139 @@ class GNNBase(ModelBase): edge_indices=loaded_event.edge_index, particle_ids=loaded_event.particle_id, ) - return loaded_event - def handle_directed(self, batch, edge_sample, truth_sample): - edge_sample = torch.cat([edge_sample, edge_sample.flip(0)], dim=-1) - truth_sample = truth_sample.repeat(2) + if self.hparams.get("shuffle_edge_direction", False): + assert self.bidir, ( + "It was required to shuffle the edge directions, even though " + "the graph is not bidirectional. This is odd." 
+ ) + # Randomly shuffle direction of edges + random_flip = torch.randint(2, (loaded_event.edge_index.shape[1],)).bool() + ( + loaded_event.edge_index[0, random_flip], + loaded_event.edge_index[1, random_flip], + ) = ( + loaded_event.edge_index[1, random_flip], + loaded_event.edge_index[0, random_flip], + ) - if ("directed" in self.hparams.keys()) and self.hparams["directed"]: - direction_mask = batch.x[edge_sample[0], 0] < batch.x[edge_sample[1], 0] - edge_sample = edge_sample[:, direction_mask] - truth_sample = truth_sample[direction_mask] + return loaded_event - return edge_sample, truth_sample + def handle_bidirectional(self, edge_sample, truth_sample): + if self.bidir: + edge_sample = torch.cat([edge_sample, edge_sample.flip(0)], dim=-1) + truth_sample = truth_sample.repeat(2) - def training_step(self, batch, batch_idx): - weight = ( - torch.tensor(self.hparams["weight"]) - if ("weight" in self.hparams) - else torch.tensor((~batch.y_pid.bool()).sum() / batch.y_pid.sum()) - ) + return edge_sample, truth_sample - truth = ( - batch.y_pid.bool() if "pid" in self.hparams["regime"] else batch.y.bool() - ) + def compute_loss(self, output: torch.Tensor, truth: torch.Tensor, batch: Data): + """Compute the loss. 
- edge_sample, truth_sample = self.handle_directed(batch, batch.edge_index, truth) - input_data = self.get_input_data(batch) - output = self(input_data, edge_sample).squeeze() + Args: + output: network output + truth: true labels (to compare to ``output``) + batch: PyTorch geometric data objects, used to compute the weights + and having access to other columns + Returns: + loss function + """ + # Compute weights if "weighting" in self.hparams["regime"]: - manual_weights = batch.weights + manual_weights = batch.edge_weights else: manual_weights = None - loss = F.binary_cross_entropy_with_logits( - output, truth_sample.float(), weight=manual_weights, pos_weight=weight + # Compute weights on positive samples + if self.hparams.get("focal_loss", False): + weight = ( + torch.tensor(self.hparams["weight"]) + if ("weight" in self.hparams) + else (~truth).sum() / truth.shape[0] + ) + else: + weight = ( + torch.tensor(self.hparams["weight"]) + if ("weight" in self.hparams) + else (~truth).sum() / truth.sum() + ) + + # Compute weighted loss + if self.hparams.get("focal_loss", False): + from torchvision.ops import sigmoid_focal_loss + + loss = sigmoid_focal_loss( + inputs=output, + targets=truth.float(), + alpha=weight, + reduction="mean", + ) + else: + loss = F.binary_cross_entropy_with_logits( + output, + truth.float(), + weight=manual_weights, + pos_weight=weight, + ) + if "triplet" in self.hparams["regime"]: + assert not self.bidir, ( + "Loss for triplets with penalty term not supported " + "for bidirectional graph" + ) + total_angle_diff_norm = torch.abs(batch.diff_angle_xz_norm) + torch.abs( + batch.diff_angle_yz_norm + ) + # Penality term if score is large + pos_penality = (torch.sigmoid(output) * total_angle_diff_norm).mean() + # Penality term if score is small + neg_penality = ( + (1 - torch.sigmoid(output)) * (2 - total_angle_diff_norm) + ).mean() + + loss += ( + self.hparams["pos_penality"] * pos_penality.mean() + + self.hparams["neg_penality"] * neg_penality.mean() 
+ ) + return loss + + def common_training_validation_step( + self, batch: Data + ) -> typing.Tuple[torch.Tensor, torch.Tensor, float]: + """Perform the inference and loss computation step that is common + to the training and validation step. + + Returns: + Network output, true labels and loss function. + """ + # Get true labels + truth = ( + batch.y_pid.bool() if "pid" in self.hparams["regime"] else batch.y.bool() ) + # Handle bidirectional graphs + edge_sample, truth_sample = self.handle_bidirectional(batch.edge_index, truth) + input_data = self.get_input_data(batch) + # Run GNN inference + output = self(input_data, edge_sample).squeeze() + loss = self.compute_loss(output=output, truth=truth_sample, batch=batch) + return output, truth_sample, loss + + def training_step(self, batch, batch_idx): + output, _, loss = self.common_training_validation_step(batch=batch) self.log( "train_loss", loss, on_epoch=True, on_step=False, - batch_size=edge_sample.shape[1], + batch_size=output.shape[0], prog_bar=True, ) return loss - def log_metrics(self, score, preds, truth, batch, loss): + def log_metrics( + self, score: float, preds: torch.Tensor, truth: torch.Tensor, loss: float + ): edge_positive = preds.sum().float() edge_true = truth.sum().float() edge_true_positive = (truth.bool() & preds).sum().float() @@ -109,7 +200,8 @@ class GNNBase(ModelBase): eff = edge_true_positive.clone().detach() / max(1, edge_true) pur = edge_true_positive.clone().detach() / max(1, edge_positive) - # Fix error: "ValueError: Only one class present in y_true. ROC AUC score is not defined in that case" + # Fix error: "ValueError: Only one class present in y_true. 
+ # ROC AUC score is not defined in that case" try: auc = roc_auc_score(truth.bool().cpu().detach(), score.cpu().detach()) except ValueError: @@ -129,42 +221,22 @@ class GNNBase(ModelBase): batch_size=preds.shape[0], ) - def shared_evaluation(self, batch, batch_idx, log=False): - weight = ( - torch.tensor(self.hparams["weight"]) - if ("weight" in self.hparams) - else torch.tensor((~batch.y_pid.bool()).sum() / batch.y_pid.sum()) - ) - - truth = ( - batch.y_pid.bool() if "pid" in self.hparams["regime"] else batch.y.bool() - ) - - edge_sample, truth_sample = self.handle_directed(batch, batch.edge_index, truth) - input_data = self.get_input_data(batch) - output = self(input_data, edge_sample).squeeze() - - if "weighting" in self.hparams["regime"]: - manual_weights = batch.weights - else: - manual_weights = None - - loss = F.binary_cross_entropy_with_logits( - output, truth_sample.float(), weight=manual_weights, pos_weight=weight - ) - + def shared_evaluation( + self, batch: Data, batch_idx: int, log: bool = False + ) -> typing.Dict[str, float]: + output, truth, loss = self.common_training_validation_step(batch=batch) # Edge filter performance score = torch.sigmoid(output) preds = score > self.hparams["edge_cut"] if log: - self.log_metrics(score, preds, truth_sample, batch, loss) + self.log_metrics(score, preds, truth, loss) return { "loss": loss, "score": score, "preds": preds, - "truth": truth_sample, + "truth": truth, } def validation_step(self, batch, batch_idx): diff --git a/LHCb_Pipeline/GNN/gnn_plots.py b/LHCb_Pipeline/GNN/gnn_plots.py new file mode 100644 index 0000000000000000000000000000000000000000..f376217575c595cfa0b7ca346f2d0be442b2de8f --- /dev/null +++ b/LHCb_Pipeline/GNN/gnn_plots.py @@ -0,0 +1,32 @@ +import typing + +import numpy.typing as npt +import matplotlib.pyplot as plt +from matplotlib.figure import Figure +from matplotlib.axes import Axes + +from .gnn_base import GNNBase +from .gnn_validation import GNNScoreCutExplorer + + +def 
plot_best_performances_score_cut( + model: GNNBase, + path_or_config: str | dict, + partition: str, + score_cuts: typing.Sequence[float], + n_events: int | None = None, + seed: int | None = None, + identifier: str | None = None, +) -> typing.Tuple[Figure, npt.NDArray, typing.Dict[str, typing.Dict[str, float]]]: + if identifier is None: + identifier = "" + + gnnScoreCutExplorer = GNNScoreCutExplorer(model=model) + return gnnScoreCutExplorer.plot( + path_or_config=path_or_config, + partition=partition, + values=score_cuts, + n_events=n_events, + seed=seed, + identifier=identifier, + ) diff --git a/LHCb_Pipeline/GNN/gnn_validation.py b/LHCb_Pipeline/GNN/gnn_validation.py new file mode 100644 index 0000000000000000000000000000000000000000..0f20e4ee34d78e00d872398b9a9b229f6ecb6c98 --- /dev/null +++ b/LHCb_Pipeline/GNN/gnn_validation.py @@ -0,0 +1,94 @@ +from __future__ import annotations +import typing +import os.path as op + +from tqdm.auto import tqdm +import pandas as pd +import torch +from torch_geometric.data import Data + +from .build_gnn import GNNInferenceBuilder +from .gnn_base import GNNBase +from Scripts.Step_5_Build_Track_Candidates import TrackBuilder +from Scripts.Step_6_Evaluate_Reconstruction_MonteTracko import ( + get_tracks_from_batch, + load_parquet_files, + perform_matching, +) +from utils.modelutils.evaluation import ParamExplorer + + +class GNNScoreCutExplorer(ParamExplorer): + """A class that allows to vary the score cut after the GNN, and compare the metric + performances of track finding. 
+ """ + def __init__(self, model: GNNBase) -> None: + super().__init__(model, varname="score_cut", varlabel="Score cut") + + def get_tracks( + self, + value: float, + batches: typing.List[Data], + ): + inferencerBuilder = GNNInferenceBuilder(model=self.model) + with torch.no_grad(): + batches = [ + inferencerBuilder.construct_downstream(batch=batch.clone()) + for batch in tqdm(batches, desc="GNN inference") + ] + + # Run track reconstruction + trackBuilder = TrackBuilder(score_cut=value) + batches = [trackBuilder.construct_downstream(batch=batch.cpu()) for batch in batches] + + # Define dataframe of tracks + return pd.concat( + tuple(get_tracks_from_batch(batch=batch) for batch in batches) + ).drop_duplicates() + + + +def compute_best_tracks( + model: GNNBase, + input_dir: str, + output_dir: str, + file_names: typing.List[str], + score_cut: float, +) -> pd.DataFrame: + """Compute the best achievable tracking efficiency given a choice of + hyperparameter ``score_cut``. + + Args: + model: PyTorch model inheriting from + :py:class:`utils.modelutils.basemodel.ModelBase` + partition: ``train``, ``val``, ``test`` (for the current already loaded + test sample) or the name of a test dataset + score_cut: Minimal GNN edge score + n_events: Number of events to compute the performance metrics on + seed: Seed used to randomly select the ``n_events`` + + Returns: + Dataframe of tracks and average graph size + """ + inferencerBuilder = GNNInferenceBuilder(model=model) + inferencerBuilder.infer( + input_dir=input_dir, + output_dir=output_dir, + reproduce=True, + file_names=file_names, + ) + batches = [ + torch.load(op.join(output_dir, filename), map_location="cpu") + for filename in file_names + ] + + # Run track reconstruction + trackBuilder = TrackBuilder(score_cut=score_cut) + batches = [trackBuilder.construct_downstream(batch=batch) for batch in batches] + + # Define dataframe of tracks + df_tracks = pd.concat( + tuple(get_tracks_from_batch(batch=batch) for batch in batches) + 
).drop_duplicates() + + return df_tracks diff --git a/LHCb_Pipeline/GNN/models/interaction_gnn.py b/LHCb_Pipeline/GNN/models/interaction_gnn.py index f369106eaef54a0aaec5e0aaadd62b526b171b9c..af6a3d7d9cdc7ccd46cbb5c844149816c0075647 100644 --- a/LHCb_Pipeline/GNN/models/interaction_gnn.py +++ b/LHCb_Pipeline/GNN/models/interaction_gnn.py @@ -3,14 +3,14 @@ import torch from torch_scatter import scatter_add, scatter_max from torch.utils.checkpoint import checkpoint -from ..gnn_base import GNNBase -from utils.modelutils.mpl import make_mlp +from utils.modelutils.mlp import make_mlp from utils.commonutils.cfeatures import get_number_input_features +from ..gnn_base import GNNBase class InteractionGNN(GNNBase): - """An interaction network class - """ + """An interaction network class""" + def __init__(self, hparams): super().__init__(hparams) """ @@ -21,11 +21,17 @@ class InteractionGNN(GNNBase): concatenation_factor = ( 3 if (self.hparams["aggregation"] in ["sum_max", "mean_max"]) else 2 ) + if not self.bidir: + concatenation_factor = (concatenation_factor - 1) * 2 + 1 + + nb_edge_layers: int = hparams["nb_edge_layers"] + nb_node_layers: int = hparams["nb_node_layers"] + nb_hidden: int = hparams["hidden"] # Setup input network self.node_encoder = make_mlp( get_number_input_features(hparams["feature_indices"]), - [hparams["hidden"]] * hparams["nb_node_layer"], + [nb_hidden] * hparams.get("nb_node_encoder_layers", nb_node_layers), output_activation=None, hidden_activation=hparams["hidden_activation"], layer_norm=hparams["layernorm"], @@ -33,8 +39,8 @@ class InteractionGNN(GNNBase): # The edge network computes new edge features from connected nodes self.edge_encoder = make_mlp( - 2 * (hparams["hidden"]), - [hparams["hidden"]] * hparams["nb_edge_layer"], + 2 * (nb_hidden), + [nb_hidden] * hparams.get("nb_edge_encoder_layers", nb_edge_layers), layer_norm=hparams["layernorm"], output_activation=None, hidden_activation=hparams["hidden_activation"], @@ -42,8 +48,8 @@ class 
InteractionGNN(GNNBase): # The edge network computes new edge features from connected nodes self.edge_network = make_mlp( - 3 * hparams["hidden"], - [hparams["hidden"]] * hparams["nb_edge_layer"], + 3 * nb_hidden, + [nb_hidden] * nb_edge_layers, layer_norm=hparams["layernorm"], output_activation=None, hidden_activation=hparams["hidden_activation"], @@ -51,8 +57,8 @@ class InteractionGNN(GNNBase): # The node network computes new node features self.node_network = make_mlp( - concatenation_factor * hparams["hidden"], - [hparams["hidden"]] * hparams["nb_node_layer"], + concatenation_factor * nb_hidden, + [nb_hidden] * nb_node_layers, layer_norm=hparams["layernorm"], output_activation=None, hidden_activation=hparams["hidden_activation"], @@ -60,8 +66,10 @@ class InteractionGNN(GNNBase): # Final edge output classification network self.output_edge_classifier = make_mlp( - 3 * hparams["hidden"], - [hparams["hidden"]] * hparams["nb_edge_layer"] + [1], + 3 * nb_hidden, + [nb_hidden] + * hparams.get("nb_edge_classifier_layers", nb_edge_layers) + + [1], layer_norm=hparams["layernorm"], output_activation=None, hidden_activation=hparams["hidden_activation"], @@ -76,19 +84,32 @@ class InteractionGNN(GNNBase): def message_step(self, x, start, end, e): # Compute new node features if self.hparams["aggregation"] == "sum": + assert not self.bidir edge_messages = scatter_add(e, end, dim=0, dim_size=x.shape[0]) elif self.hparams["aggregation"] == "max": + assert not self.bidir edge_messages = scatter_max(e, end, dim=0, dim_size=x.shape[0])[0] elif self.hparams["aggregation"] == "sum_max": - edge_messages = torch.cat( - [ - scatter_max(e, end, dim=0, dim_size=x.shape[0])[0], - scatter_add(e, end, dim=0, dim_size=x.shape[0]), - ], - dim=-1, - ) + if not self.bidir: + edge_messages = torch.cat( + [ + scatter_max(e, end, dim=0, dim_size=x.shape[0])[0], + scatter_add(e, end, dim=0, dim_size=x.shape[0]), + scatter_max(e, start, dim=0, dim_size=x.shape[0])[0], + scatter_add(e, start, dim=0, 
dim_size=x.shape[0]), + ], + dim=-1, + ) + else: + edge_messages = torch.cat( + [ + scatter_max(e, end, dim=0, dim_size=x.shape[0])[0], + scatter_add(e, end, dim=0, dim_size=x.shape[0]), + ], + dim=-1, + ) node_inputs = torch.cat([x, edge_messages], dim=-1) x_out = self.node_network(node_inputs) @@ -119,7 +140,6 @@ class InteractionGNN(GNNBase): # edge_outputs = [] # Loop over iterations of edge and node networks for i in range(self.hparams["n_graph_iters"]): - x, e = checkpoint(self.message_step, x, start, end, e) # Compute final edge scores; use original edge directions only diff --git a/LHCb_Pipeline/GNN/perfect_gnn.py b/LHCb_Pipeline/GNN/perfect_gnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5714c6458ee8e4b1842d8660c4e0582a4a7d3db7 --- /dev/null +++ b/LHCb_Pipeline/GNN/perfect_gnn.py @@ -0,0 +1,22 @@ +"""Replace the GNN by a perfect inference in order to understand +what is the best result that can be obtained with the current pipeline. +""" +from torch_geometric.data import Data +from utils.modelutils.build import BuilderBase +from GNN.gnn_base import compute_edge_labels + + +class PerfectInferenceBuilder(BuilderBase): + """Generate perfect inference, that is, the edge score is equal to the truth. 
+ """ + def construct_downstream(self, batch: Data, pid: bool = False): + if pid: + if "y_pid" not in batch: + batch["y_pid"] = compute_edge_labels( + edge_indices=batch.edge_index, + particle_ids=batch.particle_id, + ) + batch.scores = batch["y_pid"] + else: + batch.scores = batch.y + return batch diff --git a/LHCb_Pipeline/Preprocessing/particle_line_fitting.py b/LHCb_Pipeline/Preprocessing/particle_line_fitting.py index f0969ab97332426dcd7ec8cebfb35a48320ba827..24b34f72630ac9528568f6fe3e526985276b73e6 100644 --- a/LHCb_Pipeline/Preprocessing/particle_line_fitting.py +++ b/LHCb_Pipeline/Preprocessing/particle_line_fitting.py @@ -132,7 +132,7 @@ def compute_particle_metric( @nb.jit(nopython=True, cache=True) -def compute_particle_distances_to_lines_events_impl( +def compute_particle_line_metrics_events_impl( array_metric_values: np.ndarray, coords_events_particles: np.ndarray, event_ids: np.ndarray, @@ -178,9 +178,10 @@ def compute_particle_distances_to_lines_events_impl( event_idx += n_particles -def compute_particle_distances_to_lines_dataframe( +def compute_particle_line_metrics_dataframe( hits: pd.DataFrame, metric_names: typing.List[str], + event_id_column: str = "event", ) -> pd.DataFrame: """Compute the pandas Series of the distance from particle hits to straight lines fitted to these lines. The "distance" actually corresponds to the square-root @@ -197,6 +198,7 @@ def compute_particle_distances_to_lines_dataframe( * ``xz_angle`` * ``yz_angle`` + event_id_column: name of the event ID column Returns: A pandas Series with index ``event`` and ``particle_id``, and for every @@ -204,12 +206,15 @@ def compute_particle_distances_to_lines_dataframe( fitted to the points. The distance is the square-root of the average of the squared distances from the hits to the straight line. 
""" - hits = hits.sort_values(by=["event", "particle_id"]) - events_particles_group = hits.groupby(["event", "particle_id"], sort=False).size() + + hits = hits.sort_values(by=[event_id_column, "particle_id"]) + events_particles_group = hits.groupby( + [event_id_column, "particle_id"], sort=False + ).size() n_particles = events_particles_group.shape[0] array_metric_values = np.zeros(shape=(n_particles, len(metric_names))) - compute_particle_distances_to_lines_events_impl( - event_ids=hits["event"].to_numpy(), + compute_particle_line_metrics_events_impl( + event_ids=hits[event_id_column].to_numpy(), particle_ids=hits["particle_id"].to_numpy(), coords_events_particles=hits[["x", "y", "z"]].to_numpy(), array_metric_values=array_metric_values, diff --git a/LHCb_Pipeline/Preprocessing/particle_line_metrics.py b/LHCb_Pipeline/Preprocessing/particle_line_metrics.py index dca906a9c14feac071341faae98178c2ea50b5f1..b29f3f90e5f51dfdb1b9492fc75c37c7c5ac1e27 100644 --- a/LHCb_Pipeline/Preprocessing/particle_line_metrics.py +++ b/LHCb_Pipeline/Preprocessing/particle_line_metrics.py @@ -1,6 +1,11 @@ +import warnings import numpy as np import numba as nb +warnings.filterwarnings( + "ignore", ".*type 'reflected list'.*", category=DeprecationWarning +) + @nb.jit(nopython=True, cache=True) def compute_distance_to_line( diff --git a/LHCb_Pipeline/Preprocessing/preprocessing.py b/LHCb_Pipeline/Preprocessing/preprocessing.py index 6241961c1f5902ebf9043c31a71bb7836cad3935..e20f6b8a4be37a9a39ffac247520d21e949034e0 100644 --- a/LHCb_Pipeline/Preprocessing/preprocessing.py +++ b/LHCb_Pipeline/Preprocessing/preprocessing.py @@ -3,6 +3,7 @@ import typing import os import logging from tqdm.auto import tqdm +import numpy as np import pandas as pd from . 
import selecting @@ -48,37 +49,43 @@ def load_dataframes( **kwargs: other keyword arguments passed to the function that load the files Returns: - A 2-tuple containing the dataframe of hits and the dataframes of particles + A 2-tuple containing the dataframe of hits-particles and the dataframes + of particles Notes: The function also defines the column ``particle_id = mcid + 1`` in both dataframes. """ - particles = pd.read_parquet( - f"{indir}/mc_particles.parquet.lz4", - columns=None - if particles_columns is None - else ["event", "mcid"] + particles_columns, + path=os.path.join(indir, "mc_particles.parquet.lz4"), + columns=( + None if particles_columns is None else ["event", "mcid"] + particles_columns + ), **kwargs, ) - cast_boolean_columns(particles) hits_particles = pd.read_parquet( - f"{indir}/hits_velo.parquet.lz4", - columns=None - if hits_particles_columns is None - else ["event", "mcid", "lhcbid"] + hits_particles_columns, + path=os.path.join(indir, "hits_velo.parquet.lz4"), + columns=( + None + if hits_particles_columns is None + else ["event", "mcid", "lhcbid"] + hits_particles_columns + ), **kwargs, ) + cast_boolean_columns(particles) + # Define `particle_id = mcid + 1` directly in the original dataframes particles["particle_id"] = particles["mcid"] + 1 hits_particles["particle_id"] = hits_particles["mcid"] + 1 + particles.drop("mcid", axis=1, inplace=True) + hits_particles.drop("mcid", axis=1, inplace=True) + # Rename `lhcbid` to `hit_id` hits_particles.rename(columns={"lhcbid": "hit_id"}, inplace=True) - return particles, hits_particles + return hits_particles, particles def enough_true_hits( @@ -133,113 +140,242 @@ def enough_true_hits( return True -def preprocess( - input_dir: str, - output_dir: str, - n_events: int, +def load_and_filter_dataframes( + indir: str, + particles_columns: typing.List[str] | None = None, + hits_particles_columns: typing.List[str] | None = None, selection: str | None = None, - num_true_hits_threshold: int | None = None, 
-): - """Preprocess the first `n_events` events in the input files, - into the form of the TrackML dataset. - Remove any events that contain only fake hits. - """ - pd.set_option("chained_assignment", None) # disable chaine assignment warning - os.makedirs(output_dir, exist_ok=True) - logging.info(f"Preprocessing: output will be written in {output_dir}") + **kwargs, +) -> typing.Tuple[pd.DataFrame, pd.DataFrame]: + """Load and filter the dataframes of hits-particles and particles. - #: Columns to load from the `hits_particles` dataframe - hits_particles_columns = [ - # Features - "x", - "y", - "z", - # Plane-wise edges - "plane", - ] - - #: Columns to load from the `particles` dataframe - # particles_columns = [ - # # Module-wise true edges - # "vx", - # "vy", - # "vz", - # # For evaluation and selection - # "has_velo", - # "has_scifi", - # "charge", - # "pid", - # "nhits_velo", - # "pt", - # "p", - # ] - particles_columns = None + Args: + indir: directory where the dataframes are saved + particles_columns: columns to load for the dataframe of particles + hits_particles_columns: columns to load for the dataframe of hits + and the hits-particles association information + selection: function to use to filter the candidates. + The latter is defined in the :py:mod:`.selecting` module. 
+ **kwargs: other keyword arguments passed to the function that load the files + + Returns: + Filtered dataframes of hits-particles and of particles + """ # Load dataframes - logging.info("Load dataframe") - particles, hits_particles = load_dataframes( - indir=input_dir, - # particles_columns=particles_columns, + logging.info(f"Load dataframes in {indir}") + hits_particles, particles = load_dataframes( + indir=indir, + particles_columns=particles_columns, hits_particles_columns=hits_particles_columns, + **kwargs, ) # Add truth particle information to the dataframe of hits if selection: - logging.info("Apply selection") + logging.info(f"Apply selection `{selection}`") selection_function: selecting.SelectionFunction = getattr(selecting, selection) hits_particles, particles = selection_function( hits_particles=hits_particles, particles=particles, ) - event_list = particles["event"].unique() # The order is not mixed + # Define `n_unique_planes` + # We'll train with all the hits for the training + # And cut before the GNN + n_unique_planes = ( + hits_particles.groupby(["event", "particle_id"])["plane"] + .nunique() + .rename("n_unique_planes") + ) + particles = particles.merge( + n_unique_planes, how="left", on=["event", "particle_id"] + ).fillna(0) + + return hits_particles, particles + + +def get_indirs( + input_dir: str | None = None, + subdirs: int | str | typing.List[str] | typing.Dict[str, int] | None = None, +): + """Get the input directories that can be used as input of the preprocessing. + + Args: + input_dir: A single input directory if ``subdirs`` is ``None``, + or the main directory where sub-directories are + subdirs: + + * If ``subdirs`` is None, there is a single input directory, ``input_dir`` + * If ``subdirs`` is a string or a list of strings, they specify \ + the sub-directories with respect to ``input_dir``. 
If ``input_dir`` \ + is ``None``, then they are the (list of) input directories directly, which \ + can be useful if the input directories are not at the same location \ + (even though it is discouraged) + * If ``subdirs`` is an integer, it corresponds to the name of the last \ + sub-directory to consider (i.e., from 0 to ``subdirs``). If ``subdirs`` \ + is ``-1``, all the sub-directories are considered as input. + * If ``subdirs`` is a dictionary, the keys ``start`` and ``stop`` specify \ + the first and last sub-directories to consider as input. + + Returns: + List of input directories that can be considered. + """ + if input_dir is None: + if isinstance(subdirs, str): + return [subdirs] + elif isinstance(subdirs, list): + return [str(subdir) for subdir in subdirs] + else: + raise TypeError( + "`input_dir` is `None` but `subdirs` is neither a string nor " + "a list of strings, so the input directories of the preprocessing " + "cannot be determined." + ) + else: + # Get the list of all the sub-directories inside ``input_dir`` + + # Filter this list according to ``subdirs`` + if subdirs is None: + return [input_dir] + elif isinstance(subdirs, (int, dict)): + available_subdirs = sorted( + [ + int(file_or_dir.name) + for file_or_dir in os.scandir(input_dir) + if file_or_dir.is_dir() + ] + ) + if subdirs == -1: + final_subdirs = available_subdirs + else: + if isinstance(subdirs, int): + start = 0 + stop = subdirs + else: # dict + start = subdirs.get("start", 0) + stop = subdirs["stop"] + + assert ( + stop >= start + ), f"`start` ({start}) is strictly higher than `stop ({stop})" + final_subdirs = [ + subdir + for subdir in available_subdirs + if subdir >= start and subdir <= stop + ] + elif isinstance(subdirs, str): + final_subdirs = [subdirs] + elif isinstance(subdirs, list): + final_subdirs = subdirs + else: + raise ValueError( + f"`input_dir` is not `None` and `subdirs` is `{subdirs}`, which are " + "not valid inputs." 
+ ) + + return [os.path.join(input_dir, str(subdir)) for subdir in final_subdirs] + + +def preprocess( + input_dir: str, + output_dir: str, + subdirs: int | str | typing.List[str] | None = None, + n_events: int = -1, + selection: str | None = None, + num_true_hits_threshold: int | None = None, + hits_particles_columns: typing.List[str] | None = None, + particles_columns: typing.List[str] | None = None, +): + """Preprocess the first `n_events` events in the input files, + into the form of the TrackML dataset. + Remove any events that contain only fake hits. + """ + pd.set_option("chained_assignment", None) # disable chaine assignment warning + os.makedirs(output_dir, exist_ok=True) + logging.info(f"Preprocessing: output will be written in {output_dir}") + + indirs = get_indirs(input_dir=input_dir, subdirs=subdirs) + if len(indirs) == 0: + raise ValueError("No input directories.") + logging.info("Input directories:") + for indir in indirs: + logging.info(f"- {indir}") + n_output_saved = 0 # Count the number of events outputted event_idx = 0 - with tqdm(total=n_events) as pbar: - while n_output_saved < n_events and event_idx < len(event_list): - current_event_id = event_list[event_idx] - event_hits_particles = hits_particles[ - hits_particles["event"] == current_event_id - ] - event_particles = particles[particles["event"] == current_event_id] - - #: String representation of the event ID - event_id_str = str(current_event_id).zfill(9) - - if (num_true_hits_threshold is None) or enough_true_hits( - event_hits_particles=event_hits_particles, - num_true_hits_threshold=num_true_hits_threshold, - event_id_str=event_id_str, - num_events=n_output_saved, - required_num_events=n_events, - ): - # Save subset of columns - if hits_particles_columns is None: - hits_particles_csv = event_hits_particles - else: - hits_particles_csv = event_hits_particles[ - ["particle_id", "hit_id"] + hits_particles_columns - ] - if particles_columns is None: - particles_csv = event_particles - 
else: - particles_csv = event_particles[["particle_id"] + particles_columns] - - # Save - hits_particles_csv.to_parquet( - f"{output_dir}/event{event_id_str}-hits_particles.parquet", - ) - particles_csv.to_parquet( - f"{output_dir}/event{event_id_str}-particles.parquet", - ) - n_output_saved += 1 - pbar.update() - event_idx += 1 + n_required_events = np.inf if n_events == -1 else n_events + left_indirs = indirs + logging.info(f"Number of events to produce: {n_required_events}") + with tqdm(total=n_required_events) as pbar: + while n_output_saved < n_required_events and left_indirs: + hits_particles, particles = load_and_filter_dataframes( + indir=left_indirs[0], + hits_particles_columns=hits_particles_columns, + particles_columns=particles_columns, + selection=selection, + ) + left_indirs = left_indirs[1:] + + event_ids_in_df_hits_particles = hits_particles["event"].unique() + event_ids_df_particles = particles["event"].unique() + event_ids = np.intersect1d( + event_ids_in_df_hits_particles, event_ids_df_particles + ) + + # Loop over the events in the dataframe of hits-particles + grouped_df_hits_particles = hits_particles.groupby("event") + grouped_df_particles = particles.groupby("event") + for event_id in event_ids: + if n_output_saved >= n_required_events: + break + event_hits_particles = grouped_df_hits_particles.get_group(event_id) + event_particles = grouped_df_particles.get_group(event_id) + + #: String representation of the event ID + event_id_str = str(event_id).zfill(9) + + no_hits = event_hits_particles.shape[0] == 0 + + if not no_hits and ( + (num_true_hits_threshold is None) + or enough_true_hits( + event_hits_particles=event_hits_particles, + num_true_hits_threshold=num_true_hits_threshold, + event_id_str=event_id_str, + num_events=n_output_saved, + required_num_events=n_events, + ) + ): + # Save subset of columns + if hits_particles_columns is None: + hits_particles_csv = event_hits_particles + else: + hits_particles_csv = event_hits_particles[ + 
["particle_id", "hit_id"] + hits_particles_columns + ] + if particles_columns is None: + particles_csv = event_particles + else: + particles_csv = event_particles[ + ["particle_id"] + particles_columns + ] + + # Save + hits_particles_csv.to_parquet( + f"{output_dir}/event{event_id_str}-hits_particles.parquet", + ) + particles_csv.to_parquet( + f"{output_dir}/event{event_id_str}-particles.parquet", + ) + n_output_saved += 1 + pbar.update() + event_idx += 1 pbar.close() pd.set_option("chained_assignment", "warn") # re-enable chained-assignment warning - if n_output_saved < n_events: + if n_output_saved < n_required_events: raise Exception( "Not enough events found with more than " f"{num_true_hits_threshold} true hits" diff --git a/LHCb_Pipeline/Preprocessing/run_preprocessing.py b/LHCb_Pipeline/Preprocessing/run_preprocessing.py index fcbecaa550b047678ad3ccfac08afe8ac3d81733..1a1db4a8e0df5537f0c476ee3cade71d544ae7b4 100644 --- a/LHCb_Pipeline/Preprocessing/run_preprocessing.py +++ b/LHCb_Pipeline/Preprocessing/run_preprocessing.py @@ -18,7 +18,7 @@ def run_preprocessing(path_or_config: str | dict, reproduce: bool = True): """ config = load_config(path_or_config) output_dir = config["preprocessing"]["output_dir"] - + if config["preprocessing"]["n_events"] is None: config["preprocessing"]["n_events"] = ( config["processing"]["n_train_events"] diff --git a/LHCb_Pipeline/Preprocessing/selecting.py b/LHCb_Pipeline/Preprocessing/selecting.py index d9105cf383a92d3c34f37b2d6203adac3359003c..9a0000c0d6a4528b55facbfa9ca0f752a3bac5e6 100644 --- a/LHCb_Pipeline/Preprocessing/selecting.py +++ b/LHCb_Pipeline/Preprocessing/selecting.py @@ -1,5 +1,8 @@ import typing +import logging +import numpy as np import pandas as pd +from .particle_line_fitting import compute_particle_line_metrics_dataframe class SelectionFunction(typing.Protocol): @@ -29,12 +32,15 @@ def apply_mask( """ # About 3 seconds in a dataframe with 5000 events hits_particles_mask = ( - hits_particles[["event", 
"mcid"]] + hits_particles[["event", "particle_id"]] + .reset_index() .merge( - right=particles[["event", "mcid"]].assign(mask_=particles_mask), - on=["event", "mcid"], + right=particles[["event", "particle_id"]].assign(mask_=particles_mask), + on=["event", "particle_id"], how="left", - )["mask_"] + sort=False, + ) + .set_index("index")["mask_"] .fillna(True) ) # fillna to keep fake hits @@ -80,8 +86,8 @@ def everything_but_long_electrons( long electrons are left. """ # 1. Create a mask of the particles to keep: - mask_particles_to_keep = ( - particles["has_velo"] & particles["has_scifi"] & (particles["pid"].abs() != 11) + mask_particles_to_keep = particles["has_velo"] & ~( + particles["has_scifi"] & (particles["pid"].abs() == 11) ) # 2. Propagate the mask to the dataframe of `particles` and `hits_particles` @@ -91,6 +97,43 @@ def everything_but_long_electrons( return hits_particles, particles +def default_old_training_for_rta_presentation( + hits_particles: pd.DataFrame, particles: pd.DataFrame +) -> typing.Tuple[pd.DataFrame, pd.DataFrame]: + """Selection that was used in the training presented in the RTA meeting.""" + + # Drop duplicates hits + hits_particles = hits_particles.drop_duplicates( + subset=["event", "particle_id", "plane"], keep="first" + ) + # Remove fake hits (there shouldn't be any already) + hits_particles = hits_particles[hits_particles["particle_id"] != 0] + + # Compute distance to line and add it to the dataframe of particles + logging.info("Compute distance to line (that might take some time)") + new_distances = compute_particle_line_metrics_dataframe( + hits=hits_particles, + metric_names=["distance_to_line"], + ) + particles = particles.merge(new_distances, how="left", on=["event", "particle_id"]) + + # Only keep reconstructible particles that are straight enough + logging.info("Apply particle selection mask") + mask_particles_to_keep = ( + (particles["has_velo"] == 1) + & (particles["nhits_velo"] >= 3) + & (particles["distance_to_line"] < 
np.sqrt(0.6)) + ) + particles, hits_particles = apply_mask( + mask_particles_to_keep, particles, hits_particles + ) + # assert that there is not any nan values at this point + assert not particles.isna().any().any() + assert not hits_particles.isna().any().any() + + return hits_particles, particles + + def everything_but_electrons( hits_particles: pd.DataFrame, particles: pd.DataFrame ) -> typing.Tuple[pd.DataFrame, pd.DataFrame]: @@ -105,12 +148,64 @@ def everything_but_electrons( long electrons are left. """ # 1. Create a mask of the particles to keep: - mask_particles_to_keep = ( - particles["has_velo"] & (particles["pid"].abs() != 11) - ) + mask_particles_to_keep = particles["has_velo"] & (particles["pid"].abs() != 11) # 2. Propagate the mask to the dataframe of `particles` and `hits_particles` particles, hits_particles = apply_mask( mask_particles_to_keep, particles, hits_particles ) return hits_particles, particles + + +def track_weighting_selection( + hits_particles: pd.DataFrame, particles: pd.DataFrame +) -> typing.Tuple[pd.DataFrame, pd.DataFrame]: + """The selection performed in the ``track-weighting`` experiment.""" + + # Only keep reconstructible particles that are straight enough + # Also remove the hits to avoid splitted tracks + # (we will counter-balance by requiring enough clusters / event) + logging.info("Compute distance to line (that might take some time)") + new_distances = compute_particle_line_metrics_dataframe( + hits=hits_particles, metric_names=["distance_to_line"] + ) + particles = particles.merge(new_distances, how="left", on=["event", "particle_id"]) + mask_particles_to_keep = particles["distance_to_line"] < 0.8 + particles, hits_particles = apply_mask( + mask_particles_to_keep, particles, hits_particles + ) + + # assert that there is not any nan values at this point + assert not particles.isna().any().any() + assert not hits_particles.isna().any().any() + + return hits_particles, particles + + +def triplets_first_selection( + 
hits_particles: pd.DataFrame, particles: pd.DataFrame +) -> typing.Tuple[pd.DataFrame, pd.DataFrame]: + """The selection performed in the ``triplets-edge`` experiment.""" + # Only keep one particles-hits association (drop duplicates) + hits_particles = hits_particles.drop_duplicates( + subset=["event", "hit_id"], keep="first" + ) + + # Only keep reconstructible particles that are straight enough + # Also remove the hits to avoid splitted tracks + # (we will counter-balance by requiring enough clusters / event) + logging.info("Compute distance to line (that might take some time)") + new_distances = compute_particle_line_metrics_dataframe( + hits=hits_particles, metric_names=["distance_to_line"] + ) + particles = particles.merge(new_distances, how="left", on=["event", "particle_id"]) + mask_particles_to_keep = particles["distance_to_line"] < 0.8 + particles, hits_particles = apply_mask( + mask_particles_to_keep, particles, hits_particles + ) + + # assert that there is not any nan values at this point + assert not particles.isna().any().any() + assert not hits_particles.isna().any().any() + + return hits_particles, particles diff --git a/LHCb_Pipeline/Processing/planewise_edges.py b/LHCb_Pipeline/Processing/planewise_edges.py index a0757960b8387448bb34cb26aff9e360edcc0867..dcf4107e4b17938fcbb7310f3c1877bae1a74e84 100644 --- a/LHCb_Pipeline/Processing/planewise_edges.py +++ b/LHCb_Pipeline/Processing/planewise_edges.py @@ -4,6 +4,7 @@ This way, we define the edge orientation using a left to right convention. However, if a plane a multiple hits for the same particle, the edges can be not well defined. 
""" +import typing import numpy as np import pandas as pd import numba as nb @@ -13,12 +14,12 @@ from utils.tools.tgroupby import get_group_indices_from_group_lengths @nb.jit(nopython=True, cache=True) def get_edges_from_sorted_impl( - edges: np.ndarray, hit_ids: np.ndarray, - particle_group_indices: np.ndarray, -) -> None: + particle_ids: np.ndarray, + plane_ids: np.ndarray, +) -> typing.List[np.ndarray]: """Fill the array of plane-wise edges by grouping by particle ID already sorted - by plane, and forming edge by linking "adjacent" hit IDs. + by plane, and forming edge by linking "adjacent" planes. Args: edges: Pre-allocated empty array of edges to fill @@ -26,46 +27,69 @@ def get_edges_from_sorted_impl( particle_group_indices: Start and end indices in ``hit_ids`` that delimits hits that have same particle ID. """ - edge_idx = 0 - for start_idx, end_idx in zip( + n_hits_per_particles = group_lengths(particle_ids)[0] + particle_group_indices = get_group_indices_from_group_lengths(n_hits_per_particles) + + list_edges = [np.zeros(dtype=hit_ids.dtype, shape=(2, 1)) for _ in range(0)] + + for particle_start_idx, particle_end_idx in zip( particle_group_indices[:-1], particle_group_indices[1:] ): - n_edges = end_idx - start_idx - 1 - next_edge_idx = edge_idx + n_edges - edges[0, edge_idx:next_edge_idx] = hit_ids[start_idx : end_idx - 1] - edges[1, edge_idx:next_edge_idx] = hit_ids[start_idx + 1 : end_idx] - edge_idx = next_edge_idx + particle_hit_ids = hit_ids[particle_start_idx: particle_end_idx] + n_planes_per_hits = group_lengths( + plane_ids[particle_start_idx:particle_end_idx] + )[0] + plane_group_indices = get_group_indices_from_group_lengths(n_planes_per_hits) + + n_edges = np.sum(np.multiply(n_planes_per_hits[:-1], n_planes_per_hits[1:])) + edges = np.full(shape=(2, n_edges), dtype=hit_ids.dtype, fill_value=-1) + edge_idx = 0 + for plane_group in range( + len(plane_group_indices) - 2 + ): # up second to last plane + plane_start_idx = 
plane_group_indices[plane_group] + plane_end_idx = plane_group_indices[plane_group + 1] + next_plane_end_idx = plane_group_indices[plane_group + 2] + for hit_idx in range(plane_start_idx, plane_end_idx): + n_edges_to_add = next_plane_end_idx - plane_end_idx + edges[0, edge_idx : edge_idx + n_edges_to_add] = particle_hit_ids[ + hit_idx + ] + edges[1, edge_idx : edge_idx + n_edges_to_add] = particle_hit_ids[ + plane_end_idx:next_plane_end_idx + ] + edge_idx += n_edges_to_add - # Sanity check - assert edge_idx == edges.shape[1] + # Sanity check + assert edge_idx == n_edges + list_edges.append(edges) + + return list_edges def get_planewise_edges_impl( hit_ids: np.ndarray, particle_ids: np.ndarray, + plane_ids: np.ndarray, ) -> np.ndarray: """Get the plane-wise edges - + Args: hit_ids: array of hit IDs, sorted by particle IDs particle_ids: Sorted array of particle IDs for every hit - + plane_ids: Sorted array of plane IDs for every hit + Returns: Two-dimensional array where every column represent an edge. In this array, for every edge, a hit is referred to by its index in the dataframe of hits. 
""" - n_hits_per_particles = group_lengths(particle_ids)[0] - particle_group_indices = get_group_indices_from_group_lengths(n_hits_per_particles) - # Create, fill and return array of edges - n_edges = (n_hits_per_particles - 1).sum() - edges = np.zeros(shape=(2, n_edges), dtype=int) - get_edges_from_sorted_impl( - edges=edges, + list_edges = get_edges_from_sorted_impl( hit_ids=hit_ids, - particle_group_indices=particle_group_indices, + particle_ids=particle_ids, + plane_ids=plane_ids, ) - return edges + return np.hstack(list_edges) def get_planewise_edges( @@ -99,4 +123,5 @@ def get_planewise_edges( return get_planewise_edges_impl( hit_ids=signal_hits["index"].to_numpy(), particle_ids=signal_hits["particle_id"].to_numpy(), + plane_ids=signal_hits["plane"].to_numpy(), ) diff --git a/LHCb_Pipeline/Processing/processing.py b/LHCb_Pipeline/Processing/processing.py index 84b82c36da8c48e587bdd6a60560cc5eebcc5062..4e00f5bfbf3cef116ccb5cbddec06a33a816b88e 100644 --- a/LHCb_Pipeline/Processing/processing.py +++ b/LHCb_Pipeline/Processing/processing.py @@ -12,6 +12,8 @@ import torch from torch_geometric.data import Data from .modulewise_edges import get_modulewise_edges +from .planewise_edges import get_planewise_edges +from .sortedwise_edges import get_sortedwise_edges from .compute import compute_columns @@ -47,14 +49,31 @@ def get_normalised_features( return (array_features - feature_means) / feature_scales +def _get_source_target_columns( + columns: typing.List[str | typing.Dict[str, str]] +) -> typing.Tuple[typing.List[str], typing.List[str]]: + columns_source = [] + columns_target = [] + for column in columns: + if isinstance(column, dict): + first_key = next(iter(column.keys())) + columns_source.append(column[first_key]) + columns_target.append(first_key) + else: + columns_source.append(column) + columns_target.append(column) + return columns_source, columns_target + + def build_event( truncated_path: str, event_str: str, features: typing.List[str], feature_means: 
typing.List[float], feature_scales: typing.List[float], - kept_hits_columns: typing.List[str], - kept_particles_columns: typing.List[str], + kept_hits_columns: typing.List[str | typing.Dict[str, str]], + kept_particles_columns: typing.List[str | typing.Dict[str, str]], + true_edges_column: str, ) -> Data: """Load the event, compute the necessary columns. @@ -76,11 +95,19 @@ def build_event( particles = pd.read_parquet(truncated_path + "-particles.parquet") hits_particles = pd.read_parquet(truncated_path + "-hits_particles.parquet") + ( + kept_particles_columns_source, + kept_particles_columns_target, + ) = _get_source_target_columns(kept_particles_columns) + kept_hits_columns_source, kept_hits_columns_target = _get_source_target_columns( + kept_hits_columns + ) + merged_particles_columns = list( set( ["particle_id"] # index + ["vx", "vy", "vz"] # module-wise true edges - + kept_particles_columns # other columns to keep in the PyTorch data object + + kept_particles_columns_source # other columns to keep in the PyTorch data object ) ) @@ -94,16 +121,20 @@ def build_event( # Compute columns that are not already defined compute_columns( hits=hits_particles, - columns=kept_particles_columns + features, + columns=kept_particles_columns_source + features, ) # Find the true edges - true_edges = get_modulewise_edges(hits_particles) - logging.debug( - "Modulewise truth graph built for {} with size {}".format( - truncated_path, true_edges.shape + if true_edges_column == "modulewise": + true_edges = get_modulewise_edges(hits_particles) + elif true_edges_column == "planewise": + true_edges = get_planewise_edges(hits_particles) + elif true_edges_column == "sortedwise": + true_edges = get_sortedwise_edges(hits_particles) + else: + raise ValueError( + f"`true_edges_column` is `{true_edges_column}`, which is not recognised." 
) - ) normalised_features = get_normalised_features( hits_particles, @@ -112,15 +143,20 @@ def build_event( feature_scales=feature_scales, ) - kept_columns = set( + kept_columns_source = ( # required columns [ "particle_id", # Plots, "hit_id", # matching ] # Other columns - + kept_hits_columns - + kept_particles_columns + + kept_hits_columns_source + + kept_particles_columns_source + ) + kept_columns_target = ( + ["particle_id", "hit_id"] + + kept_hits_columns_target + + kept_particles_columns_target ) torch_data = Data( @@ -128,10 +164,12 @@ def build_event( truncated_path=truncated_path, # To know for sure where the data come from event_str=event_str, # for the file names **{ - column: torch.from_numpy(hits_particles[column].to_numpy()) - for column in kept_columns + column_target: torch.from_numpy(hits_particles[column_source].to_numpy()) + for column_source, column_target in zip( + kept_columns_source, kept_columns_target + ) }, - modulewise_true_edges=torch.from_numpy(true_edges), + signal_true_edges=torch.from_numpy(true_edges), ) return torch_data diff --git a/LHCb_Pipeline/Processing/run_processing.py b/LHCb_Pipeline/Processing/run_processing.py index 0ad2e97e452d72328c1c12756fe9ab22fb1554a6..de7975ba7220d822948aa61330c33bef7ffbb638 100644 --- a/LHCb_Pipeline/Processing/run_processing.py +++ b/LHCb_Pipeline/Processing/run_processing.py @@ -36,16 +36,19 @@ def run_processing_in_parallel( """ if reproduce: delete_directory(output_dir) - elif is_directory_not_empty(output_dir): - logging.warn(f"Output directory is not empty: {output_dir}") os.makedirs(os.path.join(output_dir), exist_ok=True) - logging.info("Writing outputs to " + output_dir) - # Process input files with a worker pool and progress bar - process_func = partial( - prepare_event, output_dir=output_dir, **processing_config - ) - process_map(process_func, truncated_paths, max_workers=max_workers, chunksize=1) + if is_directory_not_empty(output_dir): + logging.info( + f"Output folder is not empty so 
processing was not run: {output_dir}" + ) + else: + logging.info("Writing outputs to " + output_dir) + # Process input files with a worker pool and progress bar + process_func = partial( + prepare_event, output_dir=output_dir, **processing_config + ) + process_map(process_func, truncated_paths, max_workers=max_workers, chunksize=1) def run_processing_test_dataset( diff --git a/LHCb_Pipeline/Processing/sortedwise_edges.py b/LHCb_Pipeline/Processing/sortedwise_edges.py new file mode 100644 index 0000000000000000000000000000000000000000..53969f3ac85201cfb580ccbb950458646f65519f --- /dev/null +++ b/LHCb_Pipeline/Processing/sortedwise_edges.py @@ -0,0 +1,101 @@ +"""A module that defines the edges by sorting the hits by z-abscissa +(instead of by distance from the origin vertex). + +This way, we define the edge orientation using a left to right convention. +""" +import numpy as np +import pandas as pd +import numba as nb +from montetracko.array_utils.groupby import group_lengths +from utils.tools.tgroupby import get_group_indices_from_group_lengths + + +@nb.jit(nopython=True, cache=True) +def get_edges_from_sorted_impl( + edges: np.ndarray, + hit_ids: np.ndarray, + particle_group_indices: np.ndarray, +) -> None: + """Fill the array of sorted-wise edges by grouping by hits belonging to the + same particle, already sorted by z, and forming edge by linking "adjacent" hit IDs. + + Args: + edges: Pre-allocated empty array of edges to fill + hit_ids: List of hit IDs, sorted by particle IDs and z-coordinates. + particle_group_indices: Start and end indices in ``hit_ids`` + that delimit hits that have the same particle ID.
+ """ + edge_idx = 0 + for start_idx, end_idx in zip( + particle_group_indices[:-1], particle_group_indices[1:] + ): + n_edges = end_idx - start_idx - 1 + next_edge_idx = edge_idx + n_edges + edges[0, edge_idx:next_edge_idx] = hit_ids[start_idx : end_idx - 1] + edges[1, edge_idx:next_edge_idx] = hit_ids[start_idx + 1 : end_idx] + edge_idx = next_edge_idx + + # Sanity check + assert edge_idx == edges.shape[1] + + +def get_sortedwise_edges_impl( + hit_ids: np.ndarray, + particle_ids: np.ndarray, +) -> np.ndarray: + """Get the sorted-wise edges + + Args: + hit_ids: array of hit IDs, sorted by particle IDs + particle_ids: z-sorted array of particle IDs for every hit + + Returns: + Two-dimensional array where every column represent an edge. In this array, + for every edge, a hit is referred to by its index in the dataframe of hits. + """ + n_hits_per_particles = group_lengths(particle_ids)[0] + particle_group_indices = get_group_indices_from_group_lengths(n_hits_per_particles) + + # Create, fill and return array of edges + n_edges = (n_hits_per_particles - 1).sum() + edges = np.zeros(shape=(2, n_edges), dtype=int) + get_edges_from_sorted_impl( + edges=edges, + hit_ids=hit_ids, + particle_group_indices=particle_group_indices, + ) + return edges + + +def get_sortedwise_edges( + hits: pd.DataFrame, drop_duplicates: bool = False +) -> np.ndarray: + """Get edges by sorting the hits by ``z`` for every particle in the event, + and linking the adjacent hits by edges. + + Args: + hits: dataframe of hits, with columns ``particle_id`` and ``z`` + drop_duplicates: whether to drop hits of a particle that belong to the same + z + + Returns: + Two-dimensional array where every column represent an edge. In this array, + for every edge, a hit is referred to by its index in the dataframe of hits. 
+ """ + # Exclude noise + signal_hits = hits[hits.particle_id != 0] + + # Remove hits on the same z belonging to the same particle + if drop_duplicates: + signal_hits = signal_hits.drop_duplicates(subset=["particle_id", "z"]) + + # Sort by particle ID and z in order to group by particle ID and z in Numba + signal_hits = signal_hits.sort_values(["particle_id", "z"]).reset_index( + drop=False + ) # produce `index`, the indices before sorting + + # Get edges + return get_sortedwise_edges_impl( + hit_ids=signal_hits["index"].to_numpy(), + particle_ids=signal_hits["particle_id"].to_numpy(), + ) diff --git a/LHCb_Pipeline/Scripts/Build_Triplets.py b/LHCb_Pipeline/Scripts/Build_Triplets.py new file mode 100644 index 0000000000000000000000000000000000000000..85025101e9e5ac8ca9406cf4d0cfd3f0fa27fe15 --- /dev/null +++ b/LHCb_Pipeline/Scripts/Build_Triplets.py @@ -0,0 +1,216 @@ +""" +This script runs step 5 of the TrackML Quickstart example: Labelling spacepoints based on the scored graph. +""" +from __future__ import annotations +import typing +import logging + +import numpy as np +import scipy.sparse as sps +import torch +from torch_geometric.data import Data +from utils.commonutils.config import load_config +from utils.commonutils.ctests import get_required_test_dataset_names +from utils.commonutils.crun import run_for_different_partitions +from utils.modelutils.build import BuilderBase +from utils.scriptutils import parse_args, configure_logger, headline +from GNN.gnn_base import compute_edge_labels + + +configure_logger() + + +def filter_graph_edges( + edge_mask: torch.Tensor, + edge_indices: torch.Tensor, +) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """Filter the hits that are not connected to any edge after filtering the graph + edges. + + Args: + edge_mask: Mask to filter the edges + edge_indices: current edge indices, before filtering + features: node features + + Returns: + Tuple of two arrays. 
Array of filtered edge indices, reindexed to take into + account the filtering of the nodes. + The second array is the array of unique node indices that remain in the graph, + and that can be used to filter all the node-related attributes. + """ + # Only keep remaining nodes + unique_node_indices = torch.unique(edge_indices, sorted=True) + + # Reindex the node indices in `filtered_edge_indices` + n_filtered_nodes = unique_node_indices.shape[0] + mapping_new_indices = torch.full( + (edge_indices.max() + 1,), -1, dtype=torch.long # type: ignore + ) + mapping_new_indices[unique_node_indices] = torch.arange(n_filtered_nodes) + reindexed_edge_indices = mapping_new_indices[edge_indices] + filtered_reindexed_edge_indices = reindexed_edge_indices[:, edge_mask] + + return filtered_reindexed_edge_indices, unique_node_indices + + +def edge_to_triplet_scipy(edge_indices: np.ndarray) -> np.ndarray: + """Build the array of edge indices corresponding to overlapping doublets. + + Args: + edge_indices: Array of shape :math:`\\left(2, n_{\\text{edges}}\\right)`. + One edge is considered as a doublet + + Returns: + Array of edges, of shape :math:`\\left(2, n_{\\text{triplets}}\\right)`, + that links doublets that share a hit.
+ + Notes: + This function was taken from + https://github.com/exatrkx/exatrkx-ctd2020/blob/master/GraphLearning/build_triplets.py + """ + n_edges = edge_indices.shape[1] + n_hits = edge_indices.max() + 1 + + e_coo = sps.coo_matrix((np.ones(n_edges), (edge_indices[0], edge_indices[1]))) + + # Array (hit, edge) + # Element (i, j) = 1 if node `i` is the first node (`edge_indices[0]`) of edge `j` + e_in_coo = sps.coo_matrix( + (np.ones(n_edges), (e_coo.row, np.arange(n_edges))), + shape=(n_hits, n_edges), + ) + e_in_csr = e_in_coo.tocsr() + + # Element (i, j) = 1 if node `i` is the second node (`edge_indices[1]`) of edge `j` + e_out_coo = sps.coo_matrix( + (np.ones(n_edges), (e_coo.col, np.arange(n_edges))), + shape=(n_hits, n_edges), + ) + e_out_csr = e_out_coo.tocsr() + + # Element (j, k) = 1 if the second node of edge `j` is the first node of edge `k` + e_total = e_out_csr.T * e_in_csr + + # extract indices of non-zero elements + e_total_coo = e_total.tocoo() + return np.vstack([e_total_coo.row, e_total_coo.col]) + + +class TripletBuilder(BuilderBase): + def construct_downstream(self, batch: Data): + # Compute edge slopes + un_xe = batch.un_x[batch.edge_index] + un_ye = batch.un_y[batch.edge_index] + un_ze = batch.un_z[batch.edge_index] + + batch.angle_yz = torch.atan((un_ye[1] - un_ye[0]) / (un_ze[1] - un_ze[0])) + batch.angle_xz = torch.atan((un_xe[1] - un_xe[0]) / (un_ze[1] - un_ze[0])) + batch.zdiff = un_ze[1] - un_ze[0] + + # Compute doublet features + doublet_features = torch.stack( + ( + un_xe[0] / 14.5, + un_ye[0] / 14.5, + batch.x[:, 2][batch.edge_index[0]], + batch.angle_yz / 0.16, + batch.angle_xz / 0.16, + (batch.zdiff - 75.0) / 65.0, + ), + dim=1, + ).float() + + # Compute triplet edges + triplet_indices = torch.from_numpy( + edge_to_triplet_scipy(batch.edge_index.numpy()) + ).long() + + diff_angle_xz = ( + batch.angle_xz[triplet_indices[1]] + - batch.angle_xz[triplet_indices[0]] + ) + + diff_angle_yz = ( + batch.angle_yz[triplet_indices[1]] + - batch.angle_yz[triplet_indices[0]] + ) + + # rough_mask = ( + # (torch.abs(diff_angle_xz) <
0.01) & (torch.abs(diff_angle_yz) < 0.01) + # ) + # triplet_indices = triplet_indices[:, rough_mask] + # diff_angle_xz = diff_angle_xz[rough_mask] + # diff_angle_yz = diff_angle_yz[rough_mask] + diff_angle_xz_norm = diff_angle_xz / 0.0076 + diff_angle_yz_norm = diff_angle_yz / 0.0076 + + # Compute true labels: both edges need to be `true` for the triplet + # to be true as well + y_triplet = batch.y[triplet_indices].min(dim=0).values + + if "y_pid" not in batch: + y_pid = compute_edge_labels( + edge_indices=batch.edge_index, + particle_ids=batch.particle_id, + ) + else: + y_pid = batch.y_pid + + y_pid_triplet = y_pid[triplet_indices].min(dim=0).values + + + return Data( + x=doublet_features, + edge_index=triplet_indices, + doublet_edge_index=batch.edge_index, + y=y_triplet, + y_pid=y_pid_triplet, + diff_angle_xz_norm=diff_angle_xz_norm, + diff_angle_yz_norm=diff_angle_yz_norm, + diff_angle_yz=diff_angle_yz, + diff_angle_xz=diff_angle_xz, + hit_id=batch.hit_id, + event_str=batch.event_str, + truncated_path=batch.truncated_path, + angle_xz=batch.angle_xz, + angle_yz=batch.angle_yz, + scores=batch.scores, + # **{ + # column_name: column + # for column_name, column in batch.items() + # if column_name + # not in [ + # "x", + # "edge_index", + # "y_pid", + # "y", + # ] + # }, + ) + + +def train( + path_or_config: str | dict, + partitions: typing.List[str] = ["train", "val", "test"], + reproduce: bool = True, + parallel: bool = False, +): + all_configs = load_config(path_or_config) + triplet_building_configs = all_configs["triplet_building"] + logging.info(headline(" Step 5: Building triplets.")) + + tripletBuilder = TripletBuilder() + run_for_different_partitions( + tripletBuilder.infer, + input_dir=triplet_building_configs["input_dir"], + output_dir=triplet_building_configs["output_dir"], + reproduce=reproduce, + partitions=partitions, + test_dataset_names=get_required_test_dataset_names(all_configs), + parallel=parallel + ) + + +if __name__ == "__main__": + config_file 
= parse_args() + train(config_file) diff --git a/LHCb_Pipeline/Scripts/Filter_Edges.py b/LHCb_Pipeline/Scripts/Filter_Edges.py new file mode 100644 index 0000000000000000000000000000000000000000..a092f286df87ac9a5251efa2bc34b6fd87d0c5eb --- /dev/null +++ b/LHCb_Pipeline/Scripts/Filter_Edges.py @@ -0,0 +1,63 @@ +""" +This script runs step 5 of the TrackML Quickstart example: Labelling spacepoints based on the scored graph. +""" +from __future__ import annotations +import typing +import logging + +import numpy as np +import scipy.sparse as sps +import torch +from torch_geometric.data import Data +from utils.commonutils.config import load_config +from utils.commonutils.ctests import get_required_test_dataset_names +from utils.commonutils.crun import run_for_different_partitions +from utils.modelutils.build import BuilderBase +from utils.scriptutils import parse_args, configure_logger, headline + +configure_logger() + + +class EdgeFilter(BuilderBase): + def __init__(self, score_cut: float) -> None: + super(EdgeFilter, self).__init__() + self.score_cut = float(score_cut) + + def construct_downstream(self, batch: Data): + edge_mask = batch.scores > self.score_cut + batch.edge_index = batch.edge_index[:, edge_mask] + batch.y = batch.y[edge_mask] + batch.scores = batch.scores[edge_mask] + if "y_pid" in batch: + batch.y_pid = batch.y_pid[edge_mask] + return batch + + +def train( + path_or_config: str | dict, + partitions: typing.List[str] = ["train", "val", "test"], + score_cut: float | None = None, + reproduce: bool = True, +): + all_configs = load_config(path_or_config) + edge_filtering_configs = all_configs["edge_filtering"] + logging.info(headline(" Step 5: Building track candidates from the scored graph ")) + + trackBuilder = EdgeFilter( + score_cut=( + edge_filtering_configs["score_cut"] if score_cut is None else score_cut + ) + ) + run_for_different_partitions( + trackBuilder.infer, + input_dir=edge_filtering_configs["input_dir"], + 
output_dir=edge_filtering_configs["output_dir"], + reproduce=reproduce, + partitions=partitions, + test_dataset_names=get_required_test_dataset_names(all_configs), + ) + + +if __name__ == "__main__": + config_file = parse_args() + train(config_file) diff --git a/LHCb_Pipeline/Scripts/Step_2_Run_Metric_Learning.py b/LHCb_Pipeline/Scripts/Step_2_Run_Metric_Learning.py index dfc9e35f7e5bf20f76c3da2aba9ed2be7f5bf966..67f8d9a30f3bb05fe151a4e118fee8a27613661c 100644 --- a/LHCb_Pipeline/Scripts/Step_2_Run_Metric_Learning.py +++ b/LHCb_Pipeline/Scripts/Step_2_Run_Metric_Learning.py @@ -21,8 +21,32 @@ configure_logger() def train( path_or_config: str | dict, partitions: typing.List[str] = ["train", "val", "test"], - checkpoint: str | None = None, + checkpoint: LayerlessEmbedding | str | None = None, + reproduce: bool = True, + override_hparams: bool = False, + **kwargs, ): + """Run the inference of the metric learning stage. + + Args: + path_or_config: configuration dictionary, or path to the YAML file that contains + the configuration + partitions: Partitions to run the inference on: + + * ``train``: train dataset + * ``val``: validation dataset + * ``test``: all the test datasets + * A specific test dataset name + + checkpoint: Model already loaded, or path to its checkpoint. If ``None``, + try to find it automatically in the artifact folder given + the configuration. 
+ reproduce: whether to delete an existing folder + override_hparams: whether to override the hyperparameters of the model + that is loaded, with the ones in the YAML configuration + **kwargs: Other keyword arguments passed to the + :py:func:`PyTorch.LightingModel.load_from_checkpoint` class method + """ all_configs = load_config(path_or_config) logging.info(headline("Step 2: Constructing graphs from metric learning model")) @@ -32,24 +56,21 @@ def train( logging.info(headline("a) Loading trained model")) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if checkpoint is None: # Default loading mode from last artifact - checkpoint = os.path.join( + + if override_hparams: + kwargs = {"hparams": metric_learning_configs, **kwargs} + logging.info(str(kwargs)) + + model = LayerlessEmbedding.get_model_from_checkpoint( + checkpoint=checkpoint, + default_checkpoint=os.path.join( common_configs["artifact_directory"], "metric_learning", common_configs["experiment_name"] + ".ckpt", - ) - model = LayerlessEmbedding.load_from_checkpoint( - checkpoint, map_location=device - ) - else: - model = LayerlessEmbedding.load_from_checkpoint( - checkpoint_path=checkpoint, - map_location=device, - **metric_learning_configs, # Override the hyperparameters - ) - - # Load checkpoint from specified path - logging.info(f"Load model from {checkpoint}.") + ), + map_location=device, + **kwargs, + ) logging.info(headline("b) Running inferencing")) @@ -58,6 +79,10 @@ def train( else: radius = metric_learning_configs["r_test"] + building = metric_learning_configs.pop("building", None) + filtering = metric_learning_configs.pop("filtering", None) + + logging.info(f"Use radius {radius}") graph_builder = EmbeddingInferenceBuilder( model, knn_max=metric_learning_configs["knn"], @@ -70,10 +95,15 @@ def train( output_dir=metric_learning_configs["output_dir"], partitions=partitions, test_dataset_names=get_required_test_dataset_names(all_configs), - reproduce=True, + reproduce=reproduce, +
list_kwargs=[ + dict(building=building, filtering=filtering) + if partition in ["train", "val"] + else dict(building=building) + for partition in partitions + ], ) - return graph_builder diff --git a/LHCb_Pipeline/Scripts/Step_3_Train_GNN.py b/LHCb_Pipeline/Scripts/Step_3_Train_GNN.py index d7e062aa17c7f62a0259da79d7a3a0f3a5f19f79..a8980a52d5df23f61944df4f5142adcc7f3467c5 100644 --- a/LHCb_Pipeline/Scripts/Step_3_Train_GNN.py +++ b/LHCb_Pipeline/Scripts/Step_3_Train_GNN.py @@ -18,12 +18,14 @@ from utils.scriptutils import parse_args, configure_logger, headline configure_logger() -def train(path_or_config: str | dict): +def train(path_or_config: str | dict, identifier: str | None = None): all_configs = load_config(path_or_config) + if identifier is None: + identifier = "" logging.info(headline(" Step 3: Running GNN training ")) common_configs = all_configs["common"] - gnn_configs = all_configs["gnn"] + gnn_configs = all_configs["gnn" + identifier] logging.info(headline("a) Initialising model")) @@ -32,7 +34,7 @@ def train(path_or_config: str | dict): logging.info(headline("b) Running training")) save_directory = os.path.abspath( - os.path.join(common_configs["artifact_directory"], "gnn") + os.path.join(common_configs["artifact_directory"], "gnn" + identifier) ) logger = CSVLogger(save_directory, name=common_configs["experiment_name"]) @@ -43,6 +45,7 @@ def train(path_or_config: str | dict): devices=common_configs["gpus"], max_epochs=gnn_configs["max_epochs"], logger=logger, + gradient_clip_val=gnn_configs.get("gradient_clip_val"), # callbacks=[EarlyStopping(monitor="val_loss", mode="min")] ) diff --git a/LHCb_Pipeline/Scripts/Step_4_Run_GNN.py b/LHCb_Pipeline/Scripts/Step_4_Run_GNN.py index 03bd765bde74c0f5a49d1e3d4f27a8c8bf63e298..4d842a24e9de10962048b9bda28b4d02251ec35e 100644 --- a/LHCb_Pipeline/Scripts/Step_4_Run_GNN.py +++ b/LHCb_Pipeline/Scripts/Step_4_Run_GNN.py @@ -22,32 +22,37 @@ configure_logger() def train( path_or_config: str | dict, partitions: 
typing.List[str] = ["train", "val", "test"], - checkpoint: str | None = None, + checkpoint: InteractionGNN | str | None = None, + reproduce: bool = True, + override_hparams: bool = False, + identifier: str | None = None, + **kwargs, ): all_configs = load_config(path_or_config) + if identifier is None: + identifier = "" logging.info(headline("Step 4: Scoring graph edges using GNN ")) common_configs = all_configs["common"] - gnn_configs = all_configs["gnn"] + gnn_configs = all_configs["gnn" + identifier] logging.info(headline("a) Loading trained model")) + if override_hparams: + kwargs = {"hparams": gnn_configs, **kwargs} + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if checkpoint is None: # Default loading mode from last artifact - checkpoint = os.path.join( + model = InteractionGNN.get_model_from_checkpoint( + checkpoint=checkpoint, + default_checkpoint=os.path.join( common_configs["artifact_directory"], - "gnn", + "gnn" + identifier, common_configs["experiment_name"] + ".ckpt", - ) - - model = InteractionGNN.load_from_checkpoint(checkpoint, map_location=device) - else: - model = InteractionGNN.load_from_checkpoint( - checkpoint_path=checkpoint, - map_location=device, - **gnn_configs, # Override the hyperparameters - ) + ), + map_location=device, + **kwargs, + ) logging.info(f"Load model from {checkpoint}.") @@ -59,7 +64,7 @@ def train( output_dir=gnn_configs["output_dir"], partitions=partitions, test_dataset_names=get_required_test_dataset_names(all_configs), - reproduce=True, + reproduce=reproduce, ) diff --git a/LHCb_Pipeline/Scripts/Step_5_Build_Track_Candidates.py b/LHCb_Pipeline/Scripts/Step_5_Build_Track_Candidates.py index a593d52ca0a9387bac2aff4f1d0b46f01db689fd..7874c7a2a7fd45265a81331c791332b306c4b4a4 100644 --- a/LHCb_Pipeline/Scripts/Step_5_Build_Track_Candidates.py +++ b/LHCb_Pipeline/Scripts/Step_5_Build_Track_Candidates.py @@ -1,6 +1,7 @@ """ This script runs step 5 of the TrackML Quickstart example: Labelling spacepoints 
based on the scored graph. """ +from __future__ import annotations import typing import logging @@ -17,35 +18,61 @@ from utils.scriptutils import parse_args, configure_logger, headline configure_logger() +def get_track_ids(edge_indices: torch.Tensor, n_hits: int) -> np.ndarray: + row, col = edge_indices + edge_attr = np.ones(row.size(0)) + + sparse_edges = sps.coo_matrix( + (edge_attr, (row.numpy(), col.numpy())), (n_hits, n_hits) + ) + + _, candidate_labels = sps.csgraph.connected_components( + sparse_edges, directed=False, return_labels=True + ) + return candidate_labels + + class TrackBuilder(BuilderBase): def __init__(self, score_cut: float) -> None: super(TrackBuilder, self).__init__() self.score_cut = float(score_cut) def construct_downstream(self, batch: Data): - edge_mask = batch.scores > self.score_cut + edge_indices = batch.edge_index[:, batch.scores > self.score_cut] - row, col = batch.edge_index[:, edge_mask] - edge_attr = np.ones(row.size(0)) + if "doublet_edge_index" in batch.keys: + # Come back to doublets + double_edge_indices = batch.doublet_edge_index[ + :, torch.unique(edge_indices) + ] + else: + double_edge_indices = edge_indices - N = batch.x.size(0) - sparse_edges = sps.coo_matrix((edge_attr, (row.numpy(), col.numpy())), (N, N)) + labels = torch.from_numpy( + get_track_ids( + edge_indices=double_edge_indices, + n_hits=batch.x.shape[0], + ) + ).long() + batch.labels = labels - _, candidate_labels = sps.csgraph.connected_components( - sparse_edges, directed=False, return_labels=True - ) - batch.labels = torch.from_numpy(candidate_labels).long() + return batch def train( path_or_config: str | dict, partitions: typing.List[str] = ["train", "val", "test"], + score_cut: float | None = None, ): all_configs = load_config(path_or_config) track_building_configs = all_configs["track_building"] logging.info(headline(" Step 5: Building track candidates from the scored graph ")) - trackBuilder = TrackBuilder(score_cut=track_building_configs["score_cut"]) + 
score_cut = track_building_configs["score_cut"] if score_cut is None else score_cut + trackBuilder = TrackBuilder( + score_cut=score_cut + ) + logging.info("Score cut: " + str(score_cut)) run_for_different_partitions( trackBuilder.infer, input_dir=track_building_configs["input_dir"], diff --git a/LHCb_Pipeline/Scripts/Step_6_Evaluate_Reconstruction_MonteTracko.py b/LHCb_Pipeline/Scripts/Step_6_Evaluate_Reconstruction_MonteTracko.py index b6feb71677c02d96ae5dee981c70d75a9eaa7755..1992589961383f3add6c0ddfb08afed0f3ac8fd7 100644 --- a/LHCb_Pipeline/Scripts/Step_6_Evaluate_Reconstruction_MonteTracko.py +++ b/LHCb_Pipeline/Scripts/Step_6_Evaluate_Reconstruction_MonteTracko.py @@ -11,20 +11,34 @@ from argparse import ArgumentParser from tqdm.auto import tqdm import pandas as pd import torch +from torch_geometric.data import Data import montetracko as mt import montetracko.lhcb as mtb from Preprocessing.preprocessing_paths import get_truncated_paths_for_partition +from Preprocessing.particle_line_fitting import compute_particle_line_metrics_dataframe from utils.plotutils import plotconfig from utils.commonutils.config import load_config from utils.commonutils.ctests import get_required_test_dataset_names +from utils.commonutils.cpaths import get_performance_directory from utils.scriptutils import configure_logger, headline + configure_logger() +def get_tracks_from_batch(batch: Data) -> pd.DataFrame: + return pd.DataFrame( + { + "event_id": int(batch.event_str), + "hit_id": batch.hit_id, + "track_id": batch.labels, + } + ) + + def load_tracks_event(input_path: str) -> pd.DataFrame: """Load the dataframe of tracks out of track building. 
@@ -36,13 +50,7 @@ def load_tracks_event(input_path: str) -> pd.DataFrame: Dataframe with columns ``event_id``, ``hit_id``, ``track_id`` """ graph = torch.load(input_path, map_location="cpu") - df_tracks = pd.DataFrame( - { - "event_id": int(graph.event_str), - "hit_id": graph.hit_id, - "track_id": graph.labels, - } - ) + df_tracks = get_tracks_from_batch(graph) return df_tracks @@ -144,7 +152,7 @@ def load_dataframes_given_partition( df_hits_particles = load_parquet_files( truncated_paths=truncated_paths, ending="-hits_particles", - columns=["particle_id", "hit_id"], + columns=["particle_id", "hit_id", "plane", "x", "y", "z"], ) df_particles = load_parquet_files( truncated_paths=truncated_paths, ending="-particles" @@ -172,6 +180,7 @@ def perform_evaluation( allen_report: bool = True, table_report: bool = True, plot_categories: typing.Iterable[mt.requirement.Category] | None = None, + plotted_groups: typing.List[str] | None = ["basic"], output_dir: str | None = None, suffix: str | None = None, ): @@ -186,7 +195,13 @@ def perform_evaluation( histograms are plotted for the reconstructible tracks in the velo, and the long electrons. In order not to plot, you may set this variable to an empty list. + plotted_groups: Pre-configured metrics and columns to plot. + Each group corresponds to one plot that shows the the distributions of + various metrics as a function of various truth variables, + as hard-coded in :py:func:`plot`. + There are 3 groups: ``basic``, ``geometry`` and ``challenging``. output_dir: Output directory where to save the report and the plots + suffix: string to append to the file name of the reports and figures produced. 
""" timestr = time.strftime("%Y.%m.%d-%H.%M.%S") @@ -194,15 +209,15 @@ def perform_evaluation( if allen_report or table_report: list_reports = [] if allen_report: - allen_report = trackEvaluator.report( + allen_report_str = trackEvaluator.report( reporter=mt.AllenReporter(), categories=mtb.category.allen_categories, ) - list_reports.append(allen_report) + list_reports.append(allen_report_str) if table_report: - table_report_categories = trackEvaluator.report( + table_report_str = trackEvaluator.report( reporter=mt.TabReporter( [ "efficiency", @@ -216,7 +231,7 @@ def perform_evaluation( ), categories=mtb.category.velo_categories, ) - list_reports.append(table_report_categories) + list_reports.append(table_report_str) table_report_global = trackEvaluator.report( reporter=mt.TabReporter( metric_names=["n_ghosts", "n_tracks", "ghost_rate"], @@ -239,12 +254,13 @@ def perform_evaluation( report_file.write(total_report) logging.info(f"Report was saved in {output_path}") - if plot_categories is not None: + if plot_categories is not None and plotted_groups is not None and plotted_groups: for plot_category in plot_categories: plot( trackEvaluator=trackEvaluator, category=plot_category, output_dir=output_dir, + plotted_groups=plotted_groups, suffix=suffix, ) @@ -252,6 +268,7 @@ def perform_evaluation( def plot( trackEvaluator: mt.TrackEvaluator, category: mt.requirement.Category, + plotted_groups: typing.List[str] = ["basic"], output_dir: str | None = None, suffix: str | None = None, ): @@ -260,29 +277,107 @@ def plot( Args: trackEvaluator: A ``TrackEvaluator`` instance containing the results - of the track matching. + of the track matching + category: Truth category for the plot + plotted_groups: Pre-configured metrics and columns to plot. + Each group corresponds to one plot that shows the the distributions of + various metrics as a function of various truth variables, + as hard-coded in this function. + There are 3 groups: ``basic``, ``geometry`` and ``challenging``. 
""" plotconfig.configure_matplotlib() - fig, _, _ = trackEvaluator.plot_histograms( - columns=["pt", "p", "eta", "vz"], - metric_names=[ - "efficiency", - "clone_rate", - "hit_purity_per_candidate", - "hit_efficiency_per_candidate", - ], - column_labels=plotconfig.column_labels, - column_ranges=plotconfig.column_ranges, - category=category, - bins=20, + + group_configurations = { + "basic": dict( + columns=["pt", "p", "eta", "vz"], + metric_names=[ + "efficiency", + "clone_rate", + # "hit_purity_per_candidate", + "hit_efficiency_per_candidate", + ], + ), + "challenging": dict( + columns=["vz", "nhits_velo"], + metric_names=["efficiency"], + ), + "geometry": dict( + columns=["distance_to_line", "distance_to_z_axis", "xz_angle", "yz_angle"], + metric_names=["efficiency"], + ), + } + + for group_name in plotted_groups: + if group_name not in group_configurations: + raise ValueError( + f"Group `{group_name}` is unknown. " + "Valid groups are: " + ", ".join(group_configurations.keys()) + ) + + group_config = group_configurations[group_name] + + fig, _, _ = trackEvaluator.plot_histograms( + **group_config, + column_labels=plotconfig.column_labels, + column_ranges=plotconfig.column_ranges, + category=category, + bins=plotconfig.column_bins, + ) + if output_dir is not None: + os.makedirs(output_dir, exist_ok=True) + if suffix is None: + suffix = "" + plot_path = op.join( + output_dir, f"hist1d_{group_name}_{category.name}{suffix}.pdf" + ) + fig.savefig(plot_path, dpi=200, bbox_inches="tight") + logging.info( + f"Plot {group_name} for category {category.name} saved in {plot_path}" + ) + + +def compute_plane_stats( + df_hits_particles: pd.DataFrame, df_particles: pd.DataFrame +) -> pd.DataFrame: + """Compute variables related to the numbers of hits w.r.t. the planes. + + Args: + df_hits_particles: Dataframe of hits-particles association. Must have + the columns ``event_id``, ``particle_id`` and ``plane``. + df_particles: Dataframe of particles. 
Must have the columns ``event_id`` + and ``particle_id``. + + Returns: + Dataframe of particles with the new columns. + """ + min_planes = ( + df_hits_particles.groupby(["event_id", "particle_id"])["plane"] + .min() + .rename("min_plane") + ) + max_planes = ( + df_hits_particles.groupby(["event_id", "particle_id"])["plane"] + .max() + .rename("max_plane") + ) + n_unique_planes = ( + df_hits_particles.groupby(["event_id", "particle_id"])["plane"] + .nunique() + .rename("n_planes") + ) + n_hits = ( + df_hits_particles.groupby(["event_id", "particle_id"]).size().rename("n_hits") + ) + n_repeated_planes = (n_hits - n_unique_planes).rename("n_repeated_planes") + n_skipped_planes = (max_planes - min_planes + 1 - n_unique_planes).rename( + "n_skipped_planes" + ) + + return df_particles.merge( + pd.concat((n_repeated_planes, n_skipped_planes), axis=1).reset_index(), + how="left", + on=["event_id", "particle_id"], ) - if output_dir is not None: - os.makedirs(output_dir, exist_ok=True) - if suffix is None: - suffix = "" - plot_path = op.join(output_dir, f"hist1d_{category.name}{suffix}.pdf") - fig.savefig(plot_path, dpi=200, bbox_inches="tight") - logging.info(f"Plot for category {category.name} saved in {plot_path}") def evaluate( @@ -291,6 +386,7 @@ def evaluate( allen_report: bool = True, table_report: bool = True, plot_categories: typing.Iterable[mt.requirement.Category] | None = None, + plotted_groups: typing.List[str] | None = ["basic"], min_track_length: int = 3, ) -> mt.TrackEvaluator: """Runs truth-based tracking evaluation. @@ -304,6 +400,11 @@ def evaluate( histograms are plotted for the reconstructible tracks in the velo, and the long electrons. In order not to plot, you may set this variable to an empty list. + plotted_groups: Pre-configured metrics and columns to plot. + Each group corresponds to one plot that shows the the distributions of + various metrics as a function of various truth variables, + as hard-coded in :py:func:`plot`. 
+ There are 3 groups: ``basic``, ``geometry`` and ``challenging``. Returns: object containing the evaluation. @@ -316,6 +417,28 @@ def evaluate( df_tracks, df_hits_particles, df_particles = load_dataframes_given_partition( path_or_config=all_configs, partition=partition ) + logging.info("Compute plat stats") + df_particles = compute_plane_stats( + df_hits_particles=df_hits_particles, + df_particles=df_particles, + ) + + if plotted_groups is not None and "geometry" in plotted_groups: + logging.info("Compute particle line metrics") + new_distances = compute_particle_line_metrics_dataframe( + hits=df_hits_particles, + metric_names=[ + "distance_to_line", + "distance_to_z_axis", + "xz_angle", + "yz_angle", + ], + event_id_column="event_id", + ) + + df_particles = df_particles.merge( + new_distances, how="left", on=["event_id", "particle_id"] + ) logging.info("2) Matching") trackEvaluator = perform_matching( @@ -337,8 +460,9 @@ def evaluate( allen_report=allen_report, table_report=table_report, plot_categories=plot_categories, - output_dir=all_configs["common"]["performance_directory"], - suffix=f"-{partition}" + plotted_groups=plotted_groups, + output_dir=get_performance_directory(all_configs), + suffix=f"-{partition}", ) return trackEvaluator diff --git a/LHCb_Pipeline/analyses/README.md b/LHCb_Pipeline/analyses/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ebe04f5e3afed485b9d281d7b1ee08152c4e8948 --- /dev/null +++ b/LHCb_Pipeline/analyses/README.md @@ -0,0 +1,8 @@ +# Analyses + +This folder contains a collection of scripts and notebooks to answer +typical questions about the data distribution. + +How many hits are there in average / event? In average, how many repeated planes +do a track have? This folder contains scripts and notebooks to answer these kind of +questions in a reproducible way. 
diff --git a/LHCb_Pipeline/analyses/anaconfig.py b/LHCb_Pipeline/analyses/anaconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb1aef1baab8f1e839984e5137c40a42cd2b67e --- /dev/null +++ b/LHCb_Pipeline/analyses/anaconfig.py @@ -0,0 +1,14 @@ +"""Common configurations for the folder analyses. +""" +import sys +import os.path as op + +# Add montetracko and LHCb_Pipeline to PYTHONPATH +sys.path.append(op.abspath(op.join(op.dirname(__file__), "../../montetracko"))) +sys.path.append(op.abspath(op.join(op.dirname(__file__), ".."))) + +#: Directory where to save the plots +PLOTDIR = op.abspath(op.join("..", "output", "analyses")) + +#: Directory where the dataframes that are used for the analysis are located +DATAFRAME_DIR = "/scratch/acorreia/minbias-sim10b-xdigi-nospillover/92" diff --git a/LHCb_Pipeline/analyses/angles.ipynb b/LHCb_Pipeline/analyses/angles.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..bc29161f8d0dff2bd17830410fe5df05b392de11 --- /dev/null +++ b/LHCb_Pipeline/analyses/angles.ipynb @@ -0,0 +1,337 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import anaconfig\n", + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "import os\n", + "from tqdm.auto import tqdm\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "\n", + "configure_matplotlib()" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'tqdm' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 12\u001b[0m\n\u001b[1;32m 9\u001b[0m sizes_doublets \u001b[39m=\u001b[39m []\n\u001b[1;32m 10\u001b[0m sizes_triplets 
\u001b[39m=\u001b[39m []\n\u001b[0;32m---> 12\u001b[0m \u001b[39mfor\u001b[39;00m filename \u001b[39min\u001b[39;00m tqdm(os\u001b[39m.\u001b[39mlistdir(train_dir)[:\u001b[39m100\u001b[39m]):\n\u001b[1;32m 13\u001b[0m path \u001b[39m=\u001b[39m os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(train_dir, filename)\n\u001b[1;32m 14\u001b[0m batch \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mload(path, map_location\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mcpu\u001b[39m\u001b[39m\"\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'tqdm' is not defined" + ] + } + ], + "source": [ + "train_dir = \"/scratch/acorreia/data//triplet_building/train\"\n", + "\n", + "diff_angles_xz = []\n", + "diff_angles_yz = []\n", + "angles_xz = []\n", + "angles_yz = []\n", + "y = []\n", + "y_pid = []\n", + "sizes_doublets = []\n", + "sizes_triplets = []\n", + "\n", + "for filename in tqdm(os.listdir(train_dir)[:100]):\n", + " path = os.path.join(train_dir, filename)\n", + " batch = torch.load(path, map_location=\"cpu\")\n", + " y.append(batch.y.numpy())\n", + " angle_xz = batch.angle_xz\n", + " angle_yz = batch.angle_yz\n", + " sizes_triplets.append(batch.edge_index.shape[1])\n", + " sizes_doublets.append(batch.doublet_edge_index.shape[1])\n", + " \n", + "\n", + " diff_angles_xz.append(batch.diff_angle_xz.numpy())\n", + " diff_angles_yz.append(batch.diff_angle_yz.numpy())\n", + " angles_xz.append(batch.angle_xz.numpy())\n", + " angles_yz.append(batch.angle_xz.numpy())\n", + " y_pid.append(batch.y_pid.numpy())\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "array_diff_angles_xz = np.concatenate(diff_angles_xz, axis=0)\n", + "array_diff_angles_yz = np.concatenate(diff_angles_yz, axis=0)\n", + "array_angle_xz = np.concatenate(angles_xz, axis=0)\n", + "array_angle_yz = np.concatenate(angles_yz, axis=0)\n", + "array_y_pid = np.concatenate(y_pid, axis=0)\n", + "array_y = np.concatenate(y, axis=0)\n", + 
"sizes_doublets = np.array(sizes_doublets)\n", + "sizes_triplets = np.array(sizes_triplets)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "9.110746876470676" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(sizes_triplets / sizes_doublets).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.hist2d(\n", + " x=array_diff_angles_xz[array_y_pid == False],\n", + " y=array_diff_angles_yz[array_y_pid == False],\n", + " bins=300,\n", + " range=((-.05, .05), (-.05, 0.05))\n", + ");\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.hist2d(\n", + " x=array_diff_angles_xz[array_y_pid == True],\n", + " y=array_diff_angles_yz[array_y_pid == True],\n", + " bins=300,\n", + " range=((-.05, .05), (-.05, 0.05))\n", + ");\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<matplotlib.legend.Legend at 0x7f3899956050>" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "_, edges, _ = ax.hist(\n", + " array_diff_angles_yz[array_y_pid == False],\n", + " 
label=\"Fake triplets\",\n", + " color=\"red\",\n", + " bins=500,\n", + " range=(-2.0, 2.0)\n", + ")\n", + "ax.hist(\n", + " array_diff_angles_xz[array_y_pid == True],\n", + " label=\"Genuine triplets\",\n", + " color=\"green\",\n", + " bins=edges,\n", + ")\n", + "ax.legend()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "truth = np.abs(array_diff_angles_xz)[array_y_pid == True]" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.9944294536612657" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(truth < 0.01).sum() / truth.shape[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.6206851640087514" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(array_diff_angles_xz < 0.01).sum() / array_diff_angles_xz.shape[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6.519369284875865" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(array_diff_angles_xz < 0.01).sum() / sizes_doublets.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean: -0.0006086131776879624\n", + "std: 0.21332288342745845\n", + "mean: 0.0003985171692728271\n", + "std: 0.22402053515876744\n" + ] + } + ], + "source": [ + "print(\"mean:\", array_diff_angles_xz.mean())\n", + "print(\"std:\", array_diff_angles_xz.std())\n", + "print(\"mean:\", array_diff_angles_yz.mean())\n", + "print(\"std:\", array_diff_angles_yz.std())\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean: 0.002268124843566985\n", + "std: 0.15293637532876916\n", + "mean: 0.002268124843566985\n", + "std: 0.15293637532876916\n" + ] + } + ], + "source": [ + "print(\"mean:\", array_angle_xz.mean())\n", + "print(\"std:\", array_angle_xz.std())\n", + "print(\"mean:\", array_angle_yz.mean())\n", + "print(\"std:\", array_angle_yz.std())\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "etx4velo_updated", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/LHCb_Pipeline/analyses/edge_scores.ipynb b/LHCb_Pipeline/analyses/edge_scores.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..53e53e7ba31362b90e76925037742a030ddeadc0 --- /dev/null +++ b/LHCb_Pipeline/analyses/edge_scores.ipynb @@ -0,0 +1,531 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import anaconfig\n", + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "import os\n", + "from tqdm.auto import tqdm\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "\n", + "configure_matplotlib()\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Analyse the edge scores, depending on the track size" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch_scatter import scatter_add" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + 
"outputs": [], + "source": [ + "test_dir = \"/scratch/acorreia/data/focal-loss-pid/gnn_processed/test/velo-sim10b-nospillover/\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "93a11ac913964b5596a9ed22f4d1cf61", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/100 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "list_n_hits = []\n", + "list_scores = []\n", + "list_y_pid = []\n", + "\n", + "dict_list_values = {\n", + " \"n_hits\": [],\n", + " \"scores\": [],\n", + " \"y_pid\": [],\n", + "}\n", + "\n", + "for filename in tqdm(os.listdir(test_dir)[:100]):\n", + " path = os.path.join(test_dir, filename)\n", + " batch = torch.load(path, map_location=\"cpu\")\n", + " df_particles = pd.read_parquet(\n", + " batch.truncated_path + \"-particles.parquet\",\n", + " columns=[\"particle_id\", \"nhits_velo\"],\n", + " )\n", + " # batch.edge_index, unique_indices = torch.unique(\n", + " # batch.edge_index, dim=1, return_inverse=True\n", + " # )\n", + " # unique_indices = torch.unique(unique_indices)\n", + " # batch.y_pid = batch.y_pid[unique_indices]\n", + " # batch.y = batch.y[unique_indices]\n", + " # batch.scores = batch.scores[unique_indices]\n", + "\n", + " nhits_velo = (\n", + " pd.DataFrame({\"particle_id\": batch.particle_id.numpy()})\n", + " .merge(\n", + " df_particles[[\"particle_id\", \"nhits_velo\"]],\n", + " how=\"left\",\n", + " on=[\"particle_id\"],\n", + " )[\"nhits_velo\"]\n", + " .to_numpy()\n", + " )\n", + "\n", + " dict_list_values[\"y_pid\"].append(batch.y_pid.numpy())\n", + " dict_list_values[\"scores\"].append(batch.scores.numpy())\n", + " dict_list_values[\"n_hits\"].append(nhits_velo[batch.edge_index.numpy()])" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "82136" 
+ ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.edge_index.shape[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "41102" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.unique(batch.edge_index, dim=1).shape[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "particle_id = 5409\n", + "min_plane = batch.plane[batch.particle_id == particle_id].min()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.5567, 0.6978, 0.5567, 0.6978])" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.scores[(batch.plane[batch.edge_index[0]] == min_plane)\n", + "& (batch.particle_id[batch.edge_index[0]] == particle_id)\n", + "& (batch.y_pid)]" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "first_scores = []\n", + "other_scores = []\n", + "\n", + "for particle_id in torch.unique(batch.particle_id):\n", + " min_plane = batch.plane[batch.particle_id == particle_id].min()\n", + "\n", + " scores = batch.scores[\n", + " (batch.plane[batch.edge_index[0]] == min_plane)\n", + " & (batch.particle_id[batch.edge_index[0]] == particle_id)\n", + " & (batch.y_pid)\n", + " ]\n", + " first_scores += scores.numpy().tolist()\n", + " \n", + " scores = batch.scores[\n", + " (batch.plane[batch.edge_index[0]] != min_plane)\n", + " & (batch.particle_id[batch.edge_index[0]] == particle_id)\n", + " & (batch.y_pid)\n", + " ]\n", + " other_scores += scores.numpy().tolist()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"<matplotlib.legend.Legend at 0x7ff1f015fcd0>" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.hist(first_scores, label=\"First edges\", alpha=0.5, density=True)\n", + "ax.hist(other_scores, label=\"Other edges\", alpha=0.5, density=True)\n", + "ax.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": 196, + "metadata": {}, + "outputs": [], + "source": [ + "dict_arrays = {\n", + " name: np.concatenate(list_arrays, axis=-1)\n", + " for name, list_arrays in dict_list_values.items()\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "dict_arrays_true = {\n", + " name: array[..., dict_arrays[\"y_pid\"]] \n", + " for name, array in dict_arrays.items()\n", + " if name != \"y_pid\"\n", + "}\n", + "\n", + "dict_arrays_fake = {\n", + " name: array[..., ~dict_arrays[\"y_pid\"]] \n", + " for name, array in dict_arrays.items()\n", + " if name != \"y_pid\"\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(2, 601088)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dict_arrays_true[\"n_hits\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 
Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for nhits in range(3, 10):\n", + " fig, ax = plt.subplots(figsize=(8, 6))\n", + " _, edges, _ = ax.hist(\n", + " x=dict_arrays_fake[\"scores\"][\n", + " (dict_arrays_fake[\"n_hits\"][0] == nhits)\n", + " | (dict_arrays_fake[\"n_hits\"][1] == nhits)\n", + " ],\n", + " bins=20,\n", + " range=(0.0, 1.0),\n", + " density=False,\n", + " alpha=0.5,\n", + " label=\"Fake\",\n", + " color=\"red\",\n", + " )\n", + " _, edges, _ = ax.hist(\n", + " x=dict_arrays_true[\"scores\"][dict_arrays_true[\"n_hits\"][0] == nhits],\n", + " bins=edges,\n", + " density=False,\n", + " alpha=0.5,\n", + " label=\"True\",\n", + " color=\"green\",\n", + " )\n", + " # ax.set_yscale(\"log\")\n", + " ax.legend()\n", + " ax.set_title(f\"{nhits} hits\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "dict_arrays[\"scores\"]\n", + "keep = dict_arrays[\"scores\"] > 0.2\n", + "mask_3_hits = (dict_arrays[\"n_hits\"] == 3).min(axis=0)\n", + "keep[mask_3_hits] = dict_arrays[\"scores\"][mask_3_hits] > 0.2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.34167739628040056" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" 
+ } + ], + "source": [ + "keep.sum() / keep.shape[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.9903418255310052" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scores_3 = dict_arrays_true[\"scores\"][dict_arrays_true[\"n_hits\"][0] == 10]\n", + "\n", + "(scores_3 > 0.4).sum() / len(scores_3)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "particle_ids = df_particles[df_particles[\"nhits_velo\"] == 4][\"particle_id\"].to_numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2146, 2146, 2146, ..., 5409, 6307, 8481],\n", + " [4975, 2064, 2129, ..., 6307, 6307, 6307]])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "particle_id = particle_ids[0]\n", + "batch.particle_id[batch.edge_index]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4827])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.particle_id.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "etx4velo_updated", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/LHCb_Pipeline/analyses/normalisations.ipynb 
b/LHCb_Pipeline/analyses/normalisations.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..db887861f1da04c881e6dbbd0a85ccf321325397 --- /dev/null +++ b/LHCb_Pipeline/analyses/normalisations.ipynb @@ -0,0 +1,289 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import anaconfig\n", + "\n", + "from Preprocessing.preprocessing import load_dataframes\n", + "\n", + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "from utils.commonutils.cfeatures import get_unnormalised_features\n", + "\n", + "from tqdm.auto import tqdm\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "\n", + "configure_matplotlib()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "hits_particles, particles = load_dataframes(\n", + " indir=\"/scratch/acorreia/minbias-sim10b-xdigi-nospillover/92\",\n", + ")\n", + "hits = hits_particles.drop_duplicates([\"event\", \"hit_id\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<matplotlib.legend.Legend at 0x7f1e84902620>" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.hist(hits_particles[\"x\"], alpha=0.5, color=\"b\", label=\"x\")\n", + "ax.hist(hits_particles[\"y\"], alpha=0.5, color=\"darkorange\", label=\"y\")\n", + "ax.legend()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean: -0.08074610279141962\n", + "std: 14.51331843142442\n", + "mean: 
0.02629636994005264\n", + "std: 14.760759458144337\n" + ] + } + ], + "source": [ + "print(\"mean:\", hits_particles[\"x\"].mean())\n", + "print(\"std:\", hits_particles[\"x\"].std())\n", + "print(\"mean:\", hits_particles[\"y\"].mean())\n", + "print(\"std:\", hits_particles[\"y\"].std())" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cdab97d708ec4e958cdabf240b8c2162", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "slopes_xz = []\n", + "slopes_yz = []\n", + "zdiffs = []\n", + "\n", + "train_dir = \"/scratch/acorreia/data/track-edges/metric_learning_processed/train\"\n", + "for filename in tqdm(os.listdir(train_dir)[:1000]):\n", + " path = os.path.join(train_dir, filename)\n", + " batch = torch.load(path, map_location=\"cpu\")\n", + " start, end = batch.edge_index\n", + " r, phi, z = get_unnormalised_features(\n", + " batch=batch,\n", + " path_or_config=\"../pipeline_configs/track-edges.yaml\",\n", + " feature_names=[\"r\", \"phi\", \"z\"],\n", + " )\n", + " x = r * np.cos(phi)\n", + " y = r * np.sin(phi)\n", + " z_edge_index = batch.z[batch.edge_index]\n", + " batch.edge_index = batch.edge_index[:, z_edge_index[0] != z_edge_index[1]]\n", + "\n", + " xe = x[batch.edge_index]\n", + " ye = y[batch.edge_index]\n", + " ze = z[batch.edge_index]\n", + " \n", + " slopes_xz.append(\n", + " ((ye[1] - ye[0]) / (ze[1] - ze[0])).numpy()\n", + " )\n", + " slopes_yz.append(\n", + " ((xe[1] - xe[0]) / (ze[1] - ze[0])).numpy()\n", + " )\n", + " zdiffs.append(ze[1] - ze[0])\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "array_slopes_xz = np.concatenate(slopes_xz, axis=0)\n", + "array_slopes_yz = np.concatenate(slopes_yz, axis=0)\n", + "zdiffs = 
np.concatenate(zdiffs, axis=0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "array_angles_xz = np.arctan(array_slopes_xz)\n", + "array_angles_yz = np.arctan(array_slopes_yz)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<matplotlib.legend.Legend at 0x7f1e867c7880>" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "_, edges, _ = ax.hist(array_angles_xz, alpha=0.5, color=\"b\", label=\"x-z\")\n", + "ax.hist(array_slopes_yz, alpha=0.5, color=\"darkorange\", label=\"y-z\", bins=edges)\n", + "ax.legend()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean: -0.0065150103\n", + "std: 0.18471645\n", + "mean: 0.002169753\n", + "std: 0.16960813\n" + ] + } + ], + "source": [ + "print(\"mean:\", array_slopes_xz.mean())\n", + "print(\"std:\", array_slopes_xz.std())\n", + "print(\"mean:\", array_slopes_yz.mean())\n", + "print(\"std:\", array_slopes_yz.std())" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean: -0.0049208854\n", + "std: 0.16074418\n", + "mean: 0.0019381851\n", + "std: 0.15437067\n" + ] + } + ], + "source": [ + "print(\"mean:\", array_angles_xz.mean())\n", + "print(\"std:\", array_angles_xz.std())\n", + "print(\"mean:\", array_angles_yz.mean())\n", + "print(\"std:\", array_angles_yz.std())" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "mean: 74.665184\n", + "std: 64.93781\n" + ] + } + ], + "source": [ + "print(\"mean:\", zdiffs.mean())\n", + "print(\"std:\", zdiffs.std())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "etx4velo_updated", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/LHCb_Pipeline/analyses/particle_columns.ipynb b/LHCb_Pipeline/analyses/particle_columns.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..fb5764aeeeaaf9d824378dc19244510436b0d9ce --- /dev/null +++ b/LHCb_Pipeline/analyses/particle_columns.ipynb @@ -0,0 +1,256 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import typing\n", + "import os.path as op\n", + "\n", + "import anaconfig\n", + "\n", + "import numpy as np\n", + "import numpy.typing as npt\n", + "import pandas as pd\n", + "from montetracko.requirement import Category\n", + "import montetracko.lhcb.category as mtbc\n", + "\n", + "from Preprocessing.preprocessing import load_dataframes\n", + "from Preprocessing.particle_line_fitting import compute_particle_line_metrics_dataframe\n", + "\n", + "from utils.plotutils.plotools import save_fig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "import matplotlib.pyplot as plt\n", + "\n", + "configure_matplotlib()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#: List of the columns to plot\n", + "column_names = [\n", + " 
\"distance_to_line\",\n", + " \"n_unique_planes\",\n", + "]\n", + "\n", + "column_labels = {\n", + " \"distance_to_line\": \"Distance to line\",\n", + " \"n_unique_planes\": \"# unique planes\",\n", + "}\n", + "\n", + "column_bins = {\n", + " \"distance_to_line\": np.linspace(0.0, 0.4, 20),\n", + " \"n_unique_planes\": np.arange(3, 21) - 0.5,\n", + "}\n", + "\n", + "#: List of the categories to plot\n", + "categories = [\n", + " mtbc.category_velo,\n", + " mtbc.category_velo_no_electrons,\n", + " mtbc.category_long_only_electrons,\n", + "]\n", + "category_colors = [\"blue\", \"purple\", \"darkorange\"]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load dataframes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hits_particles, particles = load_dataframes(\n", + " indir=anaconfig.DATAFRAME_DIR,\n", + ")\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compute columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "particle_line_metrics = [\n", + " \"distance_to_line\",\n", + " \"distance_to_z_axis\",\n", + " \"xz_angle\",\n", + " \"yz_angle\",\n", + "]\n", + "\n", + "\n", + "def compute_n_unique_planes(hits_particles: pd.DataFrame, particles: pd.DataFrame):\n", + " n_unique_planes = (\n", + " hits_particles.groupby([\"event\", \"particle_id\"])[\"plane\"]\n", + " .nunique()\n", + " .rename(\"n_unique_planes\")\n", + " )\n", + " particles: pd.DataFrame = particles.merge(\n", + " n_unique_planes, how=\"left\", on=[\"event\", \"particle_id\"]\n", + " ).fillna(0)\n", + " return particles\n", + "\n", + "column_name_to_fct = {\n", + " \"n_unique_planes\": compute_n_unique_planes\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compute particle line 
metrics\n", + "particle_line_metrics_to_compute = [\n", + " column_name for column_name in column_names if column_name in particle_line_metrics\n", + "]\n", + "if particle_line_metrics_to_compute:\n", + " print(f\"Compute {particle_line_metrics_to_compute}\")\n", + " new_distances = compute_particle_line_metrics_dataframe(\n", + " hits=hits_particles,\n", + " metric_names=particle_line_metrics_to_compute,\n", + " )\n", + " particles = particles.merge(new_distances, how=\"left\", on=[\"event\", \"particle_id\"])\n", + "\n", + "# Compute other columns\n", + "other_columns_to_compute = [\n", + " column_name\n", + " for column_name in column_names\n", + " if column_name not in particle_line_metrics\n", + "]\n", + "for column_name in other_columns_to_compute:\n", + " computing_fct = column_name_to_fct.get(column_name)\n", + " if computing_fct is None:\n", + " raise ValueError(f\"No function to compute {column_name}\")\n", + " print(f\"Compute {column_name}\")\n", + " particles = computing_fct(\n", + " hits_particles=hits_particles, particles=particles\n", + " )\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_in_different_categories(\n", + " particles: pd.DataFrame,\n", + " column_name: str,\n", + " categories: typing.List[Category],\n", + " colors: typing.List[str] | None,\n", + " column_label: str | None = None,\n", + " range: typing.Tuple[float, float] | None = None,\n", + " bins: npt.ArrayLike | None = None,\n", + " alpha: float = 0.5,\n", + " **kwargs,\n", + "):\n", + " fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + " for cat_idx, category in enumerate(categories):\n", + " _, edges, _ = ax.hist(\n", + " category.filter(particles)[column_name],\n", + " color=colors[cat_idx] if colors is not None else None,\n", + " label=category.label,\n", + " range=range,\n", + " bins=bins,\n", + 
" density=True,\n", + " alpha=alpha,\n", + " **kwargs,\n", + " )\n", + " bins = edges\n", + " range = None\n", + "\n", + " if column_label is None:\n", + " column_label = column_name.replace(\"_\", \" \").title()\n", + " ax.set_xlabel(column_label)\n", + " ax.set_ylabel(\"Proportion\")\n", + " ax.legend()\n", + "\n", + " save_fig(\n", + " fig,\n", + " path=op.join(\n", + " anaconfig.PLOTDIR,\n", + " f\"{column_name}_\" + \"_vs_\".join([category.name for category in categories]),\n", + " ),\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for column_name in column_names:\n", + " plot_in_different_categories(\n", + " particles=particles,\n", + " column_name=column_name,\n", + " categories=categories,\n", + " colors=category_colors,\n", + " column_label=column_labels.get(column_name),\n", + " bins=column_bins.get(column_name),\n", + " )\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "etx4velo_updated", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/LHCb_Pipeline/analyses/relative_edge_scores.ipynb b/LHCb_Pipeline/analyses/relative_edge_scores.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..c881f03529e6bb35791542fbc670c848d82ef45d --- /dev/null +++ b/LHCb_Pipeline/analyses/relative_edge_scores.ipynb @@ -0,0 +1,364 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import anaconfig\n", + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "import os\n", + "from tqdm.auto import tqdm\n", + "import numpy as np\n", + "import 
pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "\n", + "configure_matplotlib()\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Analyse the edge scores, depending on the track size" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "test_dir = \"/scratch/acorreia/data/focal-loss-pid/gnn_processed/test/velo-sim10b-nospillover/\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c231f0031a0a4696bdaa08d887fc1c2e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/100 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "list_df_edges = []\n", + "\n", + "for filename in tqdm(os.listdir(test_dir)[:100]):\n", + " path = os.path.join(test_dir, filename)\n", + " batch = torch.load(path, map_location=\"cpu\")\n", + " df_particles = pd.read_parquet(\n", + " batch.truncated_path + \"-particles.parquet\",\n", + " columns=[\"particle_id\", \"nhits_velo\"],\n", + " )\n", + "\n", + " edge_particle_ids = batch.particle_id[batch.edge_index]\n", + " edge_genuine_mask = (edge_particle_ids[0] == edge_particle_ids[1]) & (\n", + " edge_particle_ids[0] != 0\n", + " )\n", + "\n", + " #: Dataframe of edges\n", + " df_edges_genuine = pd.DataFrame(\n", + " {\n", + " \"event_id\": int(batch.event_str),\n", + " \"particle_id\": edge_particle_ids[0, edge_genuine_mask],\n", + " \"score\": batch.scores[edge_genuine_mask].numpy(),\n", + " \"genuine\": True,\n", + " }\n", + " )\n", + " df_edges_fake_left = pd.DataFrame(\n", + " {\n", + " \"event_id\": int(batch.event_str),\n", + " \"particle_id\": edge_particle_ids[0, ~edge_genuine_mask],\n", + " \"score\": batch.scores[~edge_genuine_mask].numpy(),\n", + " \"genuine\": False,\n", + " }\n", + " )\n", + " 
df_edges_fake_right = pd.DataFrame(\n", + " {\n", + " \"event_id\": int(batch.event_str),\n", + " \"particle_id\": edge_particle_ids[1, ~edge_genuine_mask],\n", + " \"score\": batch.scores[~edge_genuine_mask].numpy(),\n", + " \"genuine\": False,\n", + " }\n", + " )\n", + " df_edges = pd.concat(\n", + " (df_edges_genuine, df_edges_fake_left, df_edges_fake_right), axis=0\n", + " )\n", + "\n", + " df_edges = df_edges.merge(\n", + " df_particles[[\"particle_id\", \"nhits_velo\"]],\n", + " how=\"left\",\n", + " on=[\"particle_id\"],\n", + " )\n", + "\n", + " list_df_edges.append(df_edges)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "df_edges_tot = pd.concat(list_df_edges, axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "genuine_scores = (\n", + " df_edges_tot[df_edges_tot[\"genuine\"]]\n", + " .groupby([\"event_id\", \"particle_id\"])[\"score\"]\n", + " .mean()\n", + ")\n", + "fake_scores = (\n", + " df_edges_tot[~df_edges_tot[\"genuine\"]]\n", + " .groupby([\"event_id\", \"particle_id\"])[\"score\"]\n", + " .max()\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "event_id particle_id\n", + "476545 514 0.838615\n", + " 568 0.904482\n", + " 576 0.867661\n", + " 584 0.887780\n", + " 588 0.903285\n", + " ... 
\n", + "1200239 13982 0.926008\n", + " 14005 0.830914\n", + " 14006 0.933353\n", + " 14007 0.924459\n", + " 14013 0.914115\n", + "Name: score, Length: 25408, dtype: float32" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "genuine_scores" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "df_diff_score = (genuine_scores - fake_scores).fillna(1.5).rename(\"diff_score\").reset_index().merge(\n", + " df_edges.drop_duplicates([\"event_id\", \"particle_id\"])[[\"event_id\", \"particle_id\", \"nhits_velo\"]],\n", + " on=[\"event_id\", \"particle_id\"],\n", + " how=\"left\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for nhits in range(3, 10):\n", + " fig, ax = plt.subplots(figsize=(8, 6))\n", + " 
diff_score = df_diff_score.query(f\"nhits_velo == {nhits}\")[\"diff_score\"]\n", + " prop_higher_50 = (diff_score > 0.05).sum() / diff_score.shape[0]\n", + " ax.hist(\n", + " x=diff_score,\n", + " bins=50,\n", + " # range=(0.0, 1.0),\n", + " density=False,\n", + " label=f\"{prop_higher_50:.2%}\"\n", + " )\n", + "\n", + " ax.set_yscale(\"log\")\n", + " ax.legend()\n", + " ax.set_title(f\"{nhits} hits\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0, 0, 0, ..., 4641, 4642, 4653],\n", + " [ 229, 231, 232, ..., 4826, 4826, 4826]])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.edge_index" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "keep = torch.zeros_like(batch.y)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 0, 0, 0, ..., 25, 25, 25])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.plane" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for plane in range(26):\n", + " # get hits\n", + " for hit_id in batch.hit_id[batch.plane == plane]:\n", + " # get next edges\n", + " " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "etx4velo_updated", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/LHCb_Pipeline/analyses/slopes.ipynb 
b/LHCb_Pipeline/analyses/slopes.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..70668f9203ccd02fed301017df8d2f27502f4b35 --- /dev/null +++ b/LHCb_Pipeline/analyses/slopes.ipynb @@ -0,0 +1,268 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import anaconfig\n", + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "import os\n", + "from tqdm.auto import tqdm\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "\n", + "configure_matplotlib()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "089fe97d81f147bd9b1c656607c35f10", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/100 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "train_dir = \"/scratch/acorreia/data/triplets-first/triplet_processed/train\"\n", + "\n", + "diff_slopes_xz = []\n", + "diff_slopes_yz = []\n", + "y = []\n", + "y_pid = []\n", + "\n", + "for filename in tqdm(os.listdir(train_dir)[:100]):\n", + " path = os.path.join(train_dir, filename)\n", + " batch = torch.load(path, map_location=\"cpu\")\n", + " y.append(batch.y.numpy())\n", + " slopes_xz = batch.slopes_xz\n", + " slopes_yz = batch.slopes_yz\n", + "\n", + " diff_slopes_xz.append(\n", + " (slopes_xz[batch.edge_index[1]] - slopes_xz[batch.edge_index[0]]).numpy()\n", + " )\n", + " diff_slopes_yz.append(\n", + " (slopes_yz[batch.edge_index[1]] - slopes_yz[batch.edge_index[0]]).numpy()\n", + " )\n", + " y_pid.append(batch.y_pid.numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "array_diff_slopes_xz = np.concatenate(diff_slopes_xz, axis=0)\n", + "array_diff_slopes_yz = np.concatenate(diff_slopes_yz, axis=0)\n", + "array_y_pid 
= np.concatenate(y_pid, axis=0)\n", + "array_y = np.concatenate(y, axis=0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.hist2d(\n", + " x=array_diff_slopes_xz[array_y_pid == False],\n", + " y=array_diff_slopes_yz[array_y_pid == False],\n", + " bins=300,\n", + " range=((-.5, .5), (-.5, .5))\n", + ");\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.hist2d(\n", + " x=array_diff_slopes_xz[array_y_pid == True],\n", + " y=array_diff_slopes_yz[array_y_pid == True],\n", + " bins=300,\n", + " range=((-.5, .5), (-.5, .5))\n", + ");\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<matplotlib.legend.Legend at 0x7f0822003a00>" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "_, edges, _ = ax.hist(\n", + " array_diff_slopes_yz[array_y_pid == False],\n", + " label=\"Fake triplets\",\n", + " color=\"red\",\n", + " bins=500,\n", + " range=(-5, 5)\n", + ")\n", + "ax.hist(\n", + " array_diff_slopes_xz[array_y_pid == True],\n", + " label=\"Genuine triplets\",\n", + " color=\"green\",\n", + " bins=edges,\n", + ")\n", + "ax.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + 
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.011936023423773843" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "array_diff_slopes_xz[array_y_pid == True].std()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<matplotlib.legend.Legend at 0x7f08224f1cc0>" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "_, edges, _ = ax.hist(\n", + " array_diff_slopes_yz[array_y == False],\n", + " label=\"False\",\n", + " color=\"red\",\n", + " bins=500,\n", + " range=(-5, 5)\n", + ")\n", + "ax.hist(\n", + " array_diff_slopes_xz[array_y == True],\n", + " label=\"True\",\n", + " color=\"green\",\n", + " bins=edges,\n", + ")\n", + "ax.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "etx4velo_updated", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/LHCb_Pipeline/analyses/slopes_after_gnn.ipynb b/LHCb_Pipeline/analyses/slopes_after_gnn.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5f002d654a8abdd94c5e8bbf7165aa3747543230 --- /dev/null +++ b/LHCb_Pipeline/analyses/slopes_after_gnn.ipynb @@ -0,0 +1,397 @@ +{ + "cells": [ + { + "cell_type": "code", + 
"execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "import anaconfig\n", + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "import os\n", + "from tqdm.auto import tqdm\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "from utils.plotutils.plotools import save_fig\n", + "\n", + "configure_matplotlib()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3f2fc36a6170419d824525bd707eaca5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "test_dir = \"/scratch/acorreia/data/focal-loss-pid/triplet_building/test/velo-sim10b-nospillover/\"\n", + "\n", + "diff_angles_xz = []\n", + "diff_angles_yz = []\n", + "y = []\n", + "y_pid = []\n", + "min_scores = []\n", + "\n", + "for filename in tqdm(os.listdir(test_dir[:100])):\n", + " path = os.path.join(test_dir, filename)\n", + " batch = torch.load(path, map_location=\"cpu\")\n", + " y.append(batch.y.numpy())\n", + " angles_xz = batch.angle_xz\n", + " angles_yz = batch.angle_yz\n", + " scores = batch.scores[batch.edge_index].min(dim=0).values\n", + "\n", + " diff_angles_xz.append(\n", + " (angles_xz[batch.edge_index[1]] - angles_xz[batch.edge_index[0]]).numpy()\n", + " )\n", + " diff_angles_yz.append(\n", + " (angles_yz[batch.edge_index[1]] - angles_yz[batch.edge_index[0]]).numpy()\n", + " )\n", + " y_pid.append(batch.y_pid.numpy())\n", + " min_scores.append(scores)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.8058, 0.9442, 0.9453, ..., 0.1108, 0.9200, 0.9075])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.scores" + ] + }, + { 
+ "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "array_diff_angles_xz = np.concatenate(diff_angles_xz, axis=0)\n", + "array_diff_angles_yz = np.concatenate(diff_angles_yz, axis=0)\n", + "array_y_pid = np.concatenate(y_pid, axis=0)\n", + "array_y = np.concatenate(y, axis=0)\n", + "array_scores = np.concatenate(min_scores, axis=0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-2.95289799, -2.95099742, -2.94795341, ..., 0.51473529,\n", + " 0.48531203, 0.5169795 ])" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "array_diff_angles_xz[array_y_pid == False]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Figure was saved in output/analyses/triplet_angles.pdf\n", + "Figure was saved in output/analyses/triplet_angles.png\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "_, edges, _ = ax.hist(\n", + " array_diff_angles_yz[array_y_pid == False],\n", + " label=\"Fake triplets\",\n", + " color=\"red\",\n", + " bins=500,\n", + " range=(-1.0, 1.0)\n", + ")\n", + "ax.hist(\n", + " array_diff_angles_xz[array_y_pid == True],\n", + " label=\"Genuine triplets\",\n", + " color=\"green\",\n", + " bins=edges,\n", + ")\n", + "ax.legend()\n", + "ax.set_xlabel(\"$y$-$z$ angle\")\n", + "save_fig(fig, os.path.join(anaconfig.PLOTDIR, \"triplet_angles\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Figure was saved in output/analyses/true_triplet_angles.pdf\n", + "Figure was 
saved in output/analyses/true_triplet_angles.png\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "_, edges, _ = ax.hist(\n", + " array_diff_angles_xz[(array_y_pid == True)],\n", + " label=\"Genuine\",\n", + " color=\"green\",\n", + " bins=200,\n", + " range=(-0.01, 0.01)\n", + " \n", + ")\n", + "ax.hist(\n", + " array_diff_angles_yz[(array_scores > 0.8)],\n", + " label=\"Score > 0.8\",\n", + " color=\"purple\",\n", + " bins=edges,\n", + ")\n", + "ax.set_xlabel(\"$y$-$z$ angle\")\n", + "ax.legend()\n", + "\n", + "save_fig(fig, os.path.join(anaconfig.PLOTDIR, \"true_triplet_angles\"))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Figure was saved in output/analyses/true_triplet_angles_0p2.pdf\n", + "Figure was saved in output/analyses/true_triplet_angles_0p2.png\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "_, edges, _ = ax.hist(\n", + " array_diff_angles_yz[(array_scores > 0.2)],\n", + " label=\"Score > 0.2\",\n", + " color=\"purple\",\n", + " bins=200,\n", + " range=(-0.01, 0.01),\n", + " alpha=0.5,\n", + " \n", + ")\n", + "\n", + "ax.hist(\n", + " array_diff_angles_xz[(array_y_pid == True)],\n", + " label=\"Genuine\",\n", + " color=\"green\",\n", + " bins=edges,\n", + " alpha=0.5,\n", + ")\n", + "ax.set_xlabel(\"$y$-$z$ angle\")\n", + "ax.legend()\n", + "\n", + "save_fig(fig, os.path.join(anaconfig.PLOTDIR, \"true_triplet_angles_0p2\"))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "Figure was saved in output/analyses/true_triplet_angles.pdf\n", + "Figure was saved in output/analyses/true_triplet_angles.png\n" + ] + }, + { + "data": { + "text/plain": [ + "<matplotlib.legend.Legend at 0x7fc9ff0ee8f0>" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "save_fig(fig, os.path.join(anaconfig.PLOTDIR, \"true_triplet_angles\"))\n", + "\n", + "_, edges, _ = ax.hist(\n", + " array_diff_angles_xz[(array_y_pid == True) & (array_scores < 0.8)],\n", + " label=\"Genuine triplets as Fake\",\n", + " color=\"blue\",\n", + " bins=500,\n", + " range=(-0.05, 0.05)\n", + " \n", + ")\n", + "ax.hist(\n", + " array_diff_angles_yz[(array_y_pid == False) & (array_scores > 0.8)],\n", + " label=\"Fake triplets as True\",\n", + " color=\"purple\",\n", + " bins=edges,\n", + ")\n", + "ax.set_xlabel(\"$y$-$z$ angle\")\n", + "ax.legend()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<matplotlib.legend.Legend at 0x7fe3489d4580>" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "_, edges, _ = ax.hist(\n", + " array_diff_angles_yz[array_y == False],\n", + " label=\"False\",\n", + " color=\"red\",\n", + " bins=500,\n", + " range=(-5, 5)\n", + ")\n", + "ax.hist(\n", + " array_diff_angles_xz[array_y == True],\n", + " label=\"True\",\n", + " color=\"green\",\n", + " bins=edges,\n", + ")\n", + "ax.legend()" + ] + }, + { 
+ "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "etx4velo_updated", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/LHCb_Pipeline/analyses/track_length.ipynb b/LHCb_Pipeline/analyses/track_length.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..2731a06ff07fdd44c1dc36a9fb66170561aed455 --- /dev/null +++ b/LHCb_Pipeline/analyses/track_length.ipynb @@ -0,0 +1,196 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import os.path as op\n", + "\n", + "import anaconfig\n", + "\n", + "import montetracko.lhcb as mtb\n", + "import montetracko.lhcb.category as mtbc\n", + "\n", + "from Preprocessing.preprocessing import load_dataframes\n", + "from Preprocessing.particle_line_fitting import compute_particle_line_metrics_dataframe\n", + "\n", + "from utils.plotutils.plotools import save_fig\n", + "\n", + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "configure_matplotlib()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "hits_particles, particles = load_dataframes(\n", + " indir=\"/scratch/acorreia/minbias-sim10b-xdigi-nospillover/92\",\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0, 0.5, 'Abundance')" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": 
"execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.hist(particles[\"nhits_velo\"], bins=np.arange(3, 27) - 0.5)\n", + "ax.set_xlabel(\"# hits\")\n", + "ax.set_ylabel(\"Abundance\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "n_particles_velo=1,561,846\n" + ] + } + ], + "source": [ + "n_particles_velo = (particles[\"nhits_velo\"] >= 3).sum()\n", + "print(f\"{n_particles_velo=:,}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$f \\times \\sum_{n = 3}^{26} \\frac{1}{n} = n_{\\text{hits}}$" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "nhits_max = 10" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "factor = 1 / np.arange(3, nhits_max)\n", + "norm = n_particles_velo / factor.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<matplotlib.legend.Legend at 0x7f20e7086140>" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.hist(\n", + " particles[\"nhits_velo\"],\n", + " bins=np.arange(3, 27) - 0.5,\n", + " label=\"Actual\",\n", + " alpha=0.5,\n", + ")\n", + "\n", + "ax.bar(\n", + " x=np.arange(3, nhits_max),\n", + " height=norm * factor,\n", + " width=1.0,\n", + " color=\"orange\",\n", + " label=\"Normalised\",\n", + " alpha=0.5,\n", + 
")\n", + "\n", + "ax.set_xlabel(\"# hits\")\n", + "ax.set_ylabel(\"Abundance\")\n", + "ax.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "etx4velo_updated", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/LHCb_Pipeline/analyses/visualisation.ipynb b/LHCb_Pipeline/analyses/visualisation.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d379cac1a87fd7ced74ddf198baccd03f31e60be --- /dev/null +++ b/LHCb_Pipeline/analyses/visualisation.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# A notebook to visualise the tracks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import anaconfig\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os.path as op\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "\n", + "from Preprocessing.preprocessing import load_dataframes\n", + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "from utils.plotutils.plotools import save_fig\n", + "configure_matplotlib()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hits_particles, particles = load_dataframes(\n", + " indir=\"/scratch/acorreia/minbias-sim10b-xdigi-nospillover/92\",\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": {}, + "outputs": [], + "source": [ + "event_ids = hits_particles[\"event\"].unique()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "event_idx = 15\n", + "\n", + "event_hits_particles = hits_particles[hits_particles[\"event\"] == event_ids[event_idx]]\n", + "particles = particles[particles[\"event\"] == event_ids[event_idx]]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_event(hits_particles: pd.DataFrame, show_tracks: bool = False):\n", + " fig, ax = plt.subplots(figsize=(30, 6))\n", + "\n", + " df_hits = hits_particles.drop_duplicates([\"hit_id\"])\n", + " \n", + " ax.axhline(y=0, color='k')\n", + " ax.scatter(\n", + " x=df_hits[\"z\"],\n", + " y=df_hits[\"x\"],\n", + " color=\"grey\",\n", + " s=2,\n", + " )\n", + " ax.set_xlabel(\"$z$ (cm)\")\n", + " ax.set_ylabel(\"$x$ (cm)\")\n", + " ax.set_ylim(-50.0, 50.0)\n", + " ax.set_xlim(-290.0, 760.0)\n", + " ax.grid(color=\"grey\", alpha=0.5)\n", + " ax.set_aspect(1)\n", + "\n", + " if show_tracks:\n", + " for (_, hits_particle) in hits_particles.groupby(by=[\"event\", \"particle_id\"]):\n", + " hit_coordinates = hits_particle.sort_values(by=\"z\")\n", + " ax.plot(\n", + " hit_coordinates[\"z\"],\n", + " hit_coordinates[\"x\"],\n", + " linestyle=\"-\",\n", + " linewidth=1.0,\n", + " marker=None,\n", + " )\n", + " return fig, ax" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plot_event(hits_particles=event_hits_particles)\n", + "save_fig(fig, op.join(anaconfig.PLOTDIR, \"hits_xz\"), dpi=300, exts=[\".svg\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plot_event(hits_particles=event_hits_particles, show_tracks=True)\n", + "save_fig(fig, op.join(anaconfig.PLOTDIR, \"hits_tracks_xz\"), dpi=300, 
exts=[\".svg\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hits_particles[\"z\"].min()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def plot_xy_graph(\n", + " df_hits_particles: pd.DataFrame,\n", + " n_tracks: int,\n", + " n_events: int = 10,\n", + " seed: int | None = None,\n", + "):\n", + " fig1, ax1 = plt.subplots(figsize=(12, 12))\n", + " fig2, ax2 = plt.subplots(figsize=(12, 5))\n", + " axes = [ax1, ax2]\n", + "\n", + " for ax in axes:\n", + " ax.axhline(y=0.0, color=\"k\", linewidth=0.5)\n", + " ax.axvline(x=0.0, color=\"k\", linewidth=0.5)\n", + "\n", + " event_ids = df_hits_particles[\"event\"].unique()\n", + "\n", + " rng = np.random.default_rng(seed=seed)\n", + " rng.shuffle(event_ids)\n", + "\n", + " for idx, (_, hits_particle) in enumerate(\n", + " df_hits_particles[\n", + " df_hits_particles[\"event\"].isin(event_ids[:n_events])\n", + " ].groupby(\n", + " by=[\"event\", \"particle_id\"]\n", + " )\n", + " ):\n", + " hit_coordinates = hits_particle.sort_values(by=\"plane\")\n", + " ax1.plot(\n", + " hit_coordinates[\"x\"],\n", + " hit_coordinates[\"y\"],\n", + " linestyle=\"-\",\n", + " markersize=5.0,\n", + " marker=\"o\",\n", + " )\n", + " ax2.plot(\n", + " hit_coordinates[\"z\"],\n", + " hit_coordinates[\"x\"],\n", + " linestyle=\"-\",\n", + " markersize=5.0,\n", + " marker=\"o\",\n", + " )\n", + " if idx > n_tracks:\n", + " break\n", + "\n", + " for ax in axes:\n", + " ax.set_ylim(-50.0, 50.0)\n", + " ax.grid(color=\"grey\", alpha=0.5)\n", + " \n", + " ax1.set_xlim(-50.0, 50.0)\n", + " ax2.set_xlim(-570, 470)\n", + " return fig1, ax1, fig2, ax2\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_xy_graph(\n", + " 
df_hits_particles=hits_particles,\n", + " n_tracks=10,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "etx4velo_updated", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/LHCb_Pipeline/analysis.ipynb b/LHCb_Pipeline/analysis.ipynb index cc0fe9e76fc910f37ad37da63ebfd4e51cd8d451..5adeae21cf3708adf740f6f1d831dc560e09c19f 100644 --- a/LHCb_Pipeline/analysis.ipynb +++ b/LHCb_Pipeline/analysis.ipynb @@ -40,6 +40,32 @@ ")\n" ] }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.03448266360631707" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(\n", + " (df_particles[\"has_velo\"]) & (df_particles[\"has_scifi\"]) & (df_particles[\"eta\"] > 2)\n", + " & (df_particles[\"eta\"] < 5.0) & (df_particles[\"pid\"].abs() == 11) \n", + ").sum() / (\n", + " (df_particles[\"has_velo\"]) & (df_particles[\"eta\"] > 2)\n", + " & (df_particles[\"eta\"] < 5.0)\n", + ").sum()\n" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -50,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -65,7 +91,7 @@ " **kwargs,\n", " )\n", " ax.set_xlabel(\"# hits in the velo\")\n", - " ax.set_ylabel(\"Origin vertex coordinate z-position\")\n", + " ax.set_ylabel(\"$v_z$\")\n", " fig.colorbar(im, ax=ax)\n", "\n", " return fig, ax\n" @@ -73,12 +99,12 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 
11, "metadata": {}, "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "<Figure size 800x600 with 2 Axes>" ] @@ -88,6 +114,9 @@ } ], "source": [ + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "configure_matplotlib()\n", + "\n", "fig, _ = compare_nhits_vz(\n", " df_particles=df_particles[\n", " (df_particles[\"has_velo\"] == 1)\n", @@ -95,8 +124,8 @@ " & (df_particles[\"pid\"].abs() == 11)\n", " ]\n", ")\n", - "\n", - "fig.savefig(\"nhits_velo_vs_vz_long_electrons.pdf\")\n" + "fig.tight_layout()\n", + "fig.savefig(\"nhits_velo_vs_vz_long_electrons.png\", dpi=300, transparent=True)\n" ] }, { @@ -161,6 +190,7 @@ } ], "source": [ + "\n", "fig, _ = compare_nhits_vz(\n", " df_particles=df_particles[\n", " (df_particles[\"has_velo\"] == 1)\n", diff --git a/LHCb_Pipeline/full_pipeline-focal-loss-pid-fixed.ipynb b/LHCb_Pipeline/full_pipeline-focal-loss-pid-fixed.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..b10c97baf089dfc18fc58a1033e24d1c678e67bf --- /dev/null +++ b/LHCb_Pipeline/full_pipeline-focal-loss-pid-fixed.ipynb @@ -0,0 +1,1714 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 0. 
Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div class=\"bk-root\">\n", + " <a href=\"https://bokeh.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n", + " <span id=\"1002\">Loading BokehJS ...</span>\n", + " </div>\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "(function(root) {\n", + " function now() {\n", + " return new Date();\n", + " }\n", + "\n", + " const force = true;\n", + "\n", + " if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n", + " root._bokeh_onload_callbacks = [];\n", + " root._bokeh_is_loading = undefined;\n", + " }\n", + "\n", + "const JS_MIME_TYPE = 'application/javascript';\n", + " const HTML_MIME_TYPE = 'text/html';\n", + " const EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n", + " const CLASS_NAME = 'output_bokeh rendered_html';\n", + "\n", + " /**\n", + " * Render data to the DOM node\n", + " */\n", + " function render(props, node) {\n", + " const script = document.createElement(\"script\");\n", + " node.appendChild(script);\n", + " }\n", + "\n", + " /**\n", + " * Handle when an output is cleared or removed\n", + " */\n", + " function handleClearOutput(event, handle) {\n", + " const cell = handle.cell;\n", + "\n", + " const id = cell.output_area._bokeh_element_id;\n", + " const server_id = cell.output_area._bokeh_server_id;\n", + " // Clean up Bokeh references\n", + " if (id != null && id in Bokeh.index) {\n", + " Bokeh.index[id].model.document.clear();\n", + " delete Bokeh.index[id];\n", + " }\n", + "\n", + " if (server_id !== undefined) {\n", + " // Clean up Bokeh references\n", + " const cmd_clean = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + 
\"'].get_sessions()[0].document.roots[0]._id)\";\n", + " cell.notebook.kernel.execute(cmd_clean, {\n", + " iopub: {\n", + " output: function(msg) {\n", + " const id = msg.content.text.trim();\n", + " if (id in Bokeh.index) {\n", + " Bokeh.index[id].model.document.clear();\n", + " delete Bokeh.index[id];\n", + " }\n", + " }\n", + " }\n", + " });\n", + " // Destroy server and session\n", + " const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n", + " cell.notebook.kernel.execute(cmd_destroy);\n", + " }\n", + " }\n", + "\n", + " /**\n", + " * Handle when a new output is added\n", + " */\n", + " function handleAddOutput(event, handle) {\n", + " const output_area = handle.output_area;\n", + " const output = handle.output;\n", + "\n", + " // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n", + " if ((output.output_type != \"display_data\") || (!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n", + " return\n", + " }\n", + "\n", + " const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n", + "\n", + " if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n", + " toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n", + " // store reference to embed id on output_area\n", + " output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n", + " }\n", + " if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n", + " const bk_div = document.createElement(\"div\");\n", + " bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n", + " const script_attrs = bk_div.children[0].attributes;\n", + " for (let i = 0; i < script_attrs.length; i++) {\n", + " toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n", + " toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n", + " }\n", + " // store reference to server id on output_area\n", + 
" output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n", + " }\n", + " }\n", + "\n", + " function register_renderer(events, OutputArea) {\n", + "\n", + " function append_mime(data, metadata, element) {\n", + " // create a DOM node to render to\n", + " const toinsert = this.create_output_subarea(\n", + " metadata,\n", + " CLASS_NAME,\n", + " EXEC_MIME_TYPE\n", + " );\n", + " this.keyboard_manager.register_events(toinsert);\n", + " // Render to node\n", + " const props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n", + " render(props, toinsert[toinsert.length - 1]);\n", + " element.append(toinsert);\n", + " return toinsert\n", + " }\n", + "\n", + " /* Handle when an output is cleared or removed */\n", + " events.on('clear_output.CodeCell', handleClearOutput);\n", + " events.on('delete.Cell', handleClearOutput);\n", + "\n", + " /* Handle when a new output is added */\n", + " events.on('output_added.OutputArea', handleAddOutput);\n", + "\n", + " /**\n", + " * Register the mime type and append_mime function with output_area\n", + " */\n", + " OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n", + " /* Is output safe? 
*/\n", + " safe: true,\n", + " /* Index of renderer in `output_area.display_order` */\n", + " index: 0\n", + " });\n", + " }\n", + "\n", + " // register the mime type if in Jupyter Notebook environment and previously unregistered\n", + " if (root.Jupyter !== undefined) {\n", + " const events = require('base/js/events');\n", + " const OutputArea = require('notebook/js/outputarea').OutputArea;\n", + "\n", + " if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n", + " register_renderer(events, OutputArea);\n", + " }\n", + " }\n", + " if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n", + " root._bokeh_timeout = Date.now() + 5000;\n", + " root._bokeh_failed_load = false;\n", + " }\n", + "\n", + " const NB_LOAD_WARNING = {'data': {'text/html':\n", + " \"<div style='background-color: #fdd'>\\n\"+\n", + " \"<p>\\n\"+\n", + " \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n", + " \"may be due to a slow or bad network connection. 
Possible fixes:\\n\"+\n", + " \"</p>\\n\"+\n", + " \"<ul>\\n\"+\n", + " \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n", + " \"<li>use INLINE resources instead, as so:</li>\\n\"+\n", + " \"</ul>\\n\"+\n", + " \"<code>\\n\"+\n", + " \"from bokeh.resources import INLINE\\n\"+\n", + " \"output_notebook(resources=INLINE)\\n\"+\n", + " \"</code>\\n\"+\n", + " \"</div>\"}};\n", + "\n", + " function display_loaded() {\n", + " const el = document.getElementById(\"1002\");\n", + " if (el != null) {\n", + " el.textContent = \"BokehJS is loading...\";\n", + " }\n", + " if (root.Bokeh !== undefined) {\n", + " if (el != null) {\n", + " el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n", + " }\n", + " } else if (Date.now() < root._bokeh_timeout) {\n", + " setTimeout(display_loaded, 100)\n", + " }\n", + " }\n", + "\n", + " function run_callbacks() {\n", + " try {\n", + " root._bokeh_onload_callbacks.forEach(function(callback) {\n", + " if (callback != null)\n", + " callback();\n", + " });\n", + " } finally {\n", + " delete root._bokeh_onload_callbacks\n", + " }\n", + " console.debug(\"Bokeh: all callbacks have finished\");\n", + " }\n", + "\n", + " function load_libs(css_urls, js_urls, callback) {\n", + " if (css_urls == null) css_urls = [];\n", + " if (js_urls == null) js_urls = [];\n", + "\n", + " root._bokeh_onload_callbacks.push(callback);\n", + " if (root._bokeh_is_loading > 0) {\n", + " console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n", + " return null;\n", + " }\n", + " if (js_urls == null || js_urls.length === 0) {\n", + " run_callbacks();\n", + " return null;\n", + " }\n", + " console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n", + " root._bokeh_is_loading = css_urls.length + js_urls.length;\n", + "\n", + " function on_load() {\n", + " root._bokeh_is_loading--;\n", + " if (root._bokeh_is_loading === 0) {\n", + " 
console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n", + " run_callbacks()\n", + " }\n", + " }\n", + "\n", + " function on_error(url) {\n", + " console.error(\"failed to load \" + url);\n", + " }\n", + "\n", + " for (let i = 0; i < css_urls.length; i++) {\n", + " const url = css_urls[i];\n", + " const element = document.createElement(\"link\");\n", + " element.onload = on_load;\n", + " element.onerror = on_error.bind(null, url);\n", + " element.rel = \"stylesheet\";\n", + " element.type = \"text/css\";\n", + " element.href = url;\n", + " console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n", + " document.body.appendChild(element);\n", + " }\n", + "\n", + " for (let i = 0; i < js_urls.length; i++) {\n", + " const url = js_urls[i];\n", + " const element = document.createElement('script');\n", + " element.onload = on_load;\n", + " element.onerror = on_error.bind(null, url);\n", + " element.async = false;\n", + " element.src = url;\n", + " console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n", + " document.head.appendChild(element);\n", + " }\n", + " };\n", + "\n", + " function inject_raw_css(css) {\n", + " const element = document.createElement(\"style\");\n", + " element.appendChild(document.createTextNode(css));\n", + " document.body.appendChild(element);\n", + " }\n", + "\n", + " const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n", + " const css_urls = [];\n", + "\n", + " const inline_js = [ function(Bokeh) {\n", + " Bokeh.set_log_level(\"info\");\n", + " },\n", + "function(Bokeh) {\n", + " }\n", + " ];\n", + "\n", + " function run_inline_js() {\n", + " if (root.Bokeh !== undefined || force === true) {\n", + " 
for (let i = 0; i < inline_js.length; i++) {\n", + " inline_js[i].call(root, root.Bokeh);\n", + " }\n", + "if (force === true) {\n", + " display_loaded();\n", + " }} else if (Date.now() < root._bokeh_timeout) {\n", + " setTimeout(run_inline_js, 100);\n", + " } else if (!root._bokeh_failed_load) {\n", + " console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n", + " root._bokeh_failed_load = true;\n", + " } else if (force !== true) {\n", + " const cell = $(document.getElementById(\"1002\")).parents('.cell').data().cell;\n", + " cell.output_area.append_execute_result(NB_LOAD_WARNING)\n", + " }\n", + " }\n", + "\n", + " if (root._bokeh_is_loading === 0) {\n", + " console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n", + " run_inline_js();\n", + " } else {\n", + " load_libs(css_urls, js_urls, function() {\n", + " console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n", + " run_inline_js();\n", + " });\n", + " }\n", + "}(window));" + ], + "application/vnd.bokehjs_load.v0+json": "(function(root) {\n function now() {\n return new Date();\n }\n\n const force = true;\n\n if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n root._bokeh_onload_callbacks = [];\n root._bokeh_is_loading = undefined;\n }\n\n\n if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n const NB_LOAD_WARNING = {'data': {'text/html':\n \"<div style='background-color: #fdd'>\\n\"+\n \"<p>\\n\"+\n \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n \"may be due to a slow or bad network connection. 
Possible fixes:\\n\"+\n \"</p>\\n\"+\n \"<ul>\\n\"+\n \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n \"<li>use INLINE resources instead, as so:</li>\\n\"+\n \"</ul>\\n\"+\n \"<code>\\n\"+\n \"from bokeh.resources import INLINE\\n\"+\n \"output_notebook(resources=INLINE)\\n\"+\n \"</code>\\n\"+\n \"</div>\"}};\n\n function display_loaded() {\n const el = document.getElementById(\"1002\");\n if (el != null) {\n el.textContent = \"BokehJS is loading...\";\n }\n if (root.Bokeh !== undefined) {\n if (el != null) {\n el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n }\n } else if (Date.now() < root._bokeh_timeout) {\n setTimeout(display_loaded, 100)\n }\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = [];\n\n root._bokeh_onload_callbacks.push(callback);\n if (root._bokeh_is_loading > 0) {\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n }\n if (js_urls == null || js_urls.length === 0) {\n run_callbacks();\n return null;\n }\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n\n function on_error(url) {\n console.error(\"failed to load \" + url);\n }\n\n for (let i = 0; i < css_urls.length; i++) {\n const url = css_urls[i];\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = 
on_error.bind(null, url);\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n }\n\n for (let i = 0; i < js_urls.length; i++) {\n const url = js_urls[i];\n const element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n const css_urls = [];\n\n const inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {\n }\n ];\n\n function run_inline_js() {\n if (root.Bokeh !== undefined || force === true) {\n for (let i = 0; i < inline_js.length; i++) {\n inline_js[i].call(root, root.Bokeh);\n }\nif (force === true) {\n display_loaded();\n }} else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n } else if (force !== true) {\n const cell = $(document.getElementById(\"1002\")).parents('.cell').data().cell;\n cell.output_area.append_execute_result(NB_LOAD_WARNING)\n }\n }\n\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: BokehJS loaded, going straight to 
plotting\");\n run_inline_js();\n } else {\n load_libs(css_urls, js_urls, function() {\n console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n run_inline_js();\n });\n }\n}(window));" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "<div class=\"bk-root\">\n", + " <a href=\"https://bokeh.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n", + " <span id=\"1003\">Loading BokehJS ...</span>\n", + " </div>\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "(function(root) {\n", + " function now() {\n", + " return new Date();\n", + " }\n", + "\n", + " const force = true;\n", + "\n", + " if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n", + " root._bokeh_onload_callbacks = [];\n", + " root._bokeh_is_loading = undefined;\n", + " }\n", + "\n", + "const JS_MIME_TYPE = 'application/javascript';\n", + " const HTML_MIME_TYPE = 'text/html';\n", + " const EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n", + " const CLASS_NAME = 'output_bokeh rendered_html';\n", + "\n", + " /**\n", + " * Render data to the DOM node\n", + " */\n", + " function render(props, node) {\n", + " const script = document.createElement(\"script\");\n", + " node.appendChild(script);\n", + " }\n", + "\n", + " /**\n", + " * Handle when an output is cleared or removed\n", + " */\n", + " function handleClearOutput(event, handle) {\n", + " const cell = handle.cell;\n", + "\n", + " const id = cell.output_area._bokeh_element_id;\n", + " const server_id = cell.output_area._bokeh_server_id;\n", + " // Clean up Bokeh references\n", + " if (id != null && id in Bokeh.index) {\n", + " Bokeh.index[id].model.document.clear();\n", + " delete Bokeh.index[id];\n", + " }\n", + "\n", + " if (server_id !== undefined) {\n", + " // Clean up Bokeh references\n", + " const cmd_clean = \"from bokeh.io.state import curstate; 
print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n", + " cell.notebook.kernel.execute(cmd_clean, {\n", + " iopub: {\n", + " output: function(msg) {\n", + " const id = msg.content.text.trim();\n", + " if (id in Bokeh.index) {\n", + " Bokeh.index[id].model.document.clear();\n", + " delete Bokeh.index[id];\n", + " }\n", + " }\n", + " }\n", + " });\n", + " // Destroy server and session\n", + " const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n", + " cell.notebook.kernel.execute(cmd_destroy);\n", + " }\n", + " }\n", + "\n", + " /**\n", + " * Handle when a new output is added\n", + " */\n", + " function handleAddOutput(event, handle) {\n", + " const output_area = handle.output_area;\n", + " const output = handle.output;\n", + "\n", + " // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n", + " if ((output.output_type != \"display_data\") || (!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n", + " return\n", + " }\n", + "\n", + " const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n", + "\n", + " if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n", + " toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n", + " // store reference to embed id on output_area\n", + " output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n", + " }\n", + " if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n", + " const bk_div = document.createElement(\"div\");\n", + " bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n", + " const script_attrs = bk_div.children[0].attributes;\n", + " for (let i = 0; i < script_attrs.length; i++) {\n", + " toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n", + " toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n", + " }\n", + " // 
store reference to server id on output_area\n", + " output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n", + " }\n", + " }\n", + "\n", + " function register_renderer(events, OutputArea) {\n", + "\n", + " function append_mime(data, metadata, element) {\n", + " // create a DOM node to render to\n", + " const toinsert = this.create_output_subarea(\n", + " metadata,\n", + " CLASS_NAME,\n", + " EXEC_MIME_TYPE\n", + " );\n", + " this.keyboard_manager.register_events(toinsert);\n", + " // Render to node\n", + " const props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n", + " render(props, toinsert[toinsert.length - 1]);\n", + " element.append(toinsert);\n", + " return toinsert\n", + " }\n", + "\n", + " /* Handle when an output is cleared or removed */\n", + " events.on('clear_output.CodeCell', handleClearOutput);\n", + " events.on('delete.Cell', handleClearOutput);\n", + "\n", + " /* Handle when a new output is added */\n", + " events.on('output_added.OutputArea', handleAddOutput);\n", + "\n", + " /**\n", + " * Register the mime type and append_mime function with output_area\n", + " */\n", + " OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n", + " /* Is output safe? 
*/\n", + " safe: true,\n", + " /* Index of renderer in `output_area.display_order` */\n", + " index: 0\n", + " });\n", + " }\n", + "\n", + " // register the mime type if in Jupyter Notebook environment and previously unregistered\n", + " if (root.Jupyter !== undefined) {\n", + " const events = require('base/js/events');\n", + " const OutputArea = require('notebook/js/outputarea').OutputArea;\n", + "\n", + " if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n", + " register_renderer(events, OutputArea);\n", + " }\n", + " }\n", + " if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n", + " root._bokeh_timeout = Date.now() + 5000;\n", + " root._bokeh_failed_load = false;\n", + " }\n", + "\n", + " const NB_LOAD_WARNING = {'data': {'text/html':\n", + " \"<div style='background-color: #fdd'>\\n\"+\n", + " \"<p>\\n\"+\n", + " \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n", + " \"may be due to a slow or bad network connection. 
Possible fixes:\\n\"+\n", + " \"</p>\\n\"+\n", + " \"<ul>\\n\"+\n", + " \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n", + " \"<li>use INLINE resources instead, as so:</li>\\n\"+\n", + " \"</ul>\\n\"+\n", + " \"<code>\\n\"+\n", + " \"from bokeh.resources import INLINE\\n\"+\n", + " \"output_notebook(resources=INLINE)\\n\"+\n", + " \"</code>\\n\"+\n", + " \"</div>\"}};\n", + "\n", + " function display_loaded() {\n", + " const el = document.getElementById(\"1003\");\n", + " if (el != null) {\n", + " el.textContent = \"BokehJS is loading...\";\n", + " }\n", + " if (root.Bokeh !== undefined) {\n", + " if (el != null) {\n", + " el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n", + " }\n", + " } else if (Date.now() < root._bokeh_timeout) {\n", + " setTimeout(display_loaded, 100)\n", + " }\n", + " }\n", + "\n", + " function run_callbacks() {\n", + " try {\n", + " root._bokeh_onload_callbacks.forEach(function(callback) {\n", + " if (callback != null)\n", + " callback();\n", + " });\n", + " } finally {\n", + " delete root._bokeh_onload_callbacks\n", + " }\n", + " console.debug(\"Bokeh: all callbacks have finished\");\n", + " }\n", + "\n", + " function load_libs(css_urls, js_urls, callback) {\n", + " if (css_urls == null) css_urls = [];\n", + " if (js_urls == null) js_urls = [];\n", + "\n", + " root._bokeh_onload_callbacks.push(callback);\n", + " if (root._bokeh_is_loading > 0) {\n", + " console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n", + " return null;\n", + " }\n", + " if (js_urls == null || js_urls.length === 0) {\n", + " run_callbacks();\n", + " return null;\n", + " }\n", + " console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n", + " root._bokeh_is_loading = css_urls.length + js_urls.length;\n", + "\n", + " function on_load() {\n", + " root._bokeh_is_loading--;\n", + " if (root._bokeh_is_loading === 0) {\n", + " 
console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n", + " run_callbacks()\n", + " }\n", + " }\n", + "\n", + " function on_error(url) {\n", + " console.error(\"failed to load \" + url);\n", + " }\n", + "\n", + " for (let i = 0; i < css_urls.length; i++) {\n", + " const url = css_urls[i];\n", + " const element = document.createElement(\"link\");\n", + " element.onload = on_load;\n", + " element.onerror = on_error.bind(null, url);\n", + " element.rel = \"stylesheet\";\n", + " element.type = \"text/css\";\n", + " element.href = url;\n", + " console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n", + " document.body.appendChild(element);\n", + " }\n", + "\n", + " for (let i = 0; i < js_urls.length; i++) {\n", + " const url = js_urls[i];\n", + " const element = document.createElement('script');\n", + " element.onload = on_load;\n", + " element.onerror = on_error.bind(null, url);\n", + " element.async = false;\n", + " element.src = url;\n", + " console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n", + " document.head.appendChild(element);\n", + " }\n", + " };\n", + "\n", + " function inject_raw_css(css) {\n", + " const element = document.createElement(\"style\");\n", + " element.appendChild(document.createTextNode(css));\n", + " document.body.appendChild(element);\n", + " }\n", + "\n", + " const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n", + " const css_urls = [];\n", + "\n", + " const inline_js = [ function(Bokeh) {\n", + " Bokeh.set_log_level(\"info\");\n", + " },\n", + "function(Bokeh) {\n", + " }\n", + " ];\n", + "\n", + " function run_inline_js() {\n", + " if (root.Bokeh !== undefined || force === true) {\n", + " 
for (let i = 0; i < inline_js.length; i++) {\n", + " inline_js[i].call(root, root.Bokeh);\n", + " }\n", + "if (force === true) {\n", + " display_loaded();\n", + " }} else if (Date.now() < root._bokeh_timeout) {\n", + " setTimeout(run_inline_js, 100);\n", + " } else if (!root._bokeh_failed_load) {\n", + " console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n", + " root._bokeh_failed_load = true;\n", + " } else if (force !== true) {\n", + " const cell = $(document.getElementById(\"1003\")).parents('.cell').data().cell;\n", + " cell.output_area.append_execute_result(NB_LOAD_WARNING)\n", + " }\n", + " }\n", + "\n", + " if (root._bokeh_is_loading === 0) {\n", + " console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n", + " run_inline_js();\n", + " } else {\n", + " load_libs(css_urls, js_urls, function() {\n", + " console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n", + " run_inline_js();\n", + " });\n", + " }\n", + "}(window));" + ], + "application/vnd.bokehjs_load.v0+json": "(function(root) {\n function now() {\n return new Date();\n }\n\n const force = true;\n\n if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n root._bokeh_onload_callbacks = [];\n root._bokeh_is_loading = undefined;\n }\n\n\n if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n const NB_LOAD_WARNING = {'data': {'text/html':\n \"<div style='background-color: #fdd'>\\n\"+\n \"<p>\\n\"+\n \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n \"may be due to a slow or bad network connection. 
Possible fixes:\\n\"+\n \"</p>\\n\"+\n \"<ul>\\n\"+\n \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n \"<li>use INLINE resources instead, as so:</li>\\n\"+\n \"</ul>\\n\"+\n \"<code>\\n\"+\n \"from bokeh.resources import INLINE\\n\"+\n \"output_notebook(resources=INLINE)\\n\"+\n \"</code>\\n\"+\n \"</div>\"}};\n\n function display_loaded() {\n const el = document.getElementById(\"1003\");\n if (el != null) {\n el.textContent = \"BokehJS is loading...\";\n }\n if (root.Bokeh !== undefined) {\n if (el != null) {\n el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n }\n } else if (Date.now() < root._bokeh_timeout) {\n setTimeout(display_loaded, 100)\n }\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = [];\n\n root._bokeh_onload_callbacks.push(callback);\n if (root._bokeh_is_loading > 0) {\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n }\n if (js_urls == null || js_urls.length === 0) {\n run_callbacks();\n return null;\n }\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n\n function on_error(url) {\n console.error(\"failed to load \" + url);\n }\n\n for (let i = 0; i < css_urls.length; i++) {\n const url = css_urls[i];\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = 
on_error.bind(null, url);\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n }\n\n for (let i = 0; i < js_urls.length; i++) {\n const url = js_urls[i];\n const element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n const css_urls = [];\n\n const inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {\n }\n ];\n\n function run_inline_js() {\n if (root.Bokeh !== undefined || force === true) {\n for (let i = 0; i < inline_js.length; i++) {\n inline_js[i].call(root, root.Bokeh);\n }\nif (force === true) {\n display_loaded();\n }} else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n } else if (force !== true) {\n const cell = $(document.getElementById(\"1003\")).parents('.cell').data().cell;\n cell.output_area.append_execute_result(NB_LOAD_WARNING)\n }\n }\n\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: BokehJS loaded, going straight to 
plotting\");\n run_inline_js();\n } else {\n load_libs(css_urls, js_urls, function() {\n console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n run_inline_js();\n });\n }\n}(window));" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "from __future__ import annotations\n", + "import typing\n", + "import logging\n", + "import os\n", + "import sys\n", + "import warnings\n", + "\n", + "sys.path.append('../montetracko')\n", + "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n", + "\n", + "import yaml\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from Preprocessing.run_preprocessing import run_preprocessing_test_dataset\n", + "from Preprocessing.run_preprocessing import run_preprocessing\n", + "from Processing.run_processing import run_processing_from_config\n", + "\n", + "from Scripts.Step_1_Train_Metric_Learning import train as train_metric_learning\n", + "from Scripts.Step_2_Run_Metric_Learning import train as run_metric_learning_inference\n", + "from Scripts.Step_3_Train_GNN import train as train_gnn\n", + "from Scripts.Step_4_Run_GNN import train as run_gnn_inference\n", + "from Scripts.Step_5_Build_Track_Candidates import train as build_track_candidates\n", + "from Scripts.Step_6_Evaluate_Reconstruction_MonteTracko import (\n", + " evaluate as evaluate_candidates_montetracko\n", + ")\n", + "\n", + "from utils.plotutils import graph as graphplot\n", + "from utils.plotutils import performance as perfplot\n", + "from utils.plotutils import performance_mpl as perfplot_mpl\n", + "from utils.commonutils.ctests import get_required_test_dataset_names\n", + "from utils.commonutils.config import load_config\n", + "from utils.modelutils import checkpoint_utils\n", + "from utils.scriptutils.loghandler import headline\n", + "\n", + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "\n", + "configure_matplotlib()\n", + "\n", + 
"warnings.filterwarnings(\n", + " \"ignore\", message=(\n", + " \"TypedStorage is deprecated. It will be removed in the future and \"\n", + " \"UntypedStorage will be the only storage class. This should only matter to you \"\n", + " \"if you are using storages directly.\"\n", + " )\n", + ")\n", + "\n", + "CONFIG = 'pipeline_configs/focal-loss-pid-fixed.yaml'\n", + "\n", + "run_training: bool = True\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Pipeline configurations\n", + "\n", + "The configurations for the entire pipeline are defined under pipeline_config.yml." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download data\n", + "Uncomment if you do not already have the data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# path = 'data/input/0'\n", + "# os.makedirs(path, exist_ok=True)\n", + "# ! xrdcp -r root://eoslhcb.cern.ch//eos/lhcb/user/a/anthonyc/tracking/data/csv/v2/minbias-sim10b-xdigi/0 data/input --parallel 4" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Telegram notification bot" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "import json\n", + "\n", + "# from datetime import datetime\n", + "\n", + "# def send_telegram_message(message: str,\n", + "# chat_id: str,\n", + "# api_key: str,\n", + "# ):\n", + "# responses = {}\n", + "\n", + "# url = f'https://api.telegram.org/bot{api_key}/sendMessage?chat_id={chat_id}&text={message}'\n", + " \n", + "# response = requests.post(url)\n", + " \n", + "# return response\n", + "\n", + "def send_telegram_message(*args, **kwargs):\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "chat_id = \"5027012918\"\n", + "api_key = \"6268687426:AAE1P7WQofCBuQPiYZlYaKU-p1GNn6OvAxM\"\n", + "\n", + 
"send_telegram_message(\"======================\", chat_id, api_key)\n", + "\n", + "send_telegram_message(\"Starting.\", chat_id, api_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preprocess the test samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for required_test_dataset_name in get_required_test_dataset_names(CONFIG):\n", + " run_preprocessing_test_dataset(\n", + " test_dataset_name=required_test_dataset_name,\n", + " path_or_config=CONFIG,\n", + " path_or_config_test=\"test_samples.yaml\",\n", + " reproduce=False,\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preprocessing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_preprocessing(CONFIG, reproduce=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Processing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_processing_from_config(CONFIG, reproduce=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for required_test_dataset_name in get_required_test_dataset_names(CONFIG):\n", + " run_processing_from_config(\n", + " test_dataset_name=required_test_dataset_name,\n", + " path_or_config=CONFIG,\n", + " reproduce=False,\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 1. Train Metric Learning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## What it does\n", + "Broadly speaking, the first stage of our pipeline is embedding the space points on to graphs, in a way that is efficient, i.e. we miss as few points on a graph as possible. 
We train an MLP to transform the input feature vector $\\mathbf{u}_i$ of each space point into a vector $\\mathbf{v}_i$ in an N-dimensional latent space. The graph is then constructed by connecting the space points whose Euclidean distance in the latent space satisfies $$d_{ij} = \\left| \\mathbf{v}_i - \\mathbf{v}_j \\right| < r_{embedding}$$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training data\n", + "Let us take a look at the data before training. In this example pipeline, we have preprocessed the TrackML data into a more convenient form. We calculated directional information and summary statistics from the charge deposited in each spacepoint, and appended them to its cylindrical coordinates. Let us load an example data file and inspect the content." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from Embedding.embedding_base import get_example_data\n", + "example_data_df, example_data_pyg = get_example_data(CONFIG)\n", + "example_data_df\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "graphplot.plot_true_graph(example_data_pyg, CONFIG, num_tracks=50)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train metric learning model\n", + "\n", + "Finally we come to model training. By default, we train the MLP for 30 epochs, which takes approximately 15 minutes on an NVIDIA V100. Feel free to adjust the epoch number in pipeline_config.yml." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "! nvcc --version" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "! 
nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "if run_training:\n", + " send_telegram_message('Started metric learning training.', chat_id, api_key)\n", + "\n", + " metric_learning_trainer, metric_learning_model = train_metric_learning(CONFIG)\n", + "\n", + " send_telegram_message('Finished metric learning training.', chat_id, api_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the case you want to continue the training of a certain network, you may\n", + "use the code below." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### From checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "embedding_metric_path='artifacts/metric_learning/focal-loss-pid-fixed/version_0/metrics.csv'\n", + "embedding_artifact_path='artifacts/metric_learning/focal-loss-pid-fixed/version_0/checkpoints/epoch=16-step=170000.ckpt'\n" + ] + } + ], + "source": [ + "from Embedding.layerless_embedding import LayerlessEmbedding\n", + "\n", + "embedding_version_dir = checkpoint_utils.get_last_version_dir_from_config(\n", + " step=\"metric_learning\", path_or_config=CONFIG\n", + ")\n", + "embedding_metric_path = os.path.join(embedding_version_dir, \"metrics.csv\")\n", + "embedding_artifact_path = checkpoint_utils.get_last_artifact(\n", + " version_dir=embedding_version_dir\n", + ")\n", + "print(f\"{embedding_metric_path=}\")\n", + "print(f\"{embedding_artifact_path=}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To continue the training of the network, you may use the code below" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_lightning.loggers import CSVLogger\n", + "from pytorch_lightning import Trainer\n", + 
"\n", + "\n", + "def continue_embedding_training(\n", + " path_or_config: str | dict,\n", + ") -> typing.Tuple[Trainer, LayerlessEmbedding]:\n", + " config = load_config(path_or_config=path_or_config)\n", + "\n", + " metric_learning_model = LayerlessEmbedding.load_from_checkpoint(\n", + " embedding_artifact_path\n", + " ) # you may change `metric_learning_model`\n", + "\n", + " save_directory = os.path.abspath(\n", + " os.path.join(config[\"common\"][\"artifact_directory\"], \"metric_learning\")\n", + " )\n", + "\n", + " logger = CSVLogger(save_directory, name=config[\"common\"][\"experiment_name\"])\n", + "\n", + " metric_learning_trainer = Trainer(\n", + " accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n", + " devices=1,\n", + " max_epochs=40, # you may increase the number of epochs\n", + " logger=logger,\n", + " # callbacks=[EarlyStopping(monitor=\"train_loss\", mode=\"min\")]\n", + " )\n", + "\n", + " metric_learning_trainer.fit(metric_learning_model)\n", + " return metric_learning_trainer, metric_learning_model\n", + "\n", + "\n", + "# metric_learning_trainer, metric_learning_model = continue_embedding_training(CONFIG)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metric_learning_model = LayerlessEmbedding.load_from_checkpoint(\n", + " embedding_artifact_path,\n", + " # map_location=\"cpu\",\n", + " # If importing model from another experiment\n", + " hparams=load_config(CONFIG)[\"metric_learning\"],\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot training metrics" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can examine how the training went. 
This is stored in a simple dataframe:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# embedding_metrics = checkpoint_utils.get_training_metrics(metric_learning_trainer) \n", + "\n", + "embedding_metrics = checkpoint_utils.get_training_metrics(embedding_metric_path)\n", + "\n", + "embedding_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "perfplot_mpl.plot_loss(embedding_metrics, CONFIG, \"metric_learning\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "perfplot_mpl.plot_metric_epochs(\n", + " \"eff\",\n", + " embedding_metrics,\n", + " CONFIG,\n", + " \"metric_learning\",\n", + " \"Edge Efficiency\",\n", + " color=perfplot_mpl.partition_to_color[\"val\"],\n", + ")\n", + "perfplot_mpl.plot_metric_epochs(\n", + " \"pur\",\n", + " embedding_metrics,\n", + " CONFIG,\n", + " \"metric_learning\",\n", + " \"Edge Purity\",\n", + " color=perfplot_mpl.partition_to_color[\"val\"],\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate model performance on sample test data\n", + "\n", + "Here we evaluate the model performance on one sample test data. We look at how the efficiency and purity change with the embedding radius." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from Embedding import embedding_plots\n", + "\n", + "embedding_plots.plot_embedding_performance_given_radius_knn_max(\n", + " model=metric_learning_model,\n", + " path_or_config=CONFIG,\n", + " radius=np.linspace(0.01, 0.05, 10),\n", + " n_events=20,\n", + " partitions=[\"train\", \"val\", \"velo-sim10b-nospillover\"],\n", + ");\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot example truth and predicted graphs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "metric_learning_model.load_partition(\"velo-sim10b-nospillover\")\n", + "graphplot.plot_predicted_graph(metric_learning_model, CONFIG)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from Embedding.embedding_plots import plot_best_performances_radius\n", + "plot_best_performances_radius(\n", + " model=metric_learning_model,\n", + " path_or_config=CONFIG,\n", + " partition=\"velo-sim10b-nospillover\",\n", + " # list_radius=[0.015, 0.020],\n", + " list_radius=[0.015, 0.020, 0.025, 0.030, 0.035, 0.040],\n", + " n_events=200,\n", + " seed=0,\n", + ");\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Track lengths" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metric_learning_model.load_partition(\"velo-sim10b-nospillover\")\n", + "perfplot.plot_track_lengths(metric_learning_model, CONFIG)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metric_learning_model.setup(stage=\"fit\") # load train and val datasets\n", + "\n", + "perfplot.plot_graph_sizes(metric_learning_model)" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "# 2. Construct graphs from metric learning inference\n", + "\n", + "This step performs model inference on the entire input datasets (train, validation and test), to obtain input graphs to the graph neural network. Optionally, we also clear the directory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "graph_builder = run_metric_learning_inference(\n", + " CONFIG,\n", + " checkpoint=embedding_artifact_path,\n", + " # checkpoint=metric_learning_model, # here directly use the model\n", + " reproduce=False,\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3. Train graph neural networks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We have a set of graphs constructed. We now train a GNN to classify edges as either \"true\" (belonging to the same track) or \"false\" (not belonging to the same track). We train for 30 epochs, which should take around 10 minutes on a V100 GPU. Your mileage may vary." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:------------------------- Step 3: Running GNN training -------------------------\n", + "INFO:----------------------------- a) Initialising model -----------------------------\n", + "INFO:------------------------------ b) Running training ------------------------------\n", + "Missing logger folder: /home/acorreia/Documents/tracking/etx4velo/LHCb_Pipeline/artifacts/gnn/focal-loss-pid-fixed\n", + "INFO:Save hyperparameters, metrics and artifacts in /home/acorreia/Documents/tracking/etx4velo/LHCb_Pipeline/artifacts/gnn/focal-loss-pid-fixed/version_0\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "INFO:Load 10000 files located in /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/train\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bbbe40dfdf7143e085e7a813d5787d6d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/10000 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:Load 500 files located in /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/val\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9edd7c25454547c6a782b3302caa67ad", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/500 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params\n", + "------------------------------------------------------\n", + "0 | 
node_encoder | Sequential | 332 K \n", + "1 | edge_encoder | Sequential | 462 K \n", + "2 | edge_network | Sequential | 793 K \n", + "3 | node_network | Sequential | 659 K \n", + "4 | output_edge_classifier | Sequential | 529 K \n", + "------------------------------------------------------\n", + "2.8 M Trainable params\n", + "0 Non-trainable params\n", + "2.8 M Total params\n", + "11.111 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5ee6357167424736846d37949bf55de0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + 
"application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "if run_training:\n", + " send_telegram_message('Started GNN training.', chat_id, api_key)\n", + " with warnings.catch_warnings():\n", + " warnings.filterwarnings(\n", + " \"ignore\", message=\"None of the inputs have requires_grad=True.\"\n", + " )\n", + " gnn_trainer, gnn_model = train_gnn(CONFIG)\n", + "\n", + " send_telegram_message('Finished GNN training.', chat_id, api_key)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### From checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from utils.modelutils.checkpoint_utils import (\n", + " get_last_version_dir_from_config,\n", + " get_last_artifact,\n", + ")\n", + "from GNN.interaction_gnn import InteractionGNN\n", + "\n", + "gnn_version_dir = get_last_version_dir_from_config(step=\"gnn\", path_or_config=CONFIG)\n", + "gnn_metric_path = os.path.join(gnn_version_dir, \"metrics.csv\")\n", + "gnn_artifact_path = get_last_artifact(version_dir=gnn_version_dir)\n", + "print(f\"{gnn_metric_path=}\")\n", + "print(f\"{gnn_artifact_path=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_lightning import Trainer\n", + "from pytorch_lightning.loggers import CSVLogger\n", + "\n", + "\n", + "def continue_gnn_training(\n", + " path_or_config: str | dict,\n", + ") -> typing.Tuple[Trainer, InteractionGNN]:\n", + " config = load_config(path_or_config=path_or_config)\n", + "\n", + " gnn_model = InteractionGNN.load_from_checkpoint(\n", + " gnn_artifact_path, **config[\"gnn\"]\n", + " ) # you may change `gnn_model`\n", + "\n", + " save_directory = os.path.abspath(\n", + " 
os.path.join(config[\"common\"][\"artifact_directory\"], \"gnn\")\n", + " )\n", + "\n", + " logger = CSVLogger(save_directory, name=config[\"common\"][\"experiment_name\"])\n", + "\n", + " gnn_trainer = Trainer(\n", + " accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n", + " devices=1,\n", + " max_epochs=50, # you may increase the number of epochs\n", + " logger=logger,\n", + " # callbacks=[EarlyStopping(monitor=\"val_loss\", mode=\"min\")]\n", + " )\n", + "\n", + " with warnings.catch_warnings():\n", + " warnings.filterwarnings(\n", + " \"ignore\", message=\"None of the inputs have requires_grad=True.\"\n", + " )\n", + " gnn_trainer.fit(gnn_model, ckpt_path=gnn_artifact_path)\n", + " return gnn_trainer, gnn_model\n", + "\n", + "# gnn_trainer, gnn_model = continue_gnn_training(CONFIG)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gnn_model = InteractionGNN.load_from_checkpoint(\n", + " gnn_artifact_path,\n", + " # map_location=\"cpu\",\n", + " # hparams=load_config(CONFIG)[\"gnn\"],\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot training metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# gnn_metrics = checkpoint_utils.get_training_metrics(gnn_trainer)\n", + "\n", + "gnn_metrics = checkpoint_utils.get_training_metrics(gnn_metric_path)\n", + "\n", + "gnn_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "perfplot_mpl.plot_loss(gnn_metrics, CONFIG, \"gnn\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "perfplot_mpl.plot_metric_epochs(\n", + " \"eff\",\n", + " gnn_metrics,\n", + " CONFIG,\n", + " \"gnn\",\n", + " \"Edge Efficiency\",\n", + " color=perfplot_mpl.partition_to_color[\"val\"],\n", + ")\n", + 
"perfplot_mpl.plot_metric_epochs(\n", + " \"pur\",\n", + " gnn_metrics,\n", + " CONFIG,\n", + " \"gnn\",\n", + " \"Edge Purity\",\n", + " color=perfplot_mpl.partition_to_color[\"val\"],\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate model performance on sample test data\n", + "\n", + "Here we evaluate the model performace on one sample test data. We look at how the efficiency and purity change with the embedding radius." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gnn_model.load_partition(\"velo-sim10b-nospillover\")\n", + "perfplot.plot_edge_performance(gnn_model, CONFIG)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "performances_for_various_score_cuts" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from GNN.gnn_plots import plot_best_performances_score_cut\n", + "_, _, performances_for_various_score_cuts = plot_best_performances_score_cut(\n", + " model=gnn_model,\n", + " path_or_config=CONFIG,\n", + " partition=\"velo-sim10b-nospillover\",\n", + " score_cuts=[0.4, 0.5, 0.6, 0.65, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8, 0.85, 0.9, 0.95],\n", + " n_events=200,\n", + " seed=0,\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 4. GNN inference " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_gnn_inference(CONFIG, checkpoint=gnn_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5. Build track candidates from GNN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "build_track_candidates(CONFIG)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 6. 
Evaluate track candidates on the same data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "evaluate_candidates_montetracko(\n", + " CONFIG,\n", + " partition=\"train\",\n", + " allen_report=True,\n", + " table_report=True,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "evaluate_candidates_montetracko(\n", + " CONFIG,\n", + " partition=\"val\",\n", + " allen_report=True,\n", + " table_report=True,\n", + " plot_categories=[],\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 7. Evaluate track candidates on unseen data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for test_dataset_name in get_required_test_dataset_names(CONFIG):\n", + " logging.info(headline(test_dataset_name))\n", + " evaluate_candidates_montetracko(\n", + " CONFIG,\n", + " partition=test_dataset_name,\n", + " allen_report=True,\n", + " table_report=True,\n", + " plot_categories=[],\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trackEvaluator = evaluate_candidates_montetracko(\n", + " CONFIG,\n", + " partition=\"velo-sim10b-nospillover\",\n", + " allen_report=True,\n", + " table_report=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "vscode": { + "interpreter": { + "hash": 
"31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/LHCb_Pipeline/full_pipeline.ipynb b/LHCb_Pipeline/full_pipeline.ipynb index a6a2278b78a7a54166efd602757a81cf770a349c..994a0a5fcbb85070bd3327c13954ea2bb27f2b8a 100644 --- a/LHCb_Pipeline/full_pipeline.ipynb +++ b/LHCb_Pipeline/full_pipeline.ipynb @@ -18,64 +18,31 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "<div class=\"bk-root\">\n", - " <a href=\"https://bokeh.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n", - " <span id=\"1002\">Loading BokehJS ...</span>\n", - " </div>\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": "(function(root) {\n function now() {\n return new Date();\n }\n\n const force = true;\n\n if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n root._bokeh_onload_callbacks = [];\n root._bokeh_is_loading = undefined;\n }\n\nconst JS_MIME_TYPE = 'application/javascript';\n const HTML_MIME_TYPE = 'text/html';\n const EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n const CLASS_NAME = 'output_bokeh rendered_html';\n\n /**\n * Render data to the DOM node\n */\n function render(props, node) {\n const script = document.createElement(\"script\");\n node.appendChild(script);\n }\n\n /**\n * Handle when an output is cleared or removed\n */\n function handleClearOutput(event, handle) {\n const cell = handle.cell;\n\n const id = cell.output_area._bokeh_element_id;\n const server_id = cell.output_area._bokeh_server_id;\n // Clean up Bokeh references\n if (id != null && id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n\n if (server_id !== undefined) {\n // Clean up Bokeh references\n const cmd_clean = \"from bokeh.io.state import curstate; 
print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n cell.notebook.kernel.execute(cmd_clean, {\n iopub: {\n output: function(msg) {\n const id = msg.content.text.trim();\n if (id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n }\n }\n });\n // Destroy server and session\n const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n cell.notebook.kernel.execute(cmd_destroy);\n }\n }\n\n /**\n * Handle when a new output is added\n */\n function handleAddOutput(event, handle) {\n const output_area = handle.output_area;\n const output = handle.output;\n\n // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n if ((output.output_type != \"display_data\") || (!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n return\n }\n\n const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n\n if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n // store reference to embed id on output_area\n output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n }\n if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n const bk_div = document.createElement(\"div\");\n bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n const script_attrs = bk_div.children[0].attributes;\n for (let i = 0; i < script_attrs.length; i++) {\n toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n }\n // store reference to server id on output_area\n output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n }\n }\n\n function register_renderer(events, OutputArea) {\n\n function append_mime(data, metadata, element) {\n // create a DOM node to 
render to\n const toinsert = this.create_output_subarea(\n metadata,\n CLASS_NAME,\n EXEC_MIME_TYPE\n );\n this.keyboard_manager.register_events(toinsert);\n // Render to node\n const props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n render(props, toinsert[toinsert.length - 1]);\n element.append(toinsert);\n return toinsert\n }\n\n /* Handle when an output is cleared or removed */\n events.on('clear_output.CodeCell', handleClearOutput);\n events.on('delete.Cell', handleClearOutput);\n\n /* Handle when a new output is added */\n events.on('output_added.OutputArea', handleAddOutput);\n\n /**\n * Register the mime type and append_mime function with output_area\n */\n OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n /* Is output safe? */\n safe: true,\n /* Index of renderer in `output_area.display_order` */\n index: 0\n });\n }\n\n // register the mime type if in Jupyter Notebook environment and previously unregistered\n if (root.Jupyter !== undefined) {\n const events = require('base/js/events');\n const OutputArea = require('notebook/js/outputarea').OutputArea;\n\n if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n register_renderer(events, OutputArea);\n }\n }\n if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n const NB_LOAD_WARNING = {'data': {'text/html':\n \"<div style='background-color: #fdd'>\\n\"+\n \"<p>\\n\"+\n \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n \"may be due to a slow or bad network connection. 
Possible fixes:\\n\"+\n \"</p>\\n\"+\n \"<ul>\\n\"+\n \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n \"<li>use INLINE resources instead, as so:</li>\\n\"+\n \"</ul>\\n\"+\n \"<code>\\n\"+\n \"from bokeh.resources import INLINE\\n\"+\n \"output_notebook(resources=INLINE)\\n\"+\n \"</code>\\n\"+\n \"</div>\"}};\n\n function display_loaded() {\n const el = document.getElementById(\"1002\");\n if (el != null) {\n el.textContent = \"BokehJS is loading...\";\n }\n if (root.Bokeh !== undefined) {\n if (el != null) {\n el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n }\n } else if (Date.now() < root._bokeh_timeout) {\n setTimeout(display_loaded, 100)\n }\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = [];\n\n root._bokeh_onload_callbacks.push(callback);\n if (root._bokeh_is_loading > 0) {\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n }\n if (js_urls == null || js_urls.length === 0) {\n run_callbacks();\n return null;\n }\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n\n function on_error(url) {\n console.error(\"failed to load \" + url);\n }\n\n for (let i = 0; i < css_urls.length; i++) {\n const url = css_urls[i];\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = 
on_error.bind(null, url);\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n }\n\n for (let i = 0; i < js_urls.length; i++) {\n const url = js_urls[i];\n const element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n const css_urls = [];\n\n const inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {\n }\n ];\n\n function run_inline_js() {\n if (root.Bokeh !== undefined || force === true) {\n for (let i = 0; i < inline_js.length; i++) {\n inline_js[i].call(root, root.Bokeh);\n }\nif (force === true) {\n display_loaded();\n }} else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n } else if (force !== true) {\n const cell = $(document.getElementById(\"1002\")).parents('.cell').data().cell;\n cell.output_area.append_execute_result(NB_LOAD_WARNING)\n }\n }\n\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: BokehJS loaded, going straight to 
plotting\");\n run_inline_js();\n } else {\n load_libs(css_urls, js_urls, function() {\n console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n run_inline_js();\n });\n }\n}(window));", - "application/vnd.bokehjs_load.v0+json": "" - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "<div class=\"bk-root\">\n", - " <a href=\"https://bokeh.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n", - " <span id=\"1003\">Loading BokehJS ...</span>\n", - " </div>\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": "(function(root) {\n function now() {\n return new Date();\n }\n\n const force = true;\n\n if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n root._bokeh_onload_callbacks = [];\n root._bokeh_is_loading = undefined;\n }\n\nconst JS_MIME_TYPE = 'application/javascript';\n const HTML_MIME_TYPE = 'text/html';\n const EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n const CLASS_NAME = 'output_bokeh rendered_html';\n\n /**\n * Render data to the DOM node\n */\n function render(props, node) {\n const script = document.createElement(\"script\");\n node.appendChild(script);\n }\n\n /**\n * Handle when an output is cleared or removed\n */\n function handleClearOutput(event, handle) {\n const cell = handle.cell;\n\n const id = cell.output_area._bokeh_element_id;\n const server_id = cell.output_area._bokeh_server_id;\n // Clean up Bokeh references\n if (id != null && id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n\n if (server_id !== undefined) {\n // Clean up Bokeh references\n const cmd_clean = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n cell.notebook.kernel.execute(cmd_clean, {\n iopub: {\n output: function(msg) {\n const id = msg.content.text.trim();\n 
if (id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n }\n }\n });\n // Destroy server and session\n const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n cell.notebook.kernel.execute(cmd_destroy);\n }\n }\n\n /**\n * Handle when a new output is added\n */\n function handleAddOutput(event, handle) {\n const output_area = handle.output_area;\n const output = handle.output;\n\n // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n if ((output.output_type != \"display_data\") || (!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n return\n }\n\n const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n\n if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n // store reference to embed id on output_area\n output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n }\n if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n const bk_div = document.createElement(\"div\");\n bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n const script_attrs = bk_div.children[0].attributes;\n for (let i = 0; i < script_attrs.length; i++) {\n toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n }\n // store reference to server id on output_area\n output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n }\n }\n\n function register_renderer(events, OutputArea) {\n\n function append_mime(data, metadata, element) {\n // create a DOM node to render to\n const toinsert = this.create_output_subarea(\n metadata,\n CLASS_NAME,\n EXEC_MIME_TYPE\n );\n this.keyboard_manager.register_events(toinsert);\n // Render to node\n const props = {data: data, metadata: 
metadata[EXEC_MIME_TYPE]};\n render(props, toinsert[toinsert.length - 1]);\n element.append(toinsert);\n return toinsert\n }\n\n /* Handle when an output is cleared or removed */\n events.on('clear_output.CodeCell', handleClearOutput);\n events.on('delete.Cell', handleClearOutput);\n\n /* Handle when a new output is added */\n events.on('output_added.OutputArea', handleAddOutput);\n\n /**\n * Register the mime type and append_mime function with output_area\n */\n OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n /* Is output safe? */\n safe: true,\n /* Index of renderer in `output_area.display_order` */\n index: 0\n });\n }\n\n // register the mime type if in Jupyter Notebook environment and previously unregistered\n if (root.Jupyter !== undefined) {\n const events = require('base/js/events');\n const OutputArea = require('notebook/js/outputarea').OutputArea;\n\n if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n register_renderer(events, OutputArea);\n }\n }\n if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n const NB_LOAD_WARNING = {'data': {'text/html':\n \"<div style='background-color: #fdd'>\\n\"+\n \"<p>\\n\"+\n \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n \"may be due to a slow or bad network connection. 
Possible fixes:\\n\"+\n \"</p>\\n\"+\n \"<ul>\\n\"+\n \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n \"<li>use INLINE resources instead, as so:</li>\\n\"+\n \"</ul>\\n\"+\n \"<code>\\n\"+\n \"from bokeh.resources import INLINE\\n\"+\n \"output_notebook(resources=INLINE)\\n\"+\n \"</code>\\n\"+\n \"</div>\"}};\n\n function display_loaded() {\n const el = document.getElementById(\"1003\");\n if (el != null) {\n el.textContent = \"BokehJS is loading...\";\n }\n if (root.Bokeh !== undefined) {\n if (el != null) {\n el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n }\n } else if (Date.now() < root._bokeh_timeout) {\n setTimeout(display_loaded, 100)\n }\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = [];\n\n root._bokeh_onload_callbacks.push(callback);\n if (root._bokeh_is_loading > 0) {\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n }\n if (js_urls == null || js_urls.length === 0) {\n run_callbacks();\n return null;\n }\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n\n function on_error(url) {\n console.error(\"failed to load \" + url);\n }\n\n for (let i = 0; i < css_urls.length; i++) {\n const url = css_urls[i];\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = 
on_error.bind(null, url);\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n }\n\n for (let i = 0; i < js_urls.length; i++) {\n const url = js_urls[i];\n const element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n const css_urls = [];\n\n const inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {\n }\n ];\n\n function run_inline_js() {\n if (root.Bokeh !== undefined || force === true) {\n for (let i = 0; i < inline_js.length; i++) {\n inline_js[i].call(root, root.Bokeh);\n }\nif (force === true) {\n display_loaded();\n }} else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n } else if (force !== true) {\n const cell = $(document.getElementById(\"1003\")).parents('.cell').data().cell;\n cell.output_area.append_execute_result(NB_LOAD_WARNING)\n }\n }\n\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: BokehJS loaded, going straight to 
plotting\");\n run_inline_js();\n } else {\n load_libs(css_urls, js_urls, function() {\n console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n run_inline_js();\n });\n }\n}(window));", - "application/vnd.bokehjs_load.v0+json": "" - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "from __future__ import annotations\n", "import typing\n", + "import logging\n", "import os\n", "import sys\n", "import warnings\n", "\n", "sys.path.append('../montetracko')\n", + "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n", "\n", + "import yaml\n", + "import numpy as np\n", "import torch\n", "\n", + "from Preprocessing.run_preprocessing import run_preprocessing_test_dataset\n", + "from Preprocessing.run_preprocessing import run_preprocessing\n", + "from Processing.run_processing import run_processing_from_config\n", + "\n", "from Scripts.Step_1_Train_Metric_Learning import train as train_metric_learning\n", "from Scripts.Step_2_Run_Metric_Learning import train as run_metric_learning_inference\n", "from Scripts.Step_3_Train_GNN import train as train_gnn\n", @@ -87,9 +54,15 @@ "\n", "from utils.plotutils import graph as graphplot\n", "from utils.plotutils import performance as perfplot\n", + "from utils.plotutils import performance_mpl as perfplot_mpl\n", "from utils.commonutils.ctests import get_required_test_dataset_names\n", "from utils.commonutils.config import load_config\n", "from utils.modelutils import checkpoint_utils\n", + "from utils.scriptutils.loghandler import headline\n", + "\n", + "from utils.plotutils.plotconfig import configure_matplotlib\n", + "\n", + "configure_matplotlib()\n", "\n", "warnings.filterwarnings(\n", " \"ignore\", message=(\n", @@ -99,7 +72,9 @@ " )\n", ")\n", "\n", - "CONFIG = 'pipeline_config.yaml'\n" + "CONFIG = 'pipeline_config_default.yaml'\n", + "\n", + "run_training: bool = True\n" ] }, { @@ -123,7 +98,7 @@ }, { "cell_type": "code", - 
"execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -142,12 +117,13 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import requests\n", "import json\n", + "\n", "# from datetime import datetime\n", "\n", "# def send_telegram_message(message: str,\n", @@ -168,7 +144,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -190,80 +166,10 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Preprocessing: output will be written in scratch/__test__/velo-minbias-sim10b-xdigi-nospillover\n", - "INFO:Load dataframe\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5b70e7143df143c4a6383dac2607528f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Preprocessing: output will be written in scratch/__test__/velo-minbias-sim10b-xdigi-nospillover-only-long-electrons\n", - "INFO:Load dataframe\n", - "INFO:Apply selection\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0227464fa17c4891862e586ba3136027", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Preprocessing: output will be written in scratch/__test__/velo-bu2kstee-sim10aU1-xdigi\n", - "INFO:Load dataframe\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "88f6e3994b964f248154137797a26f78", - "version_major": 2, - 
"version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "from Preprocessing.run_preprocessing import run_preprocessing_test_dataset\n", - "\n", "for required_test_dataset_name in get_required_test_dataset_names(CONFIG):\n", " run_preprocessing_test_dataset(\n", " test_dataset_name=required_test_dataset_name,\n", @@ -283,43 +189,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "from Preprocessing.run_preprocessing import run_preprocessing\n", - "from Processing.run_processing import run_processing_from_config\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Preprocessing: output will be written in scratch/velo-minbias-sim10b-xdigi-nospillover/preprocessed\n", - "INFO:Load dataframe\n", - "INFO:Apply selection\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6c548297cecc41478eacd64449cbea66", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/100 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], "source": [ "run_preprocessing(CONFIG, reproduce=False)" ] @@ -334,136 +206,18 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Input directory: scratch/velo-minbias-sim10b-xdigi-nospillover/preprocessed\n", - "INFO:Writing outputs to scratch/velo-minbias-sim10b-xdigi-nospillover/processed/train\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f668e8d341214a418d0f524abec68216", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/90 [00:00<?, ?it/s]" - ] - }, - 
"metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Writing outputs to scratch/velo-minbias-sim10b-xdigi-nospillover/processed/val\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "affb051c8867453ab938d065c560be13", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/10 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Splitting was saved in scratch/velo-minbias-sim10b-xdigi-nospillover/processed/splitting.json.\n" - ] - } - ], + "outputs": [], "source": [ "run_processing_from_config(CONFIG, reproduce=False)\n" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Input directory: scratch/__test__/velo-minbias-sim10b-xdigi-nospillover\n", - "INFO:Writing outputs to scratch/velo-minbias-sim10b-xdigi-nospillover/processed/test/velo-minbias-sim10b-xdigi-nospillover\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d04d349bf0374da6aa1fe890f4dcb353", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Input directory: scratch/__test__/velo-minbias-sim10b-xdigi-nospillover-only-long-electrons\n", - "INFO:Writing outputs to scratch/velo-minbias-sim10b-xdigi-nospillover/processed/test/velo-minbias-sim10b-xdigi-nospillover-only-long-electrons\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "17224ac601724b7181a511ac485e902d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, 
- "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Input directory: scratch/__test__/velo-bu2kstee-sim10aU1-xdigi\n", - "INFO:Writing outputs to scratch/velo-minbias-sim10b-xdigi-nospillover/processed/test/velo-bu2kstee-sim10aU1-xdigi\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6a84abe37d8b4a2d8600f5083c3819a0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "for required_test_dataset_name in get_required_test_dataset_names(CONFIG):\n", " run_processing_from_config(\n", @@ -501,141 +255,9 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "<div>\n", - "<style scoped>\n", - " .dataframe tbody tr th:only-of-type {\n", - " vertical-align: middle;\n", - " }\n", - "\n", - " .dataframe tbody tr th {\n", - " vertical-align: top;\n", - " }\n", - "\n", - " .dataframe thead th {\n", - " text-align: right;\n", - " }\n", - "</style>\n", - "<table border=\"1\" class=\"dataframe\">\n", - " <thead>\n", - " <tr style=\"text-align: right;\">\n", - " <th></th>\n", - " <th>0</th>\n", - " <th>1</th>\n", - " <th>2</th>\n", - " <th>3</th>\n", - " </tr>\n", - " </thead>\n", - " <tbody>\n", - " <tr>\n", - " <th>0</th>\n", - " <td>0.379526</td>\n", - " <td>-0.562705</td>\n", - " <td>-1.440705</td>\n", - " <td>0.000000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>1</th>\n", - " <td>0.468974</td>\n", - " <td>-0.762697</td>\n", - " <td>-1.434295</td>\n", - " <td>0.000000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>2</th>\n", - " <td>0.809876</td>\n", - " <td>-0.545168</td>\n", - " <td>-1.434295</td>\n", - " <td>0.000000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>3</th>\n", - " <td>0.338967</td>\n", - " <td>-0.681292</td>\n", - " 
<td>-1.440705</td>\n", - " <td>0.000000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>4</th>\n", - " <td>0.436733</td>\n", - " <td>-0.800723</td>\n", - " <td>-1.434295</td>\n", - " <td>0.000000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>...</th>\n", - " <td>...</td>\n", - " <td>...</td>\n", - " <td>...</td>\n", - " <td>...</td>\n", - " </tr>\n", - " <tr>\n", - " <th>2501</th>\n", - " <td>0.635943</td>\n", - " <td>0.085912</td>\n", - " <td>3.753205</td>\n", - " <td>0.961538</td>\n", - " </tr>\n", - " <tr>\n", - " <th>2502</th>\n", - " <td>0.646183</td>\n", - " <td>-0.155981</td>\n", - " <td>3.746795</td>\n", - " <td>0.961538</td>\n", - " </tr>\n", - " <tr>\n", - " <th>2503</th>\n", - " <td>0.622268</td>\n", - " <td>-0.197432</td>\n", - " <td>3.746795</td>\n", - " <td>0.961538</td>\n", - " </tr>\n", - " <tr>\n", - " <th>2504</th>\n", - " <td>0.479986</td>\n", - " <td>-0.175648</td>\n", - " <td>3.746795</td>\n", - " <td>0.961538</td>\n", - " </tr>\n", - " <tr>\n", - " <th>2505</th>\n", - " <td>0.273867</td>\n", - " <td>-0.018543</td>\n", - " <td>3.753205</td>\n", - " <td>0.961538</td>\n", - " </tr>\n", - " </tbody>\n", - "</table>\n", - "<p>2506 rows × 4 columns</p>\n", - "</div>" - ], - "text/plain": [ - " 0 1 2 3\n", - "0 0.379526 -0.562705 -1.440705 0.000000\n", - "1 0.468974 -0.762697 -1.434295 0.000000\n", - "2 0.809876 -0.545168 -1.434295 0.000000\n", - "3 0.338967 -0.681292 -1.440705 0.000000\n", - "4 0.436733 -0.800723 -1.434295 0.000000\n", - "... ... ... ... 
...\n", - "2501 0.635943 0.085912 3.753205 0.961538\n", - "2502 0.646183 -0.155981 3.746795 0.961538\n", - "2503 0.622268 -0.197432 3.746795 0.961538\n", - "2504 0.479986 -0.175648 3.746795 0.961538\n", - "2505 0.273867 -0.018543 3.753205 0.961538\n", - "\n", - "[2506 rows x 4 columns]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from Embedding.embedding_base import get_example_data\n", "example_data_df, example_data_pyg = get_example_data(CONFIG)\n", @@ -644,30 +266,11 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Applicable driver not found; attempting to install with Selenium Manager (Beta)\n", - "WARNING:Unable to obtain driver using Selenium Manager: /home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/selenium/webdriver/common/linux/selenium-manager is missing. Please open an issue on https://github.com/SeleniumHQ/selenium/issues\n" - ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<IPython.core.display.Image object>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "graphplot.plot_true_graph(example_data_pyg, CONFIG, num_tracks=25)" + "graphplot.plot_true_graph(example_data_pyg, CONFIG, num_tracks=50)" ] }, { @@ -682,428 +285,38 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "nvcc: NVIDIA (R) Cuda compiler driver\n", - "Copyright (c) 2005-2023 NVIDIA Corporation\n", - "Built on Mon_Apr__3_17:16:06_PDT_2023\n", - "Cuda compilation tools, release 12.1, V12.1.105\n", - "Build cuda_12.1.r12.1/compiler.32688072_0\n" - ] - } - ], + "outputs": [], "source": [ "! 
nvcc --version" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": { "scrolled": true }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Wed May 24 17:35:26 2023 \n", - "+-----------------------------------------------------------------------------+\n", - "| NVIDIA-SMI 525.60.13 Driver Version: 525.60.13 CUDA Version: 12.0 |\n", - "|-------------------------------+----------------------+----------------------+\n", - "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", - "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", - "| | | MIG M. |\n", - "|===============================+======================+======================|\n", - "| 0 NVIDIA TITAN RTX On | 00000000:5E:00.0 Off | N/A |\n", - "| 40% 28C P8 10W / 280W | 701MiB / 24576MiB | 0% Default |\n", - "| | | N/A |\n", - "+-------------------------------+----------------------+----------------------+\n", - " \n", - "+-----------------------------------------------------------------------------+\n", - "| Processes: |\n", - "| GPU GI CI PID Type Process name GPU Memory |\n", - "| ID ID Usage |\n", - "|=============================================================================|\n", - "+-----------------------------------------------------------------------------+\n" - ] - } - ], + "outputs": [], "source": [ "! 
nvidia-smi" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:-------------------- Step 1: Running metric learning training --------------------\n", - "INFO:----------------------------- a) Initialising model -----------------------------\n", - "INFO:------------------------------ b) Running training ------------------------------\n", - "INFO:Save hyperparameters, metrics and artifacts in /home/fgias/etx4velo-3/LHCb_Pipeline/artifacts/metric_learning/velo-minbias-sim10b-xdigi-nospillover/version_2\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:165: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/fgias/.conda/envs/etx4velo/lib/python3.10/site ...\n", - " rank_zero_warn(\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:165: PossibleUserWarning: The `srun` command is available on your system but is not used. 
HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/fgias/.conda/envs/etx4velo/lib/python3.10/site ...\n", - " rank_zero_warn(\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", - "\n", - " | Name | Type | Params\n", - "---------------------------------------\n", - "0 | network | Sequential | 201 K \n", - "---------------------------------------\n", - "201 K Trainable params\n", - "0 Non-trainable params\n", - "201 K Total params\n", - "0.807 Total estimated model params size (MB)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "bd89a1d626f14d5ebbfab5c96e546b49", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { 
- "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - 
"application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "`Trainer.fit` stopped: `max_epochs=20` reached.\n", - "INFO:-------------------------------- c) Saving model --------------------------------\n" - ] - } - ], + "outputs": [], "source": [ - "send_telegram_message('Started metric learning training.', chat_id, api_key)\n", + "if run_training:\n", + " send_telegram_message('Started metric learning training.', chat_id, api_key)\n", "\n", - "metric_learning_trainer, metric_learning_model = 
train_metric_learning(CONFIG)\n", + " metric_learning_trainer, metric_learning_model = train_metric_learning(CONFIG)\n", "\n", - "send_telegram_message('Finished metric learning training.', chat_id, api_key)" + " send_telegram_message('Finished metric learning training.', chat_id, api_key)" ] }, { @@ -1115,40 +328,6 @@ "use the code below." ] }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "# from Embedding.models.layerless_embedding import LayerlessEmbedding\n", - "# from pytorch_lightning import Trainer\n", - "# from pytorch_lightning.loggers import CSVLogger\n", - "\n", - "# version_number = 6\n", - "\n", - "# HPARAMS_PATH = f'/home/fgias/velo-gnn/LHCb_Pipeline/artifacts/metric_learning/velo-minbias-sim10b-xdigi/version_1/hparams.yaml'\n", - "# CKPT_PATH = f'/home/fgias/velo-gnn/LHCb_Pipeline/artifacts/metric_learning/velo-minbias-sim10b-xdigi/version_1/checkpoints/epoch=19-step=1600.ckpt'\n", - "\n", - "# load_configs = {}\n", - "# with open(HPARAMS_PATH, 'r') as f:\n", - "# load_configs = yaml.load(f, Loader=yaml.FullLoader)\n", - "\n", - "# metric_learning_model = LayerlessEmbedding(load_configs)\n", - "\n", - "# logger = CSVLogger('artifacts', name='metric_learning/velo_data')\n", - "\n", - "# metric_learning_trainer = Trainer(\n", - "# accelerator='gpu' if torch.cuda.is_available() else 'cpu',\n", - "# devices=1,\n", - "# max_epochs=40,\n", - "# logger=logger,\n", - "# # callbacks=[EarlyStopping(monitor=\"val_loss\", mode=\"min\")]\n", - "# )\n", - "\n", - "# metric_learning_trainer.fit(metric_learning_model, ckpt_path=CKPT_PATH)" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -1159,28 +338,21 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "embedding_metric_path='artifacts/metric_learning/velo-minbias-sim10b-xdigi-nospillover/version_2/metrics.csv'\n", - 
"embedding_artifact_path='artifacts/metric_learning/velo-minbias-sim10b-xdigi-nospillover/version_2/checkpoints/epoch=19-step=1800.ckpt'\n" - ] - } - ], + "outputs": [], "source": [ "from Embedding.models.layerless_embedding import LayerlessEmbedding\n", "\n", - "version_dir = checkpoint_utils.get_last_version_dir_from_config(\n", + "embedding_version_dir = checkpoint_utils.get_last_version_dir_from_config(\n", " step=\"metric_learning\", path_or_config=CONFIG\n", ")\n", - "embedding_metric_path = os.path.join(version_dir, \"metrics.csv\")\n", - "embedding_artifact_path = checkpoint_utils.get_last_artifact(version_dir=version_dir)\n", + "embedding_metric_path = os.path.join(embedding_version_dir, \"metrics.csv\")\n", + "embedding_artifact_path = checkpoint_utils.get_last_artifact(\n", + " version_dir=embedding_version_dir\n", + ")\n", "print(f\"{embedding_metric_path=}\")\n", - "print(f\"{embedding_artifact_path=}\")" + "print(f\"{embedding_artifact_path=}\")\n" ] }, { @@ -1193,7 +365,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1221,10 +393,12 @@ " devices=1,\n", " max_epochs=40, # you may increase the number of epochs\n", " logger=logger,\n", - " # callbacks=[EarlyStopping(monitor=\"val_loss\", mode=\"min\")]\n", + " # callbacks=[EarlyStopping(monitor=\"train_loss\", mode=\"min\")]\n", " )\n", "\n", - " metric_learning_trainer.fit(metric_learning_model)\n", + " metric_learning_trainer.fit(\n", + " metric_learning_model, ckpt_path=embedding_artifact_path\n", + " )\n", " return metric_learning_trainer, metric_learning_model\n", "\n", "\n", @@ -1233,11 +407,16 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "metric_learning_model = LayerlessEmbedding.load_from_checkpoint(embedding_artifact_path)\n" + "metric_learning_model = LayerlessEmbedding.load_from_checkpoint(\n", + " embedding_artifact_path,\n", + " 
# map_location=\"cpu\",\n", + " # If importing model from another experiment\n", + " # hparams=load_config(CONFIG)[\"metric_learning\"],\n", + ")\n" ] }, { @@ -1258,242 +437,9 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "<div>\n", - "<style scoped>\n", - " .dataframe tbody tr th:only-of-type {\n", - " vertical-align: middle;\n", - " }\n", - "\n", - " .dataframe tbody tr th {\n", - " vertical-align: top;\n", - " }\n", - "\n", - " .dataframe thead th {\n", - " text-align: right;\n", - " }\n", - "</style>\n", - "<table border=\"1\" class=\"dataframe\">\n", - " <thead>\n", - " <tr style=\"text-align: right;\">\n", - " <th></th>\n", - " <th>epoch</th>\n", - " <th>train_loss</th>\n", - " <th>val_loss</th>\n", - " <th>eff</th>\n", - " <th>pur</th>\n", - " <th>current_lr</th>\n", - " </tr>\n", - " </thead>\n", - " <tbody>\n", - " <tr>\n", - " <th>0</th>\n", - " <td>0</td>\n", - " <td>0.009331</td>\n", - " <td>0.008854</td>\n", - " <td>0.981461</td>\n", - " <td>0.052457</td>\n", - " <td>0.000125</td>\n", - " </tr>\n", - " <tr>\n", - " <th>1</th>\n", - " <td>1</td>\n", - " <td>0.008846</td>\n", - " <td>0.008726</td>\n", - " <td>0.977990</td>\n", - " <td>0.065469</td>\n", - " <td>0.000250</td>\n", - " </tr>\n", - " <tr>\n", - " <th>2</th>\n", - " <td>2</td>\n", - " <td>0.008646</td>\n", - " <td>0.008260</td>\n", - " <td>0.975066</td>\n", - " <td>0.115246</td>\n", - " <td>0.000375</td>\n", - " </tr>\n", - " <tr>\n", - " <th>3</th>\n", - " <td>3</td>\n", - " <td>0.008594</td>\n", - " <td>0.008224</td>\n", - " <td>0.977431</td>\n", - " <td>0.119500</td>\n", - " <td>0.000350</td>\n", - " </tr>\n", - " <tr>\n", - " <th>4</th>\n", - " <td>4</td>\n", - " <td>0.008593</td>\n", - " <td>0.008143</td>\n", - " <td>0.975767</td>\n", - " <td>0.128147</td>\n", - " <td>0.000625</td>\n", - " </tr>\n", - " <tr>\n", - " <th>5</th>\n", - " <td>5</td>\n", - " <td>0.008575</td>\n", - " 
<td>0.008300</td>\n", - " <td>0.981175</td>\n", - " <td>0.111195</td>\n", - " <td>0.000750</td>\n", - " </tr>\n", - " <tr>\n", - " <th>6</th>\n", - " <td>6</td>\n", - " <td>0.008562</td>\n", - " <td>0.008318</td>\n", - " <td>0.981980</td>\n", - " <td>0.109136</td>\n", - " <td>0.000875</td>\n", - " </tr>\n", - " <tr>\n", - " <th>7</th>\n", - " <td>7</td>\n", - " <td>0.008502</td>\n", - " <td>0.008320</td>\n", - " <td>0.982643</td>\n", - " <td>0.109381</td>\n", - " <td>0.000700</td>\n", - " </tr>\n", - " <tr>\n", - " <th>8</th>\n", - " <td>8</td>\n", - " <td>0.008484</td>\n", - " <td>0.008126</td>\n", - " <td>0.979342</td>\n", - " <td>0.129930</td>\n", - " <td>0.000700</td>\n", - " </tr>\n", - " <tr>\n", - " <th>9</th>\n", - " <td>9</td>\n", - " <td>0.008474</td>\n", - " <td>0.008101</td>\n", - " <td>0.977593</td>\n", - " <td>0.132546</td>\n", - " <td>0.000700</td>\n", - " </tr>\n", - " <tr>\n", - " <th>10</th>\n", - " <td>10</td>\n", - " <td>0.008469</td>\n", - " <td>0.008078</td>\n", - " <td>0.977144</td>\n", - " <td>0.135049</td>\n", - " <td>0.000700</td>\n", - " </tr>\n", - " <tr>\n", - " <th>11</th>\n", - " <td>11</td>\n", - " <td>0.008436</td>\n", - " <td>0.008046</td>\n", - " <td>0.976337</td>\n", - " <td>0.138554</td>\n", - " <td>0.000490</td>\n", - " </tr>\n", - " <tr>\n", - " <th>12</th>\n", - " <td>12</td>\n", - " <td>0.008428</td>\n", - " <td>0.007944</td>\n", - " <td>0.974389</td>\n", - " <td>0.149568</td>\n", - " <td>0.000490</td>\n", - " </tr>\n", - " <tr>\n", - " <th>13</th>\n", - " <td>13</td>\n", - " <td>0.008424</td>\n", - " <td>0.007950</td>\n", - " <td>0.974593</td>\n", - " <td>0.148910</td>\n", - " <td>0.000490</td>\n", - " </tr>\n", - " <tr>\n", - " <th>14</th>\n", - " <td>14</td>\n", - " <td>0.008416</td>\n", - " <td>0.007949</td>\n", - " <td>0.974523</td>\n", - " <td>0.149105</td>\n", - " <td>0.000490</td>\n", - " </tr>\n", - " <tr>\n", - " <th>15</th>\n", - " <td>15</td>\n", - " <td>0.008389</td>\n", - " <td>0.007967</td>\n", - " 
<td>0.974943</td>\n", - " <td>0.147168</td>\n", - " <td>0.000343</td>\n", - " </tr>\n", - " <tr>\n", - " <th>16</th>\n", - " <td>16</td>\n", - " <td>0.008380</td>\n", - " <td>0.007942</td>\n", - " <td>0.974692</td>\n", - " <td>0.149762</td>\n", - " <td>0.000343</td>\n", - " </tr>\n", - " <tr>\n", - " <th>17</th>\n", - " <td>17</td>\n", - " <td>0.008376</td>\n", - " <td>0.007953</td>\n", - " <td>0.974729</td>\n", - " <td>0.148579</td>\n", - " <td>0.000343</td>\n", - " </tr>\n", - " <tr>\n", - " <th>18</th>\n", - " <td>18</td>\n", - " <td>0.008373</td>\n", - " <td>0.007932</td>\n", - " <td>0.974497</td>\n", - " <td>0.150832</td>\n", - " <td>0.000343</td>\n", - " </tr>\n", - " </tbody>\n", - "</table>\n", - "</div>" - ], - "text/plain": [ - " epoch train_loss val_loss eff pur current_lr\n", - "0 0 0.009331 0.008854 0.981461 0.052457 0.000125\n", - "1 1 0.008846 0.008726 0.977990 0.065469 0.000250\n", - "2 2 0.008646 0.008260 0.975066 0.115246 0.000375\n", - "3 3 0.008594 0.008224 0.977431 0.119500 0.000350\n", - "4 4 0.008593 0.008143 0.975767 0.128147 0.000625\n", - "5 5 0.008575 0.008300 0.981175 0.111195 0.000750\n", - "6 6 0.008562 0.008318 0.981980 0.109136 0.000875\n", - "7 7 0.008502 0.008320 0.982643 0.109381 0.000700\n", - "8 8 0.008484 0.008126 0.979342 0.129930 0.000700\n", - "9 9 0.008474 0.008101 0.977593 0.132546 0.000700\n", - "10 10 0.008469 0.008078 0.977144 0.135049 0.000700\n", - "11 11 0.008436 0.008046 0.976337 0.138554 0.000490\n", - "12 12 0.008428 0.007944 0.974389 0.149568 0.000490\n", - "13 13 0.008424 0.007950 0.974593 0.148910 0.000490\n", - "14 14 0.008416 0.007949 0.974523 0.149105 0.000490\n", - "15 15 0.008389 0.007967 0.974943 0.147168 0.000343\n", - "16 16 0.008380 0.007942 0.974692 0.149762 0.000343\n", - "17 17 0.008376 0.007953 0.974729 0.148579 0.000343\n", - "18 18 0.008373 0.007932 0.974497 0.150832 0.000343" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": 
[ "# embedding_metrics = checkpoint_utils.get_training_metrics(metric_learning_trainer) \n", "\n", @@ -1504,65 +450,66 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": { "scrolled": true }, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "<IPython.core.display.Image object>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "perfplot.plot_training_metrics(embedding_metrics, CONFIG, \"metric_learning\")" + "perfplot_mpl.plot_loss(embedding_metrics, CONFIG, \"metric_learning\")\n" ] }, { - "attachments": {}, - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "## Evaluate model performance on sample test data\n", - "\n", - "Here we evaluate the model performance on one sample test data. We look at how the efficiency and purity change with the embedding radius." + "perfplot_mpl.plot_metric_epochs(\n", + " \"eff\",\n", + " embedding_metrics,\n", + " CONFIG,\n", + " \"metric_learning\",\n", + " \"Edge Efficiency\",\n", + " color=perfplot_mpl.partition_to_color[\"val\"],\n", + ")\n", + "perfplot_mpl.plot_metric_epochs(\n", + " \"pur\",\n", + " embedding_metrics,\n", + " CONFIG,\n", + " \"metric_learning\",\n", + " \"Edge Purity\",\n", + " color=perfplot_mpl.partition_to_color[\"val\"],\n", + ")\n" ] }, { - "cell_type": "code", - "execution_count": 22, + "attachments": {}, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "metric_learning_model.load_testset(\"velo-minbias-sim10b-xdigi-nospillover\")\n" + "## Evaluate model performance on sample test data\n", + "\n", + "Here we evaluate the model performance on one sample test data. We look at how the efficiency and purity change with the embedding radius." 
] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": { "scrolled": true }, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "<IPython.core.display.Image object>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "perfplot.plot_neighbor_performance(metric_learning_model, CONFIG)" + "from Embedding import embedding_plots\n", + "\n", + "embedding_plots.plot_embedding_performance_given_radius_knn_max(\n", + " model=metric_learning_model,\n", + " path_or_config=CONFIG,\n", + " radius=np.linspace(0.01, 0.05, 10),\n", + " n_events=20,\n", + " partitions=[\"train\", \"val\", \"velo-sim10b-nospillover\"],\n", + ");\n" ] }, { @@ -1575,26 +522,34 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": { "scrolled": true }, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "<IPython.core.display.Image object>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ + "metric_learning_model.load_partition(\"velo-sim10b-nospillover\")\n", "graphplot.plot_predicted_graph(metric_learning_model, CONFIG)\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from Embedding.embedding_plots import plot_best_performances_radius\n", + "plot_best_performances_radius(\n", + " model=metric_learning_model,\n", + " path_or_config=CONFIG,\n", + " partition=\"velo-sim10b-nospillover\",\n", + " # list_radius=[0.015, 0.020],\n", + " list_radius=[0.015, 0.020, 0.025, 0.030, 0.035, 0.040],\n", + " n_events=20, # for a real training, use more events to have enough stats.\n", + " seed=0,\n", + ");\n" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -1605,56 +560,22 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - 
"<IPython.core.display.Image object>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ + "metric_learning_model.load_partition(\"velo-sim10b-nospillover\")\n", "perfplot.plot_track_lengths(metric_learning_model, CONFIG)" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cba0eb252c7a45019f1dbf261dd43c79", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/90 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 1000x500 with 1 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "metric_learning_model.setup(stage=\"fit\")\n", + "metric_learning_model.setup(stage=\"fit\") # load train and val datasets\n", + "\n", "perfplot.plot_graph_sizes(metric_learning_model)" ] }, @@ -1670,126 +591,16 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:------------- Step 2: Constructing graphs from metric learning model -------------\n", - "INFO:---------------------------- a) Loading trained model ----------------------------\n", - "INFO:Load model from artifacts/metric_learning/velo-minbias-sim10b-xdigi-nospillover/version_2/checkpoints/epoch=19-step=1800.ckpt.\n", - "INFO:----------------------------- b) Running inferencing -----------------------------\n", - "INFO:Remove directory `scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/train`.\n", - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/processed/train to scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/train\n" - ] - }, - { - "data": { - 
"application/vnd.jupyter.widget-view+json": { - "model_id": "f3c42a35a4c9483db8f45981264e77dc", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/90 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Remove directory `scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/val`.\n", - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/processed/val to scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/val\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f3dfa17371d84d53ae2a515d99e537cb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/10 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Remove directory `scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/test/velo-minbias-sim10b-xdigi-nospillover`.\n", - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/processed/test/velo-minbias-sim10b-xdigi-nospillover to scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/test/velo-minbias-sim10b-xdigi-nospillover\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e39bfa75bb8a41f5b715d139b1a01513", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Remove directory `scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/test/velo-minbias-sim10b-xdigi-nospillover-only-long-electrons`.\n", - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/processed/test/velo-minbias-sim10b-xdigi-nospillover-only-long-electrons to 
scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/test/velo-minbias-sim10b-xdigi-nospillover-only-long-electrons\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "10a125cd19d6455d803490e7f61c78a4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Remove directory `scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/test/velo-bu2kstee-sim10aU1-xdigi`.\n", - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/processed/test/velo-bu2kstee-sim10aU1-xdigi to scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/test/velo-bu2kstee-sim10aU1-xdigi\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9a9447471bbb4f0aa45bb179a99f57c9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "graph_builder = run_metric_learning_inference(CONFIG, checkpoint=embedding_artifact_path)\n" + "graph_builder = run_metric_learning_inference(\n", + " CONFIG,\n", + " checkpoint=embedding_artifact_path,\n", + " # checkpoint=metric_learning_model, # here directly use the model\n", + " reproduce=False,\n", + ")\n" ] }, { @@ -1810,976 +621,19 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "import warnings\n" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:------------------------- Step 3: Running GNN training -------------------------\n", - "INFO:----------------------------- a) Initialising model 
-----------------------------\n", - "INFO:------------------------------ b) Running training ------------------------------\n", - "INFO:Save hyperparameters, metrics and artifacts in /home/fgias/etx4velo-3/LHCb_Pipeline/artifacts/gnn/velo-minbias-sim10b-xdigi-nospillover/version_1\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:165: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/fgias/.conda/envs/etx4velo/lib/python3.10/site ...\n", - " rank_zero_warn(\n", - "GPU available: True (cuda), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", - "\n", - " | Name | Type | Params\n", - "------------------------------------------------------\n", - "0 | node_encoder | Sequential | 529 K \n", - "1 | edge_encoder | Sequential | 1.8 M \n", - "2 | edge_network | Sequential | 2.1 M \n", - "3 | node_network | Sequential | 1.3 M \n", - "4 | output_edge_classifier | Sequential | 2.1 M \n", - "------------------------------------------------------\n", - "7.9 M Trainable params\n", - "0 Non-trainable params\n", - "7.9 M Total params\n", - "31.599 Total estimated model params size (MB)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. 
Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c4d734ebcf1c43d58e18534c79a2fc71", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2a5af4451be94821a2149b2e4d300835", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fe0dc0d1dc784314bfa50eef598c1b73", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3bb63a6a19b9438085104dba1748f7ce", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "02162a7fb5394ad78b45ebbbf57f96b2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "591536b8afb74244887ff8d1f6b83286", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "eb84f26e12fe4e50841e6ac537767c24", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cf125bceb74a40c3916b8cf8bb27540e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "258c67fb11ff45f2acf7989a29b001d0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b890510270fc43269fb691d89eb98a49", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "97550b9b8cec43209d4725c853955138", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7a77d2f297034dcaa44a5d8225ff667b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c3cf2e99087046d8823f2a7691d5e29a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "17a86e6140db45aabbf45954c5903c1d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ac6939837fbb47df9f6e3e85276b37cc", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "15a058c04c924b2488b900041948374f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0012e38be2274677ad54f6754e56f5da", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/data/dataloader.py:560: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 14, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a5dafb90431f45d29ad7ab58257f3eca", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "`Trainer.fit` stopped: `max_epochs=40` reached.\n", - "INFO:-------------------------------- c) Saving model --------------------------------\n" - ] - } - ], - "source": [ - "send_telegram_message('Started GNN training.', chat_id, api_key)\n", - "with warnings.catch_warnings():\n", - " warnings.filterwarnings(\n", - " \"ignore\", message=\"None of the inputs have requires_grad=True.\"\n", - " )\n", - " gnn_trainer, gnn_model = train_gnn(CONFIG)\n", + "if run_training:\n", + " send_telegram_message('Started GNN training.', chat_id, api_key)\n", + " with warnings.catch_warnings():\n", + " warnings.filterwarnings(\n", + " \"ignore\", message=\"None of the inputs have requires_grad=True.\"\n", + " )\n", + " gnn_trainer, gnn_model = train_gnn(CONFIG)\n", "\n", - "send_telegram_message('Finished GNN training.', chat_id, api_key)" + " send_telegram_message('Finished GNN training.', chat_id, api_key)" ] }, { @@ -2792,18 +646,9 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "gnn_metric_path='artifacts/gnn/velo-minbias-sim10b-xdigi-nospillover/version_1/metrics.csv'\n", - "gnn_artifact_path='artifacts/gnn/velo-minbias-sim10b-xdigi-nospillover/version_1/checkpoints/epoch=39-step=3600.ckpt'\n" - ] - } - ], + "outputs": [], "source": [ "from utils.modelutils.checkpoint_utils import (\n", " 
get_last_version_dir_from_config,\n", @@ -2820,7 +665,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2834,8 +679,8 @@ " config = load_config(path_or_config=path_or_config)\n", "\n", " gnn_model = InteractionGNN.load_from_checkpoint(\n", - " embedding_artifact_path\n", - " ) # you may change `metric_learning_model`\n", + " gnn_artifact_path, **config[\"gnn\"]\n", + " ) # you may change `gnn_model`\n", "\n", " save_directory = os.path.abspath(\n", " os.path.join(config[\"common\"][\"artifact_directory\"], \"gnn\")\n", @@ -2846,25 +691,32 @@ " gnn_trainer = Trainer(\n", " accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n", " devices=1,\n", - " max_epochs=40, # you may increase the number of epochs\n", + " max_epochs=50, # you may increase the number of epochs\n", " logger=logger,\n", " # callbacks=[EarlyStopping(monitor=\"val_loss\", mode=\"min\")]\n", " )\n", "\n", - " metric_learning_trainer.fit(metric_learning_model)\n", + " with warnings.catch_warnings():\n", + " warnings.filterwarnings(\n", + " \"ignore\", message=\"None of the inputs have requires_grad=True.\"\n", + " )\n", + " gnn_trainer.fit(gnn_model, ckpt_path=gnn_artifact_path)\n", " return gnn_trainer, gnn_model\n", "\n", - "\n", "# gnn_trainer, gnn_model = continue_gnn_training(CONFIG)\n" ] }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "gnn_model = InteractionGNN.load_from_checkpoint(gnn_artifact_path)\n" + "gnn_model = InteractionGNN.load_from_checkpoint(\n", + " gnn_artifact_path,\n", + " # map_location=\"cpu\",\n", + " # hparams=load_config(CONFIG)[\"gnn\"],\n", + ")\n" ] }, { @@ -2877,442 +729,9 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "<div>\n", - "<style scoped>\n", - " .dataframe tbody tr th:only-of-type {\n", - " 
vertical-align: middle;\n", - " }\n", - "\n", - " .dataframe tbody tr th {\n", - " vertical-align: top;\n", - " }\n", - "\n", - " .dataframe thead th {\n", - " text-align: right;\n", - " }\n", - "</style>\n", - "<table border=\"1\" class=\"dataframe\">\n", - " <thead>\n", - " <tr style=\"text-align: right;\">\n", - " <th></th>\n", - " <th>epoch</th>\n", - " <th>train_loss</th>\n", - " <th>val_loss</th>\n", - " <th>eff</th>\n", - " <th>pur</th>\n", - " <th>current_lr</th>\n", - " </tr>\n", - " </thead>\n", - " <tbody>\n", - " <tr>\n", - " <th>0</th>\n", - " <td>0</td>\n", - " <td>0.847782</td>\n", - " <td>0.835718</td>\n", - " <td>0.606771</td>\n", - " <td>0.561668</td>\n", - " <td>0.000200</td>\n", - " </tr>\n", - " <tr>\n", - " <th>1</th>\n", - " <td>1</td>\n", - " <td>0.840699</td>\n", - " <td>0.828294</td>\n", - " <td>0.650663</td>\n", - " <td>0.551916</td>\n", - " <td>0.000400</td>\n", - " </tr>\n", - " <tr>\n", - " <th>2</th>\n", - " <td>2</td>\n", - " <td>0.814449</td>\n", - " <td>0.819903</td>\n", - " <td>0.697715</td>\n", - " <td>0.546762</td>\n", - " <td>0.000600</td>\n", - " </tr>\n", - " <tr>\n", - " <th>3</th>\n", - " <td>3</td>\n", - " <td>0.808856</td>\n", - " <td>0.783907</td>\n", - " <td>0.828007</td>\n", - " <td>0.512144</td>\n", - " <td>0.000800</td>\n", - " </tr>\n", - " <tr>\n", - " <th>4</th>\n", - " <td>4</td>\n", - " <td>0.757131</td>\n", - " <td>0.795749</td>\n", - " <td>0.616819</td>\n", - " <td>0.609769</td>\n", - " <td>0.001000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>5</th>\n", - " <td>5</td>\n", - " <td>0.841196</td>\n", - " <td>0.719712</td>\n", - " <td>0.762973</td>\n", - " <td>0.614555</td>\n", - " <td>0.001200</td>\n", - " </tr>\n", - " <tr>\n", - " <th>6</th>\n", - " <td>6</td>\n", - " <td>0.739529</td>\n", - " <td>0.779560</td>\n", - " <td>0.702113</td>\n", - " <td>0.575488</td>\n", - " <td>0.001400</td>\n", - " </tr>\n", - " <tr>\n", - " <th>7</th>\n", - " <td>7</td>\n", - " <td>0.648295</td>\n", - " <td>0.672503</td>\n", - " 
<td>0.782566</td>\n", - " <td>0.646786</td>\n", - " <td>0.001120</td>\n", - " </tr>\n", - " <tr>\n", - " <th>8</th>\n", - " <td>8</td>\n", - " <td>0.540547</td>\n", - " <td>0.599143</td>\n", - " <td>0.860576</td>\n", - " <td>0.645966</td>\n", - " <td>0.001800</td>\n", - " </tr>\n", - " <tr>\n", - " <th>9</th>\n", - " <td>9</td>\n", - " <td>0.462206</td>\n", - " <td>0.495675</td>\n", - " <td>0.819526</td>\n", - " <td>0.772188</td>\n", - " <td>0.002000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>10</th>\n", - " <td>10</td>\n", - " <td>0.390897</td>\n", - " <td>0.376769</td>\n", - " <td>0.884214</td>\n", - " <td>0.800133</td>\n", - " <td>0.002000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>11</th>\n", - " <td>11</td>\n", - " <td>0.354644</td>\n", - " <td>0.322055</td>\n", - " <td>0.895604</td>\n", - " <td>0.844770</td>\n", - " <td>0.002000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>12</th>\n", - " <td>12</td>\n", - " <td>0.303603</td>\n", - " <td>0.290172</td>\n", - " <td>0.932382</td>\n", - " <td>0.824411</td>\n", - " <td>0.002000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>13</th>\n", - " <td>13</td>\n", - " <td>0.270316</td>\n", - " <td>0.334157</td>\n", - " <td>0.889905</td>\n", - " <td>0.826773</td>\n", - " <td>0.002000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>14</th>\n", - " <td>14</td>\n", - " <td>0.268074</td>\n", - " <td>0.229994</td>\n", - " <td>0.939393</td>\n", - " <td>0.867535</td>\n", - " <td>0.002000</td>\n", - " </tr>\n", - " <tr>\n", - " <th>15</th>\n", - " <td>15</td>\n", - " <td>0.208173</td>\n", - " <td>0.228011</td>\n", - " <td>0.928939</td>\n", - " <td>0.887602</td>\n", - " <td>0.001400</td>\n", - " </tr>\n", - " <tr>\n", - " <th>16</th>\n", - " <td>16</td>\n", - " <td>0.198963</td>\n", - " <td>0.169362</td>\n", - " <td>0.953345</td>\n", - " <td>0.908930</td>\n", - " <td>0.001400</td>\n", - " </tr>\n", - " <tr>\n", - " <th>17</th>\n", - " <td>17</td>\n", - " <td>0.179942</td>\n", - " <td>0.162437</td>\n", - " <td>0.965979</td>\n", - " 
<td>0.893533</td>\n", - " <td>0.001400</td>\n", - " </tr>\n", - " <tr>\n", - " <th>18</th>\n", - " <td>18</td>\n", - " <td>0.168318</td>\n", - " <td>0.167323</td>\n", - " <td>0.960488</td>\n", - " <td>0.896721</td>\n", - " <td>0.001400</td>\n", - " </tr>\n", - " <tr>\n", - " <th>19</th>\n", - " <td>19</td>\n", - " <td>0.163581</td>\n", - " <td>0.147069</td>\n", - " <td>0.966788</td>\n", - " <td>0.906295</td>\n", - " <td>0.001400</td>\n", - " </tr>\n", - " <tr>\n", - " <th>20</th>\n", - " <td>20</td>\n", - " <td>0.160301</td>\n", - " <td>0.139857</td>\n", - " <td>0.964095</td>\n", - " <td>0.916123</td>\n", - " <td>0.001400</td>\n", - " </tr>\n", - " <tr>\n", - " <th>21</th>\n", - " <td>21</td>\n", - " <td>0.148940</td>\n", - " <td>0.138042</td>\n", - " <td>0.968433</td>\n", - " <td>0.911560</td>\n", - " <td>0.001400</td>\n", - " </tr>\n", - " <tr>\n", - " <th>22</th>\n", - " <td>22</td>\n", - " <td>0.152121</td>\n", - " <td>0.137196</td>\n", - " <td>0.975356</td>\n", - " <td>0.901749</td>\n", - " <td>0.001400</td>\n", - " </tr>\n", - " <tr>\n", - " <th>23</th>\n", - " <td>23</td>\n", - " <td>0.132580</td>\n", - " <td>0.145505</td>\n", - " <td>0.961802</td>\n", - " <td>0.914907</td>\n", - " <td>0.000980</td>\n", - " </tr>\n", - " <tr>\n", - " <th>24</th>\n", - " <td>24</td>\n", - " <td>0.119087</td>\n", - " <td>0.115247</td>\n", - " <td>0.971732</td>\n", - " <td>0.929652</td>\n", - " <td>0.000980</td>\n", - " </tr>\n", - " <tr>\n", - " <th>25</th>\n", - " <td>25</td>\n", - " <td>0.115420</td>\n", - " <td>0.114022</td>\n", - " <td>0.970458</td>\n", - " <td>0.934912</td>\n", - " <td>0.000980</td>\n", - " </tr>\n", - " <tr>\n", - " <th>26</th>\n", - " <td>26</td>\n", - " <td>0.109508</td>\n", - " <td>0.113390</td>\n", - " <td>0.971275</td>\n", - " <td>0.934489</td>\n", - " <td>0.000980</td>\n", - " </tr>\n", - " <tr>\n", - " <th>27</th>\n", - " <td>27</td>\n", - " <td>0.109359</td>\n", - " <td>0.112889</td>\n", - " <td>0.973337</td>\n", - " <td>0.935166</td>\n", - " 
<td>0.000980</td>\n", - " </tr>\n", - " <tr>\n", - " <th>28</th>\n", - " <td>28</td>\n", - " <td>0.105726</td>\n", - " <td>0.119013</td>\n", - " <td>0.971684</td>\n", - " <td>0.935247</td>\n", - " <td>0.000980</td>\n", - " </tr>\n", - " <tr>\n", - " <th>29</th>\n", - " <td>29</td>\n", - " <td>0.108074</td>\n", - " <td>0.109331</td>\n", - " <td>0.974781</td>\n", - " <td>0.930926</td>\n", - " <td>0.000980</td>\n", - " </tr>\n", - " <tr>\n", - " <th>30</th>\n", - " <td>30</td>\n", - " <td>0.107911</td>\n", - " <td>0.112318</td>\n", - " <td>0.970788</td>\n", - " <td>0.934250</td>\n", - " <td>0.000980</td>\n", - " </tr>\n", - " <tr>\n", - " <th>31</th>\n", - " <td>31</td>\n", - " <td>0.092714</td>\n", - " <td>0.107155</td>\n", - " <td>0.973383</td>\n", - " <td>0.936992</td>\n", - " <td>0.000686</td>\n", - " </tr>\n", - " <tr>\n", - " <th>32</th>\n", - " <td>32</td>\n", - " <td>0.086805</td>\n", - " <td>0.097335</td>\n", - " <td>0.977432</td>\n", - " <td>0.938593</td>\n", - " <td>0.000686</td>\n", - " </tr>\n", - " <tr>\n", - " <th>33</th>\n", - " <td>33</td>\n", - " <td>0.088135</td>\n", - " <td>0.088208</td>\n", - " <td>0.978923</td>\n", - " <td>0.946520</td>\n", - " <td>0.000686</td>\n", - " </tr>\n", - " <tr>\n", - " <th>34</th>\n", - " <td>34</td>\n", - " <td>0.086349</td>\n", - " <td>0.092444</td>\n", - " <td>0.976368</td>\n", - " <td>0.947216</td>\n", - " <td>0.000686</td>\n", - " </tr>\n", - " <tr>\n", - " <th>35</th>\n", - " <td>35</td>\n", - " <td>0.084271</td>\n", - " <td>0.092907</td>\n", - " <td>0.976333</td>\n", - " <td>0.947210</td>\n", - " <td>0.000686</td>\n", - " </tr>\n", - " <tr>\n", - " <th>36</th>\n", - " <td>36</td>\n", - " <td>0.087265</td>\n", - " <td>0.093572</td>\n", - " <td>0.976155</td>\n", - " <td>0.945150</td>\n", - " <td>0.000686</td>\n", - " </tr>\n", - " <tr>\n", - " <th>37</th>\n", - " <td>37</td>\n", - " <td>0.084715</td>\n", - " <td>0.098950</td>\n", - " <td>0.973790</td>\n", - " <td>0.948907</td>\n", - " <td>0.000686</td>\n", - " 
</tr>\n", - " <tr>\n", - " <th>38</th>\n", - " <td>38</td>\n", - " <td>0.078550</td>\n", - " <td>0.094507</td>\n", - " <td>0.974446</td>\n", - " <td>0.949921</td>\n", - " <td>0.000686</td>\n", - " </tr>\n", - " </tbody>\n", - "</table>\n", - "</div>" - ], - "text/plain": [ - " epoch train_loss val_loss eff pur current_lr\n", - "0 0 0.847782 0.835718 0.606771 0.561668 0.000200\n", - "1 1 0.840699 0.828294 0.650663 0.551916 0.000400\n", - "2 2 0.814449 0.819903 0.697715 0.546762 0.000600\n", - "3 3 0.808856 0.783907 0.828007 0.512144 0.000800\n", - "4 4 0.757131 0.795749 0.616819 0.609769 0.001000\n", - "5 5 0.841196 0.719712 0.762973 0.614555 0.001200\n", - "6 6 0.739529 0.779560 0.702113 0.575488 0.001400\n", - "7 7 0.648295 0.672503 0.782566 0.646786 0.001120\n", - "8 8 0.540547 0.599143 0.860576 0.645966 0.001800\n", - "9 9 0.462206 0.495675 0.819526 0.772188 0.002000\n", - "10 10 0.390897 0.376769 0.884214 0.800133 0.002000\n", - "11 11 0.354644 0.322055 0.895604 0.844770 0.002000\n", - "12 12 0.303603 0.290172 0.932382 0.824411 0.002000\n", - "13 13 0.270316 0.334157 0.889905 0.826773 0.002000\n", - "14 14 0.268074 0.229994 0.939393 0.867535 0.002000\n", - "15 15 0.208173 0.228011 0.928939 0.887602 0.001400\n", - "16 16 0.198963 0.169362 0.953345 0.908930 0.001400\n", - "17 17 0.179942 0.162437 0.965979 0.893533 0.001400\n", - "18 18 0.168318 0.167323 0.960488 0.896721 0.001400\n", - "19 19 0.163581 0.147069 0.966788 0.906295 0.001400\n", - "20 20 0.160301 0.139857 0.964095 0.916123 0.001400\n", - "21 21 0.148940 0.138042 0.968433 0.911560 0.001400\n", - "22 22 0.152121 0.137196 0.975356 0.901749 0.001400\n", - "23 23 0.132580 0.145505 0.961802 0.914907 0.000980\n", - "24 24 0.119087 0.115247 0.971732 0.929652 0.000980\n", - "25 25 0.115420 0.114022 0.970458 0.934912 0.000980\n", - "26 26 0.109508 0.113390 0.971275 0.934489 0.000980\n", - "27 27 0.109359 0.112889 0.973337 0.935166 0.000980\n", - "28 28 0.105726 0.119013 0.971684 0.935247 0.000980\n", - "29 29 
0.108074 0.109331 0.974781 0.930926 0.000980\n", - "30 30 0.107911 0.112318 0.970788 0.934250 0.000980\n", - "31 31 0.092714 0.107155 0.973383 0.936992 0.000686\n", - "32 32 0.086805 0.097335 0.977432 0.938593 0.000686\n", - "33 33 0.088135 0.088208 0.978923 0.946520 0.000686\n", - "34 34 0.086349 0.092444 0.976368 0.947216 0.000686\n", - "35 35 0.084271 0.092907 0.976333 0.947210 0.000686\n", - "36 36 0.087265 0.093572 0.976155 0.945150 0.000686\n", - "37 37 0.084715 0.098950 0.973790 0.948907 0.000686\n", - "38 38 0.078550 0.094507 0.974446 0.949921 0.000686" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# gnn_metrics = checkpoint_utils.get_training_metrics(gnn_trainer)\n", "\n", @@ -3323,66 +742,35 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "perfplot_mpl.plot_loss(gnn_metrics, CONFIG, \"gnn\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:Retrying (Retry(total=2, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f183340de70>: Failed to establish a new connection: [Errno 111] Connection refused')': /session/216aa782-d1d1-47dc-a868-fc064e41882a/window/maximize\n", - "WARNING:Retrying (Retry(total=1, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f1833358190>: Failed to establish a new connection: [Errno 111] Connection refused')': /session/216aa782-d1d1-47dc-a868-fc064e41882a/window/maximize\n", - "WARNING:Retrying (Retry(total=0, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<urllib3.connection.HTTPConnection object at 
0x7f182b316740>: Failed to establish a new connection: [Errno 111] Connection refused')': /session/216aa782-d1d1-47dc-a868-fc064e41882a/window/maximize\n" - ] - }, - { - "ename": "MaxRetryError", - "evalue": "HTTPConnectionPool(host='localhost', port=44379): Max retries exceeded with url: /session/216aa782-d1d1-47dc-a868-fc064e41882a/window/maximize (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f1822df0370>: Failed to establish a new connection: [Errno 111] Connection refused'))", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mConnectionRefusedError\u001b[0m Traceback (most recent call last)", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connection.py:174\u001b[0m, in \u001b[0;36mHTTPConnection._new_conn\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 174\u001b[0m conn \u001b[38;5;241m=\u001b[39m \u001b[43mconnection\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_connection\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dns_host\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mport\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mextra_kw\u001b[49m\n\u001b[1;32m 176\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m SocketTimeout:\n", - "File 
\u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/util/connection.py:95\u001b[0m, in \u001b[0;36mcreate_connection\u001b[0;34m(address, timeout, source_address, socket_options)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m err \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m---> 95\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m err\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m socket\u001b[38;5;241m.\u001b[39merror(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgetaddrinfo returns an empty list\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/util/connection.py:85\u001b[0m, in \u001b[0;36mcreate_connection\u001b[0;34m(address, timeout, source_address, socket_options)\u001b[0m\n\u001b[1;32m 84\u001b[0m sock\u001b[38;5;241m.\u001b[39mbind(source_address)\n\u001b[0;32m---> 85\u001b[0m \u001b[43msock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnect\u001b[49m\u001b[43m(\u001b[49m\u001b[43msa\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m sock\n", - "\u001b[0;31mConnectionRefusedError\u001b[0m: [Errno 111] Connection refused", - "\nDuring handling of the above exception, another exception occurred:\n", - "\u001b[0;31mNewConnectionError\u001b[0m Traceback (most recent call last)", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:703\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 702\u001b[0m \u001b[38;5;66;03m# Make the request on the httplib connection object.\u001b[39;00m\n\u001b[0;32m--> 703\u001b[0m httplib_response \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_request\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 704\u001b[0m \u001b[43m \u001b[49m\u001b[43mconn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 705\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 706\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 707\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout_obj\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 708\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 709\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 710\u001b[0m \u001b[43m \u001b[49m\u001b[43mchunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchunked\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 711\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 713\u001b[0m \u001b[38;5;66;03m# If we're going to release the connection in ``finally:``, then\u001b[39;00m\n\u001b[1;32m 714\u001b[0m \u001b[38;5;66;03m# the response doesn't need to know about the connection. 
Otherwise\u001b[39;00m\n\u001b[1;32m 715\u001b[0m \u001b[38;5;66;03m# it will also try to release it and we'll have a double-release\u001b[39;00m\n\u001b[1;32m 716\u001b[0m \u001b[38;5;66;03m# mess.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:398\u001b[0m, in \u001b[0;36mHTTPConnectionPool._make_request\u001b[0;34m(self, conn, method, url, timeout, chunked, **httplib_request_kw)\u001b[0m\n\u001b[1;32m 397\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 398\u001b[0m \u001b[43mconn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhttplib_request_kw\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 400\u001b[0m \u001b[38;5;66;03m# We are swallowing BrokenPipeError (errno.EPIPE) since the server is\u001b[39;00m\n\u001b[1;32m 401\u001b[0m \u001b[38;5;66;03m# legitimately able to close the connection after sending a valid response.\u001b[39;00m\n\u001b[1;32m 402\u001b[0m \u001b[38;5;66;03m# With this behaviour, the received response is still readable.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connection.py:244\u001b[0m, in \u001b[0;36mHTTPConnection.request\u001b[0;34m(self, method, url, body, headers)\u001b[0m\n\u001b[1;32m 243\u001b[0m headers[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUser-Agent\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m _get_default_user_agent()\n\u001b[0;32m--> 244\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mHTTPConnection\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/http/client.py:1283\u001b[0m, in \u001b[0;36mHTTPConnection.request\u001b[0;34m(self, method, url, body, headers, encode_chunked)\u001b[0m\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Send a complete request to the server.\"\"\"\u001b[39;00m\n\u001b[0;32m-> 1283\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/http/client.py:1329\u001b[0m, in \u001b[0;36mHTTPConnection._send_request\u001b[0;34m(self, method, url, body, headers, encode_chunked)\u001b[0m\n\u001b[1;32m 1328\u001b[0m body \u001b[38;5;241m=\u001b[39m _encode(body, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbody\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m-> 1329\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mendheaders\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/http/client.py:1278\u001b[0m, in \u001b[0;36mHTTPConnection.endheaders\u001b[0;34m(self, message_body, encode_chunked)\u001b[0m\n\u001b[1;32m 1277\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m CannotSendHeader()\n\u001b[0;32m-> 1278\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_output\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmessage_body\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/http/client.py:1038\u001b[0m, in \u001b[0;36mHTTPConnection._send_output\u001b[0;34m(self, message_body, encode_chunked)\u001b[0m\n\u001b[1;32m 1037\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_buffer[:]\n\u001b[0;32m-> 1038\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1040\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m message_body \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1041\u001b[0m \n\u001b[1;32m 1042\u001b[0m \u001b[38;5;66;03m# create a consistent interface to message_body\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/http/client.py:976\u001b[0m, in \u001b[0;36mHTTPConnection.send\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 975\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_open:\n\u001b[0;32m--> 976\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 977\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connection.py:205\u001b[0m, in \u001b[0;36mHTTPConnection.connect\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mconnect\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 205\u001b[0m conn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_new_conn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_conn(conn)\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connection.py:186\u001b[0m, in \u001b[0;36mHTTPConnection._new_conn\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m SocketError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m--> 186\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m NewConnectionError(\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28mself\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to establish a new connection: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m e\n\u001b[1;32m 188\u001b[0m )\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m conn\n", - "\u001b[0;31mNewConnectionError\u001b[0m: <urllib3.connection.HTTPConnection object at 0x7f1822df0370>: Failed to establish a new connection: [Errno 111] Connection refused", - "\nDuring handling of the above exception, another exception occurred:\n", - "\u001b[0;31mMaxRetryError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[35], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m 
\u001b[43mperfplot\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplot_training_metrics\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgnn_metrics\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCONFIG\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgnn\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/etx4velo-3/LHCb_Pipeline/utils/plotutils/performance.py:65\u001b[0m, in \u001b[0;36mplot_training_metrics\u001b[0;34m(metrics, path_or_config, name)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;66;03m# show(row([p1,p2, p3]))\u001b[39;00m\n\u001b[1;32m 62\u001b[0m filename \u001b[38;5;241m=\u001b[39m op\u001b[38;5;241m.\u001b[39mjoin(\n\u001b[1;32m 63\u001b[0m get_performance_directory(path_or_config), \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_metrics_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.png\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 64\u001b[0m )\n\u001b[0;32m---> 65\u001b[0m \u001b[43mexport_png\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrow\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mp1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mp2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mp3\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfilename\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 66\u001b[0m display(Image(filename\u001b[38;5;241m=\u001b[39mfilename))\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/bokeh/io/export.py:111\u001b[0m, in \u001b[0;36mexport_png\u001b[0;34m(obj, filename, width, height, webdriver, timeout)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mexport_png\u001b[39m(obj: LayoutDOM \u001b[38;5;241m|\u001b[39m 
Document, \u001b[38;5;241m*\u001b[39m, filename: PathLike \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, width: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 74\u001b[0m height: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, webdriver: WebDriver \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, timeout: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m5\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[1;32m 75\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m''' Export the ``LayoutDOM`` object or document as a PNG.\u001b[39;00m\n\u001b[1;32m 76\u001b[0m \n\u001b[1;32m 77\u001b[0m \u001b[38;5;124;03m If the filename is not given, it is derived from the script name (e.g.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 109\u001b[0m \n\u001b[1;32m 110\u001b[0m \u001b[38;5;124;03m '''\u001b[39;00m\n\u001b[0;32m--> 111\u001b[0m image \u001b[38;5;241m=\u001b[39m \u001b[43mget_screenshot_as_png\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwidth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwidth\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdriver\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwebdriver\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 
113\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m filename \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 114\u001b[0m filename \u001b[38;5;241m=\u001b[39m default_filename(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpng\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/bokeh/io/export.py:237\u001b[0m, in \u001b[0;36mget_screenshot_as_png\u001b[0;34m(obj, driver, timeout, resources, width, height)\u001b[0m\n\u001b[1;32m 234\u001b[0m file\u001b[38;5;241m.\u001b[39mwrite(html)\n\u001b[1;32m 236\u001b[0m web_driver \u001b[38;5;241m=\u001b[39m driver \u001b[38;5;28;01mif\u001b[39;00m driver \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m webdriver_control\u001b[38;5;241m.\u001b[39mget()\n\u001b[0;32m--> 237\u001b[0m \u001b[43mweb_driver\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmaximize_window\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 238\u001b[0m web_driver\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfile://\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtmp\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 239\u001b[0m wait_until_render_complete(web_driver, timeout)\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/selenium/webdriver/remote/webdriver.py:592\u001b[0m, in \u001b[0;36mWebDriver.maximize_window\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 590\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Maximizes the current window that webdriver is using.\"\"\"\u001b[39;00m\n\u001b[1;32m 591\u001b[0m command \u001b[38;5;241m=\u001b[39m Command\u001b[38;5;241m.\u001b[39mW3C_MAXIMIZE_WINDOW\n\u001b[0;32m--> 592\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcommand\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/selenium/webdriver/remote/webdriver.py:438\u001b[0m, in \u001b[0;36mWebDriver.execute\u001b[0;34m(self, driver_command, params)\u001b[0m\n\u001b[1;32m 435\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msessionId\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m params:\n\u001b[1;32m 436\u001b[0m params[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msessionId\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msession_id\n\u001b[0;32m--> 438\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcommand_executor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdriver_command\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 439\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m response:\n\u001b[1;32m 440\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39merror_handler\u001b[38;5;241m.\u001b[39mcheck_response(response)\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/selenium/webdriver/remote/remote_connection.py:290\u001b[0m, in \u001b[0;36mRemoteConnection.execute\u001b[0;34m(self, command, params)\u001b[0m\n\u001b[1;32m 288\u001b[0m data \u001b[38;5;241m=\u001b[39m utils\u001b[38;5;241m.\u001b[39mdump_json(params)\n\u001b[1;32m 289\u001b[0m url \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_url\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 290\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcommand_info\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/selenium/webdriver/remote/remote_connection.py:311\u001b[0m, in \u001b[0;36mRemoteConnection._request\u001b[0;34m(self, method, url, body)\u001b[0m\n\u001b[1;32m 308\u001b[0m body \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkeep_alive:\n\u001b[0;32m--> 311\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 312\u001b[0m statuscode \u001b[38;5;241m=\u001b[39m response\u001b[38;5;241m.\u001b[39mstatus\n\u001b[1;32m 313\u001b[0m 
\u001b[38;5;28;01melse\u001b[39;00m:\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/request.py:78\u001b[0m, in \u001b[0;36mRequestMethods.request\u001b[0;34m(self, method, url, fields, headers, **urlopen_kw)\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrequest_encode_url(\n\u001b[1;32m 75\u001b[0m method, url, fields\u001b[38;5;241m=\u001b[39mfields, headers\u001b[38;5;241m=\u001b[39mheaders, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39murlopen_kw\n\u001b[1;32m 76\u001b[0m )\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 78\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest_encode_body\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfields\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfields\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43murlopen_kw\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/request.py:170\u001b[0m, in \u001b[0;36mRequestMethods.request_encode_body\u001b[0;34m(self, method, url, fields, headers, encode_multipart, multipart_boundary, **urlopen_kw)\u001b[0m\n\u001b[1;32m 167\u001b[0m extra_kw[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mheaders\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mupdate(headers)\n\u001b[1;32m 168\u001b[0m 
extra_kw\u001b[38;5;241m.\u001b[39mupdate(urlopen_kw)\n\u001b[0;32m--> 170\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mextra_kw\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/poolmanager.py:376\u001b[0m, in \u001b[0;36mPoolManager.urlopen\u001b[0;34m(self, method, url, redirect, **kw)\u001b[0m\n\u001b[1;32m 374\u001b[0m response \u001b[38;5;241m=\u001b[39m conn\u001b[38;5;241m.\u001b[39murlopen(method, url, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw)\n\u001b[1;32m 375\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 376\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43mconn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mu\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest_uri\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkw\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 378\u001b[0m redirect_location \u001b[38;5;241m=\u001b[39m redirect \u001b[38;5;129;01mand\u001b[39;00m response\u001b[38;5;241m.\u001b[39mget_redirect_location()\n\u001b[1;32m 379\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m redirect_location:\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:815\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, 
**response_kw)\u001b[0m\n\u001b[1;32m 810\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m conn:\n\u001b[1;32m 811\u001b[0m \u001b[38;5;66;03m# Try again\u001b[39;00m\n\u001b[1;32m 812\u001b[0m log\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 813\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRetrying (\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m) after connection broken by \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, retries, err, url\n\u001b[1;32m 814\u001b[0m )\n\u001b[0;32m--> 815\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 816\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 817\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 818\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 819\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 820\u001b[0m \u001b[43m \u001b[49m\u001b[43mretries\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 821\u001b[0m \u001b[43m \u001b[49m\u001b[43mredirect\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 822\u001b[0m \u001b[43m \u001b[49m\u001b[43massert_same_host\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 823\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 824\u001b[0m \u001b[43m \u001b[49m\u001b[43mpool_timeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpool_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 825\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mrelease_conn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrelease_conn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 826\u001b[0m \u001b[43m \u001b[49m\u001b[43mchunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchunked\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 827\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody_pos\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody_pos\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 828\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mresponse_kw\u001b[49m\n\u001b[1;32m 829\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 831\u001b[0m \u001b[38;5;66;03m# Handle redirect?\u001b[39;00m\n\u001b[1;32m 832\u001b[0m redirect_location \u001b[38;5;241m=\u001b[39m redirect \u001b[38;5;129;01mand\u001b[39;00m response\u001b[38;5;241m.\u001b[39mget_redirect_location()\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:815\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 810\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m conn:\n\u001b[1;32m 811\u001b[0m \u001b[38;5;66;03m# Try again\u001b[39;00m\n\u001b[1;32m 812\u001b[0m log\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 813\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRetrying (\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m) after connection broken by \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, retries, err, url\n\u001b[1;32m 814\u001b[0m )\n\u001b[0;32m--> 815\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 816\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 817\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 818\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 819\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 820\u001b[0m \u001b[43m \u001b[49m\u001b[43mretries\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 821\u001b[0m \u001b[43m \u001b[49m\u001b[43mredirect\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 822\u001b[0m \u001b[43m \u001b[49m\u001b[43massert_same_host\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 823\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 824\u001b[0m \u001b[43m \u001b[49m\u001b[43mpool_timeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpool_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 825\u001b[0m \u001b[43m \u001b[49m\u001b[43mrelease_conn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrelease_conn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 826\u001b[0m \u001b[43m \u001b[49m\u001b[43mchunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchunked\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 827\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody_pos\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody_pos\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 828\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mresponse_kw\u001b[49m\n\u001b[1;32m 829\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 831\u001b[0m \u001b[38;5;66;03m# Handle redirect?\u001b[39;00m\n\u001b[1;32m 832\u001b[0m redirect_location \u001b[38;5;241m=\u001b[39m redirect 
\u001b[38;5;129;01mand\u001b[39;00m response\u001b[38;5;241m.\u001b[39mget_redirect_location()\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:815\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 810\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m conn:\n\u001b[1;32m 811\u001b[0m \u001b[38;5;66;03m# Try again\u001b[39;00m\n\u001b[1;32m 812\u001b[0m log\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 813\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRetrying (\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m) after connection broken by \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, retries, err, url\n\u001b[1;32m 814\u001b[0m )\n\u001b[0;32m--> 815\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 816\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 817\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 818\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 819\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 820\u001b[0m \u001b[43m \u001b[49m\u001b[43mretries\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 821\u001b[0m \u001b[43m \u001b[49m\u001b[43mredirect\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 822\u001b[0m \u001b[43m \u001b[49m\u001b[43massert_same_host\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 823\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 824\u001b[0m \u001b[43m \u001b[49m\u001b[43mpool_timeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpool_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 825\u001b[0m \u001b[43m \u001b[49m\u001b[43mrelease_conn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrelease_conn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 826\u001b[0m \u001b[43m \u001b[49m\u001b[43mchunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchunked\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 827\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody_pos\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody_pos\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 828\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mresponse_kw\u001b[49m\n\u001b[1;32m 829\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 831\u001b[0m \u001b[38;5;66;03m# Handle redirect?\u001b[39;00m\n\u001b[1;32m 832\u001b[0m redirect_location \u001b[38;5;241m=\u001b[39m redirect \u001b[38;5;129;01mand\u001b[39;00m response\u001b[38;5;241m.\u001b[39mget_redirect_location()\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:787\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 784\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, (SocketError, HTTPException)):\n\u001b[1;32m 785\u001b[0m e \u001b[38;5;241m=\u001b[39m ProtocolError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mConnection aborted.\u001b[39m\u001b[38;5;124m\"\u001b[39m, e)\n\u001b[0;32m--> 787\u001b[0m retries \u001b[38;5;241m=\u001b[39m 
\u001b[43mretries\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mincrement\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 788\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merror\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_pool\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_stacktrace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msys\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexc_info\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[1;32m 789\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 790\u001b[0m retries\u001b[38;5;241m.\u001b[39msleep()\n\u001b[1;32m 792\u001b[0m \u001b[38;5;66;03m# Keep track of the error for the retry warning.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/util/retry.py:592\u001b[0m, in \u001b[0;36mRetry.increment\u001b[0;34m(self, method, url, response, error, _pool, _stacktrace)\u001b[0m\n\u001b[1;32m 581\u001b[0m new_retry \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnew(\n\u001b[1;32m 582\u001b[0m total\u001b[38;5;241m=\u001b[39mtotal,\n\u001b[1;32m 583\u001b[0m connect\u001b[38;5;241m=\u001b[39mconnect,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 588\u001b[0m history\u001b[38;5;241m=\u001b[39mhistory,\n\u001b[1;32m 589\u001b[0m )\n\u001b[1;32m 591\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m new_retry\u001b[38;5;241m.\u001b[39mis_exhausted():\n\u001b[0;32m--> 592\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MaxRetryError(_pool, url, error \u001b[38;5;129;01mor\u001b[39;00m ResponseError(cause))\n\u001b[1;32m 594\u001b[0m 
log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIncremented Retry for (url=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m): \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, url, new_retry)\n\u001b[1;32m 596\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m new_retry\n", - "\u001b[0;31mMaxRetryError\u001b[0m: HTTPConnectionPool(host='localhost', port=44379): Max retries exceeded with url: /session/216aa782-d1d1-47dc-a868-fc064e41882a/window/maximize (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f1822df0370>: Failed to establish a new connection: [Errno 111] Connection refused'))" - ] - } - ], + "outputs": [], "source": [ - "perfplot.plot_training_metrics(gnn_metrics, CONFIG, \"gnn\")" + "perfplot_mpl.plot_metric_epochs(\n", + " \"eff\",\n", + " gnn_metrics,\n", + " CONFIG,\n", + " \"gnn\",\n", + " \"Edge Efficiency\",\n", + " color=perfplot_mpl.partition_to_color[\"val\"],\n", + ")\n", + "perfplot_mpl.plot_metric_epochs(\n", + " \"pur\",\n", + " gnn_metrics,\n", + " CONFIG,\n", + " \"gnn\",\n", + " \"Edge Purity\",\n", + " color=perfplot_mpl.partition_to_color[\"val\"],\n", + ")\n" ] }, { @@ -3397,77 +785,30 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "gnn_model.load_testset(\"velo-minbias-sim10b-xdigi-nospillover\")\n" + "gnn_model.load_partition(\"velo-sim10b-nospillover\")\n", + "perfplot.plot_edge_performance(gnn_model, CONFIG)\n" ] }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. 
Gradients will be None\n", - " warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n", - "WARNING:Retrying (Retry(total=2, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f1822dda380>: Failed to establish a new connection: [Errno 111] Connection refused')': /session/216aa782-d1d1-47dc-a868-fc064e41882a/window/maximize\n", - "WARNING:Retrying (Retry(total=1, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f177073a200>: Failed to establish a new connection: [Errno 111] Connection refused')': /session/216aa782-d1d1-47dc-a868-fc064e41882a/window/maximize\n", - "WARNING:Retrying (Retry(total=0, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f17707cc370>: Failed to establish a new connection: [Errno 111] Connection refused')': /session/216aa782-d1d1-47dc-a868-fc064e41882a/window/maximize\n" - ] - }, - { - "ename": "MaxRetryError", - "evalue": "HTTPConnectionPool(host='localhost', port=44379): Max retries exceeded with url: /session/216aa782-d1d1-47dc-a868-fc064e41882a/window/maximize (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f17707cc610>: Failed to establish a new connection: [Errno 111] Connection refused'))", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mConnectionRefusedError\u001b[0m Traceback (most recent call last)", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connection.py:174\u001b[0m, in \u001b[0;36mHTTPConnection._new_conn\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 174\u001b[0m 
conn \u001b[38;5;241m=\u001b[39m \u001b[43mconnection\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_connection\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dns_host\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mport\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mextra_kw\u001b[49m\n\u001b[1;32m 176\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m SocketTimeout:\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/util/connection.py:95\u001b[0m, in \u001b[0;36mcreate_connection\u001b[0;34m(address, timeout, source_address, socket_options)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m err \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m---> 95\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m err\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m socket\u001b[38;5;241m.\u001b[39merror(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgetaddrinfo returns an empty list\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/util/connection.py:85\u001b[0m, in \u001b[0;36mcreate_connection\u001b[0;34m(address, timeout, source_address, socket_options)\u001b[0m\n\u001b[1;32m 84\u001b[0m sock\u001b[38;5;241m.\u001b[39mbind(source_address)\n\u001b[0;32m---> 85\u001b[0m 
\u001b[43msock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnect\u001b[49m\u001b[43m(\u001b[49m\u001b[43msa\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m sock\n", - "\u001b[0;31mConnectionRefusedError\u001b[0m: [Errno 111] Connection refused", - "\nDuring handling of the above exception, another exception occurred:\n", - "\u001b[0;31mNewConnectionError\u001b[0m Traceback (most recent call last)", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:703\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 702\u001b[0m \u001b[38;5;66;03m# Make the request on the httplib connection object.\u001b[39;00m\n\u001b[0;32m--> 703\u001b[0m httplib_response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_request\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 704\u001b[0m \u001b[43m \u001b[49m\u001b[43mconn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 705\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 706\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 707\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout_obj\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 708\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 709\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 710\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mchunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchunked\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 711\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 713\u001b[0m \u001b[38;5;66;03m# If we're going to release the connection in ``finally:``, then\u001b[39;00m\n\u001b[1;32m 714\u001b[0m \u001b[38;5;66;03m# the response doesn't need to know about the connection. Otherwise\u001b[39;00m\n\u001b[1;32m 715\u001b[0m \u001b[38;5;66;03m# it will also try to release it and we'll have a double-release\u001b[39;00m\n\u001b[1;32m 716\u001b[0m \u001b[38;5;66;03m# mess.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:398\u001b[0m, in \u001b[0;36mHTTPConnectionPool._make_request\u001b[0;34m(self, conn, method, url, timeout, chunked, **httplib_request_kw)\u001b[0m\n\u001b[1;32m 397\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 398\u001b[0m \u001b[43mconn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhttplib_request_kw\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 400\u001b[0m \u001b[38;5;66;03m# We are swallowing BrokenPipeError (errno.EPIPE) since the server is\u001b[39;00m\n\u001b[1;32m 401\u001b[0m \u001b[38;5;66;03m# legitimately able to close the connection after sending a valid response.\u001b[39;00m\n\u001b[1;32m 402\u001b[0m \u001b[38;5;66;03m# With this behaviour, the received response is still readable.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connection.py:244\u001b[0m, in \u001b[0;36mHTTPConnection.request\u001b[0;34m(self, method, url, body, headers)\u001b[0m\n\u001b[1;32m 243\u001b[0m 
headers[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUser-Agent\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m _get_default_user_agent()\n\u001b[0;32m--> 244\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mHTTPConnection\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/http/client.py:1283\u001b[0m, in \u001b[0;36mHTTPConnection.request\u001b[0;34m(self, method, url, body, headers, encode_chunked)\u001b[0m\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Send a complete request to the server.\"\"\"\u001b[39;00m\n\u001b[0;32m-> 1283\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/http/client.py:1329\u001b[0m, in \u001b[0;36mHTTPConnection._send_request\u001b[0;34m(self, method, url, body, headers, encode_chunked)\u001b[0m\n\u001b[1;32m 1328\u001b[0m body \u001b[38;5;241m=\u001b[39m _encode(body, 
\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbody\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m-> 1329\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mendheaders\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/http/client.py:1278\u001b[0m, in \u001b[0;36mHTTPConnection.endheaders\u001b[0;34m(self, message_body, encode_chunked)\u001b[0m\n\u001b[1;32m 1277\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m CannotSendHeader()\n\u001b[0;32m-> 1278\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_output\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmessage_body\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/http/client.py:1038\u001b[0m, in \u001b[0;36mHTTPConnection._send_output\u001b[0;34m(self, message_body, encode_chunked)\u001b[0m\n\u001b[1;32m 1037\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_buffer[:]\n\u001b[0;32m-> 1038\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1040\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m message_body \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1041\u001b[0m \n\u001b[1;32m 1042\u001b[0m \u001b[38;5;66;03m# create a consistent interface to message_body\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/http/client.py:976\u001b[0m, in 
\u001b[0;36mHTTPConnection.send\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 975\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_open:\n\u001b[0;32m--> 976\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 977\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connection.py:205\u001b[0m, in \u001b[0;36mHTTPConnection.connect\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mconnect\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 205\u001b[0m conn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_new_conn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_conn(conn)\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connection.py:186\u001b[0m, in \u001b[0;36mHTTPConnection._new_conn\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m SocketError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m--> 186\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m NewConnectionError(\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28mself\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed to establish a new connection: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m e\n\u001b[1;32m 188\u001b[0m )\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m conn\n", - "\u001b[0;31mNewConnectionError\u001b[0m: <urllib3.connection.HTTPConnection object at 0x7f17707cc610>: Failed to establish a new connection: [Errno 111] Connection refused", - "\nDuring handling of the above exception, another 
exception occurred:\n", - "\u001b[0;31mMaxRetryError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[37], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mperfplot\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplot_edge_performance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgnn_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCONFIG\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/etx4velo-3/LHCb_Pipeline/utils/plotutils/performance.py:281\u001b[0m, in \u001b[0;36mplot_edge_performance\u001b[0;34m(model, path_or_config)\u001b[0m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;66;03m# show(row(figures))\u001b[39;00m\n\u001b[1;32m 278\u001b[0m filename \u001b[38;5;241m=\u001b[39m op\u001b[38;5;241m.\u001b[39mjoin(\n\u001b[1;32m 279\u001b[0m get_performance_directory(path_or_config), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124medge_performance.png\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 280\u001b[0m )\n\u001b[0;32m--> 281\u001b[0m \u001b[43mexport_png\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfigures\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfilename\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 282\u001b[0m display(Image(filename\u001b[38;5;241m=\u001b[39mfilename))\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/bokeh/io/export.py:111\u001b[0m, in \u001b[0;36mexport_png\u001b[0;34m(obj, filename, width, height, webdriver, timeout)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mexport_png\u001b[39m(obj: LayoutDOM \u001b[38;5;241m|\u001b[39m Document, \u001b[38;5;241m*\u001b[39m, filename: PathLike \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, width: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m|\u001b[39m 
\u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 74\u001b[0m height: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, webdriver: WebDriver \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, timeout: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m5\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[1;32m 75\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m''' Export the ``LayoutDOM`` object or document as a PNG.\u001b[39;00m\n\u001b[1;32m 76\u001b[0m \n\u001b[1;32m 77\u001b[0m \u001b[38;5;124;03m If the filename is not given, it is derived from the script name (e.g.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 109\u001b[0m \n\u001b[1;32m 110\u001b[0m \u001b[38;5;124;03m '''\u001b[39;00m\n\u001b[0;32m--> 111\u001b[0m image \u001b[38;5;241m=\u001b[39m \u001b[43mget_screenshot_as_png\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwidth\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwidth\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdriver\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwebdriver\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m filename \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 114\u001b[0m filename \u001b[38;5;241m=\u001b[39m 
default_filename(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpng\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/bokeh/io/export.py:237\u001b[0m, in \u001b[0;36mget_screenshot_as_png\u001b[0;34m(obj, driver, timeout, resources, width, height)\u001b[0m\n\u001b[1;32m 234\u001b[0m file\u001b[38;5;241m.\u001b[39mwrite(html)\n\u001b[1;32m 236\u001b[0m web_driver \u001b[38;5;241m=\u001b[39m driver \u001b[38;5;28;01mif\u001b[39;00m driver \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m webdriver_control\u001b[38;5;241m.\u001b[39mget()\n\u001b[0;32m--> 237\u001b[0m \u001b[43mweb_driver\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmaximize_window\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 238\u001b[0m web_driver\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfile://\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtmp\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 239\u001b[0m wait_until_render_complete(web_driver, timeout)\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/selenium/webdriver/remote/webdriver.py:592\u001b[0m, in \u001b[0;36mWebDriver.maximize_window\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 590\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Maximizes the current window that webdriver is using.\"\"\"\u001b[39;00m\n\u001b[1;32m 591\u001b[0m command \u001b[38;5;241m=\u001b[39m Command\u001b[38;5;241m.\u001b[39mW3C_MAXIMIZE_WINDOW\n\u001b[0;32m--> 592\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcommand\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "File 
\u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/selenium/webdriver/remote/webdriver.py:438\u001b[0m, in \u001b[0;36mWebDriver.execute\u001b[0;34m(self, driver_command, params)\u001b[0m\n\u001b[1;32m 435\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msessionId\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m params:\n\u001b[1;32m 436\u001b[0m params[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msessionId\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msession_id\n\u001b[0;32m--> 438\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcommand_executor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdriver_command\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 439\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m response:\n\u001b[1;32m 440\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39merror_handler\u001b[38;5;241m.\u001b[39mcheck_response(response)\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/selenium/webdriver/remote/remote_connection.py:290\u001b[0m, in \u001b[0;36mRemoteConnection.execute\u001b[0;34m(self, command, params)\u001b[0m\n\u001b[1;32m 288\u001b[0m data \u001b[38;5;241m=\u001b[39m utils\u001b[38;5;241m.\u001b[39mdump_json(params)\n\u001b[1;32m 289\u001b[0m url \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_url\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 290\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcommand_info\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/selenium/webdriver/remote/remote_connection.py:311\u001b[0m, in \u001b[0;36mRemoteConnection._request\u001b[0;34m(self, method, url, body)\u001b[0m\n\u001b[1;32m 308\u001b[0m body \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkeep_alive:\n\u001b[0;32m--> 311\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 312\u001b[0m statuscode \u001b[38;5;241m=\u001b[39m response\u001b[38;5;241m.\u001b[39mstatus\n\u001b[1;32m 313\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/request.py:78\u001b[0m, in \u001b[0;36mRequestMethods.request\u001b[0;34m(self, method, url, fields, headers, **urlopen_kw)\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrequest_encode_url(\n\u001b[1;32m 75\u001b[0m method, url, fields\u001b[38;5;241m=\u001b[39mfields, headers\u001b[38;5;241m=\u001b[39mheaders, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39murlopen_kw\n\u001b[1;32m 76\u001b[0m )\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 78\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest_encode_body\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfields\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfields\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43murlopen_kw\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/request.py:170\u001b[0m, in \u001b[0;36mRequestMethods.request_encode_body\u001b[0;34m(self, method, url, fields, headers, encode_multipart, multipart_boundary, **urlopen_kw)\u001b[0m\n\u001b[1;32m 167\u001b[0m extra_kw[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mheaders\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mupdate(headers)\n\u001b[1;32m 168\u001b[0m extra_kw\u001b[38;5;241m.\u001b[39mupdate(urlopen_kw)\n\u001b[0;32m--> 170\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mextra_kw\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/poolmanager.py:376\u001b[0m, in \u001b[0;36mPoolManager.urlopen\u001b[0;34m(self, method, url, redirect, **kw)\u001b[0m\n\u001b[1;32m 374\u001b[0m response \u001b[38;5;241m=\u001b[39m conn\u001b[38;5;241m.\u001b[39murlopen(method, url, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw)\n\u001b[1;32m 375\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 376\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43mconn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mu\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest_uri\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkw\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 378\u001b[0m redirect_location \u001b[38;5;241m=\u001b[39m redirect \u001b[38;5;129;01mand\u001b[39;00m response\u001b[38;5;241m.\u001b[39mget_redirect_location()\n\u001b[1;32m 379\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m redirect_location:\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:815\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 810\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m conn:\n\u001b[1;32m 811\u001b[0m \u001b[38;5;66;03m# Try again\u001b[39;00m\n\u001b[1;32m 812\u001b[0m log\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 813\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRetrying 
(\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m) after connection broken by \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, retries, err, url\n\u001b[1;32m 814\u001b[0m )\n\u001b[0;32m--> 815\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 816\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 817\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 818\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 819\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 820\u001b[0m \u001b[43m \u001b[49m\u001b[43mretries\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 821\u001b[0m \u001b[43m \u001b[49m\u001b[43mredirect\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 822\u001b[0m \u001b[43m \u001b[49m\u001b[43massert_same_host\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 823\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 824\u001b[0m \u001b[43m \u001b[49m\u001b[43mpool_timeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpool_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 825\u001b[0m \u001b[43m \u001b[49m\u001b[43mrelease_conn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrelease_conn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 826\u001b[0m \u001b[43m \u001b[49m\u001b[43mchunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchunked\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 827\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mbody_pos\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody_pos\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 828\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mresponse_kw\u001b[49m\n\u001b[1;32m 829\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 831\u001b[0m \u001b[38;5;66;03m# Handle redirect?\u001b[39;00m\n\u001b[1;32m 832\u001b[0m redirect_location \u001b[38;5;241m=\u001b[39m redirect \u001b[38;5;129;01mand\u001b[39;00m response\u001b[38;5;241m.\u001b[39mget_redirect_location()\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:815\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 810\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m conn:\n\u001b[1;32m 811\u001b[0m \u001b[38;5;66;03m# Try again\u001b[39;00m\n\u001b[1;32m 812\u001b[0m log\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 813\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRetrying (\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m) after connection broken by \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, retries, err, url\n\u001b[1;32m 814\u001b[0m )\n\u001b[0;32m--> 815\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 816\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 817\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 818\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 819\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 820\u001b[0m \u001b[43m \u001b[49m\u001b[43mretries\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 821\u001b[0m \u001b[43m \u001b[49m\u001b[43mredirect\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 822\u001b[0m \u001b[43m \u001b[49m\u001b[43massert_same_host\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 823\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 824\u001b[0m \u001b[43m \u001b[49m\u001b[43mpool_timeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpool_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 825\u001b[0m \u001b[43m \u001b[49m\u001b[43mrelease_conn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrelease_conn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 826\u001b[0m \u001b[43m \u001b[49m\u001b[43mchunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchunked\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 827\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody_pos\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody_pos\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 828\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mresponse_kw\u001b[49m\n\u001b[1;32m 829\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 831\u001b[0m \u001b[38;5;66;03m# Handle redirect?\u001b[39;00m\n\u001b[1;32m 832\u001b[0m redirect_location \u001b[38;5;241m=\u001b[39m redirect \u001b[38;5;129;01mand\u001b[39;00m response\u001b[38;5;241m.\u001b[39mget_redirect_location()\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:815\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, 
release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 810\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m conn:\n\u001b[1;32m 811\u001b[0m \u001b[38;5;66;03m# Try again\u001b[39;00m\n\u001b[1;32m 812\u001b[0m log\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 813\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRetrying (\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m) after connection broken by \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, retries, err, url\n\u001b[1;32m 814\u001b[0m )\n\u001b[0;32m--> 815\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 816\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 817\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 818\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 819\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 820\u001b[0m \u001b[43m \u001b[49m\u001b[43mretries\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 821\u001b[0m \u001b[43m \u001b[49m\u001b[43mredirect\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 822\u001b[0m \u001b[43m \u001b[49m\u001b[43massert_same_host\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 823\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 824\u001b[0m \u001b[43m \u001b[49m\u001b[43mpool_timeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpool_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 825\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mrelease_conn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrelease_conn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 826\u001b[0m \u001b[43m \u001b[49m\u001b[43mchunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchunked\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 827\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody_pos\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbody_pos\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 828\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mresponse_kw\u001b[49m\n\u001b[1;32m 829\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 831\u001b[0m \u001b[38;5;66;03m# Handle redirect?\u001b[39;00m\n\u001b[1;32m 832\u001b[0m redirect_location \u001b[38;5;241m=\u001b[39m redirect \u001b[38;5;129;01mand\u001b[39;00m response\u001b[38;5;241m.\u001b[39mget_redirect_location()\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/connectionpool.py:787\u001b[0m, in \u001b[0;36mHTTPConnectionPool.urlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 784\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, (SocketError, HTTPException)):\n\u001b[1;32m 785\u001b[0m e \u001b[38;5;241m=\u001b[39m ProtocolError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mConnection aborted.\u001b[39m\u001b[38;5;124m\"\u001b[39m, e)\n\u001b[0;32m--> 787\u001b[0m retries \u001b[38;5;241m=\u001b[39m \u001b[43mretries\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mincrement\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 788\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43merror\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_pool\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_stacktrace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msys\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexc_info\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[1;32m 789\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 790\u001b[0m retries\u001b[38;5;241m.\u001b[39msleep()\n\u001b[1;32m 792\u001b[0m \u001b[38;5;66;03m# Keep track of the error for the retry warning.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/etx4velo/lib/python3.10/site-packages/urllib3/util/retry.py:592\u001b[0m, in \u001b[0;36mRetry.increment\u001b[0;34m(self, method, url, response, error, _pool, _stacktrace)\u001b[0m\n\u001b[1;32m 581\u001b[0m new_retry \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnew(\n\u001b[1;32m 582\u001b[0m total\u001b[38;5;241m=\u001b[39mtotal,\n\u001b[1;32m 583\u001b[0m connect\u001b[38;5;241m=\u001b[39mconnect,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 588\u001b[0m history\u001b[38;5;241m=\u001b[39mhistory,\n\u001b[1;32m 589\u001b[0m )\n\u001b[1;32m 591\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m new_retry\u001b[38;5;241m.\u001b[39mis_exhausted():\n\u001b[0;32m--> 592\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MaxRetryError(_pool, url, error \u001b[38;5;129;01mor\u001b[39;00m ResponseError(cause))\n\u001b[1;32m 594\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIncremented Retry for (url=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m): \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, url, 
new_retry)\n\u001b[1;32m 596\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m new_retry\n", - "\u001b[0;31mMaxRetryError\u001b[0m: HTTPConnectionPool(host='localhost', port=44379): Max retries exceeded with url: /session/216aa782-d1d1-47dc-a868-fc064e41882a/window/maximize (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f17707cc610>: Failed to establish a new connection: [Errno 111] Connection refused'))" - ] - } - ], + "outputs": [], "source": [ - "perfplot.plot_edge_performance(gnn_model, CONFIG)" + "from GNN.gnn_plots import plot_best_performances_score_cut\n", + "\n", + "_, _, performances_for_various_score_cuts = plot_best_performances_score_cut(\n", + " model=gnn_model,\n", + " path_or_config=CONFIG,\n", + " partition=\"velo-sim10b-nospillover\",\n", + " score_cuts=[0.4, 0.5, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95],\n", + " n_events=50,\n", + " seed=0,\n", + ")" ] }, { @@ -3480,121 +821,9 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:--------------------- Step 4: Scoring graph edges using GNN ---------------------\n", - "INFO:---------------------------- a) Loading trained model ----------------------------\n", - "INFO:Load model from artifacts/gnn/velo-minbias-sim10b-xdigi-nospillover/version_1/checkpoints/epoch=39-step=3600.ckpt.\n", - "INFO:----------------------------- b) Running inferencing -----------------------------\n", - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/train to scratch/velo-minbias-sim10b-xdigi-nospillover/gnn_processed/train\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "449ae081348b45a59a036f585ca77680", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/90 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": 
"stderr", - "output_type": "stream", - "text": [ - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n", - " warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n", - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/val to scratch/velo-minbias-sim10b-xdigi-nospillover/gnn_processed/val\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6e567db527404b2a83b5b499fc689a4e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/10 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/test/velo-minbias-sim10b-xdigi-nospillover to scratch/velo-minbias-sim10b-xdigi-nospillover/gnn_processed/test/velo-minbias-sim10b-xdigi-nospillover\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ff680ea179ef4a1b8b8b072845a5bd7d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/test/velo-minbias-sim10b-xdigi-nospillover-only-long-electrons to scratch/velo-minbias-sim10b-xdigi-nospillover/gnn_processed/test/velo-minbias-sim10b-xdigi-nospillover-only-long-electrons\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e7ee06f3f6184534b5aae4042ecd0871", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": 
"display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/metric_learning_processed/test/velo-bu2kstee-sim10aU1-xdigi to scratch/velo-minbias-sim10b-xdigi-nospillover/gnn_processed/test/velo-bu2kstee-sim10aU1-xdigi\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ba8bcf65fb34438d87e387a86a102f00", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "run_gnn_inference(CONFIG, checkpoint=gnn_artifact_path)" ] @@ -3609,116 +838,9 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:----------- Step 5: Building track candidates from the scored graph -----------\n", - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/gnn_processed/train to scratch/velo-minbias-sim10b-xdigi-nospillover/track_building_processed/train\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f3df351acebe4264bf81a772a4c7b574", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/90 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/gnn_processed/val to scratch/velo-minbias-sim10b-xdigi-nospillover/track_building_processed/val\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0135fe4bb33f45a9b7e7f362d666ca50", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/10 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": 
"stream", - "text": [ - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/gnn_processed/test/velo-minbias-sim10b-xdigi-nospillover to scratch/velo-minbias-sim10b-xdigi-nospillover/track_building_processed/test/velo-minbias-sim10b-xdigi-nospillover\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "df38b26086f845898e414e3d4d0f8e88", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/gnn_processed/test/velo-minbias-sim10b-xdigi-nospillover-only-long-electrons to scratch/velo-minbias-sim10b-xdigi-nospillover/track_building_processed/test/velo-minbias-sim10b-xdigi-nospillover-only-long-electrons\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c24b1b46a4cb46238d3fa9abc99c5b1f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Inference from scratch/velo-minbias-sim10b-xdigi-nospillover/gnn_processed/test/velo-bu2kstee-sim10aU1-xdigi to scratch/velo-minbias-sim10b-xdigi-nospillover/track_building_processed/test/velo-bu2kstee-sim10aU1-xdigi\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "30e83bcf22014b19adc6645e161e1f70", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "build_track_candidates(CONFIG)" ] @@ -3733,179 +855,11 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": { "scrolled": true }, - "outputs": [ - { - 
"name": "stderr", - "output_type": "stream", - "text": [ - "INFO:------------------------------ Evaluation for train ------------------------------\n", - "INFO:1) Load dataframe of tracks, hits-particles and particles\n", - "INFO:Load tracks in scratch/velo-minbias-sim10b-xdigi-nospillover/track_building_processed/train.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "14aa74c4b1da48fd9bf10285f32623dd", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/90 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Load truncated paths for train in scratch/velo-minbias-sim10b-xdigi-nospillover/processed/splitting.json\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d5dbfb0591f145c0b15c901a6abbf014", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/90 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "257aa0e2d3fd43cbb4cc0536bfbfe626", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/90 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:2) Matching\n", - "INFO:3) Evaluation\n", - "INFO:Report was saved in scratch/output/report-2023.05.25-11.31.53-train.txt\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TrackChecker output : 565/ 36219 1.56% ghosts\n", - "01_velo : 17906/ 19310 92.73% ( 92.84%), 172 ( 0.95%) clones, pur 99.48%, hit eff 94.77%\n", - "02_long : 10614/ 11198 94.78% ( 94.87%), 102 ( 0.95%) clones, pur 99.53%, hit eff 95.59%\n", - "03_long_P>5GeV : 6973/ 7280 95.78% ( 95.85%), 74 ( 1.05%) clones, pur 99.58%, hit eff 96.15%\n", - "04_long_strange : 506/ 587 
86.20% ( 87.39%), 5 ( 0.98%) clones, pur 99.18%, hit eff 92.27%\n", - "05_long_strange_P>5GeV : 252/ 288 87.50% ( 90.52%), 4 ( 1.56%) clones, pur 99.12%, hit eff 94.42%\n", - "06_long_fromB : 7/ 7 100.00% (100.00%), 0 ( 0.00%) clones, pur 98.90%, hit eff 97.07%\n", - "07_long_fromB_P>5GeV : 7/ 7 100.00% (100.00%), 0 ( 0.00%) clones, pur 98.90%, hit eff 97.07%\n", - "08_long_electrons : 0/ 0 nan% ( nan%), 0 ( nan%) clones, pur nan%, hit eff nan%\n", - "09_long_fromB_electrons : 0/ 0 nan% ( nan%), 0 ( nan%) clones, pur nan%, hit eff nan%\n", - "10_long_fromB_electrons_P>5GeV : 0/ 0 nan% ( nan%), 0 ( nan%) clones, pur nan%, hit eff nan%\n", - "\n", - "*** Benchmark score: 93.76\n", - "\n", - "| Categories | Efficiency | Average efficiency | % clones | Average hit purity | Average hit efficiency |\n", - "|:---------------------|:-------------|:---------------------|:-----------|:---------------------|:-------------------------|\n", - "| Velo | 92.73% | 92.84% | 0.95% | 99.48% | 94.77% |\n", - "| Long | 94.78% | 94.87% | 0.95% | 99.53% | 95.59% |\n", - "| Velo, no electrons | 92.73% | 92.84% | 0.95% | 99.48% | 94.77% |\n", - "| Velo, only electrons | nan% | nan% | nan% | nan% | nan% |\n", - "| Long, only electrons | nan% | nan% | nan% | nan% | nan% |\n", - "| Categories | # ghosts | # tracks | % ghosts |\n", - "|:-------------|-----------:|:-----------|:-----------|\n", - "| Everything | 565 | 36,219 | 1.56% |\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Plot for category velo_reconstructible_acceptance saved in scratch/output/hist1d_velo_reconstructible_acceptance-train.pdf\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = 
np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: 
UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "INFO:Plot for category long_reconstructible_acceptance_only_electrons saved in scratch/output/hist1d_long_reconstructible_acceptance_only_electrons-train.pdf\n" - ] - }, - { - "data": { - "text/plain": [ - "<montetracko.evaluation.trackevaluator.TrackEvaluator at 0x7f1732aeb4c0>" - ] - }, - "execution_count": 41, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 3200x2400 with 32 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 3200x2400 with 32 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "evaluate_candidates_montetracko(\n", " CONFIG,\n", @@ -3917,183 +871,16 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:------------------------------- Evaluation for val -------------------------------\n", - "INFO:1) Load dataframe of tracks, hits-particles and particles\n", - "INFO:Load tracks in scratch/velo-minbias-sim10b-xdigi-nospillover/track_building_processed/val.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": 
"b1240f1f9c504f2fac37b0e19abc9464", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/10 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Load truncated paths for val in scratch/velo-minbias-sim10b-xdigi-nospillover/processed/splitting.json\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b77595105f8d4c7ab31cebeef9720d2f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/10 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "30a5e820dfb74e04b3ff56bd93549f08", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/10 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:2) Matching\n", - "INFO:3) Evaluation\n", - "INFO:Report was saved in scratch/output/report-2023.05.25-11.32.09-val.txt\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TrackChecker output : 65/ 4029 1.61% ghosts\n", - "01_velo : 1968/ 2109 93.31% ( 93.40%), 19 ( 0.96%) clones, pur 99.24%, hit eff 94.52%\n", - "02_long : 1154/ 1215 94.98% ( 94.97%), 13 ( 1.11%) clones, pur 99.21%, hit eff 95.07%\n", - "03_long_P>5GeV : 783/ 814 96.19% ( 96.26%), 8 ( 1.01%) clones, pur 99.14%, hit eff 96.06%\n", - "04_long_strange : 48/ 59 81.36% ( 82.20%), 0 ( 0.00%) clones, pur 98.70%, hit eff 93.93%\n", - "05_long_strange_P>5GeV : 29/ 33 87.88% ( 78.00%), 0 ( 0.00%) clones, pur 98.24%, hit eff 96.42%\n", - "06_long_fromB : 15/ 18 83.33% ( 83.76%), 1 ( 6.25%) clones, pur 96.98%, hit eff 92.97%\n", - "07_long_fromB_P>5GeV : 12/ 14 85.71% ( 93.33%), 1 ( 7.69%) clones, pur 96.28%, hit eff 91.35%\n", - "08_long_electrons : 0/ 0 nan% ( nan%), 0 ( nan%) clones, pur 
nan%, hit eff nan%\n", - "09_long_fromB_electrons : 0/ 0 nan% ( nan%), 0 ( nan%) clones, pur nan%, hit eff nan%\n", - "10_long_fromB_electrons_P>5GeV : 0/ 0 nan% ( nan%), 0 ( nan%) clones, pur nan%, hit eff nan%\n", - "\n", - "*** Benchmark score: 94.06\n", - "\n", - "| Categories | Efficiency | Average efficiency | % clones | Average hit purity | Average hit efficiency |\n", - "|:---------------------|:-------------|:---------------------|:-----------|:---------------------|:-------------------------|\n", - "| Velo | 93.31% | 93.40% | 0.96% | 99.24% | 94.52% |\n", - "| Long | 94.98% | 94.97% | 1.11% | 99.21% | 95.07% |\n", - "| Velo, no electrons | 93.31% | 93.40% | 0.96% | 99.24% | 94.52% |\n", - "| Velo, only electrons | nan% | nan% | nan% | nan% | nan% |\n", - "| Long, only electrons | nan% | nan% | nan% | nan% | nan% |\n", - "| Categories | # ghosts | # tracks | % ghosts |\n", - "|:-------------|-----------:|:-----------|:-----------|\n", - "| Everything | 65 | 4,029 | 1.61% |\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Plot for category velo_reconstructible_acceptance saved in scratch/output/hist1d_velo_reconstructible_acceptance-val.pdf\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - 
"/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: 
converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "/home/fgias/.conda/envs/etx4velo/lib/python3.10/site-packages/matplotlib/axes/_base.py:2503: UserWarning: Warning: converting a masked element to nan.\n", - " xys = np.asarray(xys)\n", - "INFO:Plot for category long_reconstructible_acceptance_only_electrons saved in scratch/output/hist1d_long_reconstructible_acceptance_only_electrons-val.pdf\n" - ] - }, - { - "data": { - "text/plain": [ - "<montetracko.evaluation.trackevaluator.TrackEvaluator at 0x7f173112f940>" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 3200x2400 with 32 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 3200x2400 with 32 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "evaluate_candidates_montetracko(\n", " CONFIG,\n", " partition=\"val\",\n", " allen_report=True,\n", " table_report=True,\n", + " plot_categories=[],\n", ")\n" ] }, @@ -4107,153 +894,41 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:-------------- Evaluation for velo-minbias-sim10b-xdigi-nospillover --------------\n", - "INFO:1) Load dataframe of tracks, hits-particles and particles\n", - "INFO:Load tracks in scratch/velo-minbias-sim10b-xdigi-nospillover/track_building_processed/test/velo-minbias-sim10b-xdigi-nospillover.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9c1699093f91492a9b6b7998c785ea36", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ 
- "INFO:Load pre-processed test datasets in scratch/__test__/velo-minbias-sim10b-xdigi-nospillover.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fc500dcde87a441487da86e729696e7b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "75aae0fd7a4544bf91a5c4ebb7ff8d63", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/1000 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:2) Matching\n", - "INFO:3) Evaluation\n", - "INFO:Report was saved in scratch/output/report-2023.05.25-11.32.50-velo-minbias-sim10b-xdigi-nospillover.txt\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TrackChecker output : 7868/ 208626 3.77% ghosts\n", - "01_velo : 93076/ 100530 92.59% ( 93.62%), 954 ( 1.01%) clones, pur 99.16%, hit eff 94.86%\n", - "02_long : 54220/ 57018 95.09% ( 95.97%), 513 ( 0.94%) clones, pur 99.29%, hit eff 95.82%\n", - "03_long_P>5GeV : 35149/ 36528 96.22% ( 96.98%), 293 ( 0.83%) clones, pur 99.33%, hit eff 96.64%\n", - "04_long_strange : 2535/ 2969 85.38% ( 86.67%), 35 ( 1.36%) clones, pur 98.55%, hit eff 91.16%\n", - "05_long_strange_P>5GeV : 1241/ 1426 87.03% ( 87.77%), 12 ( 0.96%) clones, pur 98.53%, hit eff 93.32%\n", - "06_long_fromB : 79/ 85 92.94% ( 92.27%), 1 ( 1.25%) clones, pur 99.21%, hit eff 95.09%\n", - "07_long_fromB_P>5GeV : 56/ 59 94.92% ( 95.28%), 0 ( 0.00%) clones, pur 99.07%, hit eff 96.30%\n", - "08_long_electrons : 2738/ 4141 66.12% ( 68.61%), 39 ( 1.40%) clones, pur 96.10%, hit eff 81.16%\n", - "09_long_fromB_electrons : 8/ 9 88.89% ( 85.71%), 0 ( 0.00%) clones, pur 100.00%, hit eff 83.93%\n", - "10_long_fromB_electrons_P>5GeV : 6/ 7 85.71% ( 80.00%), 0 ( 
0.00%) clones, pur 100.00%, hit eff 82.74%\n", - "\n", - "*** Benchmark score: 91.21\n", - "\n", - "| Categories | Efficiency | Average efficiency | % clones | Average hit purity | Average hit efficiency |\n", - "|:---------------------|:-------------|:---------------------|:-----------|:---------------------|:-------------------------|\n", - "| Velo | 87.25% | 88.64% | 1.27% | 98.87% | 92.77% |\n", - "| Long | 93.13% | 94.13% | 0.96% | 99.13% | 95.11% |\n", - "| Velo, no electrons | 92.59% | 93.62% | 1.01% | 99.16% | 94.86% |\n", - "| Velo, only electrons | 60.21% | 62.33% | 3.25% | 96.67% | 76.81% |\n", - "| Long, only electrons | 66.12% | 68.61% | 1.40% | 96.10% | 81.16% |\n", - "| Categories | # ghosts | # tracks | % ghosts |\n", - "|:-------------|:-----------|:-----------|:-----------|\n", - "| Everything | 7,868 | 208,626 | 3.77% |\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:Plot for category velo_reconstructible_acceptance saved in scratch/output/hist1d_velo_reconstructible_acceptance-velo-minbias-sim10b-xdigi-nospillover.pdf\n", - "INFO:Plot for category long_reconstructible_acceptance_only_electrons saved in scratch/output/hist1d_long_reconstructible_acceptance_only_electrons-velo-minbias-sim10b-xdigi-nospillover.pdf\n" - ] - }, - { - "data": { - "text/plain": [ - "<montetracko.evaluation.trackevaluator.TrackEvaluator at 0x7f173119c940>" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 3200x2400 with 32 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 3200x2400 with 32 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "evaluate_candidates_montetracko(\n", + "for test_dataset_name in get_required_test_dataset_names(CONFIG):\n", + " 
logging.info(headline(test_dataset_name))\n", + " evaluate_candidates_montetracko(\n", + " CONFIG,\n", + " partition=test_dataset_name,\n", + " allen_report=True,\n", + " table_report=True,\n", + " plot_categories=[],\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trackEvaluator = evaluate_candidates_montetracko(\n", " CONFIG,\n", - " partition=\"velo-minbias-sim10b-xdigi-nospillover\",\n", + " partition=\"velo-sim10b-nospillover\",\n", " allen_report=True,\n", " table_report=True,\n", - ")\n" + ")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -4272,7 +947,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.10" }, "vscode": { "interpreter": { diff --git a/LHCb_Pipeline/pipeline_config_default.yaml b/LHCb_Pipeline/pipeline_config_default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a6c437c07f411b3fa97c0c6f92a4ae5e7d713a9d --- /dev/null +++ b/LHCb_Pipeline/pipeline_config_default.yaml @@ -0,0 +1,122 @@ +common: + experiment_name: example + data_directory: /scratch/acorreia/data # where the data are saved + artifact_directory: artifacts # where the checkpoints are saved + performance_directory: output # where the plots and reports are saved + gpus: 1 + # Name of the test datasets to use (defined in `test_samples.yaml`) + test_dataset_names: + - velo-sim10b-nospillover + - velo-sim10b-nospillover-only-long-electrons + +preprocessing: + input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover + # Can be + # - Integer: Last subdirectory that can be used (starting from `0`). `-1` for all. 
+ # - String or list of strings: sub-directories that can be used + # - `null`: use `input_dir` directly + # - Dictionary with keys `start` and `stop` + subdirs: 10 + output_subdirectory: "preprocessed" + selection: triplets_first_selection # Selection function, defined in `Preprocessing/selecting.py` + n_events: null # if `null`, default to `n_train_events + n_test_events` + num_true_hits_threshold: 500 # Minimal number of genuine hits + # Columns to keep in the dataframes of hits-particles and particles + # (excluding `event`, `particle_id` and `lhcbid`) + # `null` means keep everything + hits_particles_columns: ["x", "y", "z", "plane"] + particles_columns: null + +processing: + input_subdirectory: "preprocessed" + output_subdirectory: "processed" + n_workers: 32 # Number of processes in parallel in the processing stage + features: ["r", "phi", "z", "plane"] # Name of the features to use + feature_means: [18., 0.0, 281.0, 7.5] # Means for normalising the features + feature_scales: [9.75, 1.82, 287.0, 12.5] # Scales for normalising the features + # List of the columns to keep in the PyTorch batches, in the dataframe of hits + kept_hits_columns: ["plane", {"un_z": "z"}] + # List of columns in the dataframe of particles that are merged to the dataframe + # of hits and stored in the PyTorch batches + kept_particles_columns: ["nhits_velo"] + n_train_events: 100 # Number of training events + n_val_events: 100 # Number of validation events + split_seed: 0 # Seed used for the splitting train-val + # How the true edges are computed + # - sortwise: sort by z + # - modulewise: sort by distance to production vertex + # - planewise: hits belonging to same particle and belonging to adjacent planes + true_edges_column: planewise + +metric_learning: + # Dataset parameters + input_subdirectory: "processed" + output_subdirectory: "metric_learning_processed" + + # Model parameters + feature_indices: 4 + emb_hidden: 256 # Number of hidden units / layer in the MLP + nb_layer: 4 # 
Number of layers + emb_dim: 4 # Embedding dimension + activation: Tanh # Activation function used in the MLP + weight: 2 # Weight for positive examples + randomisation: 2 # Number of random pairs per hit + points_per_batch: 100000 # Number of query points to consider + r: 0.015 # Maximum distance for hard-mining + r_inference: 0.020 # Maximum distance for inference + knn: 50 # Maximal number of neighbours during training and inference + warmup: 8 # Start with small increasing learning rate for `warmup` epochs + margin: 0.1 # Loss for negative examples is max(0, 0.1**2 - d²) + # Multiply the initial learning rate by ``factor`` every ``patience`` epochs + lr: 0.001 + factor: 0.7 + patience: 4 + # Available regimes + # - rp: random pairs + # - hnm: hard negative mining + # - norm: perform L2 normalisation + regime: [rp, hnm, norm] + bidir: true # Whether to use a bi-directional graph + max_epochs: 30 + + filtering: edges_at_least_3_hits + +gnn: + # Dataset parameters + input_subdirectory: "metric_learning_processed" + output_subdirectory: "gnn_processed" + edge_cut: 0.5 # Edge cut for validation + + # Model parameters + feature_indices: 4 + hidden: 256 # Number of hidden units per layer in the node encoder + # Number of layers in each MLP of the GNN + n_graph_iters: 8 + nb_node_layers: 6 + nb_node_encoder_layers: 6 # = `nb_node_layers` if not specified + nb_edge_layers: 10 + nb_edge_encoder_layers: 6 # = `nb_edge_layers` if not specified + nb_edge_classifier_layers: 6 # = `nb_edge_layers` if not specified + layernorm: True # Whether to use layer normalisation + aggregation: sum_max # Message-passing aggregation + hidden_activation: SiLU # hidden activation function + weight: 0.25 # if focal loss, `alpha`. 
Otherwise, weight of positive samples + warmup: 10 # Start with small increasing learning rate for `warmup` epochs + lr: 0.0002 # initial learning rate + # Multiply the learning rate by ``factor`` every ``patience`` epochs + factor: 0.7 + patience: 8 + # Existing regimes + # - weighting: use `edge_weights` as weights + # - pid: any edge belonging to the same particle is considered as a true edge + # - triplet: use the loss for triplets, with penalty term + regime: [pid] + max_epochs: 50 # Number of training epochs + gradient_clip_val: 0.5 # Gradient clipping value. Avoid exploding gradients. + focal_loss: false # Whether to use the focal loss + bidir: true # whether to use a bi-directional graph + +track_building: + score_cut: 0.7 + input_subdirectory: "gnn_processed" + output_subdirectory: "track_building_processed" diff --git a/LHCb_Pipeline/pipeline_configs/edges-slope.yaml b/LHCb_Pipeline/pipeline_configs/edges-slope.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e20862145aa55985947f006ca8469bd0863877b9 --- /dev/null +++ b/LHCb_Pipeline/pipeline_configs/edges-slope.yaml @@ -0,0 +1,96 @@ +common: + experiment_name: edges-slope + data_directory: /scratch/acorreia/data + artifact_directory: artifacts + performance_directory: output # plots and reports + gpus: 1 + test_dataset_names: + - velo-sim10b-nospillover + - velo-sim10b-nospillover-only-long-electrons + # - bu2kstee-sim10aU1-xdigi + +preprocessing: + input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover + subdirs: 10 + output_subdirectory: "preprocessed" + selection: track_weighting_selection + n_events: 11000 # if `null`, default to `n_train_events + n_test_events` + num_true_hits_threshold: 1500 + hits_particles_columns: ["x", "y", "z", "plane"] + particles_columns: null + +processing: + input_subdirectory: "preprocessed" + output_subdirectory: "processed" + n_workers: 32 + features: ["r", "phi", "z", "plane"] + feature_means: [18., 0.0, 281.0, 7.5] + feature_scales: 
[9.75, 1.82, 287.0, 12.5] + kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}] + kept_particles_columns: ["n_unique_planes"] + n_train_events: 10000 + n_val_events: 500 + split_seed: 0 + true_edges_column: planewise + +metric_learning: + # Dataset parameters + input_subdirectory: "processed" + output_subdirectory: "metric_learning_processed" + + # Model parameters + feature_indices: 4 + emb_hidden: 256 + nb_layer: 6 + emb_dim: 4 + activation: Tanh + weight: 2 + randomisation: 2 + points_per_batch: 100000 + r: 0.015 + r_inference: 0.020 + knn: 50 + warmup: 8 + margin: 0.1 + lr: 0.001 + factor: 0.7 + patience: 10 + regime: [rp, hnm, norm] + bidir: False + max_epochs: 20 + + +gnn: + # Dataset parameters + input_subdirectory: "metric_learning_processed" + output_subdirectory: "gnn_processed" + edge_cut: 0.5 + noise: True + bidir: False + + # Model parameters + feature_indices: 3 # just r, phi, z here + hidden: 256 + n_graph_iters: 8 + nb_node_layers: 6 + nb_node_encoder_layers: 6 + nb_edge_layers: 6 + nb_edge_encoder_layers: 6 + nb_edge_classifier_layers: 6 + layernorm: True + aggregation: sum_max + hidden_activation: SiLU + # weight: 2 + warmup: 10 + lr: 0.002 + factor: 0.7 + patience: 8 + truth_key: pid_signal + regime: [pid] + max_epochs: 200 + gradient_clip_val: 0.5 + +track_building: + score_cut: 0.95 + input_subdirectory: "gnn_processed" + output_subdirectory: "track_building_processed" diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-2gnns.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-2gnns.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2813b94c79171630cfca9c032afdcacf1f4f10a2 --- /dev/null +++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-2gnns.yaml @@ -0,0 +1,136 @@ +common: + experiment_name: focal-loss-pid-2gnns + data_directory: /scratch/acorreia/data + artifact_directory: artifacts + performance_directory: output # plots and reports + gpus: 1 + test_dataset_names: + - 
velo-sim10b-nospillover + - velo-sim10b-nospillover-only-long-electrons + # - bu2kstee-sim10aU1-xdigi + +preprocessing: + input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover + subdirs: 10 + output_subdirectory: "preprocessed" + selection: triplets_first_selection + n_events: 11000 # if `null`, default to `n_train_events + n_test_events` + num_true_hits_threshold: 500 + hits_particles_columns: ["x", "y", "z", "plane"] + particles_columns: null + +processing: + input_subdirectory: "preprocessed" + output_subdirectory: "processed" + n_workers: 32 + features: ["r", "phi", "z", "plane"] + feature_means: [18., 0.0, 281.0, 7.5] + feature_scales: [9.75, 1.82, 287.0, 12.5] + kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}] + kept_particles_columns: ["n_unique_planes", "nhits_velo"] + n_train_events: 10000 + n_val_events: 500 + split_seed: 0 + true_edges_column: planewise + +metric_learning: + # Dataset parameters + input_subdirectory: "processed" + output_subdirectory: "metric_learning_processed" + + # Model parameters + feature_indices: 4 + emb_hidden: 256 + nb_layer: 6 + emb_dim: 4 + activation: Tanh + weight: 2 + randomisation: 2 + points_per_batch: 100000 + r: 0.015 + r_inference: 0.020 + knn: 50 + warmup: 8 + margin: 0.1 + lr: 0.001 + factor: 0.7 + patience: 10 + regime: [rp, hnm, norm] + bidir: False + max_epochs: 20 + + +gnn: + # Dataset parameters + input_subdirectory: "metric_learning_processed" + output_subdirectory: "gnn_processed" + edge_cut: 0.5 + noise: True + bidir: False + + # Model parameters + feature_indices: 4 # just r, phi, z here + hidden: 256 + n_graph_iters: 8 + nb_node_layers: 6 + nb_node_encoder_layers: 6 + nb_edge_layers: 10 + nb_edge_encoder_layers: 6 + nb_edge_classifier_layers: 6 + layernorm: True + aggregation: sum_max + hidden_activation: SiLU + weight: 0.25 + warmup: 10 + lr: 0.0002 + factor: 0.7 + patience: 8 + regime: ["pid"] + max_epochs: 50 + gradient_clip_val: 0.5 + focal_loss: true + +edge_filtering: 
+ score_cut: 0.2 + input_subdirectory: "gnn_processed" + output_subdirectory: "edge_filtering" + +gnn2: + # Dataset parameters + input_subdirectory: "edge_filtering" + output_subdirectory: "gnn2_processed" + edge_cut: 0.5 + noise: True + bidir: False + + # Model parameters + feature_indices: 4 # just r, phi, z here + hidden: 256 + n_graph_iters: 8 + nb_node_layers: 6 + nb_node_encoder_layers: 6 + nb_edge_layers: 10 + nb_edge_encoder_layers: 6 + nb_edge_classifier_layers: 6 + layernorm: True + aggregation: sum_max + hidden_activation: SiLU + weight: 0.25 + warmup: 10 + lr: 0.001 + factor: 0.7 + patience: 8 + regime: ["pid"] + max_epochs: 50 + gradient_clip_val: 0.5 + focal_loss: true + +triplet_building: + input_subdirectory: "gnn_processed" + output_subdirectory: "triplet_building" + +track_building: + score_cut: 0.7 + # input_subdirectory: "gnn_processed" + input_subdirectory: "gnn2_processed" + output_subdirectory: "track_building_processed" diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-bidir.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-bidir.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a752568d5aef95e5d429cf3f2b0573353f255678 --- /dev/null +++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-bidir.yaml @@ -0,0 +1,102 @@ +common: + experiment_name: focal-loss-pid-bidir + data_directory: /scratch/acorreia/data + artifact_directory: artifacts + performance_directory: output # plots and reports + gpus: 1 + test_dataset_names: + - velo-sim10b-nospillover + - velo-sim10b-nospillover-only-long-electrons + # - bu2kstee-sim10aU1-xdigi + +preprocessing: + input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover + subdirs: 10 + output_subdirectory: "preprocessed" + selection: triplets_first_selection + n_events: 11000 # if `null`, default to `n_train_events + n_test_events` + num_true_hits_threshold: 500 + hits_particles_columns: ["x", "y", "z", "plane"] + particles_columns: null + +processing: + input_subdirectory: 
"preprocessed" + output_subdirectory: "processed" + n_workers: 32 + features: ["r", "phi", "z", "plane"] + feature_means: [18., 0.0, 281.0, 7.5] + feature_scales: [9.75, 1.82, 287.0, 12.5] + kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}] + kept_particles_columns: ["n_unique_planes", "nhits_velo"] + n_train_events: 10000 + n_val_events: 500 + split_seed: 0 + true_edges_column: planewise + +metric_learning: + # Dataset parameters + input_subdirectory: "processed" + output_subdirectory: "metric_learning_processed" + + # Model parameters + feature_indices: 4 + emb_hidden: 256 + nb_layer: 6 + emb_dim: 4 + activation: Tanh + weight: 2 + randomisation: 2 + points_per_batch: 100000 + r: 0.015 + r_inference: 0.020 + knn: 50 + warmup: 8 + margin: 0.1 + lr: 0.001 + factor: 0.7 + patience: 10 + regime: [rp, hnm, norm] + bidir: False + max_epochs: 20 + + +gnn: + # Dataset parameters + input_subdirectory: "metric_learning_processed" + output_subdirectory: "gnn_processed" + edge_cut: 0.5 + noise: True + bidir: True + shuffle_edge_direction: True + + # Model parameters + feature_indices: 4 # mmh I'm actually using the plane number, which is not deliberate + hidden: 256 + n_graph_iters: 8 + nb_node_layers: 6 + nb_node_encoder_layers: 6 + nb_edge_layers: 10 + nb_edge_encoder_layers: 6 + nb_edge_classifier_layers: 6 + layernorm: True + aggregation: sum_max + hidden_activation: SiLU + weight: 0.25 + warmup: 10 + lr: 0.0002 + factor: 0.7 + patience: 8 + regime: ["pid"] + max_epochs: 50 + gradient_clip_val: 0.5 + focal_loss: true + +triplet_building: + input_subdirectory: "gnn_processed" + output_subdirectory: "triplet_building" + +track_building: + score_cut: 0.75 + # input_subdirectory: "gnn_processed" + input_subdirectory: "gnn_processed" + output_subdirectory: "track_building_processed" diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..20cdef02a0838282c026619c4c3dd50a89546aaf --- /dev/null +++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed.yaml @@ -0,0 +1,101 @@ +common: + experiment_name: focal-loss-pid-fixed + data_directory: /scratch/acorreia/data + artifact_directory: artifacts + performance_directory: output # plots and reports + gpus: 1 + test_dataset_names: + - velo-sim10b-nospillover + - velo-sim10b-nospillover-only-long-electrons + # - bu2kstee-sim10aU1-xdigi + +preprocessing: + input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover + subdirs: 10 + output_subdirectory: "preprocessed" + selection: triplets_first_selection + n_events: 11000 # if `null`, default to `n_train_events + n_test_events` + num_true_hits_threshold: 500 + hits_particles_columns: ["x", "y", "z", "plane"] + particles_columns: null + +processing: + input_subdirectory: "preprocessed" + output_subdirectory: "processed" + n_workers: 32 + features: ["r", "phi", "z", "plane"] + feature_means: [18., 0.0, 281.0, 7.5] + feature_scales: [9.75, 1.82, 287.0, 12.5] + kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}] + kept_particles_columns: ["n_unique_planes", "nhits_velo"] + n_train_events: 10000 + n_val_events: 500 + split_seed: 0 + true_edges_column: planewise + +metric_learning: + # Dataset parameters + input_subdirectory: "processed" + output_subdirectory: "metric_learning_processed" + + # Model parameters + feature_indices: 4 + emb_hidden: 256 + nb_layer: 6 + emb_dim: 4 + activation: Tanh + weight: 2 + randomisation: 2 + points_per_batch: 100000 + r: 0.015 + r_inference: 0.020 + knn: 50 + warmup: 8 + margin: 0.1 + lr: 0.001 + factor: 0.7 + patience: 10 + regime: [rp, hnm, norm] + bidir: False + max_epochs: 20 + + +gnn: + # Dataset parameters + input_subdirectory: "metric_learning_processed" + output_subdirectory: "gnn_processed" + edge_cut: 0.5 + noise: True + bidir: False + + # Model parameters + feature_indices: 4 # mmh I'm actually using the 
plane number, which is not deliberate + hidden: 256 + n_graph_iters: 8 + nb_node_layers: 6 + nb_node_encoder_layers: 6 + nb_edge_layers: 10 + nb_edge_encoder_layers: 6 + nb_edge_classifier_layers: 6 + layernorm: True + aggregation: sum_max + hidden_activation: SiLU + weight: 0.25 + warmup: 10 + lr: 0.0002 + factor: 0.7 + patience: 8 + regime: ["pid"] + max_epochs: 50 + gradient_clip_val: 0.5 + focal_loss: true + +triplet_building: + input_subdirectory: "gnn_processed" + output_subdirectory: "triplet_building" + +track_building: + score_cut: 0.73 + # input_subdirectory: "gnn_processed" + input_subdirectory: "gnn_processed" + output_subdirectory: "track_building_processed" diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid.yaml new file mode 100644 index 0000000000000000000000000000000000000000..043a6e732f3616862f9196b495b01dbdaa157729 --- /dev/null +++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid.yaml @@ -0,0 +1,104 @@ +common: + experiment_name: focal-loss-pid + data_directory: /scratch/acorreia/data + artifact_directory: artifacts + performance_directory: output # plots and reports + gpus: 1 + test_dataset_names: + - velo-sim10b-nospillover + - velo-sim10b-nospillover-only-long-electrons + # - bu2kstee-sim10aU1-xdigi + +preprocessing: + input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover + subdirs: 10 + output_subdirectory: "preprocessed" + selection: triplets_first_selection + n_events: 11000 # if `null`, default to `n_train_events + n_test_events` + num_true_hits_threshold: 500 + hits_particles_columns: ["x", "y", "z", "plane"] + particles_columns: null + +processing: + input_subdirectory: "preprocessed" + output_subdirectory: "processed" + n_workers: 32 + features: ["r", "phi", "z", "plane"] + feature_means: [18., 0.0, 281.0, 7.5] + feature_scales: [9.75, 1.82, 287.0, 12.5] + kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}] + kept_particles_columns: 
["n_unique_planes", "nhits_velo"] + n_train_events: 10000 + n_val_events: 500 + split_seed: 0 + true_edges_column: planewise + +metric_learning: + # Dataset parameters + input_subdirectory: "processed" + output_subdirectory: "metric_learning_processed" + + # Model parameters + feature_indices: 4 + emb_hidden: 256 + nb_layer: 6 + emb_dim: 4 + activation: Tanh + weight: 2 + randomisation: 2 + points_per_batch: 100000 + r: 0.015 + r_inference: 0.020 + knn: 50 + warmup: 8 + margin: 0.1 + lr: 0.001 + factor: 0.7 + patience: 10 + regime: [rp, hnm, norm] + bidir: False + max_epochs: 20 + + # Building + building: null + filtering: "edges_at_least_3_hits" + +gnn: + # Dataset parameters + input_subdirectory: "metric_learning_processed" + output_subdirectory: "gnn_processed" + edge_cut: 0.5 + noise: True + bidir: False + + # Model parameters + feature_indices: 4 # mmh I'm actually using the plane number, which is not deliberate + hidden: 256 + n_graph_iters: 8 + nb_node_layers: 6 + nb_node_encoder_layers: 6 + nb_edge_layers: 10 + nb_edge_encoder_layers: 6 + nb_edge_classifier_layers: 6 + layernorm: True + aggregation: sum_max + hidden_activation: SiLU + weight: 0.25 + warmup: 10 + lr: 0.0002 + factor: 0.7 + patience: 8 + regime: ["pid"] + max_epochs: 31 + gradient_clip_val: 0.5 + focal_loss: true + +triplet_building: + input_subdirectory: "gnn_processed" + output_subdirectory: "triplet_building" + +track_building: + score_cut: 0.73 + # input_subdirectory: "gnn_processed" + input_subdirectory: "gnn_processed" + output_subdirectory: "track_building_processed" diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..311a3c5bcff25da64cec57158b61eea03e2658dc --- /dev/null +++ b/LHCb_Pipeline/pipeline_configs/focal-loss.yaml @@ -0,0 +1,100 @@ +common: + experiment_name: focal-loss + data_directory: /scratch/acorreia/data + artifact_directory: artifacts + 
performance_directory: output # plots and reports + gpus: 1 + test_dataset_names: + - velo-sim10b-nospillover + - velo-sim10b-nospillover-only-long-electrons + # - bu2kstee-sim10aU1-xdigi + +preprocessing: + input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover + subdirs: 10 + output_subdirectory: "preprocessed" + selection: track_weighting_selection + n_events: 11000 # if `null`, default to `n_train_events + n_test_events` + num_true_hits_threshold: 500 + hits_particles_columns: ["x", "y", "z", "plane"] + particles_columns: null + +processing: + input_subdirectory: "preprocessed" + output_subdirectory: "processed" + n_workers: 32 + features: ["r", "phi", "z", "plane"] + feature_means: [18., 0.0, 281.0, 7.5] + feature_scales: [9.75, 1.82, 287.0, 12.5] + kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}] + kept_particles_columns: ["n_unique_planes"] + n_train_events: 10000 + n_val_events: 500 + split_seed: 0 + true_edges_column: planewise + +metric_learning: + # Dataset parameters + input_subdirectory: "processed" + output_subdirectory: "metric_learning_processed" + + # Model parameters + feature_indices: 4 + emb_hidden: 256 + nb_layer: 6 + emb_dim: 4 + activation: Tanh + weight: 2 + randomisation: 2 + points_per_batch: 100000 + r: 0.015 + r_inference: 0.020 + knn: 50 + warmup: 8 + margin: 0.1 + lr: 0.001 + factor: 0.7 + patience: 10 + regime: [rp, hnm, norm] + bidir: False + max_epochs: 20 + + +gnn: + # Dataset parameters + input_subdirectory: "metric_learning_processed" + output_subdirectory: "gnn_processed" + edge_cut: 0.5 + noise: True + bidir: False + + # Model parameters + feature_indices: 4 # just r, phi, z here + hidden: 256 + n_graph_iters: 8 + nb_node_layers: 6 + nb_node_encoder_layers: 6 + nb_edge_layers: 10 + nb_edge_encoder_layers: 6 + nb_edge_classifier_layers: 6 + layernorm: True + aggregation: sum_max + hidden_activation: SiLU + weight: 0.25 + warmup: 10 + lr: 0.0002 + factor: 0.7 + patience: 8 + regime: [] + 
max_epochs: 50 + gradient_clip_val: 0.5 + focal_loss: true + +# triplet_building: +# input_subdirectory: "gnn_processed" +# output_subdirectory: "triplet_processed" + +track_building: + score_cut: 0.5 + input_subdirectory: "gnn_processed" + output_subdirectory: "track_building_processed" diff --git a/LHCb_Pipeline/pipeline_configs/track-edges.yaml b/LHCb_Pipeline/pipeline_configs/track-edges.yaml new file mode 100644 index 0000000000000000000000000000000000000000..06db0d87a155a122b413c4057303def592e04a6a --- /dev/null +++ b/LHCb_Pipeline/pipeline_configs/track-edges.yaml @@ -0,0 +1,100 @@ +common: + experiment_name: track-edges + data_directory: /scratch/acorreia/data + artifact_directory: artifacts + performance_directory: output # plots and reports + gpus: 1 + test_dataset_names: + - velo-sim10b-nospillover + - velo-sim10b-nospillover-only-long-electrons + # - bu2kstee-sim10aU1-xdigi + +preprocessing: + input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover + subdirs: 10 + output_subdirectory: "preprocessed" + selection: track_weighting_selection + n_events: 11000 # if `null`, default to `n_train_events + n_test_events` + num_true_hits_threshold: 1500 + hits_particles_columns: ["x", "y", "z", "plane"] + particles_columns: null + +processing: + input_subdirectory: "preprocessed" + output_subdirectory: "processed" + n_workers: 32 + features: ["r", "phi", "z", "plane"] + feature_means: [18., 0.0, 281.0, 7.5] + feature_scales: [9.75, 1.82, 287.0, 12.5] + kept_hits_columns: ["plane", "z"] + kept_particles_columns: ["n_unique_planes"] + n_train_events: 10000 + n_val_events: 500 + split_seed: 0 + true_edges_column: planewise + +metric_learning: + # Dataset parameters + input_subdirectory: "processed" + output_subdirectory: "metric_learning_processed" + + # Model parameters + feature_indices: 4 + emb_hidden: 256 + nb_layer: 6 + emb_dim: 4 + activation: Tanh + weight: 2 + randomisation: 2 + points_per_batch: 100000 + r: 0.015 + r_inference: 0.020 + knn: 50 + 
warmup: 8 + margin: 0.1 + lr: 0.001 + factor: 0.7 + patience: 10 + regime: [rp, hnm, norm] + bidir: False + max_epochs: 20 + + +gnn: + # Dataset parameters + input_subdirectory: "metric_learning_processed" + output_subdirectory: "gnn_processed" + edge_cut: 0.5 + noise: True + bidir: False + + # Model parameters + feature_indices: 4 # just r, phi, z here + hidden: 256 + n_graph_iters: 8 + nb_node_layers: 6 + nb_node_encoder_layers: 6 + nb_edge_layers: 10 + nb_edge_encoder_layers: 6 + nb_edge_classifier_layers: 6 + layernorm: True + aggregation: sum_max + hidden_activation: SiLU + # weight: 2 + warmup: 10 + lr: 0.002 + factor: 0.7 + patience: 8 + truth_key: pid_signal + regime: [] + max_epochs: 200 + gradient_clip_val: 0.5 + +triplet_building: + input_subdirectory: "gnn_processed" + output_subdirectory: "triplet_processed" + +track_building: + score_cut: 0.95 + input_subdirectory: "gnn_processed" + output_subdirectory: "track_building_processed" diff --git a/LHCb_Pipeline/pipeline_configs/track-weighting.yaml b/LHCb_Pipeline/pipeline_configs/track-weighting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c8cbae47f5795cc7a56afdf27402f8a1bcc08e8a --- /dev/null +++ b/LHCb_Pipeline/pipeline_configs/track-weighting.yaml @@ -0,0 +1,96 @@ +common: + experiment_name: track-weighting + data_directory: /scratch/acorreia/data + artifact_directory: artifacts + performance_directory: output # plots and reports + gpus: 1 + test_dataset_names: + - velo-sim10b-nospillover + - velo-sim10b-nospillover-only-long-electrons + # - bu2kstee-sim10aU1-xdigi + +preprocessing: + input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover + subdirs: 10 + output_subdirectory: "preprocessed" + selection: track_weighting_selection + n_events: 11000 # if `null`, default to `n_train_events + n_test_events` + num_true_hits_threshold: 1500 + hits_particles_columns: ["x", "y", "z", "plane"] + particles_columns: null + +processing: + input_subdirectory: "preprocessed" + 
output_subdirectory: "processed" + n_workers: 32 + features: ["r", "phi", "z", "plane"] + feature_means: [18., 0.0, 281.0, 7.5] + feature_scales: [9.75, 1.82, 287.0, 12.5] + kept_hits_columns: ["plane", "z"] + kept_particles_columns: ["n_unique_planes"] + n_train_events: 10000 + n_val_events: 500 + split_seed: 0 + true_edges_column: planewise + +metric_learning: + # Dataset parameters + input_subdirectory: "processed" + output_subdirectory: "metric_learning_processed" + + # Model parameters + feature_indices: 4 + emb_hidden: 256 + nb_layer: 6 + emb_dim: 4 + activation: Tanh + weight: 2 + randomisation: 2 + points_per_batch: 100000 + r: 0.015 + r_inference: 0.020 + knn: 50 + warmup: 8 + margin: 0.1 + lr: 0.001 + factor: 0.7 + patience: 10 + regime: [rp, hnm, norm] + bidir: False + max_epochs: 20 + + +gnn: + # Dataset parameters + input_subdirectory: "metric_learning_processed" + output_subdirectory: "gnn_processed" + edge_cut: 0.5 + noise: True + bidir: False + + # Model parameters + feature_indices: 3 # just r, phi, z here + hidden: 256 + n_graph_iters: 8 + nb_node_layers: 6 + nb_node_encoder_layers: 6 + nb_edge_layers: 10 + nb_edge_encoder_layers: 6 + nb_edge_classifier_layers: 6 + layernorm: True + aggregation: sum_max + hidden_activation: SiLU + # weight: 2 + warmup: 10 + lr: 0.002 + factor: 0.7 + patience: 8 + truth_key: pid_signal + regime: [pid] + max_epochs: 200 + gradient_clip_val: 0.5 + +track_building: + score_cut: 0.95 + input_subdirectory: "gnn_processed" + output_subdirectory: "track_building_processed" diff --git a/LHCb_Pipeline/pipeline_configs/triplets-first.yaml b/LHCb_Pipeline/pipeline_configs/triplets-first.yaml new file mode 100644 index 0000000000000000000000000000000000000000..242207b207c1d7cf7911a79e743b91985f2abda8 --- /dev/null +++ b/LHCb_Pipeline/pipeline_configs/triplets-first.yaml @@ -0,0 +1,102 @@ +common: + experiment_name: triplets-first + data_directory: /scratch/acorreia/data + artifact_directory: artifacts + performance_directory: 
output # plots and reports + gpus: 1 + test_dataset_names: + - velo-sim10b-nospillover + - velo-sim10b-nospillover-only-long-electrons + # - bu2kstee-sim10aU1-xdigi + +preprocessing: + input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover + subdirs: 10 + output_subdirectory: "preprocessed" + selection: triplets_first_selection + n_events: 11000 # if `null`, default to `n_train_events + n_test_events` + num_true_hits_threshold: 1000 + hits_particles_columns: ["x", "y", "z", "plane"] + particles_columns: null + +processing: + input_subdirectory: "preprocessed" + output_subdirectory: "processed" + n_workers: 32 + features: ["r", "phi", "z", "plane"] + feature_means: [18., 0.0, 281.0, 7.5] + feature_scales: [9.75, 1.82, 287.0, 12.5] + kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}] + kept_particles_columns: ["n_unique_planes"] + n_train_events: 10000 + n_val_events: 500 + split_seed: 0 + true_edges_column: planewise + +metric_learning: + # Dataset parameters + input_subdirectory: "processed" + output_subdirectory: "metric_learning_processed" + + # Model parameters + feature_indices: 4 + emb_hidden: 256 + nb_layer: 6 + emb_dim: 4 + activation: Tanh + weight: 2 + randomisation: 2 + points_per_batch: 100000 + r: 0.015 + r_inference: 0.020 + knn: 50 + warmup: 8 + margin: 0.1 + lr: 0.001 + factor: 0.7 + patience: 10 + regime: [rp, hnm, norm] + bidir: False + max_epochs: 20 + +triplet_building: + input_subdirectory: "metric_learning_processed" + output_subdirectory: "triplet_processed" + +gnn: + # Dataset parameters + input_subdirectory: "triplet_processed" + output_subdirectory: "gnn_processed" + edge_cut: 0.5 + noise: True + bidir: False + n_train_events: 5000 + + # Model parameters + feature_indices: 3 # just r, phi, z here + hidden: 256 + n_graph_iters: 8 + nb_node_layers: 6 + nb_node_encoder_layers: 6 + nb_edge_layers: 6 + nb_edge_encoder_layers: 6 + nb_edge_classifier_layers: 6 + layernorm: True + aggregation: sum_max + 
hidden_activation: SiLU + # weight: 2 + warmup: 10 + lr: 0.002 + factor: 0.7 + patience: 8 + truth_key: pid_signal + regime: [pid, triplet] + pos_penality: 0.01 + neg_penality: 0.01 + max_epochs: 50 + # gradient_clip_val: 0.5 + +track_building: + score_cut: 0.95 + input_subdirectory: "gnn_processed" + output_subdirectory: "track_building_processed" diff --git a/LHCb_Pipeline/pipeline_config.yaml b/LHCb_Pipeline/pipeline_configs/velo-sim10b-nospillover-lot.yaml similarity index 51% rename from LHCb_Pipeline/pipeline_config.yaml rename to LHCb_Pipeline/pipeline_configs/velo-sim10b-nospillover-lot.yaml index 76dbc0bed2273153e2b99ef13feed583cc88978c..d272fc50b8c1c4e913a529a68da16a3104fdd503 100644 --- a/LHCb_Pipeline/pipeline_config.yaml +++ b/LHCb_Pipeline/pipeline_configs/velo-sim10b-nospillover-lot.yaml @@ -1,59 +1,60 @@ common: - experiment_name: velo-minbias-sim10b-xdigi-nospillover - data_directory: scratch + experiment_name: velo-sim10b-nospillover-lot + data_directory: /scratch/acorreia/data artifact_directory: artifacts - performance_directory: scratch/output # plots and reports + performance_directory: output # plots and reports gpus: 1 test_dataset_names: - - velo-minbias-sim10b-xdigi-nospillover - - velo-minbias-sim10b-xdigi-nospillover-only-long-electrons - - velo-bu2kstee-sim10aU1-xdigi + - velo-sim10b-nospillover + - velo-sim10b-nospillover-only-long-electrons + # - bu2kstee-sim10aU1-xdigi + # - smog2-xdigi preprocessing: - input_dir: scratch/datasets/minbias-sim10b-xdigi-nospillover/0 + input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover/0 output_subdirectory: "preprocessed" - selection: everything_but_electrons + selection: default_old_training_for_rta_presentation n_events: null # if `null`, default to `n_train_events + n_test_events` - num_true_hits_threshold: 2500 + num_true_hits_threshold: 0 processing: input_subdirectory: "preprocessed" output_subdirectory: "processed" n_workers: 32 - features: ["r", "phi", "z", "plane"] - feature_means: 
[0., 0., 0., 0.] - feature_scales: [50, 3.14159, 200, 26] - kept_hits_columns: ["plane"] - kept_particles_columns: [] - n_train_events: 90 - n_val_events: 10 + features: ["r", "phi", "z"] + feature_means: [18., 0., 281.] + feature_scales: [9.75, 1.82, 287] + kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}] + kept_particles_columns: ["n_unique_planes"] + n_train_events: 1000 + n_val_events: 1000 split_seed: 0 + true_edges_column: sortedwise metric_learning: # Dataset parameters input_subdirectory: "processed" output_subdirectory: "metric_learning_processed" - true_edges_column: modulewise_true_edges # Model parameters - feature_indices: 4 + feature_indices: 3 emb_hidden: 256 nb_layer: 4 - emb_dim: 4 + emb_dim: 3 activation: Tanh weight: 2 randomisation: 2 points_per_batch: 100000 - r: 0.035 - r_inference: 0.035 + r: 0.015 + r_inference: 0.015 knn: 50 warmup: 8 margin: 0.1 lr: 0.001 factor: 0.7 - patience: 4 + patience: 10 regime: [rp, hnm, norm] - max_epochs: 20 + max_epochs: 40 gnn: # Dataset parameters @@ -63,11 +64,14 @@ gnn: noise: True # Model parameters - feature_indices: 4 # indices in `batch.x`. If `null`, everything is taken. 
- hidden: 512 - n_graph_iters: 8 - nb_node_layer: 3 - nb_edge_layer: 6 + feature_indices: 3 + hidden: 256 + n_graph_iters: 12 + nb_node_layers: 4 + nb_node_encoder_layers: 4 + nb_edge_layers: 4 + nb_edge_encoder_layers: 4 + nb_edge_classifier_layers: 4 layernorm: True aggregation: sum_max hidden_activation: SiLU @@ -78,10 +82,13 @@ gnn: patience: 8 truth_key: pid_signal regime: [pid] - mask_background: True - max_epochs: 40 + max_epochs: 200 track_building: score_cut: 0.9 input_subdirectory: "gnn_processed" output_subdirectory: "track_building_processed" + +triplet_building: + input_subdirectory: "gnn_processed" + output_subdirectory: "triplet_processed" diff --git a/LHCb_Pipeline/test_samples.yaml b/LHCb_Pipeline/test_samples.yaml index 7e085267dfe8fb54119b7e8606cc8e882f88bacd..b76aeb12dc4fea9c89b8f6bba0f3ef7dd6f9ec58 100644 --- a/LHCb_Pipeline/test_samples.yaml +++ b/LHCb_Pipeline/test_samples.yaml @@ -1,17 +1,33 @@ -velo-minbias-sim10b-xdigi-nospillover: - input_dir: scratch/datasets/minbias-sim10b-xdigi-nospillover/102 +velo-sim10b-nospillover: + input_dir: /scratch/acorreia/data_validation/minbias-sim10b-xdigi-nospillover/500 selection: null n_events: 1000 - num_true_hits_threshold: 0 + num_true_hits_threshold: null -velo-minbias-sim10b-xdigi-nospillover-only-long-electrons: - input_dir: scratch/datasets/minbias-sim10b-xdigi-nospillover/102 +velo-sim10b-nospillover-only-long-electrons: + input_dir: /scratch/acorreia/data_validation/minbias-sim10b-xdigi-nospillover/500 selection: only_long_electrons n_events: 1000 - num_true_hits_threshold: 0 + num_true_hits_threshold: null -velo-bu2kstee-sim10aU1-xdigi: - input_dir: scratch/datasets/bu2kstee-sim10aU1-xdigi/198 +bu2kstee-sim10aU1-xdigi: + input_dir: /scratch/acorreia/data_validation/bu2kstee-sim10aU1-xdigi/500 selection: null n_events: 1000 - num_true_hits_threshold: 0 \ No newline at end of file + num_true_hits_threshold: null + +smog2-xdigi: + input_dir: + - /scratch/acorreia/data_validation/smog2-digi/432 + 
- /scratch/acorreia/data_validation/smog2-digi/433 + - /scratch/acorreia/data_validation/smog2-digi/434 + - /scratch/acorreia/data_validation/smog2-digi/435 + - /scratch/acorreia/data_validation/smog2-digi/436 + - /scratch/acorreia/data_validation/smog2-digi/437 + - /scratch/acorreia/data_validation/smog2-digi/438 + - /scratch/acorreia/data_validation/smog2-digi/439 + - /scratch/acorreia/data_validation/smog2-digi/440 + - /scratch/acorreia/data_validation/smog2-digi/441 + n_events: -1 + num_true_hits_threshold: null + new_event_id_per_file: True diff --git a/LHCb_Pipeline/utils/commonutils/config.py b/LHCb_Pipeline/utils/commonutils/config.py index 1b984a1be53bf0ec44cb0f26e00db59602b4ece8..48a77e28217db76a5f0470c2596e77752d75ad21 100644 --- a/LHCb_Pipeline/utils/commonutils/config.py +++ b/LHCb_Pipeline/utils/commonutils/config.py @@ -3,15 +3,15 @@ import os.path as op import yaml -#: List of the steps in the right order. -STEPS = [ - "preprocessing", - "processing", - "metric_learning", - "gnn", - "track_building", - "evaluation", -] +# #: List of the steps in the right order. +# STEPS = [ +# "preprocessing", +# "processing", +# "metric_learning", +# "gnn", +# "track_building", +# "evaluation", +# ] def get_data_experiment_directory(path_or_config: str | dict) -> str: @@ -42,18 +42,17 @@ def resolve_config_paths( """ if data_experiment_dir is None: data_experiment_dir = get_data_experiment_directory(configs) - for step in STEPS: - if step in configs: - for inoutput in ["input", "output"]: - if f"{inoutput}_subdirectory" in configs[step]: - assert f"{inoutput}_dir" not in configs[step], ( - f"`{inoutput}_subdirectory` and `{inoutput}_dir` as both " - f"the configuration of {step}, which might create a clash." 
- ) - configs[step][f"{inoutput}_dir"] = op.join( - data_experiment_dir, - configs[step].pop(f"{inoutput}_subdirectory"), - ) + for step in configs: + for inoutput in ["input", "output"]: + if f"{inoutput}_subdirectory" in configs[step]: + assert f"{inoutput}_dir" not in configs[step], ( + f"`{inoutput}_subdirectory` and `{inoutput}_dir` as both " + f"the configuration of {step}, which might create a clash." + ) + configs[step][f"{inoutput}_dir"] = op.join( + data_experiment_dir, + configs[step].pop(f"{inoutput}_subdirectory"), + ) def load_config(path_or_config: str | dict, resolve: bool = True) -> dict: diff --git a/LHCb_Pipeline/utils/commonutils/crun.py b/LHCb_Pipeline/utils/commonutils/crun.py index 74d3bf31b656c7ca1b3d2eeebcb8b3c79e8e485c..672ee1e877cd110e613702fc164303beb567ca7c 100644 --- a/LHCb_Pipeline/utils/commonutils/crun.py +++ b/LHCb_Pipeline/utils/commonutils/crun.py @@ -1,5 +1,6 @@ import typing import os +import logging class InOutFunction(typing.Protocol): @@ -16,6 +17,7 @@ def run_for_different_partitions( partitions: typing.List[str] = ["train", "val", "test"], test_dataset_names: typing.List[str] | None = None, reproduce: bool = True, + list_kwargs: typing.List[dict] | None = None, **kwargs, ): """Run a function for different dataset "partitions". @@ -37,12 +39,22 @@ def run_for_different_partitions( directory. 
**kwargs: keyword arguments passed to ``func`` """ - for partition in partitions: + for partition_idx, partition in enumerate(partitions): + if list_kwargs is None: + supplementary_kwargs = {} + else: + supplementary_kwargs = list_kwargs[partition_idx] + logging.info( + f"Use the following parameters for {partition}: {supplementary_kwargs}" + ) + if partition in ["train", "val"]: func( input_dir=os.path.join(input_dir, partition), output_dir=os.path.join(output_dir, partition), reproduce=reproduce, + **supplementary_kwargs, + **kwargs, ) elif partition == "test": @@ -55,6 +67,7 @@ def run_for_different_partitions( input_dir=os.path.join(input_dir, "test", test_dataset_name), output_dir=os.path.join(output_dir, "test", test_dataset_name), reproduce=reproduce, + **supplementary_kwargs, **kwargs, ) elif (test_dataset_names is not None) and (partition in test_dataset_names): @@ -62,6 +75,7 @@ def run_for_different_partitions( input_dir=os.path.join(input_dir, "test", partition), output_dir=os.path.join(output_dir, "test", partition), reproduce=reproduce, + **supplementary_kwargs, **kwargs, ) else: diff --git a/LHCb_Pipeline/utils/graphutils/__init__.py b/LHCb_Pipeline/utils/graphutils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1811ab81152ec0f34dd59cb94647359b037eb0 --- /dev/null +++ b/LHCb_Pipeline/utils/graphutils/__init__.py @@ -0,0 +1,2 @@ +"""A package that defines common utilies to handle graphs in PyTorch geometric. +""" diff --git a/LHCb_Pipeline/utils/graphutils/edgeutils.py b/LHCb_Pipeline/utils/graphutils/edgeutils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a91c70e53571c9c396a29f703eb3ec5d4b06bb5 --- /dev/null +++ b/LHCb_Pipeline/utils/graphutils/edgeutils.py @@ -0,0 +1,15 @@ +"""A module that defines utilities to handle edges exclusively. 
+""" +import torch + + +def sort_edge_nodes(edges: torch.Tensor, ordering_tensor: torch.Tensor) -> None: + """Sort the nodes of the edges in ascending value of a certain tensor + + Args: + edges: Two-dimensional array of edges, with shape :math:`(2, n_{edges})` + ordering_tensor: Tensor of values for the nodes. The first node of an edge + is required to have a lower value than the second node. + """ + not_correctly_ordered_mask = ordering_tensor[edges[0]] > ordering_tensor[edges[1]] + edges[:, not_correctly_ordered_mask] = edges[:, not_correctly_ordered_mask].flip(0) diff --git a/LHCb_Pipeline/utils/modelutils/basemodel.py b/LHCb_Pipeline/utils/modelutils/basemodel.py index ed7770272216fb66ba0b0a0faeeb13a4c9c3270b..027e39acb8eac64687029c1a1ab030ab12926894 100644 --- a/LHCb_Pipeline/utils/modelutils/basemodel.py +++ b/LHCb_Pipeline/utils/modelutils/basemodel.py @@ -1,9 +1,13 @@ """Define a base model for GNN and Embedding, to avoid copy of functions. """ +from __future__ import annotations import typing +import logging import os import os.path as op +from tqdm.auto import tqdm +import numpy as np import torch from pytorch_lightning import LightningModule from torch_geometric.data import Data @@ -15,14 +19,38 @@ from utils.commonutils.cfeatures import get_input_features class ModelBase(LightningModule): def __init__(self, hparams): super().__init__() - + self._trainset = None + self._valset = None + self.testset = None self.save_hyperparameters(hparams) def setup(self, stage): - self.trainset = self.load_datasets(op.join(self.hparams["input_dir"], "train")) - self.valset = self.load_datasets(op.join(self.hparams["input_dir"], "val")) + self.load_partition("train") + self.load_partition("val") self.testset = None + @property + def trainset(self) -> typing.List[Data]: + if self._trainset is None: + self.load_partition(partition="train") + assert self._trainset is not None + return self._trainset + + @trainset.setter + def trainset(self, batches: 
typing.List[Data]): + self._trainset = batches + + @property + def valset(self) -> typing.List[Data]: + if self._valset is None: + self.load_partition(partition="val") + assert self._valset is not None + return self._valset + + @valset.setter + def valset(self, batches: typing.List[Data]): + self._valset = batches + def train_dataloader(self): if len(self.trainset) > 0: return DataLoader(self.trainset, batch_size=1, num_workers=16) @@ -41,25 +69,58 @@ class ModelBase(LightningModule): else: return None - def load_datasets(self, input_dir: str) -> typing.List[Data]: - """Load""" + def fetch_datasets( + self, + input_dir: str, + n_events: int | None = None, + shuffle: bool = False, + seed: int | None = None, + **kwargs, + ) -> typing.List[Data]: + """Get the datasets located in a given directory. + + Args: + input_dir: input directory + n_events: number of events to load + shuffle: whether to shuffle the input paths (applied before + selected the first ``n_events``) + seed: seed for the shuffling + **kwargs: Other keyword arguments passed to + :py:func:`ModelBase.fetch_dataset` + + Returns: + List of loaded PyTorch Geometric Data objects + """ all_input_paths = [ entry.path for entry in os.scandir(input_dir) if entry.is_file() ] + if shuffle: + rng = np.random.default_rng(seed=seed) + rng.shuffle(all_input_paths) + + if n_events is not None: + all_input_paths = all_input_paths[:n_events] + + logging.info(f"Load {len(all_input_paths)} files located in {input_dir}") return [ - self.load_dataset(input_path=input_path) for input_path in all_input_paths + self.fetch_dataset(input_path=input_path, **kwargs) + for input_path in tqdm(all_input_paths) ] - def load_dataset(self, input_path: str) -> Data: + def fetch_dataset( + self, input_path: str, map_location: str = "cpu", **kwargs + ) -> Data: """Load and process one PyTorch DataSet. 
Args: input_path: path to the PyTorch dataset + map_location: location where to load the dataset + **kwargs: Other keyword arguments passed to :py:func:`torch.load` Returns: Load PyTorch data object """ - return torch.load(input_path, map_location=torch.device("cpu")) + return torch.load(input_path, map_location=map_location, **kwargs) def load_testset_from_directory(self, input_dir: str): """Load a test dataset from a path to a directory. @@ -68,17 +129,79 @@ class ModelBase(LightningModule): input_dir: path to the directory that contains the PyTorch Geometric Data pickles files. """ - self.testset = self.load_datasets(input_dir=input_dir) + self.testset = self.fetch_datasets(input_dir=input_dir) - def load_testset(self, test_dataset_name: str): - """Load the test dataset into this model from its name. + def fetch_partition( + self, + partition: str, + n_events: int | None = None, + shuffle: bool = False, + seed: int | None = None, + **kwargs, + ) -> typing.List[Data]: + """Load a partition. 
Args: - test_dataset_name: name of the test dataset + partition: ``train``, ``val`` or name of the test dataset + n_events: number of events to load for this partition + shuffle: whether to shuffle the input paths (applied before + selecting the first ``n_events``) + seed: seed for the shuffling + **kwargs: Other keyword arguments passed to + :py:func:`ModelBase.fetch_dataset` """ - self.load_testset_from_directory( - input_dir=op.join(self.hparams["input_dir"], "test", test_dataset_name) + if partition in ["train", "val"]: + datasets = self.fetch_datasets( + op.join(self.hparams["input_dir"], partition), + n_events=( + self.hparams.get(f"n_{partition}_events") + if n_events is None + else n_events + ), + shuffle=shuffle, + seed=seed, + **kwargs, + ) + + else: + datasets = self.fetch_datasets( + input_dir=op.join(self.hparams["input_dir"], "test", partition), + n_events=n_events, + shuffle=shuffle, + seed=seed, + **kwargs, + ) + + return datasets + + def load_partition( + self, + partition: str, + n_events: int | None = None, + shuffle: bool = False, + seed: int | None = None, + ) -> typing.List[Data]: + """Load datasets of a partition. 
+ + Args: + partition: ``train``, ``val`` or name of the test dataset + n_events: number of events to load for this partition + shuffle: whether to shuffle the input paths (applied before + selecting the first ``n_events``) + seed: seed for the shuffling + """ + datasets = self.fetch_partition( + partition=partition, + n_events=n_events, + shuffle=shuffle, + seed=seed, ) + if partition == "train": + self._trainset = datasets + elif partition == "val": + self._valset = datasets + else: + self.testset = datasets def get_input_data(self, batch: Data) -> torch.Tensor: return get_input_features( @@ -107,3 +230,46 @@ class ModelBase(LightningModule): } ] return optimizer, scheduler + + @classmethod + def get_model_from_checkpoint( + cls, + checkpoint: LightningModule | str | None, + default_checkpoint: str | None = None, + **kwargs, + ): + """Helper function to get a model at inference step. + + Args: + checkpoint: the model already loaded, or path to it + Mode: Model class + default_checkpoint: path to fall back to if ``checkpoint`` is None. + **kwargs: other parameters passed to :py:func:`Model.load_from_checkpoint` + + Returns: + Loaded model + """ + if isinstance(checkpoint, cls): + model = checkpoint + elif checkpoint is None: # Default loading mode from last artifact + assert ( + default_checkpoint is not None + ), "Both `checkpoint` and `default_checkpoint` are None." 
+ checkpoint = default_checkpoint + model = cls.load_from_checkpoint( + default_checkpoint, + **kwargs, + ) + logging.info(f"Load model from {checkpoint}.") + elif isinstance(checkpoint, str): + model = cls.load_from_checkpoint( + checkpoint_path=checkpoint, + **kwargs, + ) + logging.info(f"Load model from {checkpoint}.") + else: + raise TypeError( + f"Type of checkpoint is {type(checkpoint).__name__} " + "which is not recognised" + ) + return model diff --git a/LHCb_Pipeline/utils/modelutils/batches.py b/LHCb_Pipeline/utils/modelutils/batches.py new file mode 100644 index 0000000000000000000000000000000000000000..8c43e7cca4b46f7e3f08f88e57fef049b2cfc613 --- /dev/null +++ b/LHCb_Pipeline/utils/modelutils/batches.py @@ -0,0 +1,62 @@ +"""A module used to handle list of batches stored in model. +""" +from __future__ import annotations +import typing +import numpy as np +from torch_geometric.data import Data +from .basemodel import ModelBase + + +def get_batches(model: ModelBase, partition: str) -> typing.List[Data]: + """Get the list batches for the given model. + + Args: + model: PyTorch model inheriting from :py:class:`ModelBase` + partition: ``train``, ``val``, ``test`` (for the current already loaded + test sample) or the name of a test dataset + + Returns: + List of PyTorch Geometric data objects + + Notes: + The input directories are saved as hyperparameters in the model. This is why + it is possible to get the data input directories from a model. 
+ """ + # Use correct batches + if partition == "train": + batches = model.trainset + elif partition == "val": + batches = model.valset + elif partition == "test": + batches = model.testset + else: + model.load_testset(test_dataset_name=partition) + batches = model.testset + + assert ( + batches is not None + ), "Error, list of batches is `None`: no batches were loaded" + return batches + + +def select_subset( + batches: typing.List[Data], n_events: int | None = None, seed: int | None = None +) -> typing.List[Data]: + """Randomly select a subset of batches. + + Args: + batches: overall list of batches + n_events: Maximal number of events to select + seed: Seed for reproducible randomness + + Returns: + List of PyTorch Data objects + """ + if n_events is not None: + n_events = int(n_events) + if n_events < len(batches): + # Randomly select a subset of ``n_events`` events + rng = np.random.default_rng(seed=seed) + indices = rng.choice(len(batches), n_events, replace=False) + batches = [batches[idx] for idx in indices] + return batches diff --git a/LHCb_Pipeline/utils/modelutils/build.py b/LHCb_Pipeline/utils/modelutils/build.py index 710ea92825ed1a01627753ab32af0e69173df083..6fc1d779bfd009e6b6216b56a7ef2537976b9851 100644 --- a/LHCb_Pipeline/utils/modelutils/build.py +++ b/LHCb_Pipeline/utils/modelutils/build.py @@ -1,45 +1,154 @@ """Define the base class to infer on data. """ import typing +from types import ModuleType import abc import os import logging +from functools import partial from tqdm.auto import tqdm +from tqdm.contrib.concurrent import process_map import torch from pytorch_lightning import LightningModule from torch_geometric.data import Data -from utils.tools.tfiles import delete_directory +from utils.tools.tfiles import delete_directory, is_directory_not_empty class BuilderBase(abc.ABC): + """Base class for looping over input files located in a directory, processing + them and saving the output in a different directory. 
+ """ + def __init__(self) -> None: pass - def infer(self, input_dir: str, output_dir: str, reproduce: bool = True): + def infer( + self, + input_dir: str, + output_dir: str, + reproduce: bool = True, + filtering: str | None = None, + building: str | None = None, + file_names: typing.List[str] | None = None, + parallel: bool = False, + ): """Load the torch datasets located in ``input_dir``, run the model inference and save the output in ``output_dir``. + + Args: + input_dir: input directory path + output_dir: output directory path + reproduce: whether to delete the output directory if it exists, + and run again the inference + filtering: name of the function that filters the event. This would only + be applied to the train and val sets. + building: name of the function that compute columns for the event. This + would be applied to all the samples (train, val and test samples). + file_names: list of file names to run the inference on. If not specified, + the inference is run on all the datasets located in the input directory. + parallel: + Whether to run the inference in parallel. This seems quite unstable... 
""" # List paths to the input files - file_names = os.listdir(input_dir) + if file_names is None: + file_names = os.listdir(input_dir) assert len(file_names) > 0, f"No input files in {input_dir}" if reproduce: delete_directory(output_dir) os.makedirs(output_dir, exist_ok=True) - logging.info(f"Inference from {input_dir} to {output_dir}") + if is_directory_not_empty(output_dir): + logging.info( + f"Output folder is not empty so the inference was not run: {output_dir}" + ) + else: + logging.info(f"Inference from {input_dir} to {output_dir}") + + with torch.no_grad(): + infer_one_step_partial = partial( + self.infer_one_step, + input_dir=input_dir, + output_dir=output_dir, + building=building, + filtering=filtering, + ) + if parallel: + process_map(infer_one_step_partial, file_names, chunksize=1) + else: + for file_name in tqdm(file_names): + infer_one_step_partial(file_name=file_name) + + def infer_one_step( + self, + file_name: str, + input_dir: str, + output_dir: str, + filtering: str | typing.List[str] | None = None, + building: str | typing.List[str] | None = None, + ): + """Run the inference on a single file and save the output in another file. + + Args: + file_name: input file name + input_dir: input directory path + output_dir: output directory path + filtering: name of the function that filters the event. This would only + be applied to the train and val sets. + building: name of the function that compute columns for the event. This + would be applied to all the samples (train, val and test samples). 
+ """ + input_path = os.path.join(input_dir, file_name) + if not os.path.exists(os.path.join(output_dir, file_name)): + batch = self.load_batch(input_path) + batch = self.process_one_step( + batch=batch, + filtering=filtering, + building=building, + ) + self.save_downstream(batch, os.path.join(output_dir, batch.event_str)) - with torch.no_grad(): - for file_name in tqdm(file_names): - input_path = os.path.join(input_dir, file_name) - if not os.path.exists(os.path.join(output_dir, file_name)): - batch = self.load_batch(input_path) - self.construct_downstream(batch) - self.save_downstream( - batch, os.path.join(output_dir, batch.event_str) + def process_one_step( + self, + batch: Data, + filtering: str | typing.List[str] | None = None, + building: str | typing.List[str] | None = None, + ) -> Data: + """Process one event. + + Args: + batch: event stored in a PyTorch Geometric data object + filtering: name of the function that filters the event. This would only + be applied to the train and val sets. + building: name of the function that compute columns for the event. This + would be applied to all the samples (train, val and test samples). + + Returns: + Processed event, first by :py:func:`BuilderBase.construct_downstream`, + then by the filtering and building functions provided as inputs. 
+ """ + batch = self.construct_downstream(batch) + + # Apply filtering and building + for processing_step in [filtering, building]: + if processing_step is not None: + processing_fct_names = ( + [processing_step] + if isinstance(processing_step, str) + else processing_step + ) + for processing_fct_name in processing_fct_names: + processing_fct = getattr( + self._get_building_custom_module(), str(processing_fct_name) ) + batch = processing_fct(batch) + return batch + + def _get_building_custom_module(self) -> ModuleType: + """Return the module where the building and filtering functions are.""" + raise NotImplementedError() def load_batch(self, input_path: str) -> Data: """Load a PyTorch Data object from its path. @@ -47,6 +156,33 @@ class BuilderBase(abc.ABC): """ return torch.load(input_path, map_location=torch.device("cpu")) + def filter_batch(self, batch: Data) -> Data: + """Filter the batch. This should only performed in the train and val + sets. + + Args: + batch: PyTorch Data Geometric object + + Returns: + filtered batch + """ + return batch + + def build_weights(self, batch: Data) -> Data: + """Builder weights in the batch for training. + This should only be needed in the train and val sets. + + Args: + batch: PyTorch Data Geometric object + + Returns: + filtered batch + """ + return batch + + def build_features(self, batch: Data) -> Data: + return batch + @abc.abstractmethod def construct_downstream(self, batch: Data): """Run the inference on a PyTorch Data. 
In-place.""" @@ -59,6 +195,8 @@ class BuilderBase(abc.ABC): class ModelBuilderBase(BuilderBase): + """Base class for model inference.""" + def __init__(self, model: LightningModule) -> None: self.model = model model.eval() diff --git a/LHCb_Pipeline/utils/modelutils/checkpoint_utils.py b/LHCb_Pipeline/utils/modelutils/checkpoint_utils.py index b8b442a713ed51adb6fadc453d3d607e3b69a86c..2ccfabfa087ad569e0941f0ba83615056f1a8c95 100644 --- a/LHCb_Pipeline/utils/modelutils/checkpoint_utils.py +++ b/LHCb_Pipeline/utils/modelutils/checkpoint_utils.py @@ -30,7 +30,7 @@ def get_last_version_dir(experiment_dir: str) -> str: for version_folder_path in version_folder_paths ] if not available_versions: - raise ValueError(f"No version with `metrics.csv` found in {experiment_dir}.") + raise ValueError(f"No version with `metrics.csv` found in {experiment_dir}") last_version = sorted(available_versions)[-1] return os.path.join(os.path.join(experiment_dir, f"version_{last_version}")) @@ -114,13 +114,12 @@ def get_last_version_dir_from_config( return get_last_version_dir(experiment_dir=experiment_dir) - -def get_training_metrics(trainer: Trainer | str) -> pd.DataFrame: +def get_training_metrics(trainer: Trainer | str | typing.List[str]) -> pd.DataFrame: """Get the dataframe of the training metrics. Args: - trainer: either a PyTorch Lighting Trainer object, or the path to the metric - file to load directly. + trainer: either a PyTorch Lighting Trainer object, or the path(s) to the metric + file(s) to load directly. Returns: Dataframe of the training metrics (one row / epoch). 
@@ -133,17 +132,34 @@ def get_training_metrics(trainer: Trainer | str) -> pd.DataFrame: log_file = os.path.join(log_dir, "metrics.csv") elif isinstance(trainer, str): log_file = trainer + elif isinstance(trainer, (list, tuple)): + return pd.concat( + (get_training_metrics(trainer=log_file) for log_file in trainer), + axis=0, + ) else: raise TypeError( - f"`trainer` should be str or a pytorch trainer, but is " + f"`trainer` should be str, a list of str or a pytorch trainer, but is " + type(trainer).__name__ ) metrics = pd.read_csv(log_file, sep=",") - train_metrics = metrics[~metrics["train_loss"].isna()][["epoch", "train_loss"]] - train_metrics["epoch"] -= 1 - val_metrics = metrics[~metrics["val_loss"].isna()][ - ["val_loss", "eff", "pur", "current_lr", "epoch"] + + train_loss_column = "train_loss" if "train_loss" in metrics else "train_loss_epoch" + val_loss_column = "val_loss" if "val_loss" in metrics else "val_loss_epoch" + + train_metrics = metrics[~metrics[train_loss_column].isna()][ + ["epoch", train_loss_column] + ] + # train_metrics["epoch"] -= 1 + val_metrics = metrics[~metrics[val_loss_column].isna()][ + [ + column if column in metrics else column + "_epoch" + for column in ["val_loss", "eff", "pur", "current_lr", "epoch"] + ] ] metrics = pd.merge(left=train_metrics, right=val_metrics, how="inner", on="epoch") + for column in metrics.columns: + if column.endswith("_epoch"): + metrics.rename(columns={column: column[: -len("_epoch")]}, inplace=True) return metrics diff --git a/LHCb_Pipeline/utils/modelutils/evaluation.py b/LHCb_Pipeline/utils/modelutils/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..a9445a06c96ceb2d57084cf3c0e163f8d42215d9 --- /dev/null +++ b/LHCb_Pipeline/utils/modelutils/evaluation.py @@ -0,0 +1,246 @@ +"""A module that defines :py:class:`ParamExplorer`, a class that allows to vary +a parameter and check the efficiency that is obtained for this choice. 
+""" +import typing +import abc +import os.path as op + +from tqdm.auto import tqdm +import numpy as np +import numpy.typing as npt +import pandas as pd +from torch_geometric.data import Data +import montetracko as mt +import montetracko.lhcb as mtb +import matplotlib.pyplot as plt +from matplotlib.figure import Figure + +from Scripts.Step_6_Evaluate_Reconstruction_MonteTracko import ( + load_parquet_files, + perform_matching, +) +from utils.plotutils.plotools import save_fig +from utils.commonutils.cpaths import get_performance_directory +from .basemodel import ModelBase + + +class ParamExplorer(abc.ABC): + """A class that allows exploring the track matching performance for various choices + of a given parameter of a trained model (e.g., best efficiency as a function + of the maximal radius of the kNN) + """ + + def __init__( + self, model: ModelBase, varname: str, varlabel: str | None = None + ) -> None: + self.model = model + self.varname = str(varname) + self.varlabel = str(varlabel) if varlabel is not None else self.varname + + def load_preprocessed_dataframes( + self, + batches: typing.List[Data], + ) -> typing.Tuple[pd.DataFrame, pd.DataFrame]: + """Load the preprocessed dataframes of hits-particles and particles associated + with the PyTorch DataSets given as input. 
+ + Args: + batches: list of PyTorch Geometric Data objects + + Returns: + Tuple of dataframes of hits-particles and particles + """ + truncated_paths = [batch.truncated_path for batch in batches] + df_hits_particles = load_parquet_files( + truncated_paths=truncated_paths, + ending="-hits_particles", + columns=["particle_id", "hit_id"], + ) + df_particles = load_parquet_files( + truncated_paths=truncated_paths, + ending="-particles", + columns=["particle_id", "has_velo", "has_scifi", "pid", "eta"], + ) + return df_hits_particles, df_particles + + def compute_performance_metrics( + self, + values: typing.Sequence[float], + partition: str, + metric_names: typing.List[str], + categories: typing.List[mt.requirement.Category], + n_events: int | None = None, + seed: int | None = None, + **kwargs, + ) -> typing.Dict[float, typing.Dict[typing.Tuple[str, str], float]]: + """Compute the performance metrics for different values a hyperparameter. + + Args: + values: list of values for the hyperparameter of interest + partition: ``train``, ``val`` or the name of a test dataset + n_events: Maximal number of events for the evaluation + seed: Random seed for randomly selecting ``n_events`` + metric_names: List of metric names to compute + categories: list of categories to compute the performance in. 
+ + Returns: + Dictionary that maps each hyperparameter value to the dictionary of + metric values keyed by the tuple ``(category.name, metric_name)`` + """ + # Load PyTorch Geometric Data objects + batches = self.model.fetch_partition( + partition=partition, + n_events=n_events, + shuffle=True, + seed=seed, + map_location=self.model.device, + ) + + # Move batches to same device as model + # batches = [batch.to(model.device) for batch in batches] # type: ignore + + # Load associated pre-processed files that contain information + # used for matching + df_hits_particles, df_particles = self.load_preprocessed_dataframes( + batches=batches + ) + + dict_performance = {} + for value in (pbar := tqdm(values)): + pbar.set_description(f"Loop over {self.varname} (current value: {value})") + df_tracks = self.get_tracks(value=value, batches=batches, **kwargs) + dict_performance[value] = self.get_performance_from_tracks( + df_tracks=df_tracks, + df_hits_particles=df_hits_particles, + df_particles=df_particles, + metric_names=metric_names, + categories=categories, + ) + return dict_performance + + def get_performance_from_tracks( + self, + df_tracks: pd.DataFrame, + df_hits_particles: pd.DataFrame, + df_particles: pd.DataFrame, + metric_names: typing.List[str], + categories: typing.List[mt.requirement.Category], + ) -> typing.Dict[typing.Tuple[str, str], float]: + """Get performance dictionary for given tracks. + + Args: + df_tracks: dataframe of tracks + df_hits_particles: dataframe of hits-particles + df_particles: dataframe of particles + metric_names: List of metric names to compute + categories: list of categories to compute the performance in. 
@abc.abstractmethod
def get_tracks(
    self, value: float, batches: typing.List[Data], **kwargs
) -> pd.DataFrame:
    """Get the dataframe of tracks.

    Abstract hook: this base implementation only raises
    ``NotImplementedError``.
    """
    raise NotImplementedError()


def plot(
    self,
    path_or_config: str | dict,
    partition: str,
    values: typing.Sequence[float],
    n_events: int | None = None,
    seed: int | None = None,
    metric_names: typing.List[str] | None = None,
    categories: typing.List[mt.requirement.Category] | None = None,
    identifier: str | None = None,
    **kwargs,
) -> typing.Tuple[
    Figure,
    npt.NDArray,
    typing.Dict[float, typing.Dict[typing.Tuple[str, str], float]],
]:
    """Plot metrics in different categories for different hyperparameter
    values.

    One subplot is drawn per metric; inside each subplot there is one curve
    per category, with the hyperparameter values on the x-axis. The figure is
    saved as ``performance_given_{self.varname}{identifier}`` inside the
    directory returned by ``get_performance_directory(path_or_config)``.

    Args:
        path_or_config: pipeline configuration
        partition: ``train``, ``val`` or the name of a test dataset
        values: list of values for the hyperparameter of interest
        n_events: Maximal number of events for the evaluation
        seed: Random seed for randomly selecting ``n_events``
        metric_names: List of metric names to compute. If not set,
            ``efficiency``, ``clone_rate`` and ``hit_efficiency_per_candidate``
            are computed and plotted.
        categories: list of categories to compute the performance in.
            By default, ``category_velo_no_electrons`` and
            ``category_long_only_electrons`` are used.
        identifier: Suffix appended to the figure name. Defaults to an empty
            string.
        **kwargs: Other keyword arguments passed to
            :py:func:`ParamExplorer.compute_performance_metrics`

    Returns:
        3-tuple of the Matplotlib Figure, the array of Axes (one per metric)
        and the dictionary of performances. The dictionary is indexed first by
        the hyperparameter value, then by the 2-tuple
        ``(category.name, metric_name)``.
    """
    if metric_names is None:
        metric_names = ["efficiency", "clone_rate", "hit_efficiency_per_candidate"]
    if categories is None:
        categories = [
            mtb.category.category_velo_no_electrons,
            mtb.category.category_long_only_electrons,
        ]
    if identifier is None:
        identifier = ""

    dict_performances = self.compute_performance_metrics(
        values=values,
        partition=partition,
        metric_names=metric_names,
        categories=categories,
        n_events=n_events,
        seed=seed,
        **kwargs,
    )

    n_metrics = len(metric_names)
    fig, axes = plt.subplots(1, n_metrics, figsize=(8 * n_metrics, 6))
    # ``plt.subplots`` returns a bare Axes when there is a single metric:
    # normalise to an array so that the indexing below always works.
    axes = np.atleast_1d(axes)

    for ax, metric_name in zip(axes, metric_names):
        ax.set_xlabel(self.varlabel)
        ax.set_ylabel(mt.metricsLibrary.label(metric_name))
        ax.grid(color="grey", alpha=0.5)
        for category in categories:
            metric_values = [
                dict_performances[value][category.name, metric_name]
                for value in values
            ]
            ax.plot(values, metric_values, label=category.label, marker=".")

    # The categories are the same in every subplot: one legend is enough.
    axes[0].legend()

    performance_dir = get_performance_directory(path_or_config)
    save_fig(
        fig=fig,
        path=op.join(
            performance_dir, f"performance_given_{self.varname}{identifier}"
        ),
    )

    return (fig, axes, dict_performances)
def plot_metric_epochs(
    metric_name: str,
    metrics: pd.DataFrame,
    path_or_config: str | dict | None = None,
    name: str | None = None,
    metric_label: str | None = None,
    ax: Axes | None = None,
    marker: str = ".",
    **kwargs,
) -> typing.Tuple[Figure | None, Axes]:
    """Plot a metric as a function of the epoch number.

    Args:
        metric_name: name of the column of ``metrics`` to plot
        metrics: dataframe of metric values computed during training. It must
            contain the column ``epoch`` and the column ``metric_name``
        path_or_config: pipeline configuration, used to find the performance
            directory. Required only when the figure is saved.
        name: Name of the step (e.g., ``gnn``, ``embedding``). If not given,
            the plot is not saved. The plot is not saved either when ``ax``
            is provided, since no figure is created in that case.
        metric_label: Label of the metric. Used in the y-axis. Defaults to
            ``metric_name``.
        ax: Matplotlib Axes to plot on. If not given, a new figure is created.
        marker: Marker format used in the plot
        **kwargs: Other arguments passed to :py:func:`matplotlib.axes.Axes.plot`

    Returns:
        Figure and Axes of the plot. The figure is ``None`` when ``ax`` is
        provided.

    Raises:
        ValueError: if the figure has to be saved (``name`` given and no
            ``ax``) but ``path_or_config`` is ``None``.
    """
    will_save = name is not None and ax is None
    # Validate before creating anything, with a real exception: an ``assert``
    # is stripped under ``python -O`` and used to fire only after the figure
    # had already been built.
    if will_save and path_or_config is None:
        raise ValueError(
            "`path_or_config` is required to save the figure "
            "(i.e., when `name` is provided and `ax` is not)."
        )

    if ax is None:
        fig, ax_ = plt.subplots(figsize=(8, 6))
    else:
        fig = None
        ax_ = ax

    ax_.plot(
        metrics["epoch"],
        metrics[metric_name],
        marker=marker,
        **kwargs,
    )
    ax_.set_xlabel("Epoch")
    ax_.set_ylabel(metric_name if metric_label is None else metric_label)
    # Epochs are integers: avoid fractional ticks.
    ax_.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax_.grid(color="grey", alpha=0.5)

    if will_save:
        save_fig(
            fig=fig,
            path=op.join(
                get_performance_directory(path_or_config),
                f"{metric_name}_{name}",
            ),
        )

    return fig, ax_


def plot_loss(
    metrics: pd.DataFrame, path_or_config: str | dict, name: str | None = None
) -> typing.Tuple[Figure, Axes]:
    """Plot the training and validation loss on the same plot.

    Args:
        metrics: dataframe of metric values computed during training. It must
            contain the three columns ``epoch``, ``train_loss`` and
            ``val_loss``
        path_or_config: pipeline configuration, used to find the performance
            directory where the figure is saved
        name: Name of the step (e.g., ``gnn``, ``embedding``). If not given,
            the plot is not saved.

    Returns:
        Figure and Axes of the plot
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    for partition in ["train", "val"]:
        # The helper already sets the x-label, the integer tick locator and
        # the grid; only the shared y-label has to be overridden.
        plot_metric_epochs(
            metric_name=f"{partition}_loss",
            metrics=metrics,
            metric_label="Loss",
            color=partition_to_color[partition],
            label=partition_to_label[partition],
            ax=ax,
        )

    ax.legend()

    if name is not None:
        save_fig(
            fig=fig,
            path=op.join(
                get_performance_directory(path_or_config),
                f"loss_{name}",
            ),
        )
    return fig, ax
def save_fig(
    fig: Figure,
    path: str,
    exts: typing.Sequence[str] = (".pdf", ".png"),
    **kwargs,
):
    """Save a figure, once per requested extension.

    Args:
        fig: Matplotlib figure to save
        path: path where to save the figure. Its extension, if any, is
            stripped with :py:func:`os.path.splitext` and replaced by each
            element of ``exts``. Consequently, anything after the last dot of
            the file name is dropped.
        exts: extensions (with the leading dot) to save the figure with
        **kwargs: Other keyword arguments passed to ``fig.savefig``

    """
    # ``os.makedirs("")`` raises ``FileNotFoundError``: only create the
    # directory when ``path`` actually has a directory component.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    fig.tight_layout()

    path_without_ext, _ = os.path.splitext(path)

    for ext in exts:
        overall_path = path_without_ext + ext
        fig.savefig(overall_path, bbox_inches="tight", **kwargs)
        print("Figure was saved in", overall_path)
hidden: 128 n_graph_iters: 6 - nb_node_layer: 3 - nb_edge_layer: 3 + nb_node_layers: 3 + nb_edge_layers: 3 layernorm: True aggregation: sum_max hidden_activation: SiLU @@ -75,7 +75,6 @@ gnn: patience: 8 truth_key: pid_signal regime: [pid] - mask_background: True max_epochs: 1 track_building: diff --git a/exploration/check_tracks.ipynb b/exploration/check_tracks.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..879f7aceafddc1666b6b0f3408f18ae47a76c8d5 --- /dev/null +++ b/exploration/check_tracks.ipynb @@ -0,0 +1,2088 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from tqdm.auto import tqdm\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "from utils.commonutils.config import load_config\n", + "\n", + "#: Path to the configuration file\n", + "config_path = \"pipeline_configs/velo-sim10b-nospillover-lot.yaml\"\n", + "\n", + "config = load_config(config_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9c5ae4ce3390435bb39c16ff2f76df63", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "input_dir = os.path.join(\n", + " config[\"track_building\"][\"input_dir\"], \"test\", \"velo-sim10b-nospillover\"\n", + ")\n", + "input_paths = [entry.path for entry in os.scandir(input_dir) if entry.is_file]\n", + "\n", + "batches = []\n", + "\n", + "for input_path in tqdm(input_paths):\n", + " batches.append(\n", + " torch.load(input_path, map_location=torch.device(\"cpu\"))\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "batch = batches[0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, 
+ "metadata": {}, + "outputs": [], + "source": [ + "hits_particles = pd.read_parquet(batch.truncated_path + \"-hits_particles.parquet\")\n", + "particles = pd.read_parquet(batch.truncated_path + \"-particles.parquet\")\n", + "\n", + "hits_particles = hits_particles.merge(\n", + " particles[[\"particle_id\", \"nhits_velo\"]],\n", + " on=[\"particle_id\"],\n", + " how=\"left\",\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "edge_indices = batch.edge_index\n", + "\n", + "particle_id_edge_indices = batch.particle_id[edge_indices]\n", + "\n", + "true_edge_mask = particle_id_edge_indices[0] == particle_id_edge_indices[1]\n", + "\n", + "true_edge_indices = edge_indices[:, true_edge_mask]\n", + "true_edge_scores = batch.scores[true_edge_mask]\n", + "true_particle_ids = particle_id_edge_indices[:, true_edge_mask][0]\n", + "fake_edge_indices = edge_indices[:, ~true_edge_mask]\n", + "true_plane_edge_indices = batch.plane[edge_indices][:, true_edge_mask]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_true_edges = pd.DataFrame(\n", + " {\n", + " \"edge_idx_left\": true_edge_indices.numpy().min(axis=0),\n", + " \"edge_idx_right\": true_edge_indices.numpy().max(axis=0),\n", + " \"score\": true_edge_scores.numpy(),\n", + " \"particle_id\": true_particle_ids,\n", + " \"plane_left\": true_plane_edge_indices[0],\n", + " \"plane_right\": true_plane_edge_indices[1],\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_true_edges = df_true_edges.merge(\n", + " particles[[\"particle_id\", \"nhits_velo\"]],\n", + " on=[\"particle_id\"],\n", + " how=\"left\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, 
+ "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>edge_idx_left</th>\n", + " <th>edge_idx_right</th>\n", + " <th>score</th>\n", + " <th>particle_id</th>\n", + " <th>plane_left</th>\n", + " <th>plane_right</th>\n", + " <th>nhits_velo</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>385</th>\n", + " <td>279</td>\n", + " <td>298</td>\n", + " <td>0.999993</td>\n", + " <td>1529</td>\n", + " <td>14</td>\n", + " <td>13</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>410</th>\n", + " <td>279</td>\n", + " <td>319</td>\n", + " <td>0.999993</td>\n", + " <td>1529</td>\n", + " <td>15</td>\n", + " <td>13</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>411</th>\n", + " <td>298</td>\n", + " <td>319</td>\n", + " <td>0.999992</td>\n", + " <td>1529</td>\n", + " <td>15</td>\n", + " <td>14</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>434</th>\n", + " <td>298</td>\n", + " <td>339</td>\n", + " <td>0.999886</td>\n", + " <td>1529</td>\n", + " <td>14</td>\n", + " <td>16</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>435</th>\n", + " <td>319</td>\n", + " <td>339</td>\n", + " <td>0.999770</td>\n", + " <td>1529</td>\n", + " <td>15</td>\n", + " <td>16</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>457</th>\n", + " <td>339</td>\n", + " <td>360</td>\n", + " <td>0.999197</td>\n", + " <td>1529</td>\n", + " <td>17</td>\n", + " <td>16</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>477</th>\n", + " <td>339</td>\n", + " <td>372</td>\n", + " 
<td>0.999621</td>\n", + " <td>1529</td>\n", + " <td>18</td>\n", + " <td>16</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>478</th>\n", + " <td>360</td>\n", + " <td>372</td>\n", + " <td>0.999989</td>\n", + " <td>1529</td>\n", + " <td>18</td>\n", + " <td>17</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>506</th>\n", + " <td>360</td>\n", + " <td>390</td>\n", + " <td>0.999974</td>\n", + " <td>1529</td>\n", + " <td>19</td>\n", + " <td>17</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>507</th>\n", + " <td>372</td>\n", + " <td>390</td>\n", + " <td>0.999976</td>\n", + " <td>1529</td>\n", + " <td>18</td>\n", + " <td>19</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>552</th>\n", + " <td>390</td>\n", + " <td>428</td>\n", + " <td>0.999986</td>\n", + " <td>1529</td>\n", + " <td>19</td>\n", + " <td>20</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>553</th>\n", + " <td>428</td>\n", + " <td>442</td>\n", + " <td>0.999991</td>\n", + " <td>1529</td>\n", + " <td>21</td>\n", + " <td>20</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>571</th>\n", + " <td>442</td>\n", + " <td>463</td>\n", + " <td>0.999987</td>\n", + " <td>1529</td>\n", + " <td>21</td>\n", + " <td>22</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>609</th>\n", + " <td>463</td>\n", + " <td>487</td>\n", + " <td>0.999978</td>\n", + " <td>1529</td>\n", + " <td>22</td>\n", + " <td>23</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>655</th>\n", + " <td>487</td>\n", + " <td>553</td>\n", + " <td>0.999557</td>\n", + " <td>1529</td>\n", + " <td>23</td>\n", + " <td>25</td>\n", + " <td>14</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " edge_idx_left edge_idx_right score particle_id plane_left \\\n", + "385 279 298 0.999993 1529 14 \n", + "410 279 319 0.999993 1529 15 \n", + "411 298 319 0.999992 1529 15 \n", + "434 298 339 0.999886 1529 14 \n", + "435 319 
339 0.999770 1529 15 \n", + "457 339 360 0.999197 1529 17 \n", + "477 339 372 0.999621 1529 18 \n", + "478 360 372 0.999989 1529 18 \n", + "506 360 390 0.999974 1529 19 \n", + "507 372 390 0.999976 1529 18 \n", + "552 390 428 0.999986 1529 19 \n", + "553 428 442 0.999991 1529 21 \n", + "571 442 463 0.999987 1529 21 \n", + "609 463 487 0.999978 1529 22 \n", + "655 487 553 0.999557 1529 23 \n", + "\n", + " plane_right nhits_velo \n", + "385 13 14 \n", + "410 13 14 \n", + "411 14 14 \n", + "434 16 14 \n", + "435 16 14 \n", + "457 16 14 \n", + "477 16 14 \n", + "478 17 14 \n", + "506 17 14 \n", + "507 19 14 \n", + "552 20 14 \n", + "553 20 14 \n", + "571 22 14 \n", + "609 23 14 \n", + "655 25 14 " + ] + }, + "execution_count": 160, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_true_edges.query(\"particle_id == 1529 and score > 0.999\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_reco_hits(df_edges: pd.DataFrame) -> pd.DataFrame:\n", + " return pd.DataFrame(\n", + " {\n", + " \"hit_idx\": np.concatenate(\n", + " (df_true_edges[\"edge_idx_left\"], df_true_edges[\"edge_idx_right\"])\n", + " ),\n", + " \"particle_id\": np.concatenate(\n", + " (df_true_edges[\"particle_id\"], df_true_edges[\"particle_id\"])\n", + " ),\n", + " }\n", + " ).drop_duplicates()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nhits_per_particle = df_true_edges.groupby([\"particle_id\"]).size().rename(\"n_edges\")\n", + "average_score_per_particle = (\n", + " df_true_edges.groupby([\"particle_id\"])[\"score\"].mean().rename(\"avg_score\")\n", + ")\n", + "\n", + "score_cut = 0.9\n", + "\n", + "df_true_edges[\"accepted\"] = df_true_edges[\"score\"] >= score_cut\n", + "n_accepted_edges = (\n", + " df_true_edges.groupby([\"particle_id\"])[\"accepted\"].sum().rename(\"n_accepted\")\n", + ")\n", + "\n", + "n_reco_hits = 
(\n", + " get_reco_hits(df_true_edges)\n", + " .groupby([\"particle_id\"])[\"hit_idx\"]\n", + " .nunique()\n", + " .rename(\"n_reco_hits\")\n", + ")\n", + "n_reco_accepted_hits = (\n", + " get_reco_hits(df_true_edges.query(f\"score >= {score_cut}\"))\n", + " .groupby([\"particle_id\"])[\"hit_idx\"]\n", + " .nunique()\n", + " .rename(\"n_reco_accepted_hits\")\n", + ")\n", + "\n", + "df_particle_stats = pd.concat(\n", + " (nhits_per_particle, average_score_per_particle, n_accepted_edges, n_reco_hits, n_reco_accepted_hits),\n", + " axis=1,\n", + ").reset_index()\n", + "\n", + "df_particle_stats = df_particle_stats.merge(\n", + " particles[[\"particle_id\", \"nhits_velo\"]],\n", + " how=\"left\",\n", + " on=[\"particle_id\"],\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df_particle_stats[\"score_efficiency\"] = df_particle_stats[\"n_reco_accepted_hits\"] / df_particle_stats[\"n_reco_hits\"]\n", + "df_particle_stats[\"hit_efficiency\"] = df_particle_stats[\"n_reco_accepted_hits\"] / df_particle_stats[\"nhits_velo\"]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "nhits_velo\n", + "2 1.000000\n", + "3 0.958333\n", + "4 0.944444\n", + "5 1.000000\n", + "6 0.958333\n", + "7 1.000000\n", + "8 1.000000\n", + "9 1.000000\n", + "10 1.000000\n", + "11 1.000000\n", + "12 1.000000\n", + "13 1.000000\n", + "14 1.000000\n", + "15 1.000000\n", + "17 1.000000\n", + "19 1.000000\n", + "20 0.650000\n", + "Name: hit_efficiency, dtype: float64" + ] + }, + "execution_count": 152, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_particle_stats.groupby(\"nhits_velo\")[\"hit_efficiency\"].mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "fig, axes = plt.subplots(1, 2, 
figsize=(8, 6 * 2))\n", + "axes[0].bar(\n", + " x=df_particle_stats[\"nhits_velo\"].unique(),\n", + " height=df_particle_stats.groupby(),\n", + " width=1.0,\n", + " fill=False,\n", + " edgecolor=\"blue\",\n", + " \n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>edge_idx_left</th>\n", + " <th>edge_idx_right</th>\n", + " <th>score</th>\n", + " <th>particle_id</th>\n", + " <th>plane_left</th>\n", + " <th>plane_right</th>\n", + " <th>nhits_velo</th>\n", + " <th>n_edges</th>\n", + " <th>avg_score</th>\n", + " <th>accepted</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>8</td>\n", + " <td>48</td>\n", + " <td>0.968121</td>\n", + " <td>805</td>\n", + " <td>0</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>0.955308</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>11</td>\n", + " <td>78</td>\n", + " <td>0.999993</td>\n", + " <td>368</td>\n", + " <td>0</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>6</td>\n", + " <td>0.999993</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>2</td>\n", + " <td>20</td>\n", + " <td>0.999983</td>\n", + " <td>240</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>6</td>\n", + " <td>14</td>\n", + " <td>0.806738</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>3</td>\n", + " <td>21</td>\n", + " <td>0.999993</td>\n", + " <td>1059</td>\n", + " 
<td>0</td>\n", + " <td>1</td>\n", + " <td>4</td>\n", + " <td>6</td>\n", + " <td>0.999993</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>5</td>\n", + " <td>22</td>\n", + " <td>0.999775</td>\n", + " <td>245</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>4</td>\n", + " <td>6</td>\n", + " <td>0.544625</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <th>...</th>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>697</th>\n", + " <td>527</td>\n", + " <td>554</td>\n", + " <td>0.999993</td>\n", + " <td>866</td>\n", + " <td>24</td>\n", + " <td>25</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>0.999993</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <th>698</th>\n", + " <td>528</td>\n", + " <td>555</td>\n", + " <td>0.999993</td>\n", + " <td>868</td>\n", + " <td>24</td>\n", + " <td>25</td>\n", + " <td>6</td>\n", + " <td>9</td>\n", + " <td>0.999991</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <th>699</th>\n", + " <td>531</td>\n", + " <td>556</td>\n", + " <td>0.999969</td>\n", + " <td>642</td>\n", + " <td>24</td>\n", + " <td>25</td>\n", + " <td>20</td>\n", + " <td>17</td>\n", + " <td>0.999379</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <th>700</th>\n", + " <td>534</td>\n", + " <td>544</td>\n", + " <td>0.979185</td>\n", + " <td>558</td>\n", + " <td>25</td>\n", + " <td>24</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " <td>0.979185</td>\n", + " <td>True</td>\n", + " </tr>\n", + " <tr>\n", + " <th>701</th>\n", + " <td>535</td>\n", + " <td>559</td>\n", + " <td>0.999988</td>\n", + " <td>1286</td>\n", + " <td>25</td>\n", + " <td>24</td>\n", + " <td>8</td>\n", + " <td>14</td>\n", + " <td>0.999989</td>\n", + " <td>True</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", 
+ "<p>702 rows × 10 columns</p>\n", + "</div>" + ], + "text/plain": [ + " edge_idx_left edge_idx_right score particle_id plane_left \\\n", + "0 8 48 0.968121 805 0 \n", + "1 11 78 0.999993 368 0 \n", + "2 2 20 0.999983 240 0 \n", + "3 3 21 0.999993 1059 0 \n", + "4 5 22 0.999775 245 1 \n", + ".. ... ... ... ... ... \n", + "697 527 554 0.999993 866 24 \n", + "698 528 555 0.999993 868 24 \n", + "699 531 556 0.999969 642 24 \n", + "700 534 544 0.979185 558 25 \n", + "701 535 559 0.999988 1286 25 \n", + "\n", + " plane_right nhits_velo n_edges avg_score accepted \n", + "0 2 3 3 0.955308 True \n", + "1 3 4 6 0.999993 True \n", + "2 1 6 14 0.806738 True \n", + "3 1 4 6 0.999993 True \n", + "4 0 4 6 0.544625 True \n", + ".. ... ... ... ... ... \n", + "697 25 3 3 0.999993 True \n", + "698 25 6 9 0.999991 True \n", + "699 25 20 17 0.999379 True \n", + "700 24 3 1 0.979185 True \n", + "701 24 8 14 0.999989 True \n", + "\n", + "[702 rows x 10 columns]" + ] + }, + "execution_count": 117, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_particle_stats.groupby([\"nhits_velo\"])[\"accepted\"]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build tracks" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import typing\n", + "import scipy.sparse as sps\n", + "from torch_geometric.data import Data\n", + "\n", + "\n", + "def get_track_ids(edge_indices: torch.Tensor, n_hits: int) -> np.ndarray:\n", + " row, col = edge_indices\n", + " edge_attr = np.ones(row.size(0))\n", + "\n", + " sparse_edges = sps.coo_matrix(\n", + " (edge_attr, (row.numpy(), col.numpy())), (n_hits, n_hits)\n", + " )\n", + "\n", + " _, candidate_labels = sps.csgraph.connected_components(\n", + " sparse_edges, directed=False, return_labels=True\n", + " )\n", + " return candidate_labels\n", + "\n", + "\n", + "def update_next_labels(\n", + " labels: np.ndarray,\n", + " 
edge_indices: np.ndarray,\n", + " scores: np.ndarray,\n", + " score_cut: float,\n", + "):\n", + " unique_labels, n_labels = np.unique(labels, return_counts=True)\n", + " consistent_unique_labels = unique_labels[n_labels >= 3]\n", + "\n", + " hit_mask = ~np.isin(labels, consistent_unique_labels)\n", + "\n", + " n_hits = hit_mask.sum()\n", + " # Only keep edges between TWO remaining hits\n", + " edges_mask = torch.from_numpy(hit_mask[edge_indices].min(axis=0))\n", + " new_indices = np.full(fill_value=-1, shape=labels.shape[0], dtype=int)\n", + " new_indices[hit_mask] = np.arange(hit_mask.sum())\n", + "\n", + " reduced_edge_indices = new_indices[edge_indices][:, edges_mask]\n", + " reduced_scores = scores[edges_mask]\n", + "\n", + " # Build again\n", + " reduced_labels = get_track_ids(\n", + " torch.from_numpy(\n", + " reduced_edge_indices[:, reduced_scores > score_cut],\n", + " ),\n", + " n_hits=n_hits\n", + " )\n", + " reduced_labels += labels.max() + 1\n", + " labels[hit_mask] = reduced_labels\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "labels = get_track_ids(\n", + " batch.edge_index[:, batch.scores > 0.99], n_hits=batch.x.shape[0]\n", + ")\n", + "edge_indices = batch.edge_index.numpy()\n", + "scores = batch.scores.numpy()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "update_next_labels(labels, edge_indices, scores, score_cut=0.98)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([560])" + ] + }, + "execution_count": 255, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.hit_id.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(172,)" + ] + }, + "execution_count": 256, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "reduced_labels.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 2, 2, 5, 0, 6, 8, 8, 11, 12, 9, 9, 11, 9,\n", + " 3, 3, 13, 5, 13, 1, 14, 14, 1, 5, 18, 13, 14,\n", + " 1, 5, 19, 19, 14, 35, 36, 37, 22, 22, 43, 23, 22,\n", + " 50, 53, 62, 53, 54, 62, 54, 57, 57, 66, 65, 51, 62,\n", + " 56, 56, 63, 80, 57, 69, 74, 75, 75, 76, 102, 91, 104,\n", + " 91, 91, 109, 109, 113, 141, 131, 139, 116, 139, 131, 140, 141,\n", + " 155, 162, 170, 162, 160, 169],\n", + " [ 11, 12, 1, 6, 1, 2, 3, 8, 8, 2, 3, 9, 12,\n", + " 11, 12, 1, 13, 6, 14, 5, 6, 18, 18, 6, 18, 18,\n", + " 19, 19, 6, 13, 19, 22, 22, 22, 38, 39, 22, 44, 46,\n", + " 27, 59, 53, 65, 59, 54, 65, 25, 64, 58, 59, 60, 65,\n", + " 63, 67, 67, 68, 69, 64, 70, 64, 69, 71, 91, 103, 91,\n", + " 105, 106, 107, 111, 141, 114, 115, 115, 131, 116, 139, 155, 145,\n", + " 144, 150, 150, 151, 164, 167]])" + ] + }, + "execution_count": 236, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 84])" + ] + }, + "execution_count": 220, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reduced_edge_indices.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(560,)" + ] + }, + "execution_count": 216, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reduced_labels.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + 
"ename": "ValueError", + "evalue": "NumPy boolean array indexing assignment cannot assign 560 input values to the 172 output values where the mask is true", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[214], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m# Assign the labels to the remaining hits\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m labels[hit_mask] \u001b[39m=\u001b[39m reduced_labels\n", + "\u001b[0;31mValueError\u001b[0m: NumPy boolean array indexing assignment cannot assign 560 input values to the 172 output values where the mask is true" + ] + } + ], + "source": [ + "# Assign the labels to the remaining hits\n", + "labels[hit_mask] = reduced_labels\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ True, True, False, ..., False, False, True],\n", + " [ True, True, False, ..., False, False, False]])" + ] + }, + "execution_count": 232, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hit_mask[batch.edge_index]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 0, 1, 2, ..., 473, 474, 475])" + ] + }, + "execution_count": 176, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "get_track_id(batches[9], 0.999)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ True, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, True, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, 
False, False, False,\n", + " False, False, False, False, False, True, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " True, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, 
False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, 
False,\n", + " False, False, False, False, False, False, False, False, False, False])" + ] + }, + "execution_count": 167, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.labels" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Merge tracks into doublets" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from torch_geometric.data import Data" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([False, True, False, ..., False, True, False])" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.y_pid" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [], + "source": [ + "score_cut = 0.3\n", + "\n", + "# 1. Filter edges and hits with the given score_cut (needs to be conservative)\n", + "# Filter edges \n", + "edge_mask = batch.scores > score_cut\n", + "edge_indices = batch.edge_index[:, edge_mask]\n", + "y_pid = batch.y_pid[edge_mask]\n", + "\n", + "# Filter hits not connected to any edge\n", + "unique_edge_indices = torch.unique(edge_indices, sorted=True)\n", + "features = batch.x[unique_edge_indices]\n", + "\n", + "# Reindex edges given the filtering\n", + "n_hits = unique_edge_indices.shape[0]\n", + "mapping_new_indices = torch.full((edge_indices.max() + 1,), -1, dtype=torch.long)\n", + "mapping_new_indices[unique_edge_indices] = torch.arange(n_hits)\n", + "reindexed_edge_indices = mapping_new_indices[edge_indices]\n", + "\n", + "# 2. Compute doublet features\n", + "new_features = batch.x[reindexed_edge_indices]\n", + "doublet_features = torch.cat((new_features[0], new_features[1]), dim=-1)\n", + "\n", + "# 3. 
Form edges between doublets\n" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", + " -1, -1])" + ] + }, + "execution_count": 83, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.full((edge_indices.max() + 1,), -1, dtype=torch.long)" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": {}, + "outputs": [], + "source": [ + "node_features = batch.x\n", + "edge_indices = batch.edge_index\n", + "edge_mask = batch.scores > 0.2\n", + "unique_edge_indices = torch.unique(edge_indices[:, edge_mask], sorted=True)\n", + "filtered_node_features = node_features[unique_edge_indices]\n", + "\n", + "# Reindex the node indices in `filtered_edge_indices`\n", + "n_filtered_nodes = unique_edge_indices.shape[0]\n", + "mapping_new_indices = torch.full(\n", + " (edge_indices.max() + 1,), -1, dtype=torch.long # type: ignore\n", + ")\n", + "mapping_new_indices[unique_edge_indices] = torch.arange(n_filtered_nodes)\n", + "reindexed_edge_indices = mapping_new_indices[edge_indices]\n", + "filtered_reindexed_edge_indices = reindexed_edge_indices[:, edge_mask]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 108, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['plane',\n", + " 'signal_true_edges',\n", + " 'truncated_path',\n", + " 'scores',\n", + " 'x',\n", + " 'y',\n", + " 'particle_id',\n", + " 'edge_index',\n", + " 'hit_id',\n", + " 
'event_str',\n", + " 'y_pid']" + ] + }, + "execution_count": 108, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.keys" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(False)" + ] + }, + "execution_count": 104, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(filtered_reindexed_edge_indices == -1).any()" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [], + "source": [ + "triplet_index = torch.from_numpy(\n", + " edge_to_triplet_scipy(reindexed_edge_indices.numpy())\n", + ").long()\n", + "y_triplet = y_pid[triplet_index].min(dim=0).values\n" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([False, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, False, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, False, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, False, False,\n", + " True, True, False, True, True, False, False, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, False, 
False, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, False, False, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, False,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, False, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, 
True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, False, False, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, False, False, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, False, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True])" + ] + }, + "execution_count": 89, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_pid" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 
True, True, True, ..., True, True, True],\n", + " [False, True, True, ..., True, True, True]])" + ] + }, + "execution_count": 79, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_pid[triplet_index]" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ True, True, True, ..., True, True, True],\n", + " [False, True, True, ..., True, True, True]])" + ] + }, + "execution_count": 67, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_pid.numpy()[triplet_index]" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 13, 13, 46, ..., 661, 639, 680],\n", + " [ 0, 1, 2, ..., 702, 704, 707]])" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "edge_to_triplet_scipy(reindexed_edge_indices.numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[44],\n", + " [28]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reindexed_edge_indices[:, reindexed_edge_indices[0, :] == 44]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# Sort edges by plane\n", + "edge_planes = planes[reindexed_edge_indices]\n", + "\n", + "flip_mask = edge_planes[0] > edge_planes[1]\n", + "ordered_edge_indices = reindexed_edge_indices.clone()\n", + "ordered_edge_indices[:, flip_mask] = reindexed_edge_indices[:, flip_mask].flip(dims=[0])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 8, 8, 8, ..., 8, 8, 8],\n", + " [ 8, 8, 8, ..., 8, 8, 8],\n", + " [ 11, 11, 11, ..., 11, 11, 11],\n", + " ...,\n", + " [445, 
445, 445, ..., 445, 445, 445],\n", + " [446, 446, 446, ..., 446, 446, 446],\n", + " [447, 447, 447, ..., 447, 447, 447]],\n", + "\n", + " [[ 44, 44, 44, ..., 44, 44, 44],\n", + " [ 45, 45, 45, ..., 45, 45, 45],\n", + " [ 74, 74, 74, ..., 74, 74, 74],\n", + " ...,\n", + " [466, 466, 466, ..., 466, 466, 466],\n", + " [455, 455, 455, ..., 455, 455, 455],\n", + " [468, 468, 468, ..., 468, 468, 468]]])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ordered_edge_indices.unsqueeze(-1).expand(-1, -1, ordered_edge_indices.shape[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 114, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 10, 31, 43, 13, 17, 18, 11, 36, 19, 36, 12, 16, 6, 25,\n", + " 10, 31, 0, 23, 42, 1, 26, 7, 2, 16, 43, 3, 20, 22,\n", + " 22, 14, 14, 19, 30, 30, 11, 11, 53, 53, 37, 37, 13, 32,\n", + " 18, 18, 34, 34, 16, 43, 39, 18, 18, 34, 34, 12, 12, 38,\n", + " 38, 40, 40, 17, 33, 33, 35, 15, 15, 27, 27, 27, 2, 2,\n", + " 2, 19, 47, 47, 47, 3, 3, 3, 20, 20, 20, 20, 20, 48,\n", + " 48, 48, 48, 48, 0, 0, 0, 1, 1, 1, 27, 26, 26, 26,\n", + " 7, 7, 7, 14, 24, 24, 46, 46, 46, 5, 5, 5, 21, 21,\n", + " 21, 22, 22, 49, 49, 50, 50, 50, 21, 21, 21, 21, 49, 49,\n", + " 49, 49, 31, 50, 50, 50, 50, 27, 27, 27, 63, 63, 63, 14,\n", + " 14, 14, 30, 30, 30, 30, 51, 54, 30, 31, 31, 31, 31, 52,\n", + " 12, 60, 39, 19, 85, 7, 2, 7, 2, 38, 103, 95, 83, 103,\n", + " 24, 59, 103, 67, 82, 47, 19, 64, 30, 60, 16, 19, 85, 57,\n", + " 2, 59, 103, 46, 38, 103, 95, 111, 1, 26, 100, 99, 43, 88,\n", + " 82, 86, 83, 103, 2, 94, 102, 64, 14, 85, 12, 103, 95, 111,\n", + " 103, 16, 39, 39, 100, 88, 127, 57, 57, 100, 99, 126, 103, 94,\n", + " 94, 102, 1, 1, 43, 67, 67, 98, 100, 100, 100, 88, 127, 124,\n", + " 100, 99, 126, 118, 114, 95, 111, 106, 106, 107, 107, 107, 108, 108,\n", + " 108, 110, 110, 113, 113, 127, 124, 146, 126, 125, 118, 111, 120, 121,\n", + " 142, 121, 103, 103, 122, 122, 
122, 122, 122, 122, 104, 109, 109, 125,\n", + " 125, 148, 124, 146, 158, 106, 106, 126, 132, 128, 105, 105, 143, 142,\n", + " 135, 131, 131, 131, 131, 119, 115, 137, 137, 140, 140, 140, 140, 141,\n", + " 141, 148, 146, 158, 139, 111, 126, 132, 163, 143, 143, 143, 143, 143,\n", + " 143, 142, 135, 181, 109, 125, 144, 144, 144, 144, 158, 139, 105, 105,\n", + " 155, 129, 129, 159, 177, 148, 157, 177, 119, 163, 177, 135, 135, 115,\n", + " 133, 124, 124, 151, 170, 135, 181, 157, 157, 177, 157, 157, 155, 173,\n", + " 109, 109, 158, 158, 139, 159, 177, 155, 160, 163, 177, 161, 132, 132,\n", + " 149, 181, 151, 170, 196, 133, 152, 161, 161, 173, 182, 142, 142, 156,\n", + " 156, 197, 155, 197, 161, 171, 178, 179, 179, 154, 168, 167, 183, 196,\n", + " 169, 210, 152, 171, 151, 151, 169, 169, 149, 149, 166, 166, 197, 190,\n", + " 197, 175, 182, 174, 174, 156, 171, 172, 172, 193, 193, 196, 169, 210,\n", + " 196, 196, 166, 166, 188, 188, 190, 187, 175, 213, 183, 183, 200, 200,\n", + " 200, 200, 200, 210, 216, 185, 185, 185, 186, 186, 189, 189, 174, 174,\n", + " 187, 189, 213, 208, 191, 173, 189, 189, 172, 172, 193, 193, 172, 216,\n", + " 211, 211, 211, 188, 188, 209, 222, 213, 213, 223, 208, 189, 205, 201,\n", + " 185, 185, 202, 202, 167, 183, 183, 240, 216, 224, 241, 216, 216, 172,\n", + " 221, 203, 203, 198, 206, 206, 223, 228, 205, 214, 222, 220, 244, 220,\n", + " 244, 223, 209, 191, 219, 207, 208, 224, 224, 240, 254, 241, 238, 221,\n", + " 219, 225, 225, 212, 227, 227, 201, 214, 218, 217, 202, 215, 232, 232,\n", + " 229, 235, 233, 226, 244, 257, 273, 244, 248, 214, 228, 237, 237, 231,\n", + " 257, 237, 241, 249, 243, 249, 250, 230, 247, 267, 219, 267, 266, 254,\n", + " 242, 254, 243, 238, 267, 265, 220, 235, 251, 251, 245, 257, 273, 287,\n", + " 231, 257, 274, 259, 275, 246, 243, 261, 281, 250, 279, 280, 247, 267,\n", + " 247, 295, 242, 283, 267, 266, 295, 267, 265, 270, 242, 283, 265, 243,\n", + " 246, 266, 234, 287, 251, 251, 290, 255, 252, 271, 239, 256, 255, 285,\n", + " 
273, 272, 273, 287, 302, 274, 261, 277, 275, 288, 303, 268, 269, 261,\n", + " 281, 279, 280, 246, 266, 247, 247, 295, 266, 295, 265, 270, 283, 265,\n", + " 283, 262, 287, 302, 251, 290, 270, 284, 270, 271, 286, 252, 271, 255,\n", + " 285, 255, 285, 272, 272, 287, 302, 261, 277, 288, 303, 288, 303, 277,\n", + " 264, 280, 266, 281, 295, 295, 270, 262, 265, 290, 284, 271, 286, 285,\n", + " 272, 302, 302, 303, 303, 276, 282, 277])" + ] + }, + "execution_count": 114, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ordered_edge_indices[0, ]" + ] + }, + { + "cell_type": "code", + "execution_count": 109, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 8, 8, 11, ..., 445, 455, 468],\n", + " [ 44, 45, 74, ..., 466, 446, 447]])" + ] + }, + "execution_count": 109, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reindexed_edge_indices" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0, 0, 0, ..., 24, 24, 24],\n", + " [ 2, 2, 3, ..., 25, 25, 25]])" + ] + }, + "execution_count": 111, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "planes[ordered_edge_indices]" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 8, 8, 11, 2, 3, 5, 4, 0, 23, 7, 6, 1, 15, 8,\n", + " 28, 28, 9, 14, 10, 13, 17, 18, 11, 36, 36, 12, 16, 6,\n", + " 25, 0, 23, 42, 1, 26, 7, 24, 2, 19, 3, 20, 22, 22,\n", + " 14, 30, 10, 31, 11, 53, 37, 13, 32, 18, 34, 16, 39, 18,\n", + " 34, 12, 38, 40, 17, 33, 35, 15, 27, 2, 19, 47, 3, 20,\n", + " 48, 0, 1, 26, 43, 7, 24, 46, 5, 21, 22, 49, 50, 5,\n", + " 21, 22, 49, 50, 15, 27, 63, 14, 30, 51, 72, 10, 31, 52,\n", + " 12, 38, 59, 75, 35, 62, 13, 32, 55, 18, 34, 56, 17, 33,\n", + " 61, 37, 54, 40, 60, 16, 39, 57, 7, 24, 46, 68, 1, 26,\n", + " 43, 67, 86, 2, 19, 47, 64, 14, 
30, 51, 2, 19, 64, 12,\n", + " 38, 59, 16, 39, 57, 82, 83, 94, 94, 1, 26, 43, 67, 85,\n", + " 88, 98, 100, 100, 100, 101, 101, 102, 102, 103, 103, 103, 106, 106,\n", + " 107, 107, 107, 108, 108, 110, 112, 113, 114, 88, 99, 118, 118, 95,\n", + " 120, 121, 121, 103, 122, 122, 122, 122, 104, 109, 125, 125, 111, 126,\n", + " 127, 106, 128, 128, 105, 131, 131, 119, 115, 137, 123, 140, 140, 140,\n", + " 141, 124, 111, 126, 143, 143, 143, 109, 125, 144, 144, 144, 146, 105,\n", + " 129, 148, 119, 132, 135, 115, 133, 139, 124, 142, 126, 157, 157, 157,\n", + " 109, 158, 158, 159, 160, 163, 163, 132, 149, 150, 135, 151, 133, 152,\n", + " 161, 139, 155, 142, 156, 177, 177, 177, 178, 179, 154, 181, 168, 167,\n", + " 183, 170, 152, 171, 151, 169, 149, 166, 155, 173, 174, 156, 175, 161,\n", + " 172, 193, 196, 196, 196, 166, 188, 197, 197, 182, 167, 183, 200, 200,\n", + " 200, 169, 187, 171, 185, 186, 189, 174, 190, 175, 191, 173, 189, 172,\n", + " 193, 210, 211, 211, 211, 199, 188, 213, 213, 187, 201, 185, 202, 167,\n", + " 183, 216, 216, 216, 203, 198, 209, 206, 189, 208, 222, 222, 190, 205,\n", + " 223, 191, 207, 224, 224, 172, 225, 225, 212, 227, 227, 201, 214, 218,\n", + " 217, 202, 215, 232, 232, 233, 233, 206, 220, 236, 205, 223, 237, 237,\n", + " 209, 219, 208, 221, 240, 207, 241, 241, 229, 226, 244, 244, 214, 228,\n", + " 231, 248, 215, 249, 249, 230, 234, 219, 238, 254, 254, 220, 235, 221,\n", + " 239, 256, 257, 257, 259, 259, 245, 231, 246, 243, 250, 243, 247, 267,\n", + " 267, 267, 242, 242, 234, 251, 238, 252, 235, 255, 239, 256, 273, 273,\n", + " 274, 274, 275, 268, 269, 264, 279, 246, 261, 247, 266, 265, 283, 283,\n", + " 251, 270, 252, 271, 255, 272, 287, 287, 288, 288, 264, 277, 278, 261,\n", + " 280, 266, 281, 295, 295, 262, 265, 282, 270, 284, 271, 285, 272, 286,\n", + " 302, 302, 303, 303, 276, 290, 277, 291, 278, 292, 282, 298, 308, 309,\n", + " 296, 284, 299, 285, 300, 314, 314, 286, 301, 315, 316, 317, 317, 304,\n", + " 290, 305, 291, 306, 307, 280, 293, 
322, 294, 323, 323, 298, 308, 324,\n", + " 324, 299, 311, 312, 301, 315, 329, 329, 330, 330, 313, 334, 334, 335,\n", + " 335, 318, 336, 319, 337, 321, 338, 293, 322, 339, 339, 340, 340, 341,\n", + " 341, 320, 320, 320, 320, 320, 320, 320, 320, 320, 320, 320, 320, 320,\n", + " 355, 355, 327, 356, 357, 357, 357, 358, 359, 360, 361, 313, 332, 362,\n", + " 362, 326, 363, 364, 364, 365, 365, 366, 366, 366, 367, 367, 367, 368,\n", + " 368, 369, 369, 370, 370, 370, 371, 372, 372, 372, 373, 374, 374, 374,\n", + " 332, 378, 378, 379, 379, 380, 381, 381, 381, 382, 382, 382, 383, 383,\n", + " 383, 383, 384, 384, 385, 385, 385, 386, 386, 386, 387, 387, 387, 388,\n", + " 388, 388, 389, 389, 390, 390, 390, 391, 391, 391, 392, 392, 392, 392,\n", + " 393, 393, 394, 394, 394, 395, 395, 395, 396, 396, 396, 397, 399, 401,\n", + " 401, 402, 402, 403, 403, 404, 405, 405, 406, 406, 407, 407, 408, 408,\n", + " 409, 409, 410, 410, 411, 411, 412, 412, 413, 414, 414, 415, 415, 416,\n", + " 416, 417, 417, 418, 418, 419, 419, 420, 420, 421, 421, 422, 422, 423,\n", + " 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438,\n", + " 439, 440, 441, 442, 443, 445, 446, 447]])" + ] + }, + "execution_count": 98, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reindexed_edge_indices.gather(\n", + " 0,\n", + " torch.min(edge_planes, dim=0).indices.unsqueeze(0),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 708])" + ] + }, + "execution_count": 93, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reindexed_edge_indices.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([708])" + ] + }, + "execution_count": 94, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.min(edge_planes, 
dim=0).indices.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 8, 8, 11, ..., 445, 455, 468],\n", + " [ 8, 8, 11, ..., 445, 455, 468],\n", + " [ 8, 8, 11, ..., 445, 455, 468],\n", + " ...,\n", + " [ 8, 8, 11, ..., 445, 455, 468],\n", + " [ 44, 45, 74, ..., 466, 446, 447],\n", + " [ 44, 45, 74, ..., 466, 446, 447]])" + ] + }, + "execution_count": 82, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reindexed_edge_indices[edge_planes.argmin(axis=0), :]" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1,\n", + " 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0,\n", + " 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0,\n", + " 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0,\n", + " 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0,\n", + " 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1,\n", + " 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0,\n", + " 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1,\n", + " 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0,\n", + " 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1,\n", + " 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1,\n", + " 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1,\n", + " 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1,\n", + " 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,\n", 
+ " 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1,\n", + " 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1,\n", + " 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1,\n", + " 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0,\n", + " 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1,\n", + " 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1,\n", + " 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0,\n", + " 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n", + " 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0,\n", + " 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0,\n", + " 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0,\n", + " 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1,\n", + " 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0,\n", + " 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1])" + ] + }, + "execution_count": 72, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ordered_reindex_edge_indices = torch.cat(\n", + " (\n", + " \n", + " )\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", + " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,\n", + " 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5,\n", + " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", + " 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,\n", + " 6, 6, 6, 7, 7, 7, 7, 7, 
7, 7, 7, 7, 7, 7, 7, 7, 7, 7,\n", + " 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,\n", + " 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,\n", + " 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,\n", + " 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,\n", + " 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,\n", + " 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,\n", + " 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n", + " 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,\n", + " 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,\n", + " 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,\n", + " 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,\n", + " 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,\n", + " 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,\n", + " 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21,\n", + " 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,\n", + " 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22,\n", + " 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,\n", + " 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23,\n", + " 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,\n", + " 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25,\n", + " 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25,\n", + " 25, 25])" + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.plane" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 8, 8, 11, ..., 445, 455, 468],\n", + " [ 
44, 45, 74, ..., 466, 446, 447]])" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reindexed_edge_indices" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 8, 8, 11, ..., 445, 455, 468],\n", + " [ 44, 45, 74, ..., 466, 446, 447]])" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "reindexed_edge_indices" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "new_batch = Data(\n", + " x=features,\n", + " edge_index=reindexed_edge_indices,\n", + " particle_id=batch.particle_id[unique_edge_indices],\n", + " signal_true_edges=mapping_new_indices[batch.signal_true_edges],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_28393/2550153991.py:3: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n", + " from IPython.core.display import Image, display\n" + ] + }, + { + "data": { + "text/html": [ + "<div class=\"bk-root\">\n", + " <a href=\"https://bokeh.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n", + " <span id=\"2677\">Loading BokehJS ...</span>\n", + " </div>\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": "(function(root) {\n function now() {\n return new Date();\n }\n\n const force = true;\n\n if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n root._bokeh_onload_callbacks = [];\n root._bokeh_is_loading = undefined;\n }\n\nconst JS_MIME_TYPE = 'application/javascript';\n const HTML_MIME_TYPE = 'text/html';\n const EXEC_MIME_TYPE = 
'application/vnd.bokehjs_exec.v0+json';\n const CLASS_NAME = 'output_bokeh rendered_html';\n\n /**\n * Render data to the DOM node\n */\n function render(props, node) {\n const script = document.createElement(\"script\");\n node.appendChild(script);\n }\n\n /**\n * Handle when an output is cleared or removed\n */\n function handleClearOutput(event, handle) {\n const cell = handle.cell;\n\n const id = cell.output_area._bokeh_element_id;\n const server_id = cell.output_area._bokeh_server_id;\n // Clean up Bokeh references\n if (id != null && id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n\n if (server_id !== undefined) {\n // Clean up Bokeh references\n const cmd_clean = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n cell.notebook.kernel.execute(cmd_clean, {\n iopub: {\n output: function(msg) {\n const id = msg.content.text.trim();\n if (id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n }\n }\n });\n // Destroy server and session\n const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n cell.notebook.kernel.execute(cmd_destroy);\n }\n }\n\n /**\n * Handle when a new output is added\n */\n function handleAddOutput(event, handle) {\n const output_area = handle.output_area;\n const output = handle.output;\n\n // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n if ((output.output_type != \"display_data\") || (!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n return\n }\n\n const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n\n if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n // store reference to embed id on output_area\n output_area._bokeh_element_id = 
output.metadata[EXEC_MIME_TYPE][\"id\"];\n }\n if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n const bk_div = document.createElement(\"div\");\n bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n const script_attrs = bk_div.children[0].attributes;\n for (let i = 0; i < script_attrs.length; i++) {\n toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n }\n // store reference to server id on output_area\n output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n }\n }\n\n function register_renderer(events, OutputArea) {\n\n function append_mime(data, metadata, element) {\n // create a DOM node to render to\n const toinsert = this.create_output_subarea(\n metadata,\n CLASS_NAME,\n EXEC_MIME_TYPE\n );\n this.keyboard_manager.register_events(toinsert);\n // Render to node\n const props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n render(props, toinsert[toinsert.length - 1]);\n element.append(toinsert);\n return toinsert\n }\n\n /* Handle when an output is cleared or removed */\n events.on('clear_output.CodeCell', handleClearOutput);\n events.on('delete.Cell', handleClearOutput);\n\n /* Handle when a new output is added */\n events.on('output_added.OutputArea', handleAddOutput);\n\n /**\n * Register the mime type and append_mime function with output_area\n */\n OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n /* Is output safe? 
*/\n safe: true,\n /* Index of renderer in `output_area.display_order` */\n index: 0\n });\n }\n\n // register the mime type if in Jupyter Notebook environment and previously unregistered\n if (root.Jupyter !== undefined) {\n const events = require('base/js/events');\n const OutputArea = require('notebook/js/outputarea').OutputArea;\n\n if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n register_renderer(events, OutputArea);\n }\n }\n if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n const NB_LOAD_WARNING = {'data': {'text/html':\n \"<div style='background-color: #fdd'>\\n\"+\n \"<p>\\n\"+\n \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n \"</p>\\n\"+\n \"<ul>\\n\"+\n \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n \"<li>use INLINE resources instead, as so:</li>\\n\"+\n \"</ul>\\n\"+\n \"<code>\\n\"+\n \"from bokeh.resources import INLINE\\n\"+\n \"output_notebook(resources=INLINE)\\n\"+\n \"</code>\\n\"+\n \"</div>\"}};\n\n function display_loaded() {\n const el = document.getElementById(\"2677\");\n if (el != null) {\n el.textContent = \"BokehJS is loading...\";\n }\n if (root.Bokeh !== undefined) {\n if (el != null) {\n el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n }\n } else if (Date.now() < root._bokeh_timeout) {\n setTimeout(display_loaded, 100)\n }\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = 
[];\n\n root._bokeh_onload_callbacks.push(callback);\n if (root._bokeh_is_loading > 0) {\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n }\n if (js_urls == null || js_urls.length === 0) {\n run_callbacks();\n return null;\n }\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n\n function on_error(url) {\n console.error(\"failed to load \" + url);\n }\n\n for (let i = 0; i < css_urls.length; i++) {\n const url = css_urls[i];\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n }\n\n for (let i = 0; i < js_urls.length; i++) {\n const url = js_urls[i];\n const element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n 
const css_urls = [];\n\n const inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {\n }\n ];\n\n function run_inline_js() {\n if (root.Bokeh !== undefined || force === true) {\n for (let i = 0; i < inline_js.length; i++) {\n inline_js[i].call(root, root.Bokeh);\n }\nif (force === true) {\n display_loaded();\n }} else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n } else if (force !== true) {\n const cell = $(document.getElementById(\"2677\")).parents('.cell').data().cell;\n cell.output_area.append_execute_result(NB_LOAD_WARNING)\n }\n }\n\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n run_inline_js();\n } else {\n load_libs(css_urls, js_urls, function() {\n console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n run_inline_js();\n });\n }\n}(window));", + "application/vnd.bokehjs_load.v0+json": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import os.path as op\n", + "\n", + "from IPython.core.display import Image, display\n", + "import numpy as np\n", + "from bokeh.plotting import figure, show\n", + "from bokeh.models import ColumnDataSource\n", + "from bokeh.palettes import viridis\n", + "from bokeh.io import export_png, output_notebook\n", + "\n", + "output_notebook()\n", + "\n", + "def plot_graph(batch):\n", + " p = figure(\n", + " title=\"Truth graphs\", x_axis_label=\"x\", y_axis_label=\"y\", height=500, width=500\n", + " )\n", + " q = figure(\n", + " title=\"Predicted graphs\",\n", + " x_axis_label=\"x\",\n", + " y_axis_label=\"y\",\n", + " height=500,\n", + " width=500,\n", + " )\n", + "\n", + " true_edges = batch.signal_true_edges\n", + " true_unique, true_lengths = batch.particle_id[true_edges[0]].unique(\n", + " 
return_counts=True\n", + " )\n", + " pred_edges = batch.edge_index\n", + " particle_ids = batch.particle_id\n", + "\n", + " nr = batch.x[:, 0]\n", + " nphi = batch.x[:, 1]\n", + " r = nr * 9.75 + 18\n", + " phi = nphi * 1.82\n", + "\n", + " x = r * np.cos(phi)\n", + " y = r * np.sin(phi)\n", + " cmap = viridis(11)\n", + " source = ColumnDataSource(dict(x=x.numpy(), y=y.numpy()))\n", + " p.circle(x=\"x\", y=\"y\", source=source, color=cmap[0], size=1, alpha=0.1)\n", + " q.circle(x=\"x\", y=\"y\", source=source, color=cmap[0], size=1, alpha=0.1)\n", + "\n", + " for i, track in enumerate(true_unique[true_lengths >= 10][:10]):\n", + " # Get true track plot\n", + " track_true_edges = true_edges[:, particle_ids[true_edges[0]] == track].cpu()\n", + " X_edges, Y_edges = x[track_true_edges].numpy(), y[track_true_edges].numpy()\n", + " X = np.concatenate(X_edges)\n", + " Y = np.concatenate(Y_edges)\n", + "\n", + " p.circle(X, Y, color=cmap[i], size=5)\n", + " p.multi_line(X_edges.T.tolist(), Y_edges.T.tolist(), color=cmap[i])\n", + "\n", + " track_pred_edges = (\n", + " pred_edges[:, (particle_ids[pred_edges] == track).any(0)]\n", + " ).cpu()\n", + "\n", + " X_edges, Y_edges = x[track_pred_edges].numpy(), y[track_pred_edges].numpy()\n", + " X = np.concatenate(X_edges)\n", + " Y = np.concatenate(Y_edges)\n", + "\n", + " q.circle(X, Y, color=cmap[i], size=5)\n", + " q.multi_line(X_edges.T.tolist(), Y_edges.T.tolist(), color=cmap[i])\n", + "\n", + " show(p)\n", + " show(q)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 8, 8, 11, ..., 445, 455, 468],\n", + " [ 44, 45, 74, ..., 466, 446, 447]])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reindexed_edge_indices" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 8, 8, 11, ..., 531, 
544, 559],\n", + " [ 47, 48, 78, ..., 556, 534, 535]])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "edge_indices" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "etx4velo_updated", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/exploration/test.ipynb b/exploration/test.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..2b7d742735c89a4ae72425e4cb23ce4fa111987e --- /dev/null +++ b/exploration/test.ipynb @@ -0,0 +1,333 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-05-14 04:30:22,422 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-v1exanh9', purging\n", + "2023-05-14 04:30:22,422 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-_3n17svj', purging\n", + "2023-05-14 04:30:22,423 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-pecvj_nz', purging\n", + "2023-05-14 04:30:22,423 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-8mxforuv', purging\n", + "2023-05-14 04:30:22,423 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-cfdom5ui', purging\n", + "2023-05-14 04:30:22,423 - distributed.diskutils - INFO - Found stale lock file and directory 
'/tmp/dask-worker-space/worker-fijvmhzu', purging\n", + "2023-05-14 04:30:22,424 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-dwrafs7q', purging\n", + "2023-05-14 04:30:22,424 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-8avxlnms', purging\n" + ] + } + ], + "source": [ + "import os\n", + "import pandas as pd\n", + "import pyarrow as pa\n", + "import pyarrow.parquet as pq\n", + "import dask.dataframe as dd\n", + "from dask.distributed import Client, progress\n", + "\n", + "client = Client()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "hits_particles = pd.read_parquet(\n", + " \"/scratch/acorreia/minbias-sim10b-xdigi-nospillover/0/hits_velo.parquet.lz4\"\n", + ")\n", + "particles = pd.read_parquet(\n", + " \"/scratch/acorreia/minbias-sim10b-xdigi-nospillover/0/mc_particles.parquet.lz4\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"../montetracko/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from Preprocessing.particle_line_fitting import (\n", + " compute_particle_distances_to_lines_dataframe,\n", + ")\n", + "\n", + "hits_particles[\"particle_id\"] = hits_particles[\"mcid\"] + 1\n", + "particles[\"particle_id\"] = particles[\"mcid\"] + 1\n", + "\n", + "hits_particles = hits_particles.merge(\n", + " particles[[\"event\", \"particle_id\", \"has_velo\"]],\n", + " how=\"left\",\n", + " on=[\"event\", \"particle_id\"],\n", + ")\n", + "hits_particles = hits_particles[hits_particles[\"has_velo\"] == 1]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "new_distances = compute_particle_distances_to_lines_dataframe(\n", + " hits=hits_particles,\n", + " 
metric_names=[\"distance_to_line\", \"distance_to_z_axis\"],\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "hits_particles = hits_particles.merge(\n", + " new_distances, how=\"left\", on=[\"event\", \"particle_id\"]\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "particles = particles.merge(\n", + " new_distances, how=\"left\", on=[\"event\", \"particle_id\"]\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 0.000002\n", + "1 0.010684\n", + "2 0.155721\n", + "3 0.032444\n", + "4 0.034152\n", + " ... \n", + "5345106 0.025152\n", + "5345107 0.005537\n", + "5345108 0.855370\n", + "5345109 NaN\n", + "5345110 5.439347\n", + "Name: distance_to_line, Length: 5345111, dtype: float64" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "particles[\"distance_to_line\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0, 0.5, 'Abundance')" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "ax.hist(particles[\"distance_to_line\"], range=(0., 1.))\n", + "ax.set_xlabel(\"Distance to a straight line fitted to it\")\n", + "ax.set_ylabel(\"Abundance\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"Text(0, 0.5, 'Abundance')" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 800x600 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "ax.hist(particles[\"distance_to_z_axis\"], range=(0., 1.))\n", + "ax.set_xlabel(\"Distance to the z-axis\")\n", + "ax.set_ylabel(\"Abundance\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "\n", + "def plot_xy_graph(\n", + " df_hits_particles: pd.DataFrame,\n", + " seed: int | None = None,\n", + "):\n", + " fig, ax = plt.subplots(figsize=(12, 12))\n", + "\n", + " ax.axhline(y=0.0, color=\"k\", linewidth=0.5)\n", + " ax.axvline(x=0.0, color=\"k\", linewidth=0.5)\n", + "\n", + " n_lines = 20\n", + " event_ids = df_hits_particles[\"event\"].unique()\n", + "\n", + " rng = np.random.default_rng(seed=seed)\n", + " rng.shuffle(event_ids)\n", + "\n", + " for idx, (_, hits_particle) in enumerate(\n", + " df_hits_particles[\n", + " df_hits_particles[\"event\"].isin(event_ids[:10])\n", + " & (\n", + " (df_hits_particles[\"distance_to_z_axis\"] < 0.5)\n", + " & (df_hits_particles[\"distance_to_z_axis\"] > 0.2)\n", + " )\n", + " ][[\"event\", \"particle_id\", \"x\", \"y\", \"z\", \"plane\"]].groupby(\n", + " by=[\"event\", \"particle_id\"]\n", + " )\n", + " ):\n", + " hit_coordinates = hits_particle.sort_values(by=\"plane\")[[\"x\", \"y\"]]\n", + " ax.plot(\n", + " hit_coordinates[\"x\"],\n", + " hit_coordinates[\"y\"],\n", + " linestyle=\"-\",\n", + " markersize=5.0,\n", + " marker=\"o\",\n", + " )\n", + " if idx > n_lines:\n", + " break\n", + "\n", + " ax.set_xlim(-50.0, 50.0)\n", + " ax.set_ylim(-50.0, 50.0)\n", + " ax.grid(color=\"grey\", alpha=0.5)\n", + "\n", + " return fig, ax\n" + ] + }, + { + "cell_type": "code", + 
"execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(<Figure size 1200x1200 with 1 Axes>, <Axes: >)" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "<Figure size 1200x1200 with 1 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_xy_graph(hits_particles)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/montetracko b/montetracko index 15ca87819f350195715152e9f6e4c73ce8654f5d..e033e8a998cfa59df42f4e39c068e6ce4671ffae 160000 --- a/montetracko +++ b/montetracko @@ -1 +1 @@ -Subproject commit 15ca87819f350195715152e9f6e4c73ce8654f5d +Subproject commit e033e8a998cfa59df42f4e39c068e6ce4671ffae