diff --git a/LHCb_Pipeline/Embedding/build_embedding.py b/LHCb_Pipeline/Embedding/build_embedding.py
index ac2958573daa14924c704d5f209b88014406a54c..60e1779c259f443499fd50cde34345dd336e16ce 100644
--- a/LHCb_Pipeline/Embedding/build_embedding.py
+++ b/LHCb_Pipeline/Embedding/build_embedding.py
@@ -8,9 +8,7 @@ from utils.modelutils.build import ModelBuilderBase
 from utils.commonutils.config import load_config
 from utils.graphutils.edgeutils import sort_edge_nodes
 
-from . import building_custom
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
+from . import process_custom
 
 
 def get_radius_from_config(path_or_config: str | dict) -> float:
@@ -90,64 +88,12 @@ class EmbeddingInferenceBuilder(ModelBuilderBase):
             unique_indices = torch.unique(unique_inverse)
             y_cluster = y_cluster[unique_indices]
 
-        assert e_spatial.shape[1] == torch.unique(e_spatial, dim=1).shape[1]
-
         batch["edge_index"] = e_spatial
         batch["y"] = y_cluster
         return batch
 
     def _get_building_custom_module(self) -> ModuleType:
-        return building_custom
-
-    # def filter_batch(self, batch: Data) -> Data:
-    #     # at least 3 hits to be classified as valid edges
-    #     # not_reconstructible_mask = batch.n_unique_planes < 3
-    #     not_reconstructible_mask = batch.nhits_velo < 3
-
-    #     # Apply this transformation to `y` as well
-    #     not_reconstructible_edge_mask = (
-    #         not_reconstructible_mask[batch.edge_index].min(dim=0).values
-    #     )
-
-    #     batch.y[not_reconstructible_edge_mask] = False
-
-    #     # Classify the hits as fake (so that the edges are also classified like so)
-    #     batch.particle_id[not_reconstructible_mask] = 0
-
-    #     # Remove edges in same `plane` and `z`
-    #     edge_index_plane = batch.plane[batch.edge_index]
-    #     no_self_edge_mask = edge_index_plane[0] != edge_index_plane[1]
-    #     batch.edge_index = batch.edge_index[:, no_self_edge_mask]
-    #     batch.y = batch.y[no_self_edge_mask]
-
-    #     return batch
-
-    # def build_features(self, batch: Data) -> Data:
-    #     # norm_x = batch.un_x / 14.5
-    #     # norm_y = batch.un_y / 14.5
-    #     # xe = batch.un_x[batch.edge_index]
-    #     # ye = batch.un_y[batch.edge_index]
-    #     # ze = batch.un_z[batch.edge_index]
-
-    #     # slopes_yz = (ye[1] - ye[0]) / (ze[1] - ze[0]) / 0.17
-    #     # slopes_xz = (xe[1] - xe[0]) / (ze[1] - ze[0]) / 0.17
-
-    #     # assert not torch.isnan(slopes_yz).any()
-    #     # assert not torch.isnan(slopes_xz).any()
-    #     # # Modify batch definition
-    #     # batch.x = torch.stack((norm_x, norm_y, batch["x"][:, 2]), dim=1).float()
-    #     # batch.edge_features = torch.stack((slopes_xz, slopes_yz), dim=1).float()
-    #     return batch
-
-    # def build_weights(self, batch: Data) -> Data:
-    #     node_weights = 7.0 / batch.n_unique_planes
-    #     node_weights = torch.nan_to_num(node_weights, nan=1.0)
-    #     edge_weights = torch.mean(node_weights[batch.edge_index], dim=0)
-    #     batch.edge_weights = (
-    #         edge_weights * batch.edge_index.shape[1] / edge_weights.sum()
-    #     )
-    #     assert not torch.isnan(batch.edge_weights).any(), str(batch.edge_weights)
-    #     return batch
+        return process_custom
 
     def get_performance(self, batch: Data, r_max: float, k_max: int):
         with torch.no_grad():
diff --git a/LHCb_Pipeline/Embedding/embedding_base.py b/LHCb_Pipeline/Embedding/embedding_base.py
index 27d4ecdce353aed095585dfde47ee7b9ed6272fc..186021f452002e4a6d603b153d263bffa8c9634e 100644
--- a/LHCb_Pipeline/Embedding/embedding_base.py
+++ b/LHCb_Pipeline/Embedding/embedding_base.py
@@ -64,7 +64,12 @@ class EmbeddingBase(ModelBase):
     def append_hnm_pairs(self, e_spatial, query, query_indices, spatial):
         if "low_purity" in self.hparams["regime"]:
             knn_edges = build_edges(
-                query, spatial, query_indices, self.hparams["r"], 500
+                query,
+                spatial,
+                query_indices,
+                self.hparams["r"],
+                500,
+                device=self.device,
             )
             knn_edges = knn_edges[
                 :,
@@ -80,14 +85,15 @@ class EmbeddingBase(ModelBase):
                 query_indices,
                 self.hparams["r"],
                 self.hparams["knn"],
+                device=self.device,
             )
 
         e_spatial = torch.cat(
-            [
+            (
                 e_spatial.to(device),
                 knn_edges.to(device),
-            ],
-            axis=-1,
+            ),
+            dim=-1,
         )
 
         return e_spatial
@@ -101,18 +107,18 @@ class EmbeddingBase(ModelBase):
         random_pairs = torch.stack([query_indices[indices_src], indices_dest])
 
         e_spatial = torch.cat(
-            [e_spatial.to(device), random_pairs.to(device)],
-            axis=-1,
+            (e_spatial.to(device), random_pairs.to(device)),
+            dim=-1,
         )
         return e_spatial
 
     def get_true_pairs(self, e_spatial, y_cluster, new_weights, e_bidir):
         e_spatial = torch.cat(
-            [
+            (
                 e_spatial.to(self.device),
                 e_bidir,
-            ],
-            axis=-1,
+            ),
+            dim=-1,
         )
         y_cluster = torch.cat(
             [y_cluster.int(), torch.ones(e_bidir.shape[1], device=self.device)]
@@ -137,7 +143,11 @@ class EmbeddingBase(ModelBase):
         return hinge, d
 
     def get_truth(self, batch, e_spatial, e_bidir):
-        e_spatial, y_cluster = graph_intersection(e_spatial, e_bidir)
+        e_spatial, y_cluster = graph_intersection(
+            e_spatial,
+            e_bidir,
+            device=self.device,
+        )
 
         return e_spatial, y_cluster
 
@@ -238,7 +248,12 @@ class EmbeddingBase(ModelBase):
 
         # Build whole KNN graph
         e_spatial = build_edges(
-            spatial, spatial, indices=None, r_max=knn_radius, k_max=knn_num
+            spatial,
+            spatial,
+            indices=None,
+            r_max=knn_radius,
+            k_max=knn_num,
+            device=self.device,
         )
         if not self.bidir:
             sort_edge_nodes(e_spatial, batch.un_z)
diff --git a/LHCb_Pipeline/Embedding/embedding_plots.py b/LHCb_Pipeline/Embedding/embedding_plots.py
index 51e914b83efea1ed59a15527ecf4e3fa66971cd9..144e44c1e5853c312af6bb90db1723524786bfd1 100644
--- a/LHCb_Pipeline/Embedding/embedding_plots.py
+++ b/LHCb_Pipeline/Embedding/embedding_plots.py
@@ -4,6 +4,7 @@ import typing
 import os.path as op
 
 import numpy as np
+import numpy.typing as npt
 from uncertainties import unumpy as unp
 import matplotlib.pyplot as plt
 from matplotlib.figure import Figure
@@ -30,7 +31,6 @@ def plot_embedding_performance_given_radius_knn_max(
     show_err: bool = True,
 ) -> typing.Tuple[
     typing.Dict[str, typing.Tuple[Figure, Axes]],
-    typing.Tuple[Figure, Axes],
     typing.Tuple[
         np.ndarray, typing.Dict[str, unp.matrix], typing.Dict[str, unp.matrix]
     ],
@@ -121,7 +121,7 @@ def plot_best_performances_radius(
     knn_max: int | None = None,
     n_events: int | None = None,
     seed: int | None = None,
-) -> typing.Tuple[Figure, Axes, typing.Dict[str, typing.Dict[str, float]]]:
+) -> typing.Tuple[Figure, npt.NDArray, typing.Dict[str, typing.Dict[str, float]]]:
    
    embeddingRadiusExplorer = EmbeddingRadiusExplorer(model=model)
    config = load_config(path_or_config=path_or_config)
@@ -132,5 +132,5 @@ def plot_best_performances_radius(
        n_events=n_events,
        seed=seed,
        knn_max=knn_max,
-       building=config["metric_learning"].get("building"),
+       processing=config["metric_learning"].get("processing"),
    )
diff --git a/LHCb_Pipeline/Embedding/embedding_validation.py b/LHCb_Pipeline/Embedding/embedding_validation.py
index fc338ba74a853d905c5274e69efba2039ebb73a3..1ca7ea417cb12b5b7a12c8f4ac4b15fcf4391dac 100644
--- a/LHCb_Pipeline/Embedding/embedding_validation.py
+++ b/LHCb_Pipeline/Embedding/embedding_validation.py
@@ -164,7 +164,7 @@ class EmbeddingRadiusExplorer(ParamExplorer):
         value: float,
         batches: typing.List[Data],
         knn_max: int | None = None,
-        building: str | None = None,
+        processing: str | typing.List[str] | None = None,
     ) -> pd.DataFrame:
         # Run embedding inference
         embeddingInferenceBuilder = EmbeddingInferenceBuilder(
@@ -175,7 +175,7 @@ class EmbeddingRadiusExplorer(ParamExplorer):
         batches = [
             embeddingInferenceBuilder.process_one_step(
                 batch=batch.clone(),
-                building=building,
+                processing=processing,
             )
             for batch in tqdm(batches, desc="Graph Building")
         ]
diff --git a/LHCb_Pipeline/Embedding/graphutils.py b/LHCb_Pipeline/Embedding/graphutils.py
index aaa6d8975cdeab9087dbade9503d8a500ff8a783..b9ac48e8573fec1d159e26c78ee8b1ccd4938edb 100644
--- a/LHCb_Pipeline/Embedding/graphutils.py
+++ b/LHCb_Pipeline/Embedding/graphutils.py
@@ -1,8 +1,8 @@
+from __future__ import annotations
 import torch
 import scipy as sp
 import numpy as np
 
-
 # Ideally, we would be using FRNN and the GPU.
 # But in the case of a user not having a GPU, or not having FRNN, we import FAISS as the
 # nearest neighbor library
@@ -17,9 +17,9 @@ except ImportError:
     FRNN_AVAILABLE = False
 
 if torch.cuda.is_available():
-    device = "cuda"
+    default_device = "cuda"
 else:
-    device = "cpu"
+    default_device = "cpu"
     FRNN_AVAILABLE = False
 
 FRNN_AVAILABLE = False
@@ -34,8 +34,17 @@ def get_edge_subset(edges, mask_where, inverse_mask):
 
 
 def graph_intersection(
-    pred_graph, truth_graph, using_weights=False, weights_bidir=None
+    pred_graph,
+    truth_graph,
+    using_weights=False,
+    weights_bidir=None,
+    device: str | torch.device | None = None,
 ):
+    if device is None:
+        device = default_device
+    elif str(device) == "cuda:0":
+        device = "cuda"
+
     if pred_graph.numel() > 0:
         pred_graph_max = pred_graph.max().item()
     else:
@@ -110,13 +119,23 @@ def graph_intersection_pytorch(pred_graph, truth_graph):
 
 
 def build_edges(
-    query, database, indices=None, r_max=1.0, k_max=10, return_indices=False
+    query,
+    database,
+    indices=None,
+    r_max=1.0,
+    k_max=10,
+    return_indices=False,
+    device: str | torch.device | None = None,
 ):
     """
     NOTE: These KNN/FRNN algorithms return the distances**2.
     Therefore we need to be careful when comparing them to the target distances
     (r_val, r_test), and to the margin parameter (which is L1 distance)
     """
+    if device is None:
+        device = default_device
+    elif str(device) == "cuda:0":
+        device = "cuda"
 
     if FRNN_AVAILABLE:
         Dsq, I, nn, grid = frnn.frnn_grid_points(
@@ -139,13 +158,15 @@ def build_edges(
         edge_list = torch.stack([ind[positive_idxs], I[positive_idxs]]).long()
 
     else:
-        if device == "cuda":
+        if str(device) == "cuda":
             res = faiss.StandardGpuResources()
             Dsq, I = faiss.knn_gpu(res=res, xq=query, xb=database, k=k_max)
-        elif device == "cpu":
+        elif str(device) == "cpu":
             index = faiss.IndexFlatL2(database.shape[1])
             index.add(database)
             Dsq, I = index.search(query, k_max)
+        else:
+            raise ValueError(f"Device {device} is not recognised.")
 
         ind = torch.Tensor.repeat(
             torch.arange(I.shape[0], device=device), (I.shape[1], 1), 1
diff --git a/LHCb_Pipeline/Embedding/building_custom.py b/LHCb_Pipeline/Embedding/process_custom.py
similarity index 96%
rename from LHCb_Pipeline/Embedding/building_custom.py
rename to LHCb_Pipeline/Embedding/process_custom.py
index 2b0fb9b5905496327468ea61f72a8028d34ba4ec..1baa9cf56310b71f19cfbea5a49166a4728d4bce 100644
--- a/LHCb_Pipeline/Embedding/building_custom.py
+++ b/LHCb_Pipeline/Embedding/process_custom.py
@@ -19,7 +19,9 @@ def edges_at_least_3_hits(batch: Data) -> Data:
     # Classify the hits as fake (so that the edges are also classified like so)
     batch.particle_id[not_reconstructible_mask] = 0
 
-    # Remove edges in same `plane` and `z`
+    return batch
+
+def remove_edges_in_same_plane(batch: Data) -> Data:
     edge_index_plane = batch.plane[batch.edge_index]
     no_self_edge_mask = edge_index_plane[0] != edge_index_plane[1]
     batch["edge_index"] = batch.edge_index[:, no_self_edge_mask]
diff --git a/LHCb_Pipeline/GNN/gnn_base.py b/LHCb_Pipeline/GNN/gnn_base.py
index 6b167b123405bad2f078fa06388249be5cf76302..9a47457d5977d126042fe0f354ba47e166976fee 100644
--- a/LHCb_Pipeline/GNN/gnn_base.py
+++ b/LHCb_Pipeline/GNN/gnn_base.py
@@ -6,6 +6,7 @@ import torch
 from torch_geometric.data import Data
 
 from utils.modelutils.basemodel import ModelBase
+from utils.loaderutils.dataiterator import LazyDatasetBase
 
 
 def compute_edge_labels(
@@ -36,21 +37,14 @@ def compute_edge_labels(
     )
 
 
-class GNNBase(ModelBase):
-    @property
-    def bidir(self) -> bool:
-        return self.hparams.get("bidir", True)
-
-    def fetch_dataset(self, input_path: str, **kwargs) -> Data:
-        """Load and process one PyTorch DataSet.
-
-        Args:
-            input_path: path to the PyTorch dataset
+class GNNLazyDataset(LazyDatasetBase):
+    def __init__(self, bidir: bool, shuffle_edge_direction: bool, **kwargs):
+        super().__init__(**kwargs)
+        self.bidir = bool(bidir)
+        self.shuffle_edge_direction = bool(shuffle_edge_direction)
 
-        Returns:
-            PyTorch DataSet
-        """
-        loaded_event = super(GNNBase, self).fetch_dataset(
+    def fetch_dataset(self, input_path: str, **kwargs):
+        loaded_event = super(GNNLazyDataset, self).fetch_dataset(
             input_path=input_path, **kwargs
         )
 
@@ -61,7 +55,7 @@ class GNNBase(ModelBase):
                 particle_ids=loaded_event.particle_id,
             )
 
-        if self.hparams.get("shuffle_edge_direction", False):
+        if self.shuffle_edge_direction:
             assert self.bidir, (
                 "It was required to shuffle the edge directions, even though "
                 "the graph is not bidirectional. This is odd."
@@ -78,6 +72,35 @@ class GNNBase(ModelBase):
 
         return loaded_event
 
+
+class GNNBase(ModelBase):
+    def get_lazy_dataset(self, *args, **kwargs) -> GNNLazyDataset:
+        """Instantiate the lazy dataset used by the GNN.
+
+        Args:
+            input_dir: input directory
+            n_events: number of events to load
+            shuffle: whether to shuffle the input paths (applied before
+                selecting the first ``n_events``)
+            seed: seed for the shuffling
+            **kwargs: Other keyword arguments passed to the
+                :py:class:`utils.loaderutils.dataiterator.LazyDatasetBase` constructor.
+
+        Returns:
+            :py:class:`GNNLazyDataset` object, configured with the ``bidir`` and
+            ``shuffle_edge_direction`` settings of this model
+        """
+        dataset_options = {
+            "bidir": self.bidir,
+            "shuffle_edge_direction": self.hparams.get("shuffle_edge_direction", False),
+        }
+        return GNNLazyDataset(*args, **dataset_options, **kwargs)
+
+    @property
+    def bidir(self) -> bool:
+        """Whether the graph is bidirectional (hparam ``bidir``, default ``True``)."""
+        return self.hparams.get("bidir", True)
+
     def handle_bidirectional(self, edge_sample, truth_sample):
         if self.bidir:
             edge_sample = torch.cat([edge_sample, edge_sample.flip(0)], dim=-1)
@@ -117,6 +140,9 @@ class GNNBase(ModelBase):
                 else (~truth).sum() / truth.sum()
             )
 
+        if not output.shape:
+            output = output.reshape((-1,))
+
         # Compute weighted loss
         if self.hparams.get("focal_loss", False):
             from torchvision.ops import sigmoid_focal_loss
@@ -157,7 +183,7 @@ class GNNBase(ModelBase):
 
     def common_training_validation_step(
         self, batch: Data
-    ) -> typing.Tuple[torch.Tensor, torch.Tensor, float]:
+    ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Perform the inference and loss computation step that is common
         to the training and validation step.
 
@@ -183,7 +209,7 @@ class GNNBase(ModelBase):
             "train_loss",
             loss,
             on_epoch=True,
-            on_step=False,
+            on_step=self.hparams.get("on_step", False),
             batch_size=output.shape[0],
             prog_bar=True,
         )
@@ -217,13 +243,13 @@ class GNNBase(ModelBase):
                 "current_lr": current_lr,
             },
             on_epoch=True,
-            on_step=False,
+            on_step=self.hparams.get("on_step", False),
             batch_size=preds.shape[0],
         )
 
     def shared_evaluation(
         self, batch: Data, batch_idx: int, log: bool = False
-    ) -> typing.Dict[str, float]:
+    ) -> typing.Dict[str, torch.Tensor]:
         output, truth, loss = self.common_training_validation_step(batch=batch)
         # Edge filter performance
         score = torch.sigmoid(output)
@@ -263,6 +289,9 @@ class GNNBase(ModelBase):
             lr_scale = min(1.0, float(self.current_epoch + 1) / self.hparams["warmup"])
             for pg in optimizer.param_groups:
                 pg["lr"] = lr_scale * self.hparams["lr"]
+        else:
+            for pg in optimizer.param_groups:
+                pg["lr"] = self.lr_schedulers().get_last_lr()[0]
 
         # update params
         optimizer.step(closure=optimizer_closure)
diff --git a/LHCb_Pipeline/Preprocessing/inputloader.py b/LHCb_Pipeline/Preprocessing/inputloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..7885756ccf43642da26db1a1eaaa03b955b24fa3
--- /dev/null
+++ b/LHCb_Pipeline/Preprocessing/inputloader.py
@@ -0,0 +1,218 @@
+"""A module that defines the input loader that allow to loop over events scattered
+in different parquet or CSV files.
+"""
+import typing
+import logging
+import os
+from functools import reduce
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+from pandas.core.groupby.generic import DataFrameGroupBy
+
+
+def get_indirs(
+    input_dir: str | None = None,
+    subdirs: int | str | typing.List[str] | typing.Dict[str, int] | None = None,
+) -> typing.List[str]:
+    """Get the input directories that can be used as input of the preprocessing.
+
+    Args:
+        input_dir: A single input directory if ``subdirs`` is ``None``,
+            or the main directory where sub-directories are
+        subdirs:
+
+            * If ``subdirs`` is None, there is a single input directory, ``input_dir``
+            * If ``subdirs`` is a string or a list of strings, they specify \
+            the sub-directories with respect to ``input_dir``. If ``input_dir`` \
+            is ``None``, then they are the (list of) input directories directly, which \
+            can be useful if the input directories are not at the same location \
+            (even though it is discouraged)
+            * If ``subdirs`` is an integer, it corresponds to the name of the last \
+            sub-directory to consider (i.e., from 0 to ``subdirs``). If ``subdirs`` \
+            is ``-1``, all the sub-directories are considered as input.
+            * If ``subdirs`` is a dictionary, the keys ``start`` and ``stop`` specify \
+            the first and last sub-directories to consider as input.
+
+    Returns:
+        List of input directories that can be considered.
+
+    Raises:
+        TypeError: ``input_dir`` is ``None`` and ``subdirs`` is not a string or list.
+        ValueError: ``subdirs`` has an unsupported type while ``input_dir`` is set.
+    """
+    if input_dir is None:
+        if isinstance(subdirs, str):
+            return [subdirs]
+        elif isinstance(subdirs, list):
+            return [str(subdir) for subdir in subdirs]
+        else:
+            raise TypeError(
+                "`input_dir` is `None` but `subdirs` is neither a string nor "
+                "a list of strings, so the input directories of the preprocessing "
+                "cannot be determined."
+            )
+    else:
+        # Resolve ``subdirs`` into the sub-directory names located in ``input_dir``
+        if subdirs is None:
+            return [input_dir]
+        elif isinstance(subdirs, (int, dict)):
+            # Only numerically-named sub-directories are candidates: any other
+            # directory (e.g., a hidden or cache folder) is ignored, not converted
+            available_subdirs = sorted(
+                int(entry.name)
+                for entry in os.scandir(input_dir)
+                if entry.is_dir() and entry.name.isdecimal()
+            )
+            if subdirs == -1:
+                final_subdirs = available_subdirs
+            else:
+                if isinstance(subdirs, int):
+                    start = 0
+                    stop = subdirs
+                else:  # dict
+                    start = subdirs.get("start", 0)
+                    stop = subdirs["stop"]
+
+                assert (
+                    stop >= start
+                ), f"`start` ({start}) is strictly higher than `stop` ({stop})"
+                final_subdirs = [
+                    subdir for subdir in available_subdirs if start <= subdir <= stop
+                ]
+        elif isinstance(subdirs, str):
+            final_subdirs = [subdirs]
+        elif isinstance(subdirs, list):
+            final_subdirs = subdirs
+        else:
+            raise ValueError(
+                f"`input_dir` is not `None` and `subdirs` is `{subdirs}`, which are "
+                "not valid inputs."
+            )
+
+        return [os.path.join(input_dir, str(subdir)) for subdir in final_subdirs]
+
+
+class DataFrameLoader:
+    """Iterator over events scattered in various CSV or parquet files.
+
+    Attributes:
+        indirs: Input directories where the dataframes are
+        event_key: dataframe column that identifies an event
+    """
+
+    def __init__(
+        self,
+        input_dir: str | None = None,
+        subdirs: int | str | typing.List[str] | typing.Dict[str, int] | None = None,
+        event_key: str = "event",
+    ) -> None:
+        """Resolve the input directories and store the event key.
+
+        Args:
+            input_dir: A single input directory if ``subdirs`` is ``None``,
+                or the main directory where sub-directories are
+            subdirs:
+
+                * If ``subdirs`` is None, there is a single input directory, \
+                ``input_dir``
+                * If ``subdirs`` is a string or a list of strings, they specify \
+                the sub-directories with respect to ``input_dir``. If ``input_dir`` \
+                is ``None``, then they are the (list of) input directories directly, \
+                which can be useful if the input directories are not at the same \
+                location (even though it is discouraged)
+                * If ``subdirs`` is an integer, it corresponds to the name of the \
+                last sub-directory to consider (i.e., from 0 to ``subdirs``). \
+                If ``subdirs`` is ``-1``, all the sub-directories are considered as \
+                input.
+                * If ``subdirs`` is a dictionary, the keys ``start`` and ``stop`` \
+                specify the first and last sub-directories to consider as input.
+
+            event_key: dataframe column that identifies an event
+        """
+        self.indirs = get_indirs(input_dir=input_dir, subdirs=subdirs)
+
+        if not self.indirs:
+            raise ValueError("No input directories.")
+
+        logging.info("Input directories:")
+        for indir in self.indirs:
+            logging.info(f"- {indir}")
+        self.event_key = str(event_key)
+
+    def register_load(
+        self, func: typing.Callable[[str], typing.List[pd.DataFrame]]
+    ) -> None:
+        """Register the function that loads the dataframe(s).
+
+        Args:
+            func: Function that takes as input a directory and returns a list
+                of dataframes.
+        """
+        self._load = func
+
+    @property
+    def current_subdir(self) -> str:
+        """Current sub-directory that contains the dataframes to load."""
+        assert self._indir_idx is not None
+        return self.indirs[self._indir_idx]
+
+    @property
+    def current_event_id(self) -> int:
+        """Current event ID of the current dataframe(s)."""
+        assert self._event_idx is not None
+        return self._event_ids[self._event_idx]
+
+    def __iter__(self) -> "DataFrameLoader":
+        """Set up the iterator over the events"""
+        self._indir_idx: int | None = None
+        self._event_idx: int | None = None
+        self._grouped_by_dataframes: typing.List[DataFrameGroupBy] = []
+        return self
+
+    def load_dataframes(self) -> None:
+        """Load the current dataframes and group by events."""
+        dataframes = self._load(self.current_subdir)
+        self._grouped_by_dataframes = [
+            dataframe.groupby(by=self.event_key) for dataframe in dataframes
+        ]
+        # Only the events present in every dataframe can be iterated over
+        event_ids = [df[self.event_key].unique() for df in dataframes]
+        self._event_ids = reduce(np.intersect1d, event_ids)
+
+    def update_idx(self) -> None:
+        """Advance to the next event, or to the next directory if exhausted."""
+        if self._event_idx is None or self._indir_idx is None:
+            # if first iteration
+            assert self._event_idx is None and self._indir_idx is None
+            self._event_idx = 0
+            self._indir_idx = 0
+        elif self._event_idx + 1 >= len(self._event_ids):
+            # if no more events in current dataframe
+            self._event_idx = 0
+            self._indir_idx += 1
+        else:
+            self._event_idx += 1
+
+    @property
+    def no_more_dataframe(self) -> bool:
+        """Whether there is no remaining dataframe to load."""
+        assert self._indir_idx is not None
+        return self._indir_idx >= len(self.indirs)
+
+    def __next__(self) -> typing.Tuple[typing.List[pd.DataFrame], int]:
+        """Return the dataframes of the next event together with its event ID."""
+        if self._indir_idx is not None and self.no_more_dataframe:
+            raise StopIteration()  # stay exhausted on repeated calls
+        self.update_idx()
+        # A zero event index means a new directory: load it, skip it if it is empty
+        while self._event_idx == 0:
+            if self.no_more_dataframe:
+                raise StopIteration()
+            self.load_dataframes()
+            if len(self._event_ids) > 0:
+                break
+            self._indir_idx += 1
+        event_id = self.current_event_id
+        return [g.get_group(event_id) for g in self._grouped_by_dataframes], event_id
diff --git a/LHCb_Pipeline/Preprocessing/preprocessing.py b/LHCb_Pipeline/Preprocessing/preprocessing.py
index e20f6b8a4be37a9a39ffac247520d21e949034e0..5436bbfca6f06f18ab83b8ce5835e25a4828962c 100644
--- a/LHCb_Pipeline/Preprocessing/preprocessing.py
+++ b/LHCb_Pipeline/Preprocessing/preprocessing.py
@@ -2,11 +2,14 @@ from __future__ import annotations
 import typing
 import os
 import logging
+from functools import partial
+
 from tqdm.auto import tqdm
 import numpy as np
 import pandas as pd
 
-from . import selecting
+from . import process_custom
+from .inputloader import DataFrameLoader
 
 
 def cast_boolean_columns(particles: pd.DataFrame):
@@ -140,141 +143,43 @@ def enough_true_hits(
         return True
 
 
-def load_and_filter_dataframes(
-    indir: str,
-    particles_columns: typing.List[str] | None = None,
-    hits_particles_columns: typing.List[str] | None = None,
-    selection: str | None = None,
-    **kwargs,
+def apply_custom_processing(
+    hits_particles: pd.DataFrame,
+    particles: pd.DataFrame,
+    processing: str | typing.Sequence[str] | None = None,
 ) -> typing.Tuple[pd.DataFrame, pd.DataFrame]:
-    """Load and filter the dataframes of hits-particles and particles.
+    """Apply custom processing to the dataframe of hits-particles and particles.
+    The custom processing functions are defined in :py:mod:`.process_custom`.
 
     Args:
-        indir: directory where the dataframes are saved
-        particles_columns: columns to load for the dataframe of particles
-        hits_particles_columns: columns to load for the dataframe of hits
-            and the hits-particles association information
-        selection: function to use to filter the candidates.
-            The latter is defined in the :py:mod:`.selecting` module.
-        **kwargs: other keyword arguments passed to the function that load the files
-
-    Returns:
-        Filtered dataframes of hits-particles and of particles
-    """
-
-    # Load dataframes
-    logging.info(f"Load dataframes in {indir}")
-    hits_particles, particles = load_dataframes(
-        indir=indir,
-        particles_columns=particles_columns,
-        hits_particles_columns=hits_particles_columns,
-        **kwargs,
-    )
-
-    # Add truth particle information to the dataframe of hits
-    if selection:
-        logging.info(f"Apply selection `{selection}`")
-        selection_function: selecting.SelectionFunction = getattr(selecting, selection)
-        hits_particles, particles = selection_function(
-            hits_particles=hits_particles,
-            particles=particles,
-        )
-
-    # Define `n_unique_planes`
-    # We'll train with all the hits for the training
-    # And cut before the GNN
-    n_unique_planes = (
-        hits_particles.groupby(["event", "particle_id"])["plane"]
-        .nunique()
-        .rename("n_unique_planes")
-    )
-    particles = particles.merge(
-        n_unique_planes, how="left", on=["event", "particle_id"]
-    ).fillna(0)
-
-    return hits_particles, particles
-
-
-def get_indirs(
-    input_dir: str | None = None,
-    subdirs: int | str | typing.List[str] | typing.Dict[str, int] | None = None,
-):
-    """Get the input directories that can be used as input of the preprocessing.
+        hits_particles: dataframe of hits-particles
+        particles: dataframe of particles
+        processing: Name(s) of the processing function(s) to apply to the dataframes.
+            The processing functions as defined in :py:mod:`.process_custom`
 
-    Args:
-        input_dir: A single input directory if ``subdirs`` is ``None``,
-            or the main directory where sub-directories are
-        subdirs:
-        
-        * If ``subdirs`` is None, there is a single input directory, ``input_dir``
-        * If ``subdirs`` is a string or a list of strings, they specify \
-        the sub-directories with respect to ``input_dir``. If ``input_dir`` \
-        is ``None``, then they are the (list of) input directories directly, which \
-        can be useful if the input directories are not at the same location \
-        (even though it is discouraged)
-        * If ``subdirs`` is an integer, it corresponds to the the name of the last \
-        sub-directory to consider (i.e., from 0 to ``subdirs``). If ``subdirs`` \
-        is ``-1``, all the sub-directories are considered as input.
-        * If ``subdirs`` is a dictionary, the keys ``start`` and ``stop`` specify \
-        the first and last sub-directories to consider as input.
-    
     Returns:
-        List of input directories that can be considered.
+        Processed dataframe of hits-particles and particles
     """
-    if input_dir is None:
-        if isinstance(subdirs, str):
-            return [subdirs]
-        elif isinstance(subdirs, list):
-            return [str(subdir) for subdir in subdirs]
-        else:
-            raise TypeError(
-                "`input_dir` is `None` but `subdirs` is neither a string nor "
-                "a list of strings, so the input directories of the preprocessing "
-                "cannot be determined."
-            )
+    if processing is None:
+        return hits_particles, particles
     else:
-        # Get the list of all the sub-directories inside ``input_dir``
-
-        # Filter this list according to ``subdirs``
-        if subdirs is None:
-            return [input_dir]
-        elif isinstance(subdirs, (int, dict)):
-            available_subdirs = sorted(
-                [
-                    int(file_or_dir.name)
-                    for file_or_dir in os.scandir(input_dir)
-                    if file_or_dir.is_dir()
-                ]
-            )
-            if subdirs == -1:
-                final_subdirs = available_subdirs
-            else:
-                if isinstance(subdirs, int):
-                    start = 0
-                    stop = subdirs
-                else:  # dict
-                    start = subdirs.get("start", 0)
-                    stop = subdirs["stop"]
-
-                assert (
-                    stop >= start
-                ), f"`start` ({start}) is strictly higher than `stop ({stop})"
-                final_subdirs = [
-                    subdir
-                    for subdir in available_subdirs
-                    if subdir >= start and subdir <= stop
-                ]
-        elif isinstance(subdirs, str):
-            final_subdirs = [subdirs]
-        elif isinstance(subdirs, list):
-            final_subdirs = subdirs
+        # Get name of processing functions
+        if isinstance(processing, str):
+            processing_fct_names = [processing]
         else:
-            raise ValueError(
-                f"`input_dir` is not `None` and `subdirs` is `{subdirs}`, which are "
-                "not valid inputs."
+            processing_fct_names = [
+                str(processing_fct_name) for processing_fct_name in processing
+            ]
+
+        # Apply processing
+        for processing_fct_name in processing_fct_names:
+            logging.info(f"Apply `{processing_fct_name}`")
+            processing_fct: process_custom.SelectionFunction = getattr(
+                process_custom, processing_fct_name
             )
+            hits_particles, particles = processing_fct(hits_particles, particles)
 
-        return [os.path.join(input_dir, str(subdir)) for subdir in final_subdirs]
+        return hits_particles, particles
 
 
 def preprocess(
@@ -282,7 +187,7 @@ def preprocess(
     output_dir: str,
     subdirs: int | str | typing.List[str] | None = None,
     n_events: int = -1,
-    selection: str | None = None,
+    processing: str | typing.List[str] | None = None,
     num_true_hits_threshold: int | None = None,
     hits_particles_columns: typing.List[str] | None = None,
     particles_columns: typing.List[str] | None = None,
@@ -295,83 +200,86 @@ def preprocess(
     os.makedirs(output_dir, exist_ok=True)
     logging.info(f"Preprocessing: output will be written in {output_dir}")
 
-    indirs = get_indirs(input_dir=input_dir, subdirs=subdirs)
-    if len(indirs) == 0:
-        raise ValueError("No input directories.")
-    logging.info("Input directories:")
-    for indir in indirs:
-        logging.info(f"- {indir}")
+    def load_and_process_dataframes_reduced(indir: str) -> typing.List[pd.DataFrame]:
+        """Load the dataframes of hits-particles and particles located in ``indir``,
+        and apply the custom processing functions defined by ``processing``.
+
+        Args:
+            indir: input directory where the dataframes of hits-particles and particles
+                are saved
+
+        Returns:
+            List of two dataframes: the dataframe of hits-particles and the dataframe
+            of particles
+        """
+        logging.info(f"Load dataframes in {indir}")
+        # Each dataframe gets its own column selection: the hits-particles columns
+        # must not be driven by ``particles_columns``.
+        hits_particles, particles = load_dataframes(
+            indir=indir,
+            particles_columns=particles_columns,
+            hits_particles_columns=hits_particles_columns,
+        )
+        hits_particles, particles = apply_custom_processing(
+            hits_particles=hits_particles,
+            particles=particles,
+            processing=processing,
+        )
+        return [hits_particles, particles]
+
+    dataFrameLoader = DataFrameLoader(input_dir=input_dir, subdirs=subdirs)
+    dataFrameLoader.register_load(load_and_process_dataframes_reduced)
 
     n_output_saved = 0  # Count the number of events outputted
-    event_idx = 0
 
     n_required_events = np.inf if n_events == -1 else n_events
-    left_indirs = indirs
+
     logging.info(f"Number of events to produce: {n_required_events}")
     with tqdm(total=n_required_events) as pbar:
-        while n_output_saved < n_required_events and left_indirs:
-            hits_particles, particles = load_and_filter_dataframes(
-                indir=left_indirs[0],
-                hits_particles_columns=hits_particles_columns,
-                particles_columns=particles_columns,
-                selection=selection,
-            )
-            left_indirs = left_indirs[1:]
-
-            event_ids_in_df_hits_particles = hits_particles["event"].unique()
-            event_ids_df_particles = particles["event"].unique()
-            event_ids = np.intersect1d(
-                event_ids_in_df_hits_particles, event_ids_df_particles
-            )
+        for (event_hits_particles, event_particles), event_id in iter(dataFrameLoader):
+            # Stop when required # events processed
+            if n_output_saved >= n_required_events:
+                break
+
+            #: String representation of the event ID
+            event_id_str = str(event_id).zfill(9)
+
+            no_hits = event_hits_particles.shape[0] == 0
+
+            if not no_hits and (  # skip events with no hits
+                (num_true_hits_threshold is None)
+                or enough_true_hits(
+                    event_hits_particles=event_hits_particles,
+                    num_true_hits_threshold=num_true_hits_threshold,
+                    event_id_str=event_id_str,
+                    num_events=n_output_saved,
+                    required_num_events=n_events,
+                )
+            ):
+                # Select subset of columns
+                if hits_particles_columns is None:
+                    hits_particles_csv = event_hits_particles
+                else:
+                    hits_particles_csv = event_hits_particles[
+                        ["particle_id", "hit_id"] + hits_particles_columns
+                    ]
+                if particles_columns is None:
+                    particles_csv = event_particles
+                else:
+                    particles_csv = event_particles[
+                        ["particle_id"] + particles_columns
+                    ]
+
+                # Save
+                hits_particles_csv.to_parquet(
+                    f"{output_dir}/event{event_id_str}-hits_particles.parquet",
+                )
+                particles_csv.to_parquet(
+                    f"{output_dir}/event{event_id_str}-particles.parquet",
+                )
+                n_output_saved += 1
+                pbar.update()
 
-            # Loop over the events in the dataframe of hits-particles
-            grouped_df_hits_particles = hits_particles.groupby("event")
-            grouped_df_particles = particles.groupby("event")
-            for event_id in event_ids:
-                if n_output_saved >= n_required_events:
-                    break
-                event_hits_particles = grouped_df_hits_particles.get_group(event_id)
-                event_particles = grouped_df_particles.get_group(event_id)
-
-                #: String representation of the event ID
-                event_id_str = str(event_id).zfill(9)
-
-                no_hits = event_hits_particles.shape[0] == 0
-
-                if not no_hits and (
-                    (num_true_hits_threshold is None)
-                    or enough_true_hits(
-                        event_hits_particles=event_hits_particles,
-                        num_true_hits_threshold=num_true_hits_threshold,
-                        event_id_str=event_id_str,
-                        num_events=n_output_saved,
-                        required_num_events=n_events,
-                    )
-                ):
-                    # Save subset of columns
-                    if hits_particles_columns is None:
-                        hits_particles_csv = event_hits_particles
-                    else:
-                        hits_particles_csv = event_hits_particles[
-                            ["particle_id", "hit_id"] + hits_particles_columns
-                        ]
-                    if particles_columns is None:
-                        particles_csv = event_particles
-                    else:
-                        particles_csv = event_particles[
-                            ["particle_id"] + particles_columns
-                        ]
-
-                    # Save
-                    hits_particles_csv.to_parquet(
-                        f"{output_dir}/event{event_id_str}-hits_particles.parquet",
-                    )
-                    particles_csv.to_parquet(
-                        f"{output_dir}/event{event_id_str}-particles.parquet",
-                    )
-                    n_output_saved += 1
-                    pbar.update()
-                event_idx += 1
         pbar.close()
 
     pd.set_option("chained_assignment", "warn")  # re-enable chained-assignment warning
diff --git a/LHCb_Pipeline/Preprocessing/selecting.py b/LHCb_Pipeline/Preprocessing/process_custom.py
similarity index 92%
rename from LHCb_Pipeline/Preprocessing/selecting.py
rename to LHCb_Pipeline/Preprocessing/process_custom.py
index 9a0000c0d6a4528b55facbfa9ca0f752a3bac5e6..c138258da37958616f427142a65e60af0d5e0731 100644
--- a/LHCb_Pipeline/Preprocessing/selecting.py
+++ b/LHCb_Pipeline/Preprocessing/process_custom.py
@@ -209,3 +209,23 @@ def triplets_first_selection(
     assert not hits_particles.isna().any().any()
 
     return hits_particles, particles
+
+
+def compute_n_unique_planes(
+    hits_particles: pd.DataFrame, particles: pd.DataFrame
+) -> typing.Tuple[pd.DataFrame, pd.DataFrame]:
+    """Add to ``particles`` a ``n_unique_planes`` column: the number of distinct
+    planes hit by each particle (0 without hit). Nothing is filtered here: as per
+    the original note, all the hits are kept for training and cut before the GNN."""
+    n_unique_planes = (
+        hits_particles.groupby(["event", "particle_id"])["plane"]
+        .nunique()
+        .rename("n_unique_planes")
+    )
+    particles = particles.merge(
+        n_unique_planes, how="left", on=["event", "particle_id"]
+    )
+    # Left join: a particle without hits gets NaN. Zero the new column only.
+    particles["n_unique_planes"] = particles["n_unique_planes"].fillna(0).astype(int)
+
+    return hits_particles, particles
diff --git a/LHCb_Pipeline/Scripts/Step_2_Run_Metric_Learning.py b/LHCb_Pipeline/Scripts/Step_2_Run_Metric_Learning.py
index 67f8d9a30f3bb05fe151a4e118fee8a27613661c..f5ed7bc7dd7ddde0e45e1899d78ad05b4f553c30 100644
--- a/LHCb_Pipeline/Scripts/Step_2_Run_Metric_Learning.py
+++ b/LHCb_Pipeline/Scripts/Step_2_Run_Metric_Learning.py
@@ -24,6 +24,7 @@ def train(
     checkpoint: LayerlessEmbedding | str | None = None,
     reproduce: bool = True,
     override_hparams: bool = False,
+    use_gpu: bool = True,
     **kwargs,
 ):
     """Run the inference of the metric learning stage.
@@ -44,6 +45,7 @@ def train(
         reproduce: whether to delete an existing folder
         override_hparams: whether to override the hyparameters of the model
             that is loaded, with the ones in the YAML configuration
+        use_gpu: whether to use the GPU (if available)
         **kwargs: Other keyword arguments passed to the
             :py:func:`PyTorch.LightingModel.load_from_checkpoint` class method
     """
@@ -55,7 +57,7 @@ def train(
 
     logging.info(headline("a) Loading trained model"))
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
 
     if override_hparams:
         kwargs = {"hparams": metric_learning_configs, **kwargs}
@@ -79,8 +81,8 @@ def train(
     else:
         radius = metric_learning_configs["r_test"]
 
-    building = metric_learning_configs.pop("building", None)
-    filtering = metric_learning_configs.pop("filtering", None)
+    test_processing = metric_learning_configs.pop("test_processing", None)
+    training_processing = metric_learning_configs.pop("training_processing", None)
 
     logging.info(f"Use radius {radius}")
     graph_builder = EmbeddingInferenceBuilder(
@@ -97,9 +99,9 @@ def train(
         test_dataset_names=get_required_test_dataset_names(all_configs),
         reproduce=reproduce,
         list_kwargs=[
-            dict(building=building, filtering=filtering)
+            dict(processing=training_processing)
             if partition in ["train", "val"]
-            else dict(building=building)
+            else dict(processing=test_processing)
             for partition in partitions
         ],
     )
diff --git a/LHCb_Pipeline/Scripts/Step_5_Build_Track_Candidates.py b/LHCb_Pipeline/Scripts/Step_5_Build_Track_Candidates.py
index 7874c7a2a7fd45265a81331c791332b306c4b4a4..47d348a942180af257046e60432b68397f128d55 100644
--- a/LHCb_Pipeline/Scripts/Step_5_Build_Track_Candidates.py
+++ b/LHCb_Pipeline/Scripts/Step_5_Build_Track_Candidates.py
@@ -47,13 +47,16 @@ class TrackBuilder(BuilderBase):
             ]
         else:
             double_edge_indices = edge_indices
-
-        labels = torch.from_numpy(
-            get_track_ids(
-                edge_indices=double_edge_indices,
-                n_hits=batch.x.shape[0],
-            )
-        ).long()
+            
+        if double_edge_indices.nelement():
+            labels = torch.from_numpy(
+                get_track_ids(
+                    edge_indices=double_edge_indices,
+                    n_hits=batch.x.shape[0],
+                )
+            ).long()
+        else:
+            labels = torch.arange(batch.x.shape[0])
         batch.labels = labels
 
         return batch
diff --git a/LHCb_Pipeline/Scripts/Step_7_Compare_With_Allen.py b/LHCb_Pipeline/Scripts/Step_7_Compare_With_Allen.py
new file mode 100644
index 0000000000000000000000000000000000000000..a66307ae88951116a7db6cc302a9bc12a123a766
--- /dev/null
+++ b/LHCb_Pipeline/Scripts/Step_7_Compare_With_Allen.py
@@ -0,0 +1,217 @@
+from __future__ import annotations
+import typing
+import os.path as op
+
+import numpy as np
+import matplotlib.pyplot as plt
+import montetracko as mt
+import montetracko.lhcb as mtb
+
+from Scripts.Step_6_Evaluate_Reconstruction_MonteTracko import (
+    load_parquet_files,
+    perform_matching,
+    perform_evaluation,
+)
+from utils.commonutils.ctests import get_preprocessed_test_dataset_dir
+from utils.plotutils import plotconfig
+from utils.plotutils.plotools import save_fig
+from utils.commonutils.config import load_config
+
+
+def plot_histograms(
+    trackEvaluator1: mt.TrackEvaluator,
+    trackEvaluator2: mt.TrackEvaluator,
+    label1: str,
+    label2: str,
+    color1: str,
+    color2: str,
+    columns: typing.List[str],
+    metric_names: typing.List[str],
+    column_labels: typing.Optional[typing.Dict[str, str]] = None,
+    bins: typing.Optional[
+        int | typing.Sequence[float] | str | typing.Dict[str, typing.Any]
+    ] = 50,
+    column_ranges: typing.Optional[typing.Dict[str, typing.Tuple[float, float]]] = None,
+    average: typing.Optional[typing.List[str]] = None,
+    category: typing.Optional[mt.requirement.Category] = None,
+    hide_repetitive_labels: bool = True,
+    **kwargs,
+):
+    """Overlay the metric histograms of two track evaluators on a grid of subplots.
+
+    Args:
+        trackEvaluator1, trackEvaluator2: evaluators to compare, sharing bin edges
+        label1, label2, color1, color2: legend label and color of each evaluator
+        columns: columns to histogram the metrics in (one subplot column each)
+        metric_names: metrics to plot (one subplot row each)
+        column_labels: optional axis label per column (default: escaped column name)
+        bins: binning of every column, or dictionary that maps a column to its
+            binning (20 bins for a column without entry)
+        column_ranges: optional histogram range per column
+        average, category: forwarded to ``compute_histogram``
+        hide_repetitive_labels: whether to hide the labels repeated across subplots
+        **kwargs: forwarded to ``_plot_histogram``
+
+    Returns:
+        Figure, 2D axes array, 2D array of ``trackEvaluator1._plot_histogram`` results
+    """
+    # ``None`` means "not specified": normalise so that the lookups below work
+    column_labels = {} if column_labels is None else column_labels
+    column_ranges = {} if column_ranges is None else column_ranges
+    bins = 20 if bins is None else bins
+    fig, axes = plt.subplots(
+        nrows=len(metric_names),
+        ncols=len(columns),
+        figsize=(8 * len(columns), 6 * len(metric_names)),
+        squeeze=False,  # always a 2D array, even for a single row or column
+    )
+    axes_histogram = np.empty_like(axes)
+    for idx_col, column in enumerate(columns):
+        column_label = column_labels.get(column, column.replace("_", r"\_"))
+        column_range = column_ranges.get(column)
+        bins_column = bins.get(column, 20) if isinstance(bins, dict) else bins
+        edges = None  # edges of the first metric are reused for the following ones
+        for idx_metric, metric_name in enumerate(metric_names):
+            ax = axes[idx_metric][idx_col]
+            histogram, values1, edges = trackEvaluator1.compute_histogram(
+                column=column,
+                metric_name=metric_name,
+                bins=bins_column if edges is None else edges,
+                range=column_range,
+                average=average,
+                category=category,
+            )
+            axes_histogram[idx_metric][idx_col] = trackEvaluator1._plot_histogram(
+                column=column,
+                metric_name=metric_name,
+                array_metric_values=values1,
+                edges=edges,
+                column_label=column_label,
+                histogram=histogram,
+                ax=ax,
+                label=label1,
+                color=color1,
+                **kwargs,
+            )
+            _, values2, edges = trackEvaluator2.compute_histogram(
+                column=column,
+                metric_name=metric_name,
+                bins=edges,
+                range=column_range,
+                average=average,
+                category=category,
+            )
+            trackEvaluator2._plot_histogram(
+                column=column,
+                metric_name=metric_name,
+                array_metric_values=values2,
+                edges=edges,
+                column_label=column_label,
+                ax=ax,
+                label=label2,
+                color=color2,
+                **kwargs,
+            )
+            if idx_metric == 0 and idx_col == 0:
+                ax.legend()
+
+    for axes_row in axes:  # common y-range along a row, lower bound at most 0
+        ymins, ymaxs = zip(*(ax.get_ylim() for ax in axes_row))
+        for ax in axes_row:
+            ax.set_ylim(min(0, *ymins), max(ymaxs))
+
+    if hide_repetitive_labels:
+        for idx_metric, idx_col in np.ndindex(*axes.shape):
+            ax = axes[idx_metric][idx_col]
+            if idx_metric != len(metric_names) - 1:  # x labels: last row only
+                ax.tick_params(axis="x", labelbottom=False)
+                ax.xaxis.label.set_visible(False)
+            if idx_col != 0:  # y labels: first column only
+                ax.tick_params(axis="y", labelleft=False)
+                ax.yaxis.label.set_visible(False)
+            if idx_col != len(columns) - 1:
+                axes_histogram[idx_metric][idx_col].yaxis.label.set_visible(False)
+    return fig, axes, axes_histogram
+
+
+def evaluate_allen(
+    event_ids: typing.Iterable,
+    path_or_config: str | dict,
+    allen_tracks_dir: str = "/scratch/acorreia/tracks_allen_test",
+    test_dataset_name: str = "velo-sim10b-nospillover",
+) -> mt.TrackEvaluator:
+    """Evaluate the Allen tracks of the events ``event_ids`` against the truth.
+
+    Args:
+        path_or_config: configuration dictionary, or path to the configuration file
+        allen_tracks_dir: directory that contains one ``{event_id}.csv`` per event
+        test_dataset_name: test dataset whose preprocessed files provide the truth
+    """
+    config = load_config(path_or_config)
+    event_ids = list(event_ids)  # iterated twice: do not exhaust a generator
+    trackhandler = mt.TrackHandler.from_padded_csv(
+        paths=[op.join(allen_tracks_dir, f"{event_id}.csv") for event_id in event_ids],
+        padding_value=0,
+        skip_header=True,
+    )
+    preprocessed_input_dir = get_preprocessed_test_dataset_dir(
+        test_dataset_name=test_dataset_name, path_or_config=config
+    )
+    event_paths = [
+        op.join(preprocessed_input_dir, "event" + str(event_id).zfill(9))
+        for event_id in event_ids
+    ]
+    df_hits_particles = load_parquet_files(
+        truncated_paths=event_paths,
+        ending="-hits_particles",
+        columns=["particle_id", "hit_id", "plane", "x", "y", "z"],
+    )
+    df_particles = load_parquet_files(truncated_paths=event_paths, ending="-particles")
+    trackEvaluator_Allen = perform_matching(
+        df_tracks=trackhandler.dataframe,
+        df_hits_particles=df_hits_particles,
+        df_particles=df_particles,
+    )
+    common = config["common"]
+    perf_dir = op.join(common["performance_directory"], common["experiment_name"])
+    perform_evaluation(trackEvaluator_Allen, output_dir=op.join(perf_dir, "allen"))
+    return trackEvaluator_Allen
+
+
+def compare_allen_vs_etx4velo(
+    path_or_config: str | dict,
+    trackEvaluator: mt.TrackEvaluator,
+    trackEvaluator_Allen: mt.TrackEvaluator,
+):
+    """Plot etx4velo against Allen, and save one figure per category.
+
+    Args:
+        path_or_config: configuration dictionary, or path to the configuration file
+        trackEvaluator, trackEvaluator_Allen: evaluators of etx4velo and of Allen
+    """
+    config = load_config(path_or_config)
+    categories = (
+        mtb.category.category_velo_no_electrons,
+        mtb.category.category_long_only_electrons,
+    )
+    for category in categories:
+        figure, _, _ = plot_histograms(
+            trackEvaluator1=trackEvaluator,
+            trackEvaluator2=trackEvaluator_Allen,
+            label1="etx4velo",
+            label2="Allen",
+            color1="blue",
+            color2="green",
+            columns=["nhits_velo", "vz", "pt"],
+            metric_names=["efficiency", "hit_efficiency_per_candidate", "clone_rate"],
+            column_ranges=plotconfig.column_ranges,
+            column_labels=plotconfig.column_labels,
+            bins=plotconfig.column_bins,
+            category=category,
+        )
+        output_path = op.join(
+            config["common"]["performance_directory"],
+            config["common"]["experiment_name"],
+            f"etx4velo_vs_allen_{category.name}",
+        )
+        save_fig(figure, output_path)
diff --git a/LHCb_Pipeline/full_pipeline-focal-loss-pid-fixed-20000.ipynb b/LHCb_Pipeline/full_pipeline-focal-loss-pid-fixed-20000.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..72b8637396cd9963ff076ee7934a51d5c090b670
--- /dev/null
+++ b/LHCb_Pipeline/full_pipeline-focal-loss-pid-fixed-20000.ipynb
@@ -0,0 +1,3305 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 0. Setup"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div class=\"bk-root\">\n",
+       "        <a href=\"https://bokeh.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n",
+       "        <span id=\"1002\">Loading BokehJS ...</span>\n",
+       "    </div>\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/javascript": [
+       "(function(root) {\n",
+       "  function now() {\n",
+       "    return new Date();\n",
+       "  }\n",
+       "\n",
+       "  const force = true;\n",
+       "\n",
+       "  if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n",
+       "    root._bokeh_onload_callbacks = [];\n",
+       "    root._bokeh_is_loading = undefined;\n",
+       "  }\n",
+       "\n",
+       "const JS_MIME_TYPE = 'application/javascript';\n",
+       "  const HTML_MIME_TYPE = 'text/html';\n",
+       "  const EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n",
+       "  const CLASS_NAME = 'output_bokeh rendered_html';\n",
+       "\n",
+       "  /**\n",
+       "   * Render data to the DOM node\n",
+       "   */\n",
+       "  function render(props, node) {\n",
+       "    const script = document.createElement(\"script\");\n",
+       "    node.appendChild(script);\n",
+       "  }\n",
+       "\n",
+       "  /**\n",
+       "   * Handle when an output is cleared or removed\n",
+       "   */\n",
+       "  function handleClearOutput(event, handle) {\n",
+       "    const cell = handle.cell;\n",
+       "\n",
+       "    const id = cell.output_area._bokeh_element_id;\n",
+       "    const server_id = cell.output_area._bokeh_server_id;\n",
+       "    // Clean up Bokeh references\n",
+       "    if (id != null && id in Bokeh.index) {\n",
+       "      Bokeh.index[id].model.document.clear();\n",
+       "      delete Bokeh.index[id];\n",
+       "    }\n",
+       "\n",
+       "    if (server_id !== undefined) {\n",
+       "      // Clean up Bokeh references\n",
+       "      const cmd_clean = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n",
+       "      cell.notebook.kernel.execute(cmd_clean, {\n",
+       "        iopub: {\n",
+       "          output: function(msg) {\n",
+       "            const id = msg.content.text.trim();\n",
+       "            if (id in Bokeh.index) {\n",
+       "              Bokeh.index[id].model.document.clear();\n",
+       "              delete Bokeh.index[id];\n",
+       "            }\n",
+       "          }\n",
+       "        }\n",
+       "      });\n",
+       "      // Destroy server and session\n",
+       "      const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n",
+       "      cell.notebook.kernel.execute(cmd_destroy);\n",
+       "    }\n",
+       "  }\n",
+       "\n",
+       "  /**\n",
+       "   * Handle when a new output is added\n",
+       "   */\n",
+       "  function handleAddOutput(event, handle) {\n",
+       "    const output_area = handle.output_area;\n",
+       "    const output = handle.output;\n",
+       "\n",
+       "    // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n",
+       "    if ((output.output_type != \"display_data\") || (!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n",
+       "      return\n",
+       "    }\n",
+       "\n",
+       "    const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n",
+       "\n",
+       "    if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n",
+       "      toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n",
+       "      // store reference to embed id on output_area\n",
+       "      output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n",
+       "    }\n",
+       "    if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n",
+       "      const bk_div = document.createElement(\"div\");\n",
+       "      bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n",
+       "      const script_attrs = bk_div.children[0].attributes;\n",
+       "      for (let i = 0; i < script_attrs.length; i++) {\n",
+       "        toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n",
+       "        toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n",
+       "      }\n",
+       "      // store reference to server id on output_area\n",
+       "      output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n",
+       "    }\n",
+       "  }\n",
+       "\n",
+       "  function register_renderer(events, OutputArea) {\n",
+       "\n",
+       "    function append_mime(data, metadata, element) {\n",
+       "      // create a DOM node to render to\n",
+       "      const toinsert = this.create_output_subarea(\n",
+       "        metadata,\n",
+       "        CLASS_NAME,\n",
+       "        EXEC_MIME_TYPE\n",
+       "      );\n",
+       "      this.keyboard_manager.register_events(toinsert);\n",
+       "      // Render to node\n",
+       "      const props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n",
+       "      render(props, toinsert[toinsert.length - 1]);\n",
+       "      element.append(toinsert);\n",
+       "      return toinsert\n",
+       "    }\n",
+       "\n",
+       "    /* Handle when an output is cleared or removed */\n",
+       "    events.on('clear_output.CodeCell', handleClearOutput);\n",
+       "    events.on('delete.Cell', handleClearOutput);\n",
+       "\n",
+       "    /* Handle when a new output is added */\n",
+       "    events.on('output_added.OutputArea', handleAddOutput);\n",
+       "\n",
+       "    /**\n",
+       "     * Register the mime type and append_mime function with output_area\n",
+       "     */\n",
+       "    OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n",
+       "      /* Is output safe? */\n",
+       "      safe: true,\n",
+       "      /* Index of renderer in `output_area.display_order` */\n",
+       "      index: 0\n",
+       "    });\n",
+       "  }\n",
+       "\n",
+       "  // register the mime type if in Jupyter Notebook environment and previously unregistered\n",
+       "  if (root.Jupyter !== undefined) {\n",
+       "    const events = require('base/js/events');\n",
+       "    const OutputArea = require('notebook/js/outputarea').OutputArea;\n",
+       "\n",
+       "    if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n",
+       "      register_renderer(events, OutputArea);\n",
+       "    }\n",
+       "  }\n",
+       "  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n",
+       "    root._bokeh_timeout = Date.now() + 5000;\n",
+       "    root._bokeh_failed_load = false;\n",
+       "  }\n",
+       "\n",
+       "  const NB_LOAD_WARNING = {'data': {'text/html':\n",
+       "     \"<div style='background-color: #fdd'>\\n\"+\n",
+       "     \"<p>\\n\"+\n",
+       "     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n",
+       "     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n",
+       "     \"</p>\\n\"+\n",
+       "     \"<ul>\\n\"+\n",
+       "     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n",
+       "     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n",
+       "     \"</ul>\\n\"+\n",
+       "     \"<code>\\n\"+\n",
+       "     \"from bokeh.resources import INLINE\\n\"+\n",
+       "     \"output_notebook(resources=INLINE)\\n\"+\n",
+       "     \"</code>\\n\"+\n",
+       "     \"</div>\"}};\n",
+       "\n",
+       "  function display_loaded() {\n",
+       "    const el = document.getElementById(\"1002\");\n",
+       "    if (el != null) {\n",
+       "      el.textContent = \"BokehJS is loading...\";\n",
+       "    }\n",
+       "    if (root.Bokeh !== undefined) {\n",
+       "      if (el != null) {\n",
+       "        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n",
+       "      }\n",
+       "    } else if (Date.now() < root._bokeh_timeout) {\n",
+       "      setTimeout(display_loaded, 100)\n",
+       "    }\n",
+       "  }\n",
+       "\n",
+       "  function run_callbacks() {\n",
+       "    try {\n",
+       "      root._bokeh_onload_callbacks.forEach(function(callback) {\n",
+       "        if (callback != null)\n",
+       "          callback();\n",
+       "      });\n",
+       "    } finally {\n",
+       "      delete root._bokeh_onload_callbacks\n",
+       "    }\n",
+       "    console.debug(\"Bokeh: all callbacks have finished\");\n",
+       "  }\n",
+       "\n",
+       "  function load_libs(css_urls, js_urls, callback) {\n",
+       "    if (css_urls == null) css_urls = [];\n",
+       "    if (js_urls == null) js_urls = [];\n",
+       "\n",
+       "    root._bokeh_onload_callbacks.push(callback);\n",
+       "    if (root._bokeh_is_loading > 0) {\n",
+       "      console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n",
+       "      return null;\n",
+       "    }\n",
+       "    if (js_urls == null || js_urls.length === 0) {\n",
+       "      run_callbacks();\n",
+       "      return null;\n",
+       "    }\n",
+       "    console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n",
+       "    root._bokeh_is_loading = css_urls.length + js_urls.length;\n",
+       "\n",
+       "    function on_load() {\n",
+       "      root._bokeh_is_loading--;\n",
+       "      if (root._bokeh_is_loading === 0) {\n",
+       "        console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n",
+       "        run_callbacks()\n",
+       "      }\n",
+       "    }\n",
+       "\n",
+       "    function on_error(url) {\n",
+       "      console.error(\"failed to load \" + url);\n",
+       "    }\n",
+       "\n",
+       "    for (let i = 0; i < css_urls.length; i++) {\n",
+       "      const url = css_urls[i];\n",
+       "      const element = document.createElement(\"link\");\n",
+       "      element.onload = on_load;\n",
+       "      element.onerror = on_error.bind(null, url);\n",
+       "      element.rel = \"stylesheet\";\n",
+       "      element.type = \"text/css\";\n",
+       "      element.href = url;\n",
+       "      console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n",
+       "      document.body.appendChild(element);\n",
+       "    }\n",
+       "\n",
+       "    for (let i = 0; i < js_urls.length; i++) {\n",
+       "      const url = js_urls[i];\n",
+       "      const element = document.createElement('script');\n",
+       "      element.onload = on_load;\n",
+       "      element.onerror = on_error.bind(null, url);\n",
+       "      element.async = false;\n",
+       "      element.src = url;\n",
+       "      console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n",
+       "      document.head.appendChild(element);\n",
+       "    }\n",
+       "  };\n",
+       "\n",
+       "  function inject_raw_css(css) {\n",
+       "    const element = document.createElement(\"style\");\n",
+       "    element.appendChild(document.createTextNode(css));\n",
+       "    document.body.appendChild(element);\n",
+       "  }\n",
+       "\n",
+       "  const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n",
+       "  const css_urls = [];\n",
+       "\n",
+       "  const inline_js = [    function(Bokeh) {\n",
+       "      Bokeh.set_log_level(\"info\");\n",
+       "    },\n",
+       "function(Bokeh) {\n",
+       "    }\n",
+       "  ];\n",
+       "\n",
+       "  function run_inline_js() {\n",
+       "    if (root.Bokeh !== undefined || force === true) {\n",
+       "          for (let i = 0; i < inline_js.length; i++) {\n",
+       "      inline_js[i].call(root, root.Bokeh);\n",
+       "    }\n",
+       "if (force === true) {\n",
+       "        display_loaded();\n",
+       "      }} else if (Date.now() < root._bokeh_timeout) {\n",
+       "      setTimeout(run_inline_js, 100);\n",
+       "    } else if (!root._bokeh_failed_load) {\n",
+       "      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n",
+       "      root._bokeh_failed_load = true;\n",
+       "    } else if (force !== true) {\n",
+       "      const cell = $(document.getElementById(\"1002\")).parents('.cell').data().cell;\n",
+       "      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n",
+       "    }\n",
+       "  }\n",
+       "\n",
+       "  if (root._bokeh_is_loading === 0) {\n",
+       "    console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n",
+       "    run_inline_js();\n",
+       "  } else {\n",
+       "    load_libs(css_urls, js_urls, function() {\n",
+       "      console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n",
+       "      run_inline_js();\n",
+       "    });\n",
+       "  }\n",
+       "}(window));"
+      ],
+      "application/vnd.bokehjs_load.v0+json": "(function(root) {\n  function now() {\n    return new Date();\n  }\n\n  const force = true;\n\n  if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n    root._bokeh_onload_callbacks = [];\n    root._bokeh_is_loading = undefined;\n  }\n\n\n  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n    root._bokeh_timeout = Date.now() + 5000;\n    root._bokeh_failed_load = false;\n  }\n\n  const NB_LOAD_WARNING = {'data': {'text/html':\n     \"<div style='background-color: #fdd'>\\n\"+\n     \"<p>\\n\"+\n     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n     \"</p>\\n\"+\n     \"<ul>\\n\"+\n     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n     \"</ul>\\n\"+\n     \"<code>\\n\"+\n     \"from bokeh.resources import INLINE\\n\"+\n     \"output_notebook(resources=INLINE)\\n\"+\n     \"</code>\\n\"+\n     \"</div>\"}};\n\n  function display_loaded() {\n    const el = document.getElementById(\"1002\");\n    if (el != null) {\n      el.textContent = \"BokehJS is loading...\";\n    }\n    if (root.Bokeh !== undefined) {\n      if (el != null) {\n        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n      }\n    } else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(display_loaded, 100)\n    }\n  }\n\n  function run_callbacks() {\n    try {\n      root._bokeh_onload_callbacks.forEach(function(callback) {\n        if (callback != null)\n          callback();\n      });\n    } finally {\n      delete root._bokeh_onload_callbacks\n    }\n    console.debug(\"Bokeh: all callbacks have finished\");\n  }\n\n  function load_libs(css_urls, js_urls, callback) {\n    if (css_urls == null) css_urls = [];\n    if (js_urls == 
null) js_urls = [];\n\n    root._bokeh_onload_callbacks.push(callback);\n    if (root._bokeh_is_loading > 0) {\n      console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n      return null;\n    }\n    if (js_urls == null || js_urls.length === 0) {\n      run_callbacks();\n      return null;\n    }\n    console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n    root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n    function on_load() {\n      root._bokeh_is_loading--;\n      if (root._bokeh_is_loading === 0) {\n        console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n        run_callbacks()\n      }\n    }\n\n    function on_error(url) {\n      console.error(\"failed to load \" + url);\n    }\n\n    for (let i = 0; i < css_urls.length; i++) {\n      const url = css_urls[i];\n      const element = document.createElement(\"link\");\n      element.onload = on_load;\n      element.onerror = on_error.bind(null, url);\n      element.rel = \"stylesheet\";\n      element.type = \"text/css\";\n      element.href = url;\n      console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n      document.body.appendChild(element);\n    }\n\n    for (let i = 0; i < js_urls.length; i++) {\n      const url = js_urls[i];\n      const element = document.createElement('script');\n      element.onload = on_load;\n      element.onerror = on_error.bind(null, url);\n      element.async = false;\n      element.src = url;\n      console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n      document.head.appendChild(element);\n    }\n  };\n\n  function inject_raw_css(css) {\n    const element = document.createElement(\"style\");\n    element.appendChild(document.createTextNode(css));\n    document.body.appendChild(element);\n  }\n\n  const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", 
\"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n  const css_urls = [];\n\n  const inline_js = [    function(Bokeh) {\n      Bokeh.set_log_level(\"info\");\n    },\nfunction(Bokeh) {\n    }\n  ];\n\n  function run_inline_js() {\n    if (root.Bokeh !== undefined || force === true) {\n          for (let i = 0; i < inline_js.length; i++) {\n      inline_js[i].call(root, root.Bokeh);\n    }\nif (force === true) {\n        display_loaded();\n      }} else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(run_inline_js, 100);\n    } else if (!root._bokeh_failed_load) {\n      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n      root._bokeh_failed_load = true;\n    } else if (force !== true) {\n      const cell = $(document.getElementById(\"1002\")).parents('.cell').data().cell;\n      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n    }\n  }\n\n  if (root._bokeh_is_loading === 0) {\n    console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n    run_inline_js();\n  } else {\n    load_libs(css_urls, js_urls, function() {\n      console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n      run_inline_js();\n    });\n  }\n}(window));"
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div class=\"bk-root\">\n",
+       "        <a href=\"https://bokeh.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n",
+       "        <span id=\"1003\">Loading BokehJS ...</span>\n",
+       "    </div>\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/javascript": [
+       "(function(root) {\n",
+       "  function now() {\n",
+       "    return new Date();\n",
+       "  }\n",
+       "\n",
+       "  const force = true;\n",
+       "\n",
+       "  if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n",
+       "    root._bokeh_onload_callbacks = [];\n",
+       "    root._bokeh_is_loading = undefined;\n",
+       "  }\n",
+       "\n",
+       "const JS_MIME_TYPE = 'application/javascript';\n",
+       "  const HTML_MIME_TYPE = 'text/html';\n",
+       "  const EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n",
+       "  const CLASS_NAME = 'output_bokeh rendered_html';\n",
+       "\n",
+       "  /**\n",
+       "   * Render data to the DOM node\n",
+       "   */\n",
+       "  function render(props, node) {\n",
+       "    const script = document.createElement(\"script\");\n",
+       "    node.appendChild(script);\n",
+       "  }\n",
+       "\n",
+       "  /**\n",
+       "   * Handle when an output is cleared or removed\n",
+       "   */\n",
+       "  function handleClearOutput(event, handle) {\n",
+       "    const cell = handle.cell;\n",
+       "\n",
+       "    const id = cell.output_area._bokeh_element_id;\n",
+       "    const server_id = cell.output_area._bokeh_server_id;\n",
+       "    // Clean up Bokeh references\n",
+       "    if (id != null && id in Bokeh.index) {\n",
+       "      Bokeh.index[id].model.document.clear();\n",
+       "      delete Bokeh.index[id];\n",
+       "    }\n",
+       "\n",
+       "    if (server_id !== undefined) {\n",
+       "      // Clean up Bokeh references\n",
+       "      const cmd_clean = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n",
+       "      cell.notebook.kernel.execute(cmd_clean, {\n",
+       "        iopub: {\n",
+       "          output: function(msg) {\n",
+       "            const id = msg.content.text.trim();\n",
+       "            if (id in Bokeh.index) {\n",
+       "              Bokeh.index[id].model.document.clear();\n",
+       "              delete Bokeh.index[id];\n",
+       "            }\n",
+       "          }\n",
+       "        }\n",
+       "      });\n",
+       "      // Destroy server and session\n",
+       "      const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n",
+       "      cell.notebook.kernel.execute(cmd_destroy);\n",
+       "    }\n",
+       "  }\n",
+       "\n",
+       "  /**\n",
+       "   * Handle when a new output is added\n",
+       "   */\n",
+       "  function handleAddOutput(event, handle) {\n",
+       "    const output_area = handle.output_area;\n",
+       "    const output = handle.output;\n",
+       "\n",
+       "    // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n",
+       "    if ((output.output_type != \"display_data\") || (!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n",
+       "      return\n",
+       "    }\n",
+       "\n",
+       "    const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n",
+       "\n",
+       "    if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n",
+       "      toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n",
+       "      // store reference to embed id on output_area\n",
+       "      output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n",
+       "    }\n",
+       "    if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n",
+       "      const bk_div = document.createElement(\"div\");\n",
+       "      bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n",
+       "      const script_attrs = bk_div.children[0].attributes;\n",
+       "      for (let i = 0; i < script_attrs.length; i++) {\n",
+       "        toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n",
+       "        toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n",
+       "      }\n",
+       "      // store reference to server id on output_area\n",
+       "      output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n",
+       "    }\n",
+       "  }\n",
+       "\n",
+       "  function register_renderer(events, OutputArea) {\n",
+       "\n",
+       "    function append_mime(data, metadata, element) {\n",
+       "      // create a DOM node to render to\n",
+       "      const toinsert = this.create_output_subarea(\n",
+       "        metadata,\n",
+       "        CLASS_NAME,\n",
+       "        EXEC_MIME_TYPE\n",
+       "      );\n",
+       "      this.keyboard_manager.register_events(toinsert);\n",
+       "      // Render to node\n",
+       "      const props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n",
+       "      render(props, toinsert[toinsert.length - 1]);\n",
+       "      element.append(toinsert);\n",
+       "      return toinsert\n",
+       "    }\n",
+       "\n",
+       "    /* Handle when an output is cleared or removed */\n",
+       "    events.on('clear_output.CodeCell', handleClearOutput);\n",
+       "    events.on('delete.Cell', handleClearOutput);\n",
+       "\n",
+       "    /* Handle when a new output is added */\n",
+       "    events.on('output_added.OutputArea', handleAddOutput);\n",
+       "\n",
+       "    /**\n",
+       "     * Register the mime type and append_mime function with output_area\n",
+       "     */\n",
+       "    OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n",
+       "      /* Is output safe? */\n",
+       "      safe: true,\n",
+       "      /* Index of renderer in `output_area.display_order` */\n",
+       "      index: 0\n",
+       "    });\n",
+       "  }\n",
+       "\n",
+       "  // register the mime type if in Jupyter Notebook environment and previously unregistered\n",
+       "  if (root.Jupyter !== undefined) {\n",
+       "    const events = require('base/js/events');\n",
+       "    const OutputArea = require('notebook/js/outputarea').OutputArea;\n",
+       "\n",
+       "    if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n",
+       "      register_renderer(events, OutputArea);\n",
+       "    }\n",
+       "  }\n",
+       "  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n",
+       "    root._bokeh_timeout = Date.now() + 5000;\n",
+       "    root._bokeh_failed_load = false;\n",
+       "  }\n",
+       "\n",
+       "  const NB_LOAD_WARNING = {'data': {'text/html':\n",
+       "     \"<div style='background-color: #fdd'>\\n\"+\n",
+       "     \"<p>\\n\"+\n",
+       "     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n",
+       "     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n",
+       "     \"</p>\\n\"+\n",
+       "     \"<ul>\\n\"+\n",
+       "     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n",
+       "     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n",
+       "     \"</ul>\\n\"+\n",
+       "     \"<code>\\n\"+\n",
+       "     \"from bokeh.resources import INLINE\\n\"+\n",
+       "     \"output_notebook(resources=INLINE)\\n\"+\n",
+       "     \"</code>\\n\"+\n",
+       "     \"</div>\"}};\n",
+       "\n",
+       "  function display_loaded() {\n",
+       "    const el = document.getElementById(\"1003\");\n",
+       "    if (el != null) {\n",
+       "      el.textContent = \"BokehJS is loading...\";\n",
+       "    }\n",
+       "    if (root.Bokeh !== undefined) {\n",
+       "      if (el != null) {\n",
+       "        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n",
+       "      }\n",
+       "    } else if (Date.now() < root._bokeh_timeout) {\n",
+       "      setTimeout(display_loaded, 100)\n",
+       "    }\n",
+       "  }\n",
+       "\n",
+       "  function run_callbacks() {\n",
+       "    try {\n",
+       "      root._bokeh_onload_callbacks.forEach(function(callback) {\n",
+       "        if (callback != null)\n",
+       "          callback();\n",
+       "      });\n",
+       "    } finally {\n",
+       "      delete root._bokeh_onload_callbacks\n",
+       "    }\n",
+       "    console.debug(\"Bokeh: all callbacks have finished\");\n",
+       "  }\n",
+       "\n",
+       "  function load_libs(css_urls, js_urls, callback) {\n",
+       "    if (css_urls == null) css_urls = [];\n",
+       "    if (js_urls == null) js_urls = [];\n",
+       "\n",
+       "    root._bokeh_onload_callbacks.push(callback);\n",
+       "    if (root._bokeh_is_loading > 0) {\n",
+       "      console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n",
+       "      return null;\n",
+       "    }\n",
+       "    if (js_urls == null || js_urls.length === 0) {\n",
+       "      run_callbacks();\n",
+       "      return null;\n",
+       "    }\n",
+       "    console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n",
+       "    root._bokeh_is_loading = css_urls.length + js_urls.length;\n",
+       "\n",
+       "    function on_load() {\n",
+       "      root._bokeh_is_loading--;\n",
+       "      if (root._bokeh_is_loading === 0) {\n",
+       "        console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n",
+       "        run_callbacks()\n",
+       "      }\n",
+       "    }\n",
+       "\n",
+       "    function on_error(url) {\n",
+       "      console.error(\"failed to load \" + url);\n",
+       "    }\n",
+       "\n",
+       "    for (let i = 0; i < css_urls.length; i++) {\n",
+       "      const url = css_urls[i];\n",
+       "      const element = document.createElement(\"link\");\n",
+       "      element.onload = on_load;\n",
+       "      element.onerror = on_error.bind(null, url);\n",
+       "      element.rel = \"stylesheet\";\n",
+       "      element.type = \"text/css\";\n",
+       "      element.href = url;\n",
+       "      console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n",
+       "      document.body.appendChild(element);\n",
+       "    }\n",
+       "\n",
+       "    for (let i = 0; i < js_urls.length; i++) {\n",
+       "      const url = js_urls[i];\n",
+       "      const element = document.createElement('script');\n",
+       "      element.onload = on_load;\n",
+       "      element.onerror = on_error.bind(null, url);\n",
+       "      element.async = false;\n",
+       "      element.src = url;\n",
+       "      console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n",
+       "      document.head.appendChild(element);\n",
+       "    }\n",
+       "  };\n",
+       "\n",
+       "  function inject_raw_css(css) {\n",
+       "    const element = document.createElement(\"style\");\n",
+       "    element.appendChild(document.createTextNode(css));\n",
+       "    document.body.appendChild(element);\n",
+       "  }\n",
+       "\n",
+       "  const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n",
+       "  const css_urls = [];\n",
+       "\n",
+       "  const inline_js = [    function(Bokeh) {\n",
+       "      Bokeh.set_log_level(\"info\");\n",
+       "    },\n",
+       "function(Bokeh) {\n",
+       "    }\n",
+       "  ];\n",
+       "\n",
+       "  function run_inline_js() {\n",
+       "    if (root.Bokeh !== undefined || force === true) {\n",
+       "          for (let i = 0; i < inline_js.length; i++) {\n",
+       "      inline_js[i].call(root, root.Bokeh);\n",
+       "    }\n",
+       "if (force === true) {\n",
+       "        display_loaded();\n",
+       "      }} else if (Date.now() < root._bokeh_timeout) {\n",
+       "      setTimeout(run_inline_js, 100);\n",
+       "    } else if (!root._bokeh_failed_load) {\n",
+       "      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n",
+       "      root._bokeh_failed_load = true;\n",
+       "    } else if (force !== true) {\n",
+       "      const cell = $(document.getElementById(\"1003\")).parents('.cell').data().cell;\n",
+       "      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n",
+       "    }\n",
+       "  }\n",
+       "\n",
+       "  if (root._bokeh_is_loading === 0) {\n",
+       "    console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n",
+       "    run_inline_js();\n",
+       "  } else {\n",
+       "    load_libs(css_urls, js_urls, function() {\n",
+       "      console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n",
+       "      run_inline_js();\n",
+       "    });\n",
+       "  }\n",
+       "}(window));"
+      ],
+      "application/vnd.bokehjs_load.v0+json": "(function(root) {\n  function now() {\n    return new Date();\n  }\n\n  const force = true;\n\n  if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n    root._bokeh_onload_callbacks = [];\n    root._bokeh_is_loading = undefined;\n  }\n\n\n  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n    root._bokeh_timeout = Date.now() + 5000;\n    root._bokeh_failed_load = false;\n  }\n\n  const NB_LOAD_WARNING = {'data': {'text/html':\n     \"<div style='background-color: #fdd'>\\n\"+\n     \"<p>\\n\"+\n     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n     \"</p>\\n\"+\n     \"<ul>\\n\"+\n     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n     \"</ul>\\n\"+\n     \"<code>\\n\"+\n     \"from bokeh.resources import INLINE\\n\"+\n     \"output_notebook(resources=INLINE)\\n\"+\n     \"</code>\\n\"+\n     \"</div>\"}};\n\n  function display_loaded() {\n    const el = document.getElementById(\"1003\");\n    if (el != null) {\n      el.textContent = \"BokehJS is loading...\";\n    }\n    if (root.Bokeh !== undefined) {\n      if (el != null) {\n        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n      }\n    } else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(display_loaded, 100)\n    }\n  }\n\n  function run_callbacks() {\n    try {\n      root._bokeh_onload_callbacks.forEach(function(callback) {\n        if (callback != null)\n          callback();\n      });\n    } finally {\n      delete root._bokeh_onload_callbacks\n    }\n    console.debug(\"Bokeh: all callbacks have finished\");\n  }\n\n  function load_libs(css_urls, js_urls, callback) {\n    if (css_urls == null) css_urls = [];\n    if (js_urls == 
null) js_urls = [];\n\n    root._bokeh_onload_callbacks.push(callback);\n    if (root._bokeh_is_loading > 0) {\n      console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n      return null;\n    }\n    if (js_urls == null || js_urls.length === 0) {\n      run_callbacks();\n      return null;\n    }\n    console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n    root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n    function on_load() {\n      root._bokeh_is_loading--;\n      if (root._bokeh_is_loading === 0) {\n        console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n        run_callbacks()\n      }\n    }\n\n    function on_error(url) {\n      console.error(\"failed to load \" + url);\n    }\n\n    for (let i = 0; i < css_urls.length; i++) {\n      const url = css_urls[i];\n      const element = document.createElement(\"link\");\n      element.onload = on_load;\n      element.onerror = on_error.bind(null, url);\n      element.rel = \"stylesheet\";\n      element.type = \"text/css\";\n      element.href = url;\n      console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n      document.body.appendChild(element);\n    }\n\n    for (let i = 0; i < js_urls.length; i++) {\n      const url = js_urls[i];\n      const element = document.createElement('script');\n      element.onload = on_load;\n      element.onerror = on_error.bind(null, url);\n      element.async = false;\n      element.src = url;\n      console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n      document.head.appendChild(element);\n    }\n  };\n\n  function inject_raw_css(css) {\n    const element = document.createElement(\"style\");\n    element.appendChild(document.createTextNode(css));\n    document.body.appendChild(element);\n  }\n\n  const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", 
\"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n  const css_urls = [];\n\n  const inline_js = [    function(Bokeh) {\n      Bokeh.set_log_level(\"info\");\n    },\nfunction(Bokeh) {\n    }\n  ];\n\n  function run_inline_js() {\n    if (root.Bokeh !== undefined || force === true) {\n          for (let i = 0; i < inline_js.length; i++) {\n      inline_js[i].call(root, root.Bokeh);\n    }\nif (force === true) {\n        display_loaded();\n      }} else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(run_inline_js, 100);\n    } else if (!root._bokeh_failed_load) {\n      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n      root._bokeh_failed_load = true;\n    } else if (force !== true) {\n      const cell = $(document.getElementById(\"1003\")).parents('.cell').data().cell;\n      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n    }\n  }\n\n  if (root._bokeh_is_loading === 0) {\n    console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n    run_inline_js();\n  } else {\n    load_libs(css_urls, js_urls, function() {\n      console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n      run_inline_js();\n    });\n  }\n}(window));"
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "\n",
+    "from __future__ import annotations\n",
+    "import typing\n",
+    "import logging\n",
+    "import os\n",
+    "import sys\n",
+    "import warnings\n",
+    "\n",
+    "sys.path.append('../montetracko')\n",
+    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
+    "\n",
+    "import yaml\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "\n",
+    "from Preprocessing.run_preprocessing import run_preprocessing_test_dataset\n",
+    "from Preprocessing.run_preprocessing import run_preprocessing\n",
+    "from Processing.run_processing import run_processing_from_config\n",
+    "\n",
+    "from Scripts.Step_1_Train_Metric_Learning import train as train_metric_learning\n",
+    "from Scripts.Step_2_Run_Metric_Learning import train as run_metric_learning_inference\n",
+    "from Scripts.Step_3_Train_GNN import train as train_gnn\n",
+    "from Scripts.Step_4_Run_GNN import train as run_gnn_inference\n",
+    "from Scripts.Step_5_Build_Track_Candidates import train as build_track_candidates\n",
+    "from Scripts.Step_6_Evaluate_Reconstruction_MonteTracko import (\n",
+    "    evaluate as evaluate_candidates_montetracko\n",
+    ")\n",
+    "\n",
+    "from utils.plotutils import graph as graphplot\n",
+    "from utils.plotutils import performance as perfplot\n",
+    "from utils.plotutils import performance_mpl as perfplot_mpl\n",
+    "from utils.commonutils.ctests import get_required_test_dataset_names\n",
+    "from utils.commonutils.config import load_config\n",
+    "from utils.modelutils import checkpoint_utils\n",
+    "from utils.scriptutils.loghandler import headline\n",
+    "\n",
+    "from utils.plotutils.plotconfig import configure_matplotlib\n",
+    "\n",
+    "configure_matplotlib()\n",
+    "\n",
+    "warnings.filterwarnings(\n",
+    "    \"ignore\", message=(\n",
+    "        \"TypedStorage is deprecated. It will be removed in the future and \"\n",
+    "        \"UntypedStorage will be the only storage class. This should only matter to you \"\n",
+    "        \"if you are using storages directly.\"\n",
+    "    )\n",
+    ")\n",
+    "\n",
+    "CONFIG = 'pipeline_configs/focal-loss-pid-fixed-20000.yaml'\n",
+    "\n",
+    "run_training: bool = True\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Pipeline configurations\n",
+    "\n",
+    "The configuration for the entire pipeline is defined in a YAML file; this notebook uses the file that the `CONFIG` variable points to (set in the code cell above)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Download data\n",
+    "Uncomment the commands in the next cell if you do not already have the data."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# path = 'data/input/0'\n",
+    "# os.makedirs(path, exist_ok=True)\n",
+    "# ! xrdcp -r root://eoslhcb.cern.ch//eos/lhcb/user/a/anthonyc/tracking/data/csv/v2/minbias-sim10b-xdigi/0 data/input  --parallel 4"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Telegram notification bot"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import requests\n",
+    "import json\n",
+    "\n",
+    "# from datetime import datetime\n",
+    "\n",
+    "# def send_telegram_message(message: str,\n",
+    "#                           chat_id: str,\n",
+    "#                           api_key: str,\n",
+    "#                          ):\n",
+    "#     responses = {}\n",
+    "\n",
+    "#     url = f'https://api.telegram.org/bot{api_key}/sendMessage?chat_id={chat_id}&text={message}'\n",
+    "    \n",
+    "#     response = requests.post(url)\n",
+    "    \n",
+    "#     return response\n",
+    "\n",
+    "def send_telegram_message(*args, **kwargs):\n",
+    "    pass"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "chat_id = \"5027012918\"\n",
+    "api_key = \"6268687426:AAE1P7WQofCBuQPiYZlYaKU-p1GNn6OvAxM\"\n",
+    "\n",
+    "send_telegram_message(\"======================\", chat_id, api_key)\n",
+    "\n",
+    "send_telegram_message(\"Starting.\", chat_id, api_key)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Preprocess the test samples"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for required_test_dataset_name in get_required_test_dataset_names(CONFIG):\n",
+    "    run_preprocessing_test_dataset(\n",
+    "        test_dataset_name=required_test_dataset_name,\n",
+    "        path_or_config=CONFIG,\n",
+    "        path_or_config_test=\"test_samples.yaml\",\n",
+    "        reproduce=False,\n",
+    "    )\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Preprocessing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Preprocessing: output will be written in /scratch/acorreia/data/focal-loss-pid-fixed-20000/preprocessed\n",
+      "INFO:Input directories:\n",
+      "INFO:- /scratch/acorreia/minbias-sim10b-xdigi-nospillover/10\n",
+      "INFO:- /scratch/acorreia/minbias-sim10b-xdigi-nospillover/11\n",
+      "INFO:- /scratch/acorreia/minbias-sim10b-xdigi-nospillover/12\n",
+      "INFO:- /scratch/acorreia/minbias-sim10b-xdigi-nospillover/13\n",
+      "INFO:- /scratch/acorreia/minbias-sim10b-xdigi-nospillover/14\n",
+      "INFO:- /scratch/acorreia/minbias-sim10b-xdigi-nospillover/15\n",
+      "INFO:- /scratch/acorreia/minbias-sim10b-xdigi-nospillover/16\n",
+      "INFO:- /scratch/acorreia/minbias-sim10b-xdigi-nospillover/17\n",
+      "INFO:- /scratch/acorreia/minbias-sim10b-xdigi-nospillover/18\n",
+      "INFO:- /scratch/acorreia/minbias-sim10b-xdigi-nospillover/19\n",
+      "INFO:- /scratch/acorreia/minbias-sim10b-xdigi-nospillover/20\n",
+      "INFO:Number of events to produce: 21000\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "75e456dbc8714b0ab3c97d39c060575f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/21000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Load dataframes in /scratch/acorreia/minbias-sim10b-xdigi-nospillover/10\n",
+      "INFO:Apply selection `triplets_first_selection`\n",
+      "INFO:Compute distance to line (that might take some time)\n",
+      "INFO:Load dataframes in /scratch/acorreia/minbias-sim10b-xdigi-nospillover/11\n",
+      "INFO:Apply selection `triplets_first_selection`\n",
+      "INFO:Compute distance to line (that might take some time)\n",
+      "INFO:Load dataframes in /scratch/acorreia/minbias-sim10b-xdigi-nospillover/12\n",
+      "INFO:Apply selection `triplets_first_selection`\n",
+      "INFO:Compute distance to line (that might take some time)\n",
+      "INFO:Load dataframes in /scratch/acorreia/minbias-sim10b-xdigi-nospillover/13\n",
+      "INFO:Apply selection `triplets_first_selection`\n",
+      "INFO:Compute distance to line (that might take some time)\n"
+     ]
+    }
+   ],
+   "source": [
+    "run_preprocessing(CONFIG, reproduce=False)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Processing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Input directory: /scratch/acorreia/data/focal-loss-pid-fixed-20000/preprocessed\n",
+      "INFO:Writing outputs to /scratch/acorreia/data/focal-loss-pid-fixed-20000/processed/train\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "394333f02c3445edb109e9ca67311228",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/20000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Writing outputs to /scratch/acorreia/data/focal-loss-pid-fixed-20000/processed/val\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b659e22a7a074332821bbeccca00d09a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Splitting was saved in /scratch/acorreia/data/focal-loss-pid-fixed-20000/processed/splitting.json.\n"
+     ]
+    }
+   ],
+   "source": [
+    "run_processing_from_config(CONFIG, reproduce=False)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Input directory: /scratch/acorreia/data/__test__/velo-sim10b-nospillover\n",
+      "INFO:Writing outputs to /scratch/acorreia/data/focal-loss-pid-fixed-20000/processed/test/velo-sim10b-nospillover\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ec290d5663144d78b18046d15351fa6c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Input directory: /scratch/acorreia/data/__test__/velo-sim10b-nospillover-only-long-electrons\n",
+      "INFO:Writing outputs to /scratch/acorreia/data/focal-loss-pid-fixed-20000/processed/test/velo-sim10b-nospillover-only-long-electrons\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "74507c79e2d3448a95da1edac7c462e4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "for required_test_dataset_name in get_required_test_dataset_names(CONFIG):\n",
+    "    run_processing_from_config(\n",
+    "        test_dataset_name=required_test_dataset_name,\n",
+    "        path_or_config=CONFIG,\n",
+    "        reproduce=False,\n",
+    "    )\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 1. Train Metric Learning"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## What it does\n",
+    "Broadly speaking, the first stage of our pipeline is embedding the space points onto graphs in a way that is efficient, i.e. we miss as few points on a graph as possible. We train an MLP to transform the input feature vector $\\mathbf{u}_i$ of each space point into a point $\\mathbf{v}_i$ of an N-dimensional latent space. The graph is then constructed by connecting the pairs of space points whose Euclidean distance in the latent space satisfies $$d_{ij} = \\left| \\mathbf{v}_i - \\mathbf{v}_j \\right| < r_{\\mathrm{embedding}}$$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Training data\n",
+    "Let us take a look at the data before training. In this example pipeline, we have preprocessed the TrackML data into a more convenient form. We calculated directional information and summary statistics from the charge deposited in each space point, and appended them to its cylindrical coordinates. Let us load an example data file and inspect the content."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from Embedding.embedding_base import get_example_data\n",
+    "example_data_df, example_data_pyg = get_example_data(CONFIG)\n",
+    "example_data_df\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "graphplot.plot_true_graph(example_data_pyg, CONFIG, num_tracks=50)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Train metric learning model\n",
+    "\n",
+    "Finally, we come to model training. By default, we train the MLP for 30 epochs, which takes approximately 15 minutes on an NVIDIA V100. Feel free to adjust the number of epochs in `pipeline_config.yml`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "! nvcc --version"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "! nvidia-smi"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "if run_training:\n",
+    "    send_telegram_message('Started metric learning training.', chat_id, api_key)\n",
+    "\n",
+    "    metric_learning_trainer, metric_learning_model = train_metric_learning(CONFIG)\n",
+    "\n",
+    "    send_telegram_message('Finished metric learning training.', chat_id, api_key)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In the case you want to continue the training of a certain network, you may\n",
+    "use the code below."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### From checkpoint"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "embedding_metric_path='artifacts/metric_learning/focal-loss-pid-fixed-20000/version_0/metrics.csv'\n",
+      "embedding_artifact_path='artifacts/metric_learning/focal-loss-pid-fixed-20000/version_0/checkpoints/epoch=16-step=170000.ckpt'\n"
+     ]
+    }
+   ],
+   "source": [
+    "from Embedding.models.layerless_embedding import LayerlessEmbedding\n",
+    "\n",
+    "embedding_version_dir = checkpoint_utils.get_last_version_dir_from_config(\n",
+    "    step=\"metric_learning\", path_or_config=CONFIG\n",
+    ")\n",
+    "embedding_metric_path = os.path.join(embedding_version_dir, \"metrics.csv\")\n",
+    "embedding_artifact_path = checkpoint_utils.get_last_artifact(\n",
+    "    version_dir=embedding_version_dir\n",
+    ")\n",
+    "print(f\"{embedding_metric_path=}\")\n",
+    "print(f\"{embedding_artifact_path=}\")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To continue the training of the network, you may use the code below"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pytorch_lightning.loggers import CSVLogger\n",
+    "from pytorch_lightning import Trainer\n",
+    "\n",
+    "\n",
+    "def continue_embedding_training(\n",
+    "    path_or_config: str | dict,\n",
+    ") -> typing.Tuple[Trainer, LayerlessEmbedding]:\n",
+    "    config = load_config(path_or_config=path_or_config)\n",
+    "\n",
+    "    metric_learning_model = LayerlessEmbedding.load_from_checkpoint(\n",
+    "        embedding_artifact_path\n",
+    "    )  # you may change `metric_learning_model`\n",
+    "\n",
+    "    save_directory = os.path.abspath(\n",
+    "        os.path.join(config[\"common\"][\"artifact_directory\"], \"metric_learning\")\n",
+    "    )\n",
+    "\n",
+    "    logger = CSVLogger(save_directory, name=config[\"common\"][\"experiment_name\"])\n",
+    "\n",
+    "    metric_learning_trainer = Trainer(\n",
+    "        accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n",
+    "        devices=1,\n",
+    "        max_epochs=40,  # you may increase the number of epochs\n",
+    "        logger=logger,\n",
+    "        # callbacks=[EarlyStopping(monitor=\"train_loss\", mode=\"min\")]\n",
+    "    )\n",
+    "\n",
+    "    metric_learning_trainer.fit(metric_learning_model)\n",
+    "    return metric_learning_trainer, metric_learning_model\n",
+    "\n",
+    "\n",
+    "# metric_learning_trainer, metric_learning_model = continue_embedding_training(CONFIG)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metric_learning_model = LayerlessEmbedding.load_from_checkpoint(\n",
+    "    embedding_artifact_path,\n",
+    "    # map_location=\"cpu\",\n",
+    "    # If importing model from another experiment\n",
+    "    hparams=load_config(CONFIG)[\"metric_learning\"],\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Plot training metrics"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can examine how the training went. This is stored in a simple dataframe:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# embedding_metrics = checkpoint_utils.get_training_metrics(metric_learning_trainer) \n",
+    "\n",
+    "embedding_metrics = checkpoint_utils.get_training_metrics(embedding_metric_path)\n",
+    "\n",
+    "embedding_metrics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "perfplot_mpl.plot_loss(embedding_metrics, CONFIG, \"metric_learning\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "perfplot_mpl.plot_metric_epochs(\n",
+    "    \"eff\",\n",
+    "    embedding_metrics,\n",
+    "    CONFIG,\n",
+    "    \"metric_learning\",\n",
+    "    \"Edge Efficiency\",\n",
+    "    color=perfplot_mpl.partition_to_color[\"val\"],\n",
+    ")\n",
+    "perfplot_mpl.plot_metric_epochs(\n",
+    "    \"pur\",\n",
+    "    embedding_metrics,\n",
+    "    CONFIG,\n",
+    "    \"metric_learning\",\n",
+    "    \"Edge Purity\",\n",
+    "    color=perfplot_mpl.partition_to_color[\"val\"],\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Evaluate model performance on sample test data\n",
+    "\n",
+    "Here we evaluate the model performance on one sample test data. We look at how the efficiency and purity change with the embedding radius."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true,
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from Embedding import embedding_plots\n",
+    "\n",
+    "embedding_plots.plot_embedding_performance_given_radius_knn_max(\n",
+    "    model=metric_learning_model,\n",
+    "    path_or_config=CONFIG,\n",
+    "    radius=np.linspace(0.01, 0.05, 10),\n",
+    "    n_events=20,\n",
+    "    partitions=[\"train\", \"val\", \"velo-sim10b-nospillover\"],\n",
+    ");\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Plot example truth and predicted graphs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "metric_learning_model.load_partition(\"velo-sim10b-nospillover\")\n",
+    "graphplot.plot_predicted_graph(metric_learning_model, CONFIG)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from Embedding.embedding_plots import plot_best_performances_radius\n",
+    "plot_best_performances_radius(\n",
+    "    model=metric_learning_model,\n",
+    "    path_or_config=CONFIG,\n",
+    "    partition=\"velo-sim10b-nospillover\",\n",
+    "    # list_radius=[0.015, 0.020],\n",
+    "    list_radius=[0.015, 0.020, 0.025, 0.030, 0.035, 0.040],\n",
+    "    n_events=200,\n",
+    "    seed=0,\n",
+    ");\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Track lengths"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metric_learning_model.load_partition(\"velo-sim10b-nospillover\")\n",
+    "perfplot.plot_track_lengths(metric_learning_model, CONFIG)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metric_learning_model.setup(stage=\"fit\")  # load train and val datasets\n",
+    "\n",
+    "perfplot.plot_graph_sizes(metric_learning_model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 2. Construct graphs from metric learning inference\n",
+    "\n",
+    "This step runs model inference on all the input datasets (train, validation and test) to obtain the input graphs for the graph neural network. Optionally, we also clear the directory."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:------------- Step 2: Constructing graphs from metric learning model -------------\n",
+      "INFO:---------------------------- a) Loading trained model ----------------------------\n",
+      "INFO:Load model from artifacts/metric_learning/focal-loss-pid-fixed-20000/version_0/checkpoints/epoch=16-step=170000.ckpt.\n",
+      "INFO:----------------------------- b) Running inferencing -----------------------------\n",
+      "INFO:Use radius 0.02\n",
+      "INFO:Use the following parameters for train: {'building': None, 'filtering': 'edges_at_least_3_hits'}\n",
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed-20000/processed/train to /scratch/acorreia/data/focal-loss-pid-fixed-20000/metric_learning_processed/train\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b65c3f994ca64de7adf2be3ac3bb3648",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/20000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Use the following parameters for val: {'building': None, 'filtering': 'edges_at_least_3_hits'}\n",
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed-20000/processed/val to /scratch/acorreia/data/focal-loss-pid-fixed-20000/metric_learning_processed/val\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "140c65ea95ec476097d211a1d699d119",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Use the following parameters for test: {'building': None}\n",
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed-20000/processed/test/velo-sim10b-nospillover to /scratch/acorreia/data/focal-loss-pid-fixed-20000/metric_learning_processed/test/velo-sim10b-nospillover\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a2b001a1fd1c48ea9a26195c66558502",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed-20000/processed/test/velo-sim10b-nospillover-only-long-electrons to /scratch/acorreia/data/focal-loss-pid-fixed-20000/metric_learning_processed/test/velo-sim10b-nospillover-only-long-electrons\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "54c0b8ab8880445d8284a5f95e281342",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "graph_builder = run_metric_learning_inference(\n",
+    "    CONFIG,\n",
+    "    checkpoint=embedding_artifact_path,\n",
+    "    # checkpoint=metric_learning_model,  # here directly use the model\n",
+    "    reproduce=False,\n",
+    "    use_gpu=False,\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 3. Train graph neural networks"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We have a set of graphs constructed. We now train a GNN to classify edges as either \"true\" (belonging to the same track) or \"false\" (not belonging to the same track). We train for 30 epochs, which should take around 10 minutes on a V100 GPU. Your mileage may vary."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if run_training:\n",
+    "    send_telegram_message('Started GNN training.', chat_id, api_key)\n",
+    "    with warnings.catch_warnings():\n",
+    "        warnings.filterwarnings(\n",
+    "            \"ignore\", message=\"None of the inputs have requires_grad=True.\"\n",
+    "        )\n",
+    "        gnn_trainer, gnn_model = train_gnn(CONFIG)\n",
+    "\n",
+    "    send_telegram_message('Finished GNN training.', chat_id, api_key)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### From checkpoint"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "gnn_metric_path='artifacts/gnn/focal-loss-pid-fixed/version_0/metrics.csv'\n",
+      "gnn_artifact_path='artifacts/gnn/focal-loss-pid-fixed/version_0/checkpoints/epoch=49-step=500000.ckpt'\n"
+     ]
+    }
+   ],
+   "source": [
+    "from utils.modelutils.checkpoint_utils import (\n",
+    "    get_last_version_dir_from_config,\n",
+    "    get_last_artifact,\n",
+    ")\n",
+    "from GNN.models.interaction_gnn import InteractionGNN\n",
+    "\n",
+    "gnn_version_dir = get_last_version_dir_from_config(step=\"gnn\", path_or_config=\"pipeline_configs/focal-loss-pid-fixed.yaml\")\n",
+    "gnn_metric_path = os.path.join(gnn_version_dir, \"metrics.csv\")\n",
+    "gnn_artifact_path = get_last_artifact(version_dir=gnn_version_dir)\n",
+    "print(f\"{gnn_metric_path=}\")\n",
+    "print(f\"{gnn_artifact_path=}\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "GPU available: True (cuda), used: True\n",
+      "TPU available: False, using: 0 TPU cores\n",
+      "IPU available: False, using: 0 IPUs\n",
+      "HPU available: False, using: 0 HPUs\n",
+      "Missing logger folder: /home/acorreia/Documents/tracking/etx4velo/LHCb_Pipeline/artifacts/gnn/focal-loss-pid-fixed-20000\n",
+      "INFO:Load 20000 files located in /scratch/acorreia/data/focal-loss-pid-fixed-20000/metric_learning_processed/train\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "79a91c643fde4cc68d866661047f98b1",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/20000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Load 1000 files located in /scratch/acorreia/data/focal-loss-pid-fixed-20000/metric_learning_processed/val\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5775ab029e5b461cb56baf1cee6dbcb9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Restoring states from the checkpoint path at artifacts/gnn/focal-loss-pid-fixed/version_0/checkpoints/epoch=49-step=500000.ckpt\n",
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:337: UserWarning: The dirpath has changed from '/home/acorreia/Documents/tracking/etx4velo/LHCb_Pipeline/artifacts/gnn/focal-loss-pid-fixed/version_0/checkpoints' to '/home/acorreia/Documents/tracking/etx4velo/LHCb_Pipeline/artifacts/gnn/focal-loss-pid-fixed-20000/version_0/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.\n",
+      "  warnings.warn(\n",
+      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+      "\n",
+      "  | Name                   | Type       | Params\n",
+      "------------------------------------------------------\n",
+      "0 | node_encoder           | Sequential | 332 K \n",
+      "1 | edge_encoder           | Sequential | 462 K \n",
+      "2 | edge_network           | Sequential | 793 K \n",
+      "3 | node_network           | Sequential | 659 K \n",
+      "4 | output_edge_classifier | Sequential | 529 K \n",
+      "------------------------------------------------------\n",
+      "2.8 M     Trainable params\n",
+      "0         Non-trainable params\n",
+      "2.8 M     Total params\n",
+      "11.111    Total estimated model params size (MB)\n",
+      "Restored all states from the checkpoint at artifacts/gnn/focal-loss-pid-fixed/version_0/checkpoints/epoch=49-step=500000.ckpt\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Sanity Checking: 0it [00:00, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "035609d787c1434ea4acaa49bab5ec5a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Training: 0it [00:00, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from pytorch_lightning import Trainer\n",
+    "from pytorch_lightning.loggers import CSVLogger\n",
+    "\n",
+    "\n",
+    "def continue_gnn_training(\n",
+    "    path_or_config: str | dict,\n",
+    ") -> typing.Tuple[Trainer, InteractionGNN]:\n",
+    "    config = load_config(path_or_config=path_or_config)\n",
+    "\n",
+    "    gnn_model = InteractionGNN.load_from_checkpoint(\n",
+    "        gnn_artifact_path, hparams=config[\"gnn\"]\n",
+    "    )  # you may change `gnn_model`\n",
+    "\n",
+    "    save_directory = os.path.abspath(\n",
+    "        os.path.join(config[\"common\"][\"artifact_directory\"], \"gnn\")\n",
+    "    )\n",
+    "\n",
+    "    logger = CSVLogger(save_directory, name=config[\"common\"][\"experiment_name\"])\n",
+    "\n",
+    "    gnn_trainer = Trainer(\n",
+    "        accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n",
+    "        devices=1,\n",
+    "        max_epochs=150,  # you may increase the number of epochs\n",
+    "        logger=logger,\n",
+    "        # callbacks=[EarlyStopping(monitor=\"val_loss\", mode=\"min\")]\n",
+    "    )\n",
+    "\n",
+    "    with warnings.catch_warnings():\n",
+    "        warnings.filterwarnings(\n",
+    "            \"ignore\", message=\"None of the inputs have requires_grad=True.\"\n",
+    "        )\n",
+    "        gnn_trainer.fit(gnn_model, ckpt_path=gnn_artifact_path)\n",
+    "    return gnn_trainer, gnn_model\n",
+    "\n",
+    "gnn_trainer, gnn_model = continue_gnn_training(CONFIG)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gnn_model = InteractionGNN.load_from_checkpoint(\n",
+    "    gnn_artifact_path,\n",
+    "    # map_location=\"cpu\",\n",
+    "    # hparams=load_config(CONFIG)[\"gnn\"],\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Plot training metrics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>epoch</th>\n",
+       "      <th>train_loss</th>\n",
+       "      <th>val_loss</th>\n",
+       "      <th>eff</th>\n",
+       "      <th>pur</th>\n",
+       "      <th>current_lr</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>0</td>\n",
+       "      <td>0.015745</td>\n",
+       "      <td>0.009409</td>\n",
+       "      <td>0.882107</td>\n",
+       "      <td>0.979640</td>\n",
+       "      <td>0.000020</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>1</td>\n",
+       "      <td>0.007946</td>\n",
+       "      <td>0.006557</td>\n",
+       "      <td>0.922102</td>\n",
+       "      <td>0.986181</td>\n",
+       "      <td>0.000040</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>2</td>\n",
+       "      <td>0.006280</td>\n",
+       "      <td>0.005664</td>\n",
+       "      <td>0.927634</td>\n",
+       "      <td>0.990096</td>\n",
+       "      <td>0.000060</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>3</td>\n",
+       "      <td>0.005555</td>\n",
+       "      <td>0.005122</td>\n",
+       "      <td>0.937812</td>\n",
+       "      <td>0.990240</td>\n",
+       "      <td>0.000080</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>4</td>\n",
+       "      <td>0.005132</td>\n",
+       "      <td>0.005172</td>\n",
+       "      <td>0.932919</td>\n",
+       "      <td>0.991395</td>\n",
+       "      <td>0.000100</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5</th>\n",
+       "      <td>5</td>\n",
+       "      <td>0.004851</td>\n",
+       "      <td>0.005073</td>\n",
+       "      <td>0.937282</td>\n",
+       "      <td>0.990720</td>\n",
+       "      <td>0.000120</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>6</th>\n",
+       "      <td>6</td>\n",
+       "      <td>0.004646</td>\n",
+       "      <td>0.005248</td>\n",
+       "      <td>0.930109</td>\n",
+       "      <td>0.991742</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>7</th>\n",
+       "      <td>7</td>\n",
+       "      <td>0.004510</td>\n",
+       "      <td>0.004563</td>\n",
+       "      <td>0.940789</td>\n",
+       "      <td>0.992526</td>\n",
+       "      <td>0.000112</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>8</th>\n",
+       "      <td>8</td>\n",
+       "      <td>0.004373</td>\n",
+       "      <td>0.004509</td>\n",
+       "      <td>0.946273</td>\n",
+       "      <td>0.991391</td>\n",
+       "      <td>0.000180</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9</th>\n",
+       "      <td>9</td>\n",
+       "      <td>0.004298</td>\n",
+       "      <td>0.004374</td>\n",
+       "      <td>0.947959</td>\n",
+       "      <td>0.991694</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>10</th>\n",
+       "      <td>10</td>\n",
+       "      <td>0.004109</td>\n",
+       "      <td>0.004050</td>\n",
+       "      <td>0.953197</td>\n",
+       "      <td>0.992015</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>11</th>\n",
+       "      <td>11</td>\n",
+       "      <td>0.003983</td>\n",
+       "      <td>0.004233</td>\n",
+       "      <td>0.954432</td>\n",
+       "      <td>0.990489</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>12</th>\n",
+       "      <td>12</td>\n",
+       "      <td>0.003879</td>\n",
+       "      <td>0.003788</td>\n",
+       "      <td>0.958103</td>\n",
+       "      <td>0.991964</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>13</th>\n",
+       "      <td>13</td>\n",
+       "      <td>0.003778</td>\n",
+       "      <td>0.003780</td>\n",
+       "      <td>0.960245</td>\n",
+       "      <td>0.991730</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>14</th>\n",
+       "      <td>14</td>\n",
+       "      <td>0.003703</td>\n",
+       "      <td>0.003744</td>\n",
+       "      <td>0.959127</td>\n",
+       "      <td>0.992023</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>15</th>\n",
+       "      <td>15</td>\n",
+       "      <td>0.003641</td>\n",
+       "      <td>0.003704</td>\n",
+       "      <td>0.959411</td>\n",
+       "      <td>0.992099</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>16</th>\n",
+       "      <td>16</td>\n",
+       "      <td>0.003305</td>\n",
+       "      <td>0.003534</td>\n",
+       "      <td>0.961065</td>\n",
+       "      <td>0.992371</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>17</th>\n",
+       "      <td>17</td>\n",
+       "      <td>0.003253</td>\n",
+       "      <td>0.003618</td>\n",
+       "      <td>0.963001</td>\n",
+       "      <td>0.991412</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>18</th>\n",
+       "      <td>18</td>\n",
+       "      <td>0.003225</td>\n",
+       "      <td>0.003503</td>\n",
+       "      <td>0.963811</td>\n",
+       "      <td>0.991958</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>19</th>\n",
+       "      <td>19</td>\n",
+       "      <td>0.003224</td>\n",
+       "      <td>0.003467</td>\n",
+       "      <td>0.964852</td>\n",
+       "      <td>0.991678</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>20</th>\n",
+       "      <td>20</td>\n",
+       "      <td>0.003200</td>\n",
+       "      <td>0.003380</td>\n",
+       "      <td>0.964631</td>\n",
+       "      <td>0.992555</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>21</th>\n",
+       "      <td>21</td>\n",
+       "      <td>0.003180</td>\n",
+       "      <td>0.003469</td>\n",
+       "      <td>0.965098</td>\n",
+       "      <td>0.992115</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>22</th>\n",
+       "      <td>22</td>\n",
+       "      <td>0.003160</td>\n",
+       "      <td>0.003374</td>\n",
+       "      <td>0.964880</td>\n",
+       "      <td>0.992387</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>23</th>\n",
+       "      <td>23</td>\n",
+       "      <td>0.003152</td>\n",
+       "      <td>0.003522</td>\n",
+       "      <td>0.965240</td>\n",
+       "      <td>0.991543</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>24</th>\n",
+       "      <td>24</td>\n",
+       "      <td>0.002912</td>\n",
+       "      <td>0.003308</td>\n",
+       "      <td>0.968241</td>\n",
+       "      <td>0.992033</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>25</th>\n",
+       "      <td>25</td>\n",
+       "      <td>0.002879</td>\n",
+       "      <td>0.003358</td>\n",
+       "      <td>0.968242</td>\n",
+       "      <td>0.991781</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>26</th>\n",
+       "      <td>26</td>\n",
+       "      <td>0.002879</td>\n",
+       "      <td>0.003365</td>\n",
+       "      <td>0.968507</td>\n",
+       "      <td>0.991775</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>27</th>\n",
+       "      <td>27</td>\n",
+       "      <td>0.002853</td>\n",
+       "      <td>0.003240</td>\n",
+       "      <td>0.969288</td>\n",
+       "      <td>0.992256</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>28</th>\n",
+       "      <td>28</td>\n",
+       "      <td>0.002858</td>\n",
+       "      <td>0.003385</td>\n",
+       "      <td>0.968127</td>\n",
+       "      <td>0.991958</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>29</th>\n",
+       "      <td>29</td>\n",
+       "      <td>0.002849</td>\n",
+       "      <td>0.003389</td>\n",
+       "      <td>0.969099</td>\n",
+       "      <td>0.991858</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>30</th>\n",
+       "      <td>30</td>\n",
+       "      <td>0.002855</td>\n",
+       "      <td>0.003413</td>\n",
+       "      <td>0.966039</td>\n",
+       "      <td>0.992158</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>31</th>\n",
+       "      <td>31</td>\n",
+       "      <td>0.002830</td>\n",
+       "      <td>0.003342</td>\n",
+       "      <td>0.969118</td>\n",
+       "      <td>0.991828</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>32</th>\n",
+       "      <td>32</td>\n",
+       "      <td>0.002659</td>\n",
+       "      <td>0.003288</td>\n",
+       "      <td>0.969825</td>\n",
+       "      <td>0.992203</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>33</th>\n",
+       "      <td>33</td>\n",
+       "      <td>0.002611</td>\n",
+       "      <td>0.002948</td>\n",
+       "      <td>0.972177</td>\n",
+       "      <td>0.993180</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>34</th>\n",
+       "      <td>34</td>\n",
+       "      <td>0.002425</td>\n",
+       "      <td>0.002649</td>\n",
+       "      <td>0.974883</td>\n",
+       "      <td>0.994110</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>35</th>\n",
+       "      <td>35</td>\n",
+       "      <td>0.002253</td>\n",
+       "      <td>0.002432</td>\n",
+       "      <td>0.977438</td>\n",
+       "      <td>0.994554</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>36</th>\n",
+       "      <td>36</td>\n",
+       "      <td>0.002139</td>\n",
+       "      <td>0.002351</td>\n",
+       "      <td>0.977666</td>\n",
+       "      <td>0.994845</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>37</th>\n",
+       "      <td>37</td>\n",
+       "      <td>0.002048</td>\n",
+       "      <td>0.002227</td>\n",
+       "      <td>0.978469</td>\n",
+       "      <td>0.995250</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>38</th>\n",
+       "      <td>38</td>\n",
+       "      <td>0.001969</td>\n",
+       "      <td>0.002170</td>\n",
+       "      <td>0.979098</td>\n",
+       "      <td>0.995544</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>39</th>\n",
+       "      <td>39</td>\n",
+       "      <td>0.001897</td>\n",
+       "      <td>0.001969</td>\n",
+       "      <td>0.981175</td>\n",
+       "      <td>0.995911</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>40</th>\n",
+       "      <td>40</td>\n",
+       "      <td>0.001624</td>\n",
+       "      <td>0.001698</td>\n",
+       "      <td>0.982833</td>\n",
+       "      <td>0.996831</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>41</th>\n",
+       "      <td>41</td>\n",
+       "      <td>0.001523</td>\n",
+       "      <td>0.001631</td>\n",
+       "      <td>0.983602</td>\n",
+       "      <td>0.996945</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>42</th>\n",
+       "      <td>42</td>\n",
+       "      <td>0.001450</td>\n",
+       "      <td>0.001593</td>\n",
+       "      <td>0.984215</td>\n",
+       "      <td>0.997049</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>43</th>\n",
+       "      <td>43</td>\n",
+       "      <td>0.001411</td>\n",
+       "      <td>0.001537</td>\n",
+       "      <td>0.984988</td>\n",
+       "      <td>0.997071</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>44</th>\n",
+       "      <td>44</td>\n",
+       "      <td>0.001374</td>\n",
+       "      <td>0.001585</td>\n",
+       "      <td>0.984472</td>\n",
+       "      <td>0.997054</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>45</th>\n",
+       "      <td>45</td>\n",
+       "      <td>0.001344</td>\n",
+       "      <td>0.001564</td>\n",
+       "      <td>0.984906</td>\n",
+       "      <td>0.997020</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>46</th>\n",
+       "      <td>46</td>\n",
+       "      <td>0.001334</td>\n",
+       "      <td>0.001556</td>\n",
+       "      <td>0.984871</td>\n",
+       "      <td>0.996946</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>47</th>\n",
+       "      <td>47</td>\n",
+       "      <td>0.001305</td>\n",
+       "      <td>0.001551</td>\n",
+       "      <td>0.984850</td>\n",
+       "      <td>0.997180</td>\n",
+       "      <td>0.000034</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>48</th>\n",
+       "      <td>48</td>\n",
+       "      <td>0.001215</td>\n",
+       "      <td>0.001518</td>\n",
+       "      <td>0.985659</td>\n",
+       "      <td>0.997226</td>\n",
+       "      <td>0.000034</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>49</th>\n",
+       "      <td>49</td>\n",
+       "      <td>0.001197</td>\n",
+       "      <td>0.001530</td>\n",
+       "      <td>0.985763</td>\n",
+       "      <td>0.997204</td>\n",
+       "      <td>0.000034</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "    epoch  train_loss  val_loss       eff       pur  current_lr\n",
+       "0       0    0.015745  0.009409  0.882107  0.979640    0.000020\n",
+       "1       1    0.007946  0.006557  0.922102  0.986181    0.000040\n",
+       "2       2    0.006280  0.005664  0.927634  0.990096    0.000060\n",
+       "3       3    0.005555  0.005122  0.937812  0.990240    0.000080\n",
+       "4       4    0.005132  0.005172  0.932919  0.991395    0.000100\n",
+       "5       5    0.004851  0.005073  0.937282  0.990720    0.000120\n",
+       "6       6    0.004646  0.005248  0.930109  0.991742    0.000140\n",
+       "7       7    0.004510  0.004563  0.940789  0.992526    0.000112\n",
+       "8       8    0.004373  0.004509  0.946273  0.991391    0.000180\n",
+       "9       9    0.004298  0.004374  0.947959  0.991694    0.000200\n",
+       "10     10    0.004109  0.004050  0.953197  0.992015    0.000200\n",
+       "11     11    0.003983  0.004233  0.954432  0.990489    0.000200\n",
+       "12     12    0.003879  0.003788  0.958103  0.991964    0.000200\n",
+       "13     13    0.003778  0.003780  0.960245  0.991730    0.000200\n",
+       "14     14    0.003703  0.003744  0.959127  0.992023    0.000200\n",
+       "15     15    0.003641  0.003704  0.959411  0.992099    0.000140\n",
+       "16     16    0.003305  0.003534  0.961065  0.992371    0.000140\n",
+       "17     17    0.003253  0.003618  0.963001  0.991412    0.000140\n",
+       "18     18    0.003225  0.003503  0.963811  0.991958    0.000140\n",
+       "19     19    0.003224  0.003467  0.964852  0.991678    0.000140\n",
+       "20     20    0.003200  0.003380  0.964631  0.992555    0.000140\n",
+       "21     21    0.003180  0.003469  0.965098  0.992115    0.000140\n",
+       "22     22    0.003160  0.003374  0.964880  0.992387    0.000140\n",
+       "23     23    0.003152  0.003522  0.965240  0.991543    0.000098\n",
+       "24     24    0.002912  0.003308  0.968241  0.992033    0.000098\n",
+       "25     25    0.002879  0.003358  0.968242  0.991781    0.000098\n",
+       "26     26    0.002879  0.003365  0.968507  0.991775    0.000098\n",
+       "27     27    0.002853  0.003240  0.969288  0.992256    0.000098\n",
+       "28     28    0.002858  0.003385  0.968127  0.991958    0.000098\n",
+       "29     29    0.002849  0.003389  0.969099  0.991858    0.000098\n",
+       "30     30    0.002855  0.003413  0.966039  0.992158    0.000098\n",
+       "31     31    0.002830  0.003342  0.969118  0.991828    0.000069\n",
+       "32     32    0.002659  0.003288  0.969825  0.992203    0.000069\n",
+       "33     33    0.002611  0.002948  0.972177  0.993180    0.000069\n",
+       "34     34    0.002425  0.002649  0.974883  0.994110    0.000069\n",
+       "35     35    0.002253  0.002432  0.977438  0.994554    0.000069\n",
+       "36     36    0.002139  0.002351  0.977666  0.994845    0.000069\n",
+       "37     37    0.002048  0.002227  0.978469  0.995250    0.000069\n",
+       "38     38    0.001969  0.002170  0.979098  0.995544    0.000069\n",
+       "39     39    0.001897  0.001969  0.981175  0.995911    0.000048\n",
+       "40     40    0.001624  0.001698  0.982833  0.996831    0.000048\n",
+       "41     41    0.001523  0.001631  0.983602  0.996945    0.000048\n",
+       "42     42    0.001450  0.001593  0.984215  0.997049    0.000048\n",
+       "43     43    0.001411  0.001537  0.984988  0.997071    0.000048\n",
+       "44     44    0.001374  0.001585  0.984472  0.997054    0.000048\n",
+       "45     45    0.001344  0.001564  0.984906  0.997020    0.000048\n",
+       "46     46    0.001334  0.001556  0.984871  0.996946    0.000048\n",
+       "47     47    0.001305  0.001551  0.984850  0.997180    0.000034\n",
+       "48     48    0.001215  0.001518  0.985659  0.997226    0.000034\n",
+       "49     49    0.001197  0.001530  0.985763  0.997204    0.000034"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# gnn_metrics = checkpoint_utils.get_training_metrics(gnn_trainer)\n",
+    "\n",
+    "gnn_metrics = checkpoint_utils.get_training_metrics(gnn_metric_path)\n",
+    "\n",
+    "gnn_metrics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Figure was saved in output/focal-loss-pid-fixed/loss_gnn.pdf\n",
+      "Figure was saved in output/focal-loss-pid-fixed/loss_gnn.png\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(<Figure size 800x600 with 1 Axes>, <Axes: xlabel='Epoch', ylabel='Loss'>)"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 800x600 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "perfplot_mpl.plot_loss(gnn_metrics, CONFIG, \"gnn\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Figure was saved in output/focal-loss-pid-fixed/eff_gnn.pdf\n",
+      "Figure was saved in output/focal-loss-pid-fixed/eff_gnn.png\n",
+      "Figure was saved in output/focal-loss-pid-fixed/pur_gnn.pdf\n",
+      "Figure was saved in output/focal-loss-pid-fixed/pur_gnn.png\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(<Figure size 800x600 with 1 Axes>,\n",
+       " <Axes: xlabel='Epoch', ylabel='Edge Purity'>)"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 800x600 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 800x600 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "perfplot_mpl.plot_metric_epochs(\n",
+    "    \"eff\",\n",
+    "    gnn_metrics,\n",
+    "    CONFIG,\n",
+    "    \"gnn\",\n",
+    "    \"Edge Efficiency\",\n",
+    "    color=perfplot_mpl.partition_to_color[\"val\"],\n",
+    ")\n",
+    "perfplot_mpl.plot_metric_epochs(\n",
+    "    \"pur\",\n",
+    "    gnn_metrics,\n",
+    "    CONFIG,\n",
+    "    \"gnn\",\n",
+    "    \"Edge Purity\",\n",
+    "    color=perfplot_mpl.partition_to_color[\"val\"],\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Evaluate model performance on sample test data\n",
+    "\n",
+    "Here we evaluate the model performance on a sample of test data. We look at how the efficiency and purity change with the embedding radius."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Load 1000 files located in /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/test/velo-sim10b-nospillover\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0b5c1718765a4379b0a98e86f0baacde",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n",
+      "WARNING:Unable to obtain driver using Selenium Manager: /scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/selenium/webdriver/common/linux/selenium-manager is missing.  Please open an issue on https://github.com/SeleniumHQ/selenium/issues\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<IPython.core.display.Image object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "gnn_model.load_partition(\"velo-sim10b-nospillover\")\n",
+    "perfplot.plot_edge_performance(gnn_model, CONFIG)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Load 200 files located in /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/test/velo-sim10b-nospillover\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "91ddae17207d4b9a9e133df76e7887f6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c737875249e04c03b5c38e3ec9955499",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ac377279aafc4b5180f4639ca6d039e9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f85018c85c654d548916880638dee82e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/16 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "27babb13915a4937bb5ecb72296ee5eb",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7f6a4f1b645b4747b13c741e234ed7bc",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fe292d18a81c464a97ed34ef443e2fa6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4135aaeedd4c4ce3ab377fab149dba5a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8f2216787d9d4aeaa02b8b9bdfefaa40",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "6992b033ce614752878da7593a75cfff",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ed9f3c809e7c4e51b44a358ec82b78b7",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "896ef3db87ee47f2a44d9737076232c2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4b6e0862e10a4d3895142114ea69d609",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2c942e6fc1354e4e8c29448b4849bcb9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "81517adcee3f45f2aabf4180ccf00a19",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5f6e952c17a0420aac2e2e011bdfb016",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7aef953e9dd34c40b680254ae973cace",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9bd250618ba44ce5a74a9776fb8580b2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "47d644dc6dbf4c9cbcb40c56ada0ed92",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b4989466e64246258d9c54882fb24ba5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Figure was saved in output/focal-loss-pid-fixed/performance_given_score_cut.pdf\n",
+      "Figure was saved in output/focal-loss-pid-fixed/performance_given_score_cut.png\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 2400x600 with 3 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from GNN.gnn_plots import plot_best_performances_score_cut\n",
+    "\n",
+    "_, _, performances_for_various_score_cuts = plot_best_performances_score_cut(\n",
+    "    model=gnn_model,\n",
+    "    path_or_config=CONFIG,\n",
+    "    partition=\"velo-sim10b-nospillover\",\n",
+    "    score_cuts=[\n",
+    "        0.1,\n",
+    "        0.2,\n",
+    "        0.3,\n",
+    "        0.4,\n",
+    "        0.42,\n",
+    "        0.43,\n",
+    "        0.44,\n",
+    "        0.45,\n",
+    "        0.46,\n",
+    "        0.47,\n",
+    "        0.48,\n",
+    "        0.5,\n",
+    "        0.6,\n",
+    "        0.7,\n",
+    "        0.8,\n",
+    "        0.9,\n",
+    "    ],\n",
+    "    n_events=200,\n",
+    "    seed=0,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 4. GNN inference "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:--------------------- Step 4: Scoring graph edges using GNN  ---------------------\n",
+      "INFO:---------------------------- a) Loading trained model ----------------------------\n",
+      "INFO:Load model from artifacts/gnn/focal-loss-pid-fixed/version_0/checkpoints/epoch=49-step=500000.ckpt.\n",
+      "INFO:Load model from artifacts/gnn/focal-loss-pid-fixed/version_0/checkpoints/epoch=49-step=500000.ckpt.\n",
+      "INFO:----------------------------- b) Running inferencing -----------------------------\n",
+      "INFO:Remove directory `/scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover`.\n",
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/test/velo-sim10b-nospillover to /scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a6cb9afd5cb9438e8880192981766da0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n",
+      "INFO:Remove directory `/scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover-only-long-electrons`.\n",
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/test/velo-sim10b-nospillover-only-long-electrons to /scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover-only-long-electrons\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a1af6c66f0574263ba085bcdea898a7a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "run_gnn_inference(CONFIG, partitions=[\"test\"], checkpoint=gnn_artifact_path)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 5. Build track candidates from GNN"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:-----------  Step 5: Building track candidates from the scored graph  -----------\n",
+      "INFO:Score cut: 0.45\n",
+      "INFO:Remove directory `/scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover`.\n",
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover to /scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "15ad24e1d6f1444b9d3242e715a0dc9e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Remove directory `/scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover-only-long-electrons`.\n",
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover-only-long-electrons to /scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover-only-long-electrons\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e865558d4d084522b9a35a2f76ee8a29",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "build_track_candidates(CONFIG, partitions=[\"test\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 6. Evaluate track candidates on the same data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "evaluate_candidates_montetracko(\n",
+    "    CONFIG,\n",
+    "    partition=\"train\",\n",
+    "    allen_report=True,\n",
+    "    table_report=True,\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "evaluate_candidates_montetracko(\n",
+    "    CONFIG,\n",
+    "    partition=\"val\",\n",
+    "    allen_report=True,\n",
+    "    table_report=True,\n",
+    "    plot_categories=[],\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 7. Evaluate track candidates on unseen data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:---------------------------- velo-sim10b-nospillover ----------------------------\n",
+      "INFO:--------------------- Evaluation for velo-sim10b-nospillover ---------------------\n",
+      "INFO:1) Load dataframe of tracks, hits-particles and particles\n",
+      "INFO:Load tracks in /scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover.\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "093d1a9e10864656acccf65296bb9b51",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Load pre-processed test datasets in /scratch/acorreia/data/__test__/velo-sim10b-nospillover.\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "abc8329945a54b44a5747708f1256c12",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3b22d2cf13c24237bb0c6b187f60bc29",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Compute plat stats\n",
+      "INFO:2) Matching\n",
+      "INFO:3) Evaluation\n",
+      "INFO:Report was saved in output/focal-loss-pid-fixed/report-2023.06.11-12.04.53-velo-sim10b-nospillover.txt\n",
+      "INFO:------------------ velo-sim10b-nospillover-only-long-electrons ------------------\n",
+      "INFO:----------- Evaluation for velo-sim10b-nospillover-only-long-electrons -----------\n",
+      "INFO:1) Load dataframe of tracks, hits-particles and particles\n",
+      "INFO:Load tracks in /scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover-only-long-electrons.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "TrackChecker output                               :      2020/   243242   0.83% ghosts\n",
+      "01_velo                                           :    101318/   104345  97.10% ( 97.24%),       837 (  0.82%) clones, pur  99.86%, hit eff  97.73%\n",
+      "02_long                                           :     58330/    59167  98.59% ( 98.62%),       457 (  0.78%) clones, pur  99.93%, hit eff  98.31%\n",
+      "03_long_P>5GeV                                    :     37865/    38150  99.25% ( 99.26%),       231 (  0.61%) clones, pur  99.93%, hit eff  98.85%\n",
+      "04_long_strange                                   :      2971/     3142  94.56% ( 94.88%),        38 (  1.26%) clones, pur  99.64%, hit eff  95.25%\n",
+      "05_long_strange_P>5GeV                            :      1443/     1521  94.87% ( 94.91%),         7 (  0.48%) clones, pur  99.58%, hit eff  97.68%\n",
+      "06_long_fromB                                     :       120/      120 100.00% (100.00%),         2 (  1.64%) clones, pur  99.80%, hit eff  98.13%\n",
+      "07_long_fromB_P>5GeV                              :        87/       87 100.00% (100.00%),         0 (  0.00%) clones, pur 100.00%, hit eff  99.81%\n",
+      "08_long_electrons                                 :      3459/     4198  82.40% ( 82.85%),        74 (  2.09%) clones, pur  98.84%, hit eff  87.95%\n",
+      "09_long_fromB_electrons                           :        10/       10 100.00% (100.00%),         0 (  0.00%) clones, pur 100.00%, hit eff  93.33%\n",
+      "10_long_fromB_electrons_P>5GeV                    :         7/        7 100.00% (100.00%),         0 (  0.00%) clones, pur 100.00%, hit eff  96.43%\n",
+      "\n",
+      "| Categories           | Efficiency   | Average efficiency   | % clones   | Average hit purity   | Average hit efficiency   |\n",
+      "|:---------------------|:-------------|:---------------------|:-----------|:---------------------|:-------------------------|\n",
+      "| Velo                 | 93.82%       | 94.10%               | 1.06%      | 99.75%               | 96.26%                   |\n",
+      "| Long                 | 97.51%       | 97.58%               | 0.85%      | 99.87%               | 97.72%                   |\n",
+      "| Velo, no electrons   | 97.10%       | 97.24%               | 0.82%      | 99.86%               | 97.73%                   |\n",
+      "| Velo, only electrons | 76.99%       | 77.37%               | 2.57%      | 99.03%               | 86.85%                   |\n",
+      "| Long, only electrons | 82.40%       | 82.85%               | 2.09%      | 98.84%               | 87.95%                   |\n",
+      "| Categories   | # ghosts   | # tracks   | % ghosts   |\n",
+      "|:-------------|:-----------|:-----------|:-----------|\n",
+      "| Everything   | 2,020      | 243,242    | 0.83%      |\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c118aca80cae4b62831251402e288928",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Load pre-processed test datasets in /scratch/acorreia/data/__test__/velo-sim10b-nospillover-only-long-electrons.\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3bffc5f3235146cc89114053b6cc801c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9eae617c6d194998adb1a1469ab68796",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Compute plat stats\n",
+      "INFO:2) Matching\n",
+      "INFO:3) Evaluation\n",
+      "INFO:Report was saved in output/focal-loss-pid-fixed/report-2023.06.11-12.05.13-velo-sim10b-nospillover-only-long-electrons.txt\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "TrackChecker output                               :        79/     4432   1.78% ghosts\n",
+      "01_velo                                           :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "02_long                                           :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "03_long_P>5GeV                                    :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "04_long_strange                                   :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "05_long_strange_P>5GeV                            :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "06_long_fromB                                     :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "07_long_fromB_P>5GeV                              :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "08_long_electrons                                 :      4286/     4670  91.78% ( 93.18%),        37 (  0.86%) clones, pur  99.29%, hit eff  95.56%\n",
+      "09_long_fromB_electrons                           :        10/       10 100.00% (100.00%),         0 (  0.00%) clones, pur 100.00%, hit eff  97.50%\n",
+      "10_long_fromB_electrons_P>5GeV                    :         7/        7 100.00% (100.00%),         0 (  0.00%) clones, pur 100.00%, hit eff  96.43%\n",
+      "\n",
+      "| Categories           | Efficiency   | Average efficiency   | % clones   | Average hit purity   | Average hit efficiency   |\n",
+      "|:---------------------|:-------------|:---------------------|:-----------|:---------------------|:-------------------------|\n",
+      "| Velo                 | 91.78%       | 93.18%               | 0.86%      | 99.29%               | 95.56%                   |\n",
+      "| Long                 | 91.78%       | 93.18%               | 0.86%      | 99.29%               | 95.56%                   |\n",
+      "| Velo, no electrons   | nan%         | nan%                 | nan%       | nan%                 | nan%                     |\n",
+      "| Velo, only electrons | 91.78%       | 93.18%               | 0.86%      | 99.29%               | 95.56%                   |\n",
+      "| Long, only electrons | 91.78%       | 93.18%               | 0.86%      | 99.29%               | 95.56%                   |\n",
+      "| Categories   |   # ghosts | # tracks   | % ghosts   |\n",
+      "|:-------------|-----------:|:-----------|:-----------|\n",
+      "| Everything   |         79 | 4,432      | 1.78%      |\n"
+     ]
+    }
+   ],
+   "source": [
+    "for test_dataset_name in get_required_test_dataset_names(CONFIG):\n",
+    "    logging.info(headline(test_dataset_name))\n",
+    "    evaluate_candidates_montetracko(\n",
+    "        CONFIG,\n",
+    "        partition=test_dataset_name,\n",
+    "        allen_report=True,\n",
+    "        table_report=True,\n",
+    "        plot_categories=[],\n",
+    "    )\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trackEvaluator = evaluate_candidates_montetracko(\n",
+    "    CONFIG,\n",
+    "    partition=\"velo-sim10b-nospillover\",\n",
+    "    allen_report=True,\n",
+    "    table_report=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.10"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/LHCb_Pipeline/full_pipeline-focal-loss-pid-fixed.ipynb b/LHCb_Pipeline/full_pipeline-focal-loss-pid-fixed.ipynb
index b10c97baf089dfc18fc58a1033e24d1c678e67bf..a3a82d58e8daa3d5ecff77d40e2c00747f7222b0 100644
--- a/LHCb_Pipeline/full_pipeline-focal-loss-pid-fixed.ipynb
+++ b/LHCb_Pipeline/full_pipeline-focal-loss-pid-fixed.ipynb
@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -8,6 +9,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -16,608 +18,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "<div class=\"bk-root\">\n",
-       "        <a href=\"https://bokeh.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n",
-       "        <span id=\"1002\">Loading BokehJS ...</span>\n",
-       "    </div>\n"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/javascript": [
-       "(function(root) {\n",
-       "  function now() {\n",
-       "    return new Date();\n",
-       "  }\n",
-       "\n",
-       "  const force = true;\n",
-       "\n",
-       "  if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n",
-       "    root._bokeh_onload_callbacks = [];\n",
-       "    root._bokeh_is_loading = undefined;\n",
-       "  }\n",
-       "\n",
-       "const JS_MIME_TYPE = 'application/javascript';\n",
-       "  const HTML_MIME_TYPE = 'text/html';\n",
-       "  const EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n",
-       "  const CLASS_NAME = 'output_bokeh rendered_html';\n",
-       "\n",
-       "  /**\n",
-       "   * Render data to the DOM node\n",
-       "   */\n",
-       "  function render(props, node) {\n",
-       "    const script = document.createElement(\"script\");\n",
-       "    node.appendChild(script);\n",
-       "  }\n",
-       "\n",
-       "  /**\n",
-       "   * Handle when an output is cleared or removed\n",
-       "   */\n",
-       "  function handleClearOutput(event, handle) {\n",
-       "    const cell = handle.cell;\n",
-       "\n",
-       "    const id = cell.output_area._bokeh_element_id;\n",
-       "    const server_id = cell.output_area._bokeh_server_id;\n",
-       "    // Clean up Bokeh references\n",
-       "    if (id != null && id in Bokeh.index) {\n",
-       "      Bokeh.index[id].model.document.clear();\n",
-       "      delete Bokeh.index[id];\n",
-       "    }\n",
-       "\n",
-       "    if (server_id !== undefined) {\n",
-       "      // Clean up Bokeh references\n",
-       "      const cmd_clean = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n",
-       "      cell.notebook.kernel.execute(cmd_clean, {\n",
-       "        iopub: {\n",
-       "          output: function(msg) {\n",
-       "            const id = msg.content.text.trim();\n",
-       "            if (id in Bokeh.index) {\n",
-       "              Bokeh.index[id].model.document.clear();\n",
-       "              delete Bokeh.index[id];\n",
-       "            }\n",
-       "          }\n",
-       "        }\n",
-       "      });\n",
-       "      // Destroy server and session\n",
-       "      const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n",
-       "      cell.notebook.kernel.execute(cmd_destroy);\n",
-       "    }\n",
-       "  }\n",
-       "\n",
-       "  /**\n",
-       "   * Handle when a new output is added\n",
-       "   */\n",
-       "  function handleAddOutput(event, handle) {\n",
-       "    const output_area = handle.output_area;\n",
-       "    const output = handle.output;\n",
-       "\n",
-       "    // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n",
-       "    if ((output.output_type != \"display_data\") || (!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n",
-       "      return\n",
-       "    }\n",
-       "\n",
-       "    const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n",
-       "\n",
-       "    if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n",
-       "      toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n",
-       "      // store reference to embed id on output_area\n",
-       "      output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n",
-       "    }\n",
-       "    if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n",
-       "      const bk_div = document.createElement(\"div\");\n",
-       "      bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n",
-       "      const script_attrs = bk_div.children[0].attributes;\n",
-       "      for (let i = 0; i < script_attrs.length; i++) {\n",
-       "        toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n",
-       "        toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n",
-       "      }\n",
-       "      // store reference to server id on output_area\n",
-       "      output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n",
-       "    }\n",
-       "  }\n",
-       "\n",
-       "  function register_renderer(events, OutputArea) {\n",
-       "\n",
-       "    function append_mime(data, metadata, element) {\n",
-       "      // create a DOM node to render to\n",
-       "      const toinsert = this.create_output_subarea(\n",
-       "        metadata,\n",
-       "        CLASS_NAME,\n",
-       "        EXEC_MIME_TYPE\n",
-       "      );\n",
-       "      this.keyboard_manager.register_events(toinsert);\n",
-       "      // Render to node\n",
-       "      const props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n",
-       "      render(props, toinsert[toinsert.length - 1]);\n",
-       "      element.append(toinsert);\n",
-       "      return toinsert\n",
-       "    }\n",
-       "\n",
-       "    /* Handle when an output is cleared or removed */\n",
-       "    events.on('clear_output.CodeCell', handleClearOutput);\n",
-       "    events.on('delete.Cell', handleClearOutput);\n",
-       "\n",
-       "    /* Handle when a new output is added */\n",
-       "    events.on('output_added.OutputArea', handleAddOutput);\n",
-       "\n",
-       "    /**\n",
-       "     * Register the mime type and append_mime function with output_area\n",
-       "     */\n",
-       "    OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n",
-       "      /* Is output safe? */\n",
-       "      safe: true,\n",
-       "      /* Index of renderer in `output_area.display_order` */\n",
-       "      index: 0\n",
-       "    });\n",
-       "  }\n",
-       "\n",
-       "  // register the mime type if in Jupyter Notebook environment and previously unregistered\n",
-       "  if (root.Jupyter !== undefined) {\n",
-       "    const events = require('base/js/events');\n",
-       "    const OutputArea = require('notebook/js/outputarea').OutputArea;\n",
-       "\n",
-       "    if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n",
-       "      register_renderer(events, OutputArea);\n",
-       "    }\n",
-       "  }\n",
-       "  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n",
-       "    root._bokeh_timeout = Date.now() + 5000;\n",
-       "    root._bokeh_failed_load = false;\n",
-       "  }\n",
-       "\n",
-       "  const NB_LOAD_WARNING = {'data': {'text/html':\n",
-       "     \"<div style='background-color: #fdd'>\\n\"+\n",
-       "     \"<p>\\n\"+\n",
-       "     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n",
-       "     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n",
-       "     \"</p>\\n\"+\n",
-       "     \"<ul>\\n\"+\n",
-       "     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n",
-       "     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n",
-       "     \"</ul>\\n\"+\n",
-       "     \"<code>\\n\"+\n",
-       "     \"from bokeh.resources import INLINE\\n\"+\n",
-       "     \"output_notebook(resources=INLINE)\\n\"+\n",
-       "     \"</code>\\n\"+\n",
-       "     \"</div>\"}};\n",
-       "\n",
-       "  function display_loaded() {\n",
-       "    const el = document.getElementById(\"1002\");\n",
-       "    if (el != null) {\n",
-       "      el.textContent = \"BokehJS is loading...\";\n",
-       "    }\n",
-       "    if (root.Bokeh !== undefined) {\n",
-       "      if (el != null) {\n",
-       "        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n",
-       "      }\n",
-       "    } else if (Date.now() < root._bokeh_timeout) {\n",
-       "      setTimeout(display_loaded, 100)\n",
-       "    }\n",
-       "  }\n",
-       "\n",
-       "  function run_callbacks() {\n",
-       "    try {\n",
-       "      root._bokeh_onload_callbacks.forEach(function(callback) {\n",
-       "        if (callback != null)\n",
-       "          callback();\n",
-       "      });\n",
-       "    } finally {\n",
-       "      delete root._bokeh_onload_callbacks\n",
-       "    }\n",
-       "    console.debug(\"Bokeh: all callbacks have finished\");\n",
-       "  }\n",
-       "\n",
-       "  function load_libs(css_urls, js_urls, callback) {\n",
-       "    if (css_urls == null) css_urls = [];\n",
-       "    if (js_urls == null) js_urls = [];\n",
-       "\n",
-       "    root._bokeh_onload_callbacks.push(callback);\n",
-       "    if (root._bokeh_is_loading > 0) {\n",
-       "      console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n",
-       "      return null;\n",
-       "    }\n",
-       "    if (js_urls == null || js_urls.length === 0) {\n",
-       "      run_callbacks();\n",
-       "      return null;\n",
-       "    }\n",
-       "    console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n",
-       "    root._bokeh_is_loading = css_urls.length + js_urls.length;\n",
-       "\n",
-       "    function on_load() {\n",
-       "      root._bokeh_is_loading--;\n",
-       "      if (root._bokeh_is_loading === 0) {\n",
-       "        console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n",
-       "        run_callbacks()\n",
-       "      }\n",
-       "    }\n",
-       "\n",
-       "    function on_error(url) {\n",
-       "      console.error(\"failed to load \" + url);\n",
-       "    }\n",
-       "\n",
-       "    for (let i = 0; i < css_urls.length; i++) {\n",
-       "      const url = css_urls[i];\n",
-       "      const element = document.createElement(\"link\");\n",
-       "      element.onload = on_load;\n",
-       "      element.onerror = on_error.bind(null, url);\n",
-       "      element.rel = \"stylesheet\";\n",
-       "      element.type = \"text/css\";\n",
-       "      element.href = url;\n",
-       "      console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n",
-       "      document.body.appendChild(element);\n",
-       "    }\n",
-       "\n",
-       "    for (let i = 0; i < js_urls.length; i++) {\n",
-       "      const url = js_urls[i];\n",
-       "      const element = document.createElement('script');\n",
-       "      element.onload = on_load;\n",
-       "      element.onerror = on_error.bind(null, url);\n",
-       "      element.async = false;\n",
-       "      element.src = url;\n",
-       "      console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n",
-       "      document.head.appendChild(element);\n",
-       "    }\n",
-       "  };\n",
-       "\n",
-       "  function inject_raw_css(css) {\n",
-       "    const element = document.createElement(\"style\");\n",
-       "    element.appendChild(document.createTextNode(css));\n",
-       "    document.body.appendChild(element);\n",
-       "  }\n",
-       "\n",
-       "  const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n",
-       "  const css_urls = [];\n",
-       "\n",
-       "  const inline_js = [    function(Bokeh) {\n",
-       "      Bokeh.set_log_level(\"info\");\n",
-       "    },\n",
-       "function(Bokeh) {\n",
-       "    }\n",
-       "  ];\n",
-       "\n",
-       "  function run_inline_js() {\n",
-       "    if (root.Bokeh !== undefined || force === true) {\n",
-       "          for (let i = 0; i < inline_js.length; i++) {\n",
-       "      inline_js[i].call(root, root.Bokeh);\n",
-       "    }\n",
-       "if (force === true) {\n",
-       "        display_loaded();\n",
-       "      }} else if (Date.now() < root._bokeh_timeout) {\n",
-       "      setTimeout(run_inline_js, 100);\n",
-       "    } else if (!root._bokeh_failed_load) {\n",
-       "      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n",
-       "      root._bokeh_failed_load = true;\n",
-       "    } else if (force !== true) {\n",
-       "      const cell = $(document.getElementById(\"1002\")).parents('.cell').data().cell;\n",
-       "      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n",
-       "    }\n",
-       "  }\n",
-       "\n",
-       "  if (root._bokeh_is_loading === 0) {\n",
-       "    console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n",
-       "    run_inline_js();\n",
-       "  } else {\n",
-       "    load_libs(css_urls, js_urls, function() {\n",
-       "      console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n",
-       "      run_inline_js();\n",
-       "    });\n",
-       "  }\n",
-       "}(window));"
-      ],
-      "application/vnd.bokehjs_load.v0+json": "(function(root) {\n  function now() {\n    return new Date();\n  }\n\n  const force = true;\n\n  if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n    root._bokeh_onload_callbacks = [];\n    root._bokeh_is_loading = undefined;\n  }\n\n\n  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n    root._bokeh_timeout = Date.now() + 5000;\n    root._bokeh_failed_load = false;\n  }\n\n  const NB_LOAD_WARNING = {'data': {'text/html':\n     \"<div style='background-color: #fdd'>\\n\"+\n     \"<p>\\n\"+\n     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n     \"</p>\\n\"+\n     \"<ul>\\n\"+\n     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n     \"</ul>\\n\"+\n     \"<code>\\n\"+\n     \"from bokeh.resources import INLINE\\n\"+\n     \"output_notebook(resources=INLINE)\\n\"+\n     \"</code>\\n\"+\n     \"</div>\"}};\n\n  function display_loaded() {\n    const el = document.getElementById(\"1002\");\n    if (el != null) {\n      el.textContent = \"BokehJS is loading...\";\n    }\n    if (root.Bokeh !== undefined) {\n      if (el != null) {\n        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n      }\n    } else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(display_loaded, 100)\n    }\n  }\n\n  function run_callbacks() {\n    try {\n      root._bokeh_onload_callbacks.forEach(function(callback) {\n        if (callback != null)\n          callback();\n      });\n    } finally {\n      delete root._bokeh_onload_callbacks\n    }\n    console.debug(\"Bokeh: all callbacks have finished\");\n  }\n\n  function load_libs(css_urls, js_urls, callback) {\n    if (css_urls == null) css_urls = [];\n    if (js_urls == 
null) js_urls = [];\n\n    root._bokeh_onload_callbacks.push(callback);\n    if (root._bokeh_is_loading > 0) {\n      console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n      return null;\n    }\n    if (js_urls == null || js_urls.length === 0) {\n      run_callbacks();\n      return null;\n    }\n    console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n    root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n    function on_load() {\n      root._bokeh_is_loading--;\n      if (root._bokeh_is_loading === 0) {\n        console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n        run_callbacks()\n      }\n    }\n\n    function on_error(url) {\n      console.error(\"failed to load \" + url);\n    }\n\n    for (let i = 0; i < css_urls.length; i++) {\n      const url = css_urls[i];\n      const element = document.createElement(\"link\");\n      element.onload = on_load;\n      element.onerror = on_error.bind(null, url);\n      element.rel = \"stylesheet\";\n      element.type = \"text/css\";\n      element.href = url;\n      console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n      document.body.appendChild(element);\n    }\n\n    for (let i = 0; i < js_urls.length; i++) {\n      const url = js_urls[i];\n      const element = document.createElement('script');\n      element.onload = on_load;\n      element.onerror = on_error.bind(null, url);\n      element.async = false;\n      element.src = url;\n      console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n      document.head.appendChild(element);\n    }\n  };\n\n  function inject_raw_css(css) {\n    const element = document.createElement(\"style\");\n    element.appendChild(document.createTextNode(css));\n    document.body.appendChild(element);\n  }\n\n  const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", 
\"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n  const css_urls = [];\n\n  const inline_js = [    function(Bokeh) {\n      Bokeh.set_log_level(\"info\");\n    },\nfunction(Bokeh) {\n    }\n  ];\n\n  function run_inline_js() {\n    if (root.Bokeh !== undefined || force === true) {\n          for (let i = 0; i < inline_js.length; i++) {\n      inline_js[i].call(root, root.Bokeh);\n    }\nif (force === true) {\n        display_loaded();\n      }} else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(run_inline_js, 100);\n    } else if (!root._bokeh_failed_load) {\n      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n      root._bokeh_failed_load = true;\n    } else if (force !== true) {\n      const cell = $(document.getElementById(\"1002\")).parents('.cell').data().cell;\n      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n    }\n  }\n\n  if (root._bokeh_is_loading === 0) {\n    console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n    run_inline_js();\n  } else {\n    load_libs(css_urls, js_urls, function() {\n      console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n      run_inline_js();\n    });\n  }\n}(window));"
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "<div class=\"bk-root\">\n",
-       "        <a href=\"https://bokeh.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n",
-       "        <span id=\"1003\">Loading BokehJS ...</span>\n",
-       "    </div>\n"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/javascript": [
-       "(function(root) {\n",
-       "  function now() {\n",
-       "    return new Date();\n",
-       "  }\n",
-       "\n",
-       "  const force = true;\n",
-       "\n",
-       "  if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n",
-       "    root._bokeh_onload_callbacks = [];\n",
-       "    root._bokeh_is_loading = undefined;\n",
-       "  }\n",
-       "\n",
-       "const JS_MIME_TYPE = 'application/javascript';\n",
-       "  const HTML_MIME_TYPE = 'text/html';\n",
-       "  const EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n",
-       "  const CLASS_NAME = 'output_bokeh rendered_html';\n",
-       "\n",
-       "  /**\n",
-       "   * Render data to the DOM node\n",
-       "   */\n",
-       "  function render(props, node) {\n",
-       "    const script = document.createElement(\"script\");\n",
-       "    node.appendChild(script);\n",
-       "  }\n",
-       "\n",
-       "  /**\n",
-       "   * Handle when an output is cleared or removed\n",
-       "   */\n",
-       "  function handleClearOutput(event, handle) {\n",
-       "    const cell = handle.cell;\n",
-       "\n",
-       "    const id = cell.output_area._bokeh_element_id;\n",
-       "    const server_id = cell.output_area._bokeh_server_id;\n",
-       "    // Clean up Bokeh references\n",
-       "    if (id != null && id in Bokeh.index) {\n",
-       "      Bokeh.index[id].model.document.clear();\n",
-       "      delete Bokeh.index[id];\n",
-       "    }\n",
-       "\n",
-       "    if (server_id !== undefined) {\n",
-       "      // Clean up Bokeh references\n",
-       "      const cmd_clean = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n",
-       "      cell.notebook.kernel.execute(cmd_clean, {\n",
-       "        iopub: {\n",
-       "          output: function(msg) {\n",
-       "            const id = msg.content.text.trim();\n",
-       "            if (id in Bokeh.index) {\n",
-       "              Bokeh.index[id].model.document.clear();\n",
-       "              delete Bokeh.index[id];\n",
-       "            }\n",
-       "          }\n",
-       "        }\n",
-       "      });\n",
-       "      // Destroy server and session\n",
-       "      const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n",
-       "      cell.notebook.kernel.execute(cmd_destroy);\n",
-       "    }\n",
-       "  }\n",
-       "\n",
-       "  /**\n",
-       "   * Handle when a new output is added\n",
-       "   */\n",
-       "  function handleAddOutput(event, handle) {\n",
-       "    const output_area = handle.output_area;\n",
-       "    const output = handle.output;\n",
-       "\n",
-       "    // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n",
-       "    if ((output.output_type != \"display_data\") || (!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n",
-       "      return\n",
-       "    }\n",
-       "\n",
-       "    const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n",
-       "\n",
-       "    if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n",
-       "      toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n",
-       "      // store reference to embed id on output_area\n",
-       "      output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n",
-       "    }\n",
-       "    if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n",
-       "      const bk_div = document.createElement(\"div\");\n",
-       "      bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n",
-       "      const script_attrs = bk_div.children[0].attributes;\n",
-       "      for (let i = 0; i < script_attrs.length; i++) {\n",
-       "        toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n",
-       "        toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n",
-       "      }\n",
-       "      // store reference to server id on output_area\n",
-       "      output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n",
-       "    }\n",
-       "  }\n",
-       "\n",
-       "  function register_renderer(events, OutputArea) {\n",
-       "\n",
-       "    function append_mime(data, metadata, element) {\n",
-       "      // create a DOM node to render to\n",
-       "      const toinsert = this.create_output_subarea(\n",
-       "        metadata,\n",
-       "        CLASS_NAME,\n",
-       "        EXEC_MIME_TYPE\n",
-       "      );\n",
-       "      this.keyboard_manager.register_events(toinsert);\n",
-       "      // Render to node\n",
-       "      const props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n",
-       "      render(props, toinsert[toinsert.length - 1]);\n",
-       "      element.append(toinsert);\n",
-       "      return toinsert\n",
-       "    }\n",
-       "\n",
-       "    /* Handle when an output is cleared or removed */\n",
-       "    events.on('clear_output.CodeCell', handleClearOutput);\n",
-       "    events.on('delete.Cell', handleClearOutput);\n",
-       "\n",
-       "    /* Handle when a new output is added */\n",
-       "    events.on('output_added.OutputArea', handleAddOutput);\n",
-       "\n",
-       "    /**\n",
-       "     * Register the mime type and append_mime function with output_area\n",
-       "     */\n",
-       "    OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n",
-       "      /* Is output safe? */\n",
-       "      safe: true,\n",
-       "      /* Index of renderer in `output_area.display_order` */\n",
-       "      index: 0\n",
-       "    });\n",
-       "  }\n",
-       "\n",
-       "  // register the mime type if in Jupyter Notebook environment and previously unregistered\n",
-       "  if (root.Jupyter !== undefined) {\n",
-       "    const events = require('base/js/events');\n",
-       "    const OutputArea = require('notebook/js/outputarea').OutputArea;\n",
-       "\n",
-       "    if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n",
-       "      register_renderer(events, OutputArea);\n",
-       "    }\n",
-       "  }\n",
-       "  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n",
-       "    root._bokeh_timeout = Date.now() + 5000;\n",
-       "    root._bokeh_failed_load = false;\n",
-       "  }\n",
-       "\n",
-       "  const NB_LOAD_WARNING = {'data': {'text/html':\n",
-       "     \"<div style='background-color: #fdd'>\\n\"+\n",
-       "     \"<p>\\n\"+\n",
-       "     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n",
-       "     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n",
-       "     \"</p>\\n\"+\n",
-       "     \"<ul>\\n\"+\n",
-       "     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n",
-       "     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n",
-       "     \"</ul>\\n\"+\n",
-       "     \"<code>\\n\"+\n",
-       "     \"from bokeh.resources import INLINE\\n\"+\n",
-       "     \"output_notebook(resources=INLINE)\\n\"+\n",
-       "     \"</code>\\n\"+\n",
-       "     \"</div>\"}};\n",
-       "\n",
-       "  function display_loaded() {\n",
-       "    const el = document.getElementById(\"1003\");\n",
-       "    if (el != null) {\n",
-       "      el.textContent = \"BokehJS is loading...\";\n",
-       "    }\n",
-       "    if (root.Bokeh !== undefined) {\n",
-       "      if (el != null) {\n",
-       "        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n",
-       "      }\n",
-       "    } else if (Date.now() < root._bokeh_timeout) {\n",
-       "      setTimeout(display_loaded, 100)\n",
-       "    }\n",
-       "  }\n",
-       "\n",
-       "  function run_callbacks() {\n",
-       "    try {\n",
-       "      root._bokeh_onload_callbacks.forEach(function(callback) {\n",
-       "        if (callback != null)\n",
-       "          callback();\n",
-       "      });\n",
-       "    } finally {\n",
-       "      delete root._bokeh_onload_callbacks\n",
-       "    }\n",
-       "    console.debug(\"Bokeh: all callbacks have finished\");\n",
-       "  }\n",
-       "\n",
-       "  function load_libs(css_urls, js_urls, callback) {\n",
-       "    if (css_urls == null) css_urls = [];\n",
-       "    if (js_urls == null) js_urls = [];\n",
-       "\n",
-       "    root._bokeh_onload_callbacks.push(callback);\n",
-       "    if (root._bokeh_is_loading > 0) {\n",
-       "      console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n",
-       "      return null;\n",
-       "    }\n",
-       "    if (js_urls == null || js_urls.length === 0) {\n",
-       "      run_callbacks();\n",
-       "      return null;\n",
-       "    }\n",
-       "    console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n",
-       "    root._bokeh_is_loading = css_urls.length + js_urls.length;\n",
-       "\n",
-       "    function on_load() {\n",
-       "      root._bokeh_is_loading--;\n",
-       "      if (root._bokeh_is_loading === 0) {\n",
-       "        console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n",
-       "        run_callbacks()\n",
-       "      }\n",
-       "    }\n",
-       "\n",
-       "    function on_error(url) {\n",
-       "      console.error(\"failed to load \" + url);\n",
-       "    }\n",
-       "\n",
-       "    for (let i = 0; i < css_urls.length; i++) {\n",
-       "      const url = css_urls[i];\n",
-       "      const element = document.createElement(\"link\");\n",
-       "      element.onload = on_load;\n",
-       "      element.onerror = on_error.bind(null, url);\n",
-       "      element.rel = \"stylesheet\";\n",
-       "      element.type = \"text/css\";\n",
-       "      element.href = url;\n",
-       "      console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n",
-       "      document.body.appendChild(element);\n",
-       "    }\n",
-       "\n",
-       "    for (let i = 0; i < js_urls.length; i++) {\n",
-       "      const url = js_urls[i];\n",
-       "      const element = document.createElement('script');\n",
-       "      element.onload = on_load;\n",
-       "      element.onerror = on_error.bind(null, url);\n",
-       "      element.async = false;\n",
-       "      element.src = url;\n",
-       "      console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n",
-       "      document.head.appendChild(element);\n",
-       "    }\n",
-       "  };\n",
-       "\n",
-       "  function inject_raw_css(css) {\n",
-       "    const element = document.createElement(\"style\");\n",
-       "    element.appendChild(document.createTextNode(css));\n",
-       "    document.body.appendChild(element);\n",
-       "  }\n",
-       "\n",
-       "  const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n",
-       "  const css_urls = [];\n",
-       "\n",
-       "  const inline_js = [    function(Bokeh) {\n",
-       "      Bokeh.set_log_level(\"info\");\n",
-       "    },\n",
-       "function(Bokeh) {\n",
-       "    }\n",
-       "  ];\n",
-       "\n",
-       "  function run_inline_js() {\n",
-       "    if (root.Bokeh !== undefined || force === true) {\n",
-       "          for (let i = 0; i < inline_js.length; i++) {\n",
-       "      inline_js[i].call(root, root.Bokeh);\n",
-       "    }\n",
-       "if (force === true) {\n",
-       "        display_loaded();\n",
-       "      }} else if (Date.now() < root._bokeh_timeout) {\n",
-       "      setTimeout(run_inline_js, 100);\n",
-       "    } else if (!root._bokeh_failed_load) {\n",
-       "      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n",
-       "      root._bokeh_failed_load = true;\n",
-       "    } else if (force !== true) {\n",
-       "      const cell = $(document.getElementById(\"1003\")).parents('.cell').data().cell;\n",
-       "      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n",
-       "    }\n",
-       "  }\n",
-       "\n",
-       "  if (root._bokeh_is_loading === 0) {\n",
-       "    console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n",
-       "    run_inline_js();\n",
-       "  } else {\n",
-       "    load_libs(css_urls, js_urls, function() {\n",
-       "      console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n",
-       "      run_inline_js();\n",
-       "    });\n",
-       "  }\n",
-       "}(window));"
-      ],
-      "application/vnd.bokehjs_load.v0+json": "(function(root) {\n  function now() {\n    return new Date();\n  }\n\n  const force = true;\n\n  if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n    root._bokeh_onload_callbacks = [];\n    root._bokeh_is_loading = undefined;\n  }\n\n\n  if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n    root._bokeh_timeout = Date.now() + 5000;\n    root._bokeh_failed_load = false;\n  }\n\n  const NB_LOAD_WARNING = {'data': {'text/html':\n     \"<div style='background-color: #fdd'>\\n\"+\n     \"<p>\\n\"+\n     \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n     \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n     \"</p>\\n\"+\n     \"<ul>\\n\"+\n     \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n     \"<li>use INLINE resources instead, as so:</li>\\n\"+\n     \"</ul>\\n\"+\n     \"<code>\\n\"+\n     \"from bokeh.resources import INLINE\\n\"+\n     \"output_notebook(resources=INLINE)\\n\"+\n     \"</code>\\n\"+\n     \"</div>\"}};\n\n  function display_loaded() {\n    const el = document.getElementById(\"1003\");\n    if (el != null) {\n      el.textContent = \"BokehJS is loading...\";\n    }\n    if (root.Bokeh !== undefined) {\n      if (el != null) {\n        el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n      }\n    } else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(display_loaded, 100)\n    }\n  }\n\n  function run_callbacks() {\n    try {\n      root._bokeh_onload_callbacks.forEach(function(callback) {\n        if (callback != null)\n          callback();\n      });\n    } finally {\n      delete root._bokeh_onload_callbacks\n    }\n    console.debug(\"Bokeh: all callbacks have finished\");\n  }\n\n  function load_libs(css_urls, js_urls, callback) {\n    if (css_urls == null) css_urls = [];\n    if (js_urls == 
null) js_urls = [];\n\n    root._bokeh_onload_callbacks.push(callback);\n    if (root._bokeh_is_loading > 0) {\n      console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n      return null;\n    }\n    if (js_urls == null || js_urls.length === 0) {\n      run_callbacks();\n      return null;\n    }\n    console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n    root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n    function on_load() {\n      root._bokeh_is_loading--;\n      if (root._bokeh_is_loading === 0) {\n        console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n        run_callbacks()\n      }\n    }\n\n    function on_error(url) {\n      console.error(\"failed to load \" + url);\n    }\n\n    for (let i = 0; i < css_urls.length; i++) {\n      const url = css_urls[i];\n      const element = document.createElement(\"link\");\n      element.onload = on_load;\n      element.onerror = on_error.bind(null, url);\n      element.rel = \"stylesheet\";\n      element.type = \"text/css\";\n      element.href = url;\n      console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n      document.body.appendChild(element);\n    }\n\n    for (let i = 0; i < js_urls.length; i++) {\n      const url = js_urls[i];\n      const element = document.createElement('script');\n      element.onload = on_load;\n      element.onerror = on_error.bind(null, url);\n      element.async = false;\n      element.src = url;\n      console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n      document.head.appendChild(element);\n    }\n  };\n\n  function inject_raw_css(css) {\n    const element = document.createElement(\"style\");\n    element.appendChild(document.createTextNode(css));\n    document.body.appendChild(element);\n  }\n\n  const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", 
\"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n  const css_urls = [];\n\n  const inline_js = [    function(Bokeh) {\n      Bokeh.set_log_level(\"info\");\n    },\nfunction(Bokeh) {\n    }\n  ];\n\n  function run_inline_js() {\n    if (root.Bokeh !== undefined || force === true) {\n          for (let i = 0; i < inline_js.length; i++) {\n      inline_js[i].call(root, root.Bokeh);\n    }\nif (force === true) {\n        display_loaded();\n      }} else if (Date.now() < root._bokeh_timeout) {\n      setTimeout(run_inline_js, 100);\n    } else if (!root._bokeh_failed_load) {\n      console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n      root._bokeh_failed_load = true;\n    } else if (force !== true) {\n      const cell = $(document.getElementById(\"1003\")).parents('.cell').data().cell;\n      cell.output_area.append_execute_result(NB_LOAD_WARNING)\n    }\n  }\n\n  if (root._bokeh_is_loading === 0) {\n    console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n    run_inline_js();\n  } else {\n    load_libs(css_urls, js_urls, function() {\n      console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n      run_inline_js();\n    });\n  }\n}(window));"
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "%load_ext autoreload\n",
     "%autoreload 2\n",
@@ -675,6 +78,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -684,6 +88,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -693,7 +98,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -703,6 +108,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -711,7 +117,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -738,7 +144,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -751,6 +157,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -773,6 +180,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -789,6 +197,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -819,6 +228,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -826,6 +236,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -834,6 +245,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -862,6 +274,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -907,6 +320,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -915,6 +329,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -923,18 +338,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "embedding_metric_path='artifacts/metric_learning/focal-loss-pid-fixed/version_0/metrics.csv'\n",
-      "embedding_artifact_path='artifacts/metric_learning/focal-loss-pid-fixed/version_0/checkpoints/epoch=16-step=170000.ckpt'\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from Embedding.layerless_embedding import LayerlessEmbedding\n",
     "\n",
@@ -950,6 +356,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1011,6 +418,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1018,6 +426,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1073,6 +482,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1101,6 +511,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1138,6 +549,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1166,6 +578,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1189,6 +602,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1196,6 +610,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1206,32 +621,801 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
+   "outputs": [],
+   "source": [
+    "if run_training:\n",
+    "    send_telegram_message('Started GNN training.', chat_id, api_key)\n",
+    "    with warnings.catch_warnings():\n",
+    "        warnings.filterwarnings(\n",
+    "            \"ignore\", message=\"None of the inputs have requires_grad=True.\"\n",
+    "        )\n",
+    "        gnn_trainer, gnn_model = train_gnn(CONFIG)\n",
+    "\n",
+    "    send_telegram_message('Finished GNN training.', chat_id, api_key)\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### From checkpoint"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from utils.modelutils.checkpoint_utils import (\n",
+    "    get_last_version_dir_from_config,\n",
+    "    get_last_artifact,\n",
+    ")\n",
+    "from GNN.models.interaction_gnn import InteractionGNN\n",
+    "\n",
+    "gnn_version_dir = get_last_version_dir_from_config(step=\"gnn\", path_or_config=CONFIG)\n",
+    "gnn_metric_path = os.path.join(gnn_version_dir, \"metrics.csv\")\n",
+    "gnn_artifact_path = get_last_artifact(version_dir=gnn_version_dir)\n",
+    "print(f\"{gnn_metric_path=}\")\n",
+    "print(f\"{gnn_artifact_path=}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pytorch_lightning import Trainer\n",
+    "from pytorch_lightning.loggers import CSVLogger\n",
+    "\n",
+    "\n",
+    "def continue_gnn_training(\n",
+    "    path_or_config: str | dict,\n",
+    ") -> typing.Tuple[Trainer, InteractionGNN]:\n",
+    "    config = load_config(path_or_config=path_or_config)\n",
+    "\n",
+    "    gnn_model = InteractionGNN.load_from_checkpoint(\n",
+    "        gnn_artifact_path, hparams=config[\"gnn\"]\n",
+    "    )  # you may change `gnn_model`\n",
+    "\n",
+    "    save_directory = os.path.abspath(\n",
+    "        os.path.join(config[\"common\"][\"artifact_directory\"], \"gnn\")\n",
+    "    )\n",
+    "\n",
+    "    logger = CSVLogger(save_directory, name=config[\"common\"][\"experiment_name\"])\n",
+    "\n",
+    "    gnn_trainer = Trainer(\n",
+    "        accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n",
+    "        devices=1,\n",
+    "        max_epochs=150,  # you may increase the number of epochs\n",
+    "        logger=logger,\n",
+    "        # callbacks=[EarlyStopping(monitor=\"val_loss\", mode=\"min\")]\n",
+    "    )\n",
+    "\n",
+    "    with warnings.catch_warnings():\n",
+    "        warnings.filterwarnings(\n",
+    "            \"ignore\", message=\"None of the inputs have requires_grad=True.\"\n",
+    "        )\n",
+    "        gnn_trainer.fit(gnn_model, ckpt_path=gnn_artifact_path)\n",
+    "    return gnn_trainer, gnn_model\n",
+    "\n",
+    "# gnn_trainer, gnn_model = continue_gnn_training(CONFIG)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gnn_model = InteractionGNN.load_from_checkpoint(\n",
+    "    gnn_artifact_path,\n",
+    "    # map_location=\"cpu\",\n",
+    "    # hparams=load_config(CONFIG)[\"gnn\"],\n",
+    ")\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Plot training metrics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
    "outputs": [
     {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "INFO:-------------------------  Step 3: Running GNN training  -------------------------\n",
-      "INFO:----------------------------- a) Initialising model -----------------------------\n",
-      "INFO:------------------------------ b) Running training ------------------------------\n",
-      "Missing logger folder: /home/acorreia/Documents/tracking/etx4velo/LHCb_Pipeline/artifacts/gnn/focal-loss-pid-fixed\n",
-      "INFO:Save hyperparameters, metrics and artifacts in /home/acorreia/Documents/tracking/etx4velo/LHCb_Pipeline/artifacts/gnn/focal-loss-pid-fixed/version_0\n",
-      "GPU available: True (cuda), used: True\n",
-      "TPU available: False, using: 0 TPU cores\n",
-      "IPU available: False, using: 0 IPUs\n",
-      "HPU available: False, using: 0 HPUs\n",
-      "INFO:Load 10000 files located in /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/train\n"
-     ]
-    },
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>epoch</th>\n",
+       "      <th>train_loss</th>\n",
+       "      <th>val_loss</th>\n",
+       "      <th>eff</th>\n",
+       "      <th>pur</th>\n",
+       "      <th>current_lr</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>0</td>\n",
+       "      <td>0.015745</td>\n",
+       "      <td>0.009409</td>\n",
+       "      <td>0.882107</td>\n",
+       "      <td>0.979640</td>\n",
+       "      <td>0.000020</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>1</td>\n",
+       "      <td>0.007946</td>\n",
+       "      <td>0.006557</td>\n",
+       "      <td>0.922102</td>\n",
+       "      <td>0.986181</td>\n",
+       "      <td>0.000040</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>2</td>\n",
+       "      <td>0.006280</td>\n",
+       "      <td>0.005664</td>\n",
+       "      <td>0.927634</td>\n",
+       "      <td>0.990096</td>\n",
+       "      <td>0.000060</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>3</td>\n",
+       "      <td>0.005555</td>\n",
+       "      <td>0.005122</td>\n",
+       "      <td>0.937812</td>\n",
+       "      <td>0.990240</td>\n",
+       "      <td>0.000080</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>4</td>\n",
+       "      <td>0.005132</td>\n",
+       "      <td>0.005172</td>\n",
+       "      <td>0.932919</td>\n",
+       "      <td>0.991395</td>\n",
+       "      <td>0.000100</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5</th>\n",
+       "      <td>5</td>\n",
+       "      <td>0.004851</td>\n",
+       "      <td>0.005073</td>\n",
+       "      <td>0.937282</td>\n",
+       "      <td>0.990720</td>\n",
+       "      <td>0.000120</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>6</th>\n",
+       "      <td>6</td>\n",
+       "      <td>0.004646</td>\n",
+       "      <td>0.005248</td>\n",
+       "      <td>0.930109</td>\n",
+       "      <td>0.991742</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>7</th>\n",
+       "      <td>7</td>\n",
+       "      <td>0.004510</td>\n",
+       "      <td>0.004563</td>\n",
+       "      <td>0.940789</td>\n",
+       "      <td>0.992526</td>\n",
+       "      <td>0.000112</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>8</th>\n",
+       "      <td>8</td>\n",
+       "      <td>0.004373</td>\n",
+       "      <td>0.004509</td>\n",
+       "      <td>0.946273</td>\n",
+       "      <td>0.991391</td>\n",
+       "      <td>0.000180</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>9</th>\n",
+       "      <td>9</td>\n",
+       "      <td>0.004298</td>\n",
+       "      <td>0.004374</td>\n",
+       "      <td>0.947959</td>\n",
+       "      <td>0.991694</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>10</th>\n",
+       "      <td>10</td>\n",
+       "      <td>0.004109</td>\n",
+       "      <td>0.004050</td>\n",
+       "      <td>0.953197</td>\n",
+       "      <td>0.992015</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>11</th>\n",
+       "      <td>11</td>\n",
+       "      <td>0.003983</td>\n",
+       "      <td>0.004233</td>\n",
+       "      <td>0.954432</td>\n",
+       "      <td>0.990489</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>12</th>\n",
+       "      <td>12</td>\n",
+       "      <td>0.003879</td>\n",
+       "      <td>0.003788</td>\n",
+       "      <td>0.958103</td>\n",
+       "      <td>0.991964</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>13</th>\n",
+       "      <td>13</td>\n",
+       "      <td>0.003778</td>\n",
+       "      <td>0.003780</td>\n",
+       "      <td>0.960245</td>\n",
+       "      <td>0.991730</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>14</th>\n",
+       "      <td>14</td>\n",
+       "      <td>0.003703</td>\n",
+       "      <td>0.003744</td>\n",
+       "      <td>0.959127</td>\n",
+       "      <td>0.992023</td>\n",
+       "      <td>0.000200</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>15</th>\n",
+       "      <td>15</td>\n",
+       "      <td>0.003641</td>\n",
+       "      <td>0.003704</td>\n",
+       "      <td>0.959411</td>\n",
+       "      <td>0.992099</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>16</th>\n",
+       "      <td>16</td>\n",
+       "      <td>0.003305</td>\n",
+       "      <td>0.003534</td>\n",
+       "      <td>0.961065</td>\n",
+       "      <td>0.992371</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>17</th>\n",
+       "      <td>17</td>\n",
+       "      <td>0.003253</td>\n",
+       "      <td>0.003618</td>\n",
+       "      <td>0.963001</td>\n",
+       "      <td>0.991412</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>18</th>\n",
+       "      <td>18</td>\n",
+       "      <td>0.003225</td>\n",
+       "      <td>0.003503</td>\n",
+       "      <td>0.963811</td>\n",
+       "      <td>0.991958</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>19</th>\n",
+       "      <td>19</td>\n",
+       "      <td>0.003224</td>\n",
+       "      <td>0.003467</td>\n",
+       "      <td>0.964852</td>\n",
+       "      <td>0.991678</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>20</th>\n",
+       "      <td>20</td>\n",
+       "      <td>0.003200</td>\n",
+       "      <td>0.003380</td>\n",
+       "      <td>0.964631</td>\n",
+       "      <td>0.992555</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>21</th>\n",
+       "      <td>21</td>\n",
+       "      <td>0.003180</td>\n",
+       "      <td>0.003469</td>\n",
+       "      <td>0.965098</td>\n",
+       "      <td>0.992115</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>22</th>\n",
+       "      <td>22</td>\n",
+       "      <td>0.003160</td>\n",
+       "      <td>0.003374</td>\n",
+       "      <td>0.964880</td>\n",
+       "      <td>0.992387</td>\n",
+       "      <td>0.000140</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>23</th>\n",
+       "      <td>23</td>\n",
+       "      <td>0.003152</td>\n",
+       "      <td>0.003522</td>\n",
+       "      <td>0.965240</td>\n",
+       "      <td>0.991543</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>24</th>\n",
+       "      <td>24</td>\n",
+       "      <td>0.002912</td>\n",
+       "      <td>0.003308</td>\n",
+       "      <td>0.968241</td>\n",
+       "      <td>0.992033</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>25</th>\n",
+       "      <td>25</td>\n",
+       "      <td>0.002879</td>\n",
+       "      <td>0.003358</td>\n",
+       "      <td>0.968242</td>\n",
+       "      <td>0.991781</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>26</th>\n",
+       "      <td>26</td>\n",
+       "      <td>0.002879</td>\n",
+       "      <td>0.003365</td>\n",
+       "      <td>0.968507</td>\n",
+       "      <td>0.991775</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>27</th>\n",
+       "      <td>27</td>\n",
+       "      <td>0.002853</td>\n",
+       "      <td>0.003240</td>\n",
+       "      <td>0.969288</td>\n",
+       "      <td>0.992256</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>28</th>\n",
+       "      <td>28</td>\n",
+       "      <td>0.002858</td>\n",
+       "      <td>0.003385</td>\n",
+       "      <td>0.968127</td>\n",
+       "      <td>0.991958</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>29</th>\n",
+       "      <td>29</td>\n",
+       "      <td>0.002849</td>\n",
+       "      <td>0.003389</td>\n",
+       "      <td>0.969099</td>\n",
+       "      <td>0.991858</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>30</th>\n",
+       "      <td>30</td>\n",
+       "      <td>0.002855</td>\n",
+       "      <td>0.003413</td>\n",
+       "      <td>0.966039</td>\n",
+       "      <td>0.992158</td>\n",
+       "      <td>0.000098</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>31</th>\n",
+       "      <td>31</td>\n",
+       "      <td>0.002830</td>\n",
+       "      <td>0.003342</td>\n",
+       "      <td>0.969118</td>\n",
+       "      <td>0.991828</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>32</th>\n",
+       "      <td>32</td>\n",
+       "      <td>0.002659</td>\n",
+       "      <td>0.003288</td>\n",
+       "      <td>0.969825</td>\n",
+       "      <td>0.992203</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>33</th>\n",
+       "      <td>33</td>\n",
+       "      <td>0.002611</td>\n",
+       "      <td>0.002948</td>\n",
+       "      <td>0.972177</td>\n",
+       "      <td>0.993180</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>34</th>\n",
+       "      <td>34</td>\n",
+       "      <td>0.002425</td>\n",
+       "      <td>0.002649</td>\n",
+       "      <td>0.974883</td>\n",
+       "      <td>0.994110</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>35</th>\n",
+       "      <td>35</td>\n",
+       "      <td>0.002253</td>\n",
+       "      <td>0.002432</td>\n",
+       "      <td>0.977438</td>\n",
+       "      <td>0.994554</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>36</th>\n",
+       "      <td>36</td>\n",
+       "      <td>0.002139</td>\n",
+       "      <td>0.002351</td>\n",
+       "      <td>0.977666</td>\n",
+       "      <td>0.994845</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>37</th>\n",
+       "      <td>37</td>\n",
+       "      <td>0.002048</td>\n",
+       "      <td>0.002227</td>\n",
+       "      <td>0.978469</td>\n",
+       "      <td>0.995250</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>38</th>\n",
+       "      <td>38</td>\n",
+       "      <td>0.001969</td>\n",
+       "      <td>0.002170</td>\n",
+       "      <td>0.979098</td>\n",
+       "      <td>0.995544</td>\n",
+       "      <td>0.000069</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>39</th>\n",
+       "      <td>39</td>\n",
+       "      <td>0.001897</td>\n",
+       "      <td>0.001969</td>\n",
+       "      <td>0.981175</td>\n",
+       "      <td>0.995911</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>40</th>\n",
+       "      <td>40</td>\n",
+       "      <td>0.001624</td>\n",
+       "      <td>0.001698</td>\n",
+       "      <td>0.982833</td>\n",
+       "      <td>0.996831</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>41</th>\n",
+       "      <td>41</td>\n",
+       "      <td>0.001523</td>\n",
+       "      <td>0.001631</td>\n",
+       "      <td>0.983602</td>\n",
+       "      <td>0.996945</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>42</th>\n",
+       "      <td>42</td>\n",
+       "      <td>0.001450</td>\n",
+       "      <td>0.001593</td>\n",
+       "      <td>0.984215</td>\n",
+       "      <td>0.997049</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>43</th>\n",
+       "      <td>43</td>\n",
+       "      <td>0.001411</td>\n",
+       "      <td>0.001537</td>\n",
+       "      <td>0.984988</td>\n",
+       "      <td>0.997071</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>44</th>\n",
+       "      <td>44</td>\n",
+       "      <td>0.001374</td>\n",
+       "      <td>0.001585</td>\n",
+       "      <td>0.984472</td>\n",
+       "      <td>0.997054</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>45</th>\n",
+       "      <td>45</td>\n",
+       "      <td>0.001344</td>\n",
+       "      <td>0.001564</td>\n",
+       "      <td>0.984906</td>\n",
+       "      <td>0.997020</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>46</th>\n",
+       "      <td>46</td>\n",
+       "      <td>0.001334</td>\n",
+       "      <td>0.001556</td>\n",
+       "      <td>0.984871</td>\n",
+       "      <td>0.996946</td>\n",
+       "      <td>0.000048</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>47</th>\n",
+       "      <td>47</td>\n",
+       "      <td>0.001305</td>\n",
+       "      <td>0.001551</td>\n",
+       "      <td>0.984850</td>\n",
+       "      <td>0.997180</td>\n",
+       "      <td>0.000034</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>48</th>\n",
+       "      <td>48</td>\n",
+       "      <td>0.001215</td>\n",
+       "      <td>0.001518</td>\n",
+       "      <td>0.985659</td>\n",
+       "      <td>0.997226</td>\n",
+       "      <td>0.000034</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>49</th>\n",
+       "      <td>49</td>\n",
+       "      <td>0.001197</td>\n",
+       "      <td>0.001530</td>\n",
+       "      <td>0.985763</td>\n",
+       "      <td>0.997204</td>\n",
+       "      <td>0.000034</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "    epoch  train_loss  val_loss       eff       pur  current_lr\n",
+       "0       0    0.015745  0.009409  0.882107  0.979640    0.000020\n",
+       "1       1    0.007946  0.006557  0.922102  0.986181    0.000040\n",
+       "2       2    0.006280  0.005664  0.927634  0.990096    0.000060\n",
+       "3       3    0.005555  0.005122  0.937812  0.990240    0.000080\n",
+       "4       4    0.005132  0.005172  0.932919  0.991395    0.000100\n",
+       "5       5    0.004851  0.005073  0.937282  0.990720    0.000120\n",
+       "6       6    0.004646  0.005248  0.930109  0.991742    0.000140\n",
+       "7       7    0.004510  0.004563  0.940789  0.992526    0.000112\n",
+       "8       8    0.004373  0.004509  0.946273  0.991391    0.000180\n",
+       "9       9    0.004298  0.004374  0.947959  0.991694    0.000200\n",
+       "10     10    0.004109  0.004050  0.953197  0.992015    0.000200\n",
+       "11     11    0.003983  0.004233  0.954432  0.990489    0.000200\n",
+       "12     12    0.003879  0.003788  0.958103  0.991964    0.000200\n",
+       "13     13    0.003778  0.003780  0.960245  0.991730    0.000200\n",
+       "14     14    0.003703  0.003744  0.959127  0.992023    0.000200\n",
+       "15     15    0.003641  0.003704  0.959411  0.992099    0.000140\n",
+       "16     16    0.003305  0.003534  0.961065  0.992371    0.000140\n",
+       "17     17    0.003253  0.003618  0.963001  0.991412    0.000140\n",
+       "18     18    0.003225  0.003503  0.963811  0.991958    0.000140\n",
+       "19     19    0.003224  0.003467  0.964852  0.991678    0.000140\n",
+       "20     20    0.003200  0.003380  0.964631  0.992555    0.000140\n",
+       "21     21    0.003180  0.003469  0.965098  0.992115    0.000140\n",
+       "22     22    0.003160  0.003374  0.964880  0.992387    0.000140\n",
+       "23     23    0.003152  0.003522  0.965240  0.991543    0.000098\n",
+       "24     24    0.002912  0.003308  0.968241  0.992033    0.000098\n",
+       "25     25    0.002879  0.003358  0.968242  0.991781    0.000098\n",
+       "26     26    0.002879  0.003365  0.968507  0.991775    0.000098\n",
+       "27     27    0.002853  0.003240  0.969288  0.992256    0.000098\n",
+       "28     28    0.002858  0.003385  0.968127  0.991958    0.000098\n",
+       "29     29    0.002849  0.003389  0.969099  0.991858    0.000098\n",
+       "30     30    0.002855  0.003413  0.966039  0.992158    0.000098\n",
+       "31     31    0.002830  0.003342  0.969118  0.991828    0.000069\n",
+       "32     32    0.002659  0.003288  0.969825  0.992203    0.000069\n",
+       "33     33    0.002611  0.002948  0.972177  0.993180    0.000069\n",
+       "34     34    0.002425  0.002649  0.974883  0.994110    0.000069\n",
+       "35     35    0.002253  0.002432  0.977438  0.994554    0.000069\n",
+       "36     36    0.002139  0.002351  0.977666  0.994845    0.000069\n",
+       "37     37    0.002048  0.002227  0.978469  0.995250    0.000069\n",
+       "38     38    0.001969  0.002170  0.979098  0.995544    0.000069\n",
+       "39     39    0.001897  0.001969  0.981175  0.995911    0.000048\n",
+       "40     40    0.001624  0.001698  0.982833  0.996831    0.000048\n",
+       "41     41    0.001523  0.001631  0.983602  0.996945    0.000048\n",
+       "42     42    0.001450  0.001593  0.984215  0.997049    0.000048\n",
+       "43     43    0.001411  0.001537  0.984988  0.997071    0.000048\n",
+       "44     44    0.001374  0.001585  0.984472  0.997054    0.000048\n",
+       "45     45    0.001344  0.001564  0.984906  0.997020    0.000048\n",
+       "46     46    0.001334  0.001556  0.984871  0.996946    0.000048\n",
+       "47     47    0.001305  0.001551  0.984850  0.997180    0.000034\n",
+       "48     48    0.001215  0.001518  0.985659  0.997226    0.000034\n",
+       "49     49    0.001197  0.001530  0.985763  0.997204    0.000034"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# gnn_metrics = checkpoint_utils.get_training_metrics(gnn_trainer)\n",
+    "\n",
+    "gnn_metrics = checkpoint_utils.get_training_metrics(gnn_metric_path)\n",
+    "\n",
+    "gnn_metrics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Figure was saved in output/focal-loss-pid-fixed/loss_gnn.pdf\n",
+      "Figure was saved in output/focal-loss-pid-fixed/loss_gnn.png\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(<Figure size 800x600 with 1 Axes>, <Axes: xlabel='Epoch', ylabel='Loss'>)"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 800x600 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "perfplot_mpl.plot_loss(gnn_metrics, CONFIG, \"gnn\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Figure was saved in output/focal-loss-pid-fixed/eff_gnn.pdf\n",
+      "Figure was saved in output/focal-loss-pid-fixed/eff_gnn.png\n",
+      "Figure was saved in output/focal-loss-pid-fixed/pur_gnn.pdf\n",
+      "Figure was saved in output/focal-loss-pid-fixed/pur_gnn.png\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(<Figure size 800x600 with 1 Axes>,\n",
+       " <Axes: xlabel='Epoch', ylabel='Edge Purity'>)"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 800x600 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 800x600 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "perfplot_mpl.plot_metric_epochs(\n",
+    "    \"eff\",\n",
+    "    gnn_metrics,\n",
+    "    CONFIG,\n",
+    "    \"gnn\",\n",
+    "    \"Edge Efficiency\",\n",
+    "    color=perfplot_mpl.partition_to_color[\"val\"],\n",
+    ")\n",
+    "perfplot_mpl.plot_metric_epochs(\n",
+    "    \"pur\",\n",
+    "    gnn_metrics,\n",
+    "    CONFIG,\n",
+    "    \"gnn\",\n",
+    "    \"Edge Purity\",\n",
+    "    color=perfplot_mpl.partition_to_color[\"val\"],\n",
+    ")\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Evaluate model performance on sample test data\n",
+    "\n",
+    "Here we evaluate the model performance on a sample of test data. We look at how the efficiency and purity change with the embedding radius."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Load 1000 files located in /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/test/velo-sim10b-nospillover\n"
+     ]
+    },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "bbbe40dfdf7143e085e7a813d5787d6d",
+       "model_id": "0b5c1718765a4379b0a98e86f0baacde",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "  0%|          | 0/10000 [00:00<?, ?it/s]"
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
@@ -1241,18 +1425,104 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "INFO:Load 500 files located in /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/val\n"
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n",
+      "WARNING:Unable to obtain driver using Selenium Manager: /scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/selenium/webdriver/common/linux/selenium-manager is missing.  Please open an issue on https://github.com/SeleniumHQ/selenium/issues\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<IPython.core.display.Image object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "gnn_model.load_partition(\"velo-sim10b-nospillover\")\n",
+    "perfplot.plot_edge_performance(gnn_model, CONFIG)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Load 200 files located in /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/test/velo-sim10b-nospillover\n"
      ]
     },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "9edd7c25454547c6a782b3302caa67ad",
+       "model_id": "91ddae17207d4b9a9e133df76e7887f6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c737875249e04c03b5c38e3ec9955499",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ac377279aafc4b5180f4639ca6d039e9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "f85018c85c654d548916880638dee82e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/16 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "27babb13915a4937bb5ecb72296ee5eb",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "  0%|          | 0/500 [00:00<?, ?it/s]"
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
@@ -1262,314 +1532,391 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
-      "\n",
-      "  | Name                   | Type       | Params\n",
-      "------------------------------------------------------\n",
-      "0 | node_encoder           | Sequential | 332 K \n",
-      "1 | edge_encoder           | Sequential | 462 K \n",
-      "2 | edge_network           | Sequential | 793 K \n",
-      "3 | node_network           | Sequential | 659 K \n",
-      "4 | output_edge_classifier | Sequential | 529 K \n",
-      "------------------------------------------------------\n",
-      "2.8 M     Trainable params\n",
-      "0         Non-trainable params\n",
-      "2.8 M     Total params\n",
-      "11.111    Total estimated model params size (MB)\n"
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
      ]
     },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
+       "model_id": "7f6a4f1b645b4747b13c741e234ed7bc",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "Sanity Checking: 0it [00:00, ?it/s]"
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "5ee6357167424736846d37949bf55de0",
+       "model_id": "fe292d18a81c464a97ed34ef443e2fa6",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "Training: 0it [00:00, ?it/s]"
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
+       "model_id": "4135aaeedd4c4ce3ab377fab149dba5a",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "Validation: 0it [00:00, ?it/s]"
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
+       "model_id": "8f2216787d9d4aeaa02b8b9bdfefaa40",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "Validation: 0it [00:00, ?it/s]"
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
+       "model_id": "6992b033ce614752878da7593a75cfff",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "Validation: 0it [00:00, ?it/s]"
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
+       "model_id": "ed9f3c809e7c4e51b44a358ec82b78b7",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "Validation: 0it [00:00, ?it/s]"
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
+       "model_id": "896ef3db87ee47f2a44d9737076232c2",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "Validation: 0it [00:00, ?it/s]"
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "if run_training:\n",
-    "    send_telegram_message('Started GNN training.', chat_id, api_key)\n",
-    "    with warnings.catch_warnings():\n",
-    "        warnings.filterwarnings(\n",
-    "            \"ignore\", message=\"None of the inputs have requires_grad=True.\"\n",
-    "        )\n",
-    "        gnn_trainer, gnn_model = train_gnn(CONFIG)\n",
-    "\n",
-    "    send_telegram_message('Finished GNN training.', chat_id, api_key)\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "#### From checkpoint"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from utils.modelutils.checkpoint_utils import (\n",
-    "    get_last_version_dir_from_config,\n",
-    "    get_last_artifact,\n",
-    ")\n",
-    "from GNN.interaction_gnn import InteractionGNN\n",
-    "\n",
-    "gnn_version_dir = get_last_version_dir_from_config(step=\"gnn\", path_or_config=CONFIG)\n",
-    "gnn_metric_path = os.path.join(gnn_version_dir, \"metrics.csv\")\n",
-    "gnn_artifact_path = get_last_artifact(version_dir=gnn_version_dir)\n",
-    "print(f\"{gnn_metric_path=}\")\n",
-    "print(f\"{gnn_artifact_path=}\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from pytorch_lightning import Trainer\n",
-    "from pytorch_lightning.loggers import CSVLogger\n",
-    "\n",
-    "\n",
-    "def continue_gnn_training(\n",
-    "    path_or_config: str | dict,\n",
-    ") -> typing.Tuple[Trainer, InteractionGNN]:\n",
-    "    config = load_config(path_or_config=path_or_config)\n",
-    "\n",
-    "    gnn_model = InteractionGNN.load_from_checkpoint(\n",
-    "        gnn_artifact_path, **config[\"gnn\"]\n",
-    "    )  # you may change `gnn_model`\n",
-    "\n",
-    "    save_directory = os.path.abspath(\n",
-    "        os.path.join(config[\"common\"][\"artifact_directory\"], \"gnn\")\n",
-    "    )\n",
-    "\n",
-    "    logger = CSVLogger(save_directory, name=config[\"common\"][\"experiment_name\"])\n",
-    "\n",
-    "    gnn_trainer = Trainer(\n",
-    "        accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n",
-    "        devices=1,\n",
-    "        max_epochs=50,  # you may increase the number of epochs\n",
-    "        logger=logger,\n",
-    "        # callbacks=[EarlyStopping(monitor=\"val_loss\", mode=\"min\")]\n",
-    "    )\n",
-    "\n",
-    "    with warnings.catch_warnings():\n",
-    "        warnings.filterwarnings(\n",
-    "            \"ignore\", message=\"None of the inputs have requires_grad=True.\"\n",
-    "        )\n",
-    "        gnn_trainer.fit(gnn_model, ckpt_path=gnn_artifact_path)\n",
-    "    return gnn_trainer, gnn_model\n",
-    "\n",
-    "# gnn_trainer, gnn_model = continue_gnn_training(CONFIG)\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "gnn_model = InteractionGNN.load_from_checkpoint(\n",
-    "    gnn_artifact_path,\n",
-    "    # map_location=\"cpu\",\n",
-    "    # hparams=load_config(CONFIG)[\"gnn\"],\n",
-    ")\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Plot training metrics"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# gnn_metrics = checkpoint_utils.get_training_metrics(gnn_trainer)\n",
-    "\n",
-    "gnn_metrics = checkpoint_utils.get_training_metrics(gnn_metric_path)\n",
-    "\n",
-    "gnn_metrics"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "perfplot_mpl.plot_loss(gnn_metrics, CONFIG, \"gnn\")\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "perfplot_mpl.plot_metric_epochs(\n",
-    "    \"eff\",\n",
-    "    gnn_metrics,\n",
-    "    CONFIG,\n",
-    "    \"gnn\",\n",
-    "    \"Edge Efficiency\",\n",
-    "    color=perfplot_mpl.partition_to_color[\"val\"],\n",
-    ")\n",
-    "perfplot_mpl.plot_metric_epochs(\n",
-    "    \"pur\",\n",
-    "    gnn_metrics,\n",
-    "    CONFIG,\n",
-    "    \"gnn\",\n",
-    "    \"Edge Purity\",\n",
-    "    color=perfplot_mpl.partition_to_color[\"val\"],\n",
-    ")\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Evaluate model performance on sample test data\n",
-    "\n",
-    "Here we evaluate the model performace on one sample test data. We look at how the efficiency and purity change with the embedding radius."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "gnn_model.load_partition(\"velo-sim10b-nospillover\")\n",
-    "perfplot.plot_edge_performance(gnn_model, CONFIG)\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "performances_for_various_score_cuts"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "4b6e0862e10a4d3895142114ea69d609",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2c942e6fc1354e4e8c29448b4849bcb9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "81517adcee3f45f2aabf4180ccf00a19",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5f6e952c17a0420aac2e2e011bdfb016",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7aef953e9dd34c40b680254ae973cace",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9bd250618ba44ce5a74a9776fb8580b2",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "47d644dc6dbf4c9cbcb40c56ada0ed92",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b4989466e64246258d9c54882fb24ba5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "GNN inference:   0%|          | 0/200 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Figure was saved in output/focal-loss-pid-fixed/performance_given_score_cut.pdf\n",
+      "Figure was saved in output/focal-loss-pid-fixed/performance_given_score_cut.png\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 2400x600 with 3 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
     "from GNN.gnn_plots import plot_best_performances_score_cut\n",
+    "\n",
     "_, _, performances_for_various_score_cuts = plot_best_performances_score_cut(\n",
     "    model=gnn_model,\n",
     "    path_or_config=CONFIG,\n",
     "    partition=\"velo-sim10b-nospillover\",\n",
-    "    score_cuts=[0.4, 0.5, 0.6, 0.65, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8, 0.85, 0.9, 0.95],\n",
+    "    score_cuts=[\n",
+    "        0.1,\n",
+    "        0.2,\n",
+    "        0.3,\n",
+    "        0.4,\n",
+    "        0.42,\n",
+    "        0.43,\n",
+    "        0.44,\n",
+    "        0.45,\n",
+    "        0.46,\n",
+    "        0.47,\n",
+    "        0.48,\n",
+    "        0.5,\n",
+    "        0.6,\n",
+    "        0.7,\n",
+    "        0.8,\n",
+    "        0.9,\n",
+    "    ],\n",
     "    n_events=200,\n",
     "    seed=0,\n",
-    ")\n"
+    ")"
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1578,14 +1925,67 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 16,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:--------------------- Step 4: Scoring graph edges using GNN  ---------------------\n",
+      "INFO:---------------------------- a) Loading trained model ----------------------------\n",
+      "INFO:Load model from artifacts/gnn/focal-loss-pid-fixed/version_0/checkpoints/epoch=49-step=500000.ckpt.\n",
+      "INFO:Load model from artifacts/gnn/focal-loss-pid-fixed/version_0/checkpoints/epoch=49-step=500000.ckpt.\n",
+      "INFO:----------------------------- b) Running inferencing -----------------------------\n",
+      "INFO:Remove directory `/scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover`.\n",
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/test/velo-sim10b-nospillover to /scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a6cb9afd5cb9438e8880192981766da0",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/scratch/acorreia/mambaforge/envs/etx4velo_updated/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
+      "  warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n",
+      "INFO:Remove directory `/scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover-only-long-electrons`.\n",
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed/metric_learning_processed/test/velo-sim10b-nospillover-only-long-electrons to /scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover-only-long-electrons\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a1af6c66f0574263ba085bcdea898a7a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
-    "run_gnn_inference(CONFIG, checkpoint=gnn_model)"
+    "run_gnn_inference(CONFIG, partitions=[\"test\"], checkpoint=gnn_artifact_path)"
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1594,14 +1994,62 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 17,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:-----------  Step 5: Building track candidates from the scored graph  -----------\n",
+      "INFO:Score cut: 0.45\n",
+      "INFO:Remove directory `/scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover`.\n",
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover to /scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "15ad24e1d6f1444b9d3242e715a0dc9e",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Remove directory `/scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover-only-long-electrons`.\n",
+      "INFO:Inference from /scratch/acorreia/data/focal-loss-pid-fixed/gnn_processed/test/velo-sim10b-nospillover-only-long-electrons to /scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover-only-long-electrons\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e865558d4d084522b9a35a2f76ee8a29",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
-    "build_track_candidates(CONFIG)"
+    "build_track_candidates(CONFIG, partitions=[\"test\"])"
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1640,6 +2088,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -1648,9 +2097,198 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 18,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:---------------------------- velo-sim10b-nospillover ----------------------------\n",
+      "INFO:--------------------- Evaluation for velo-sim10b-nospillover ---------------------\n",
+      "INFO:1) Load dataframe of tracks, hits-particles and particles\n",
+      "INFO:Load tracks in /scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover.\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "093d1a9e10864656acccf65296bb9b51",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Load pre-processed test datasets in /scratch/acorreia/data/__test__/velo-sim10b-nospillover.\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "abc8329945a54b44a5747708f1256c12",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3b22d2cf13c24237bb0c6b187f60bc29",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Compute plat stats\n",
+      "INFO:2) Matching\n",
+      "INFO:3) Evaluation\n",
+      "INFO:Report was saved in output/focal-loss-pid-fixed/report-2023.06.11-12.04.53-velo-sim10b-nospillover.txt\n",
+      "INFO:------------------ velo-sim10b-nospillover-only-long-electrons ------------------\n",
+      "INFO:----------- Evaluation for velo-sim10b-nospillover-only-long-electrons -----------\n",
+      "INFO:1) Load dataframe of tracks, hits-particles and particles\n",
+      "INFO:Load tracks in /scratch/acorreia/data/focal-loss-pid-fixed/track_building_processed/test/velo-sim10b-nospillover-only-long-electrons.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "TrackChecker output                               :      2020/   243242   0.83% ghosts\n",
+      "01_velo                                           :    101318/   104345  97.10% ( 97.24%),       837 (  0.82%) clones, pur  99.86%, hit eff  97.73%\n",
+      "02_long                                           :     58330/    59167  98.59% ( 98.62%),       457 (  0.78%) clones, pur  99.93%, hit eff  98.31%\n",
+      "03_long_P>5GeV                                    :     37865/    38150  99.25% ( 99.26%),       231 (  0.61%) clones, pur  99.93%, hit eff  98.85%\n",
+      "04_long_strange                                   :      2971/     3142  94.56% ( 94.88%),        38 (  1.26%) clones, pur  99.64%, hit eff  95.25%\n",
+      "05_long_strange_P>5GeV                            :      1443/     1521  94.87% ( 94.91%),         7 (  0.48%) clones, pur  99.58%, hit eff  97.68%\n",
+      "06_long_fromB                                     :       120/      120 100.00% (100.00%),         2 (  1.64%) clones, pur  99.80%, hit eff  98.13%\n",
+      "07_long_fromB_P>5GeV                              :        87/       87 100.00% (100.00%),         0 (  0.00%) clones, pur 100.00%, hit eff  99.81%\n",
+      "08_long_electrons                                 :      3459/     4198  82.40% ( 82.85%),        74 (  2.09%) clones, pur  98.84%, hit eff  87.95%\n",
+      "09_long_fromB_electrons                           :        10/       10 100.00% (100.00%),         0 (  0.00%) clones, pur 100.00%, hit eff  93.33%\n",
+      "10_long_fromB_electrons_P>5GeV                    :         7/        7 100.00% (100.00%),         0 (  0.00%) clones, pur 100.00%, hit eff  96.43%\n",
+      "\n",
+      "| Categories           | Efficiency   | Average efficiency   | % clones   | Average hit purity   | Average hit efficiency   |\n",
+      "|:---------------------|:-------------|:---------------------|:-----------|:---------------------|:-------------------------|\n",
+      "| Velo                 | 93.82%       | 94.10%               | 1.06%      | 99.75%               | 96.26%                   |\n",
+      "| Long                 | 97.51%       | 97.58%               | 0.85%      | 99.87%               | 97.72%                   |\n",
+      "| Velo, no electrons   | 97.10%       | 97.24%               | 0.82%      | 99.86%               | 97.73%                   |\n",
+      "| Velo, only electrons | 76.99%       | 77.37%               | 2.57%      | 99.03%               | 86.85%                   |\n",
+      "| Long, only electrons | 82.40%       | 82.85%               | 2.09%      | 98.84%               | 87.95%                   |\n",
+      "| Categories   | # ghosts   | # tracks   | % ghosts   |\n",
+      "|:-------------|:-----------|:-----------|:-----------|\n",
+      "| Everything   | 2,020      | 243,242    | 0.83%      |\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c118aca80cae4b62831251402e288928",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Load pre-processed test datasets in /scratch/acorreia/data/__test__/velo-sim10b-nospillover-only-long-electrons.\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3bffc5f3235146cc89114053b6cc801c",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9eae617c6d194998adb1a1469ab68796",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1000 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:Compute plat stats\n",
+      "INFO:2) Matching\n",
+      "INFO:3) Evaluation\n",
+      "INFO:Report was saved in output/focal-loss-pid-fixed/report-2023.06.11-12.05.13-velo-sim10b-nospillover-only-long-electrons.txt\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "TrackChecker output                               :        79/     4432   1.78% ghosts\n",
+      "01_velo                                           :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "02_long                                           :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "03_long_P>5GeV                                    :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "04_long_strange                                   :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "05_long_strange_P>5GeV                            :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "06_long_fromB                                     :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "07_long_fromB_P>5GeV                              :         0/        0    nan% (   nan%),         0 (   nan%) clones, pur    nan%, hit eff    nan%\n",
+      "08_long_electrons                                 :      4286/     4670  91.78% ( 93.18%),        37 (  0.86%) clones, pur  99.29%, hit eff  95.56%\n",
+      "09_long_fromB_electrons                           :        10/       10 100.00% (100.00%),         0 (  0.00%) clones, pur 100.00%, hit eff  97.50%\n",
+      "10_long_fromB_electrons_P>5GeV                    :         7/        7 100.00% (100.00%),         0 (  0.00%) clones, pur 100.00%, hit eff  96.43%\n",
+      "\n",
+      "| Categories           | Efficiency   | Average efficiency   | % clones   | Average hit purity   | Average hit efficiency   |\n",
+      "|:---------------------|:-------------|:---------------------|:-----------|:---------------------|:-------------------------|\n",
+      "| Velo                 | 91.78%       | 93.18%               | 0.86%      | 99.29%               | 95.56%                   |\n",
+      "| Long                 | 91.78%       | 93.18%               | 0.86%      | 99.29%               | 95.56%                   |\n",
+      "| Velo, no electrons   | nan%         | nan%                 | nan%       | nan%                 | nan%                     |\n",
+      "| Velo, only electrons | 91.78%       | 93.18%               | 0.86%      | 99.29%               | 95.56%                   |\n",
+      "| Long, only electrons | 91.78%       | 93.18%               | 0.86%      | 99.29%               | 95.56%                   |\n",
+      "| Categories   |   # ghosts | # tracks   | % ghosts   |\n",
+      "|:-------------|-----------:|:-----------|:-----------|\n",
+      "| Everything   |         79 | 4,432      | 1.78%      |\n"
+     ]
+    }
+   ],
    "source": [
     "for test_dataset_name in get_required_test_dataset_names(CONFIG):\n",
     "    logging.info(headline(test_dataset_name))\n",
diff --git a/LHCb_Pipeline/pipeline_config_default.yaml b/LHCb_Pipeline/pipeline_config_default.yaml
index a6c437c07f411b3fa97c0c6f92a4ae5e7d713a9d..d3ad84032f1368abefda7844377dc4c0350a49db 100644
--- a/LHCb_Pipeline/pipeline_config_default.yaml
+++ b/LHCb_Pipeline/pipeline_config_default.yaml
@@ -18,7 +18,7 @@ preprocessing:
   # - Dictionary with keys `start` and `stop`
   subdirs: 10
   output_subdirectory: "preprocessed"
-  selection: triplets_first_selection # Selection function, defined in `Preprocessing/selecting.py`
+  processing: triplets_first_selection # Processing function(s), defined in `Preprocessing/processing.py`
   n_events: null # if `null`, default to `n_train_events + n_test_events`
   num_true_hits_threshold: 500 # Minimal number of genuine hits
   # Columns to keep in the dataframes of hits-particles and particles
diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-bidir.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-bidir.yaml
index a752568d5aef95e5d429cf3f2b0573353f255678..254fc165594089f276f9a05f5d0ec0dce3e3a70a 100644
--- a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-bidir.yaml
+++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-bidir.yaml
@@ -59,6 +59,10 @@ metric_learning:
   bidir: False
   max_epochs: 20
 
+  # Building
+  test_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+  training_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+
 
 gnn:
   # Dataset parameters
diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-100000.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-100000.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d235fa6f79988aae3870d0fdf0cf91ba400a290b
--- /dev/null
+++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-100000.yaml
@@ -0,0 +1,105 @@
+common:
+  experiment_name: focal-loss-pid-fixed-100000
+  data_directory: /scratch/acorreia/data
+  artifact_directory: artifacts
+  performance_directory: output # plots and reports
+  gpus: 1
+  test_dataset_names:
+  - velo-sim10b-nospillover
+  - velo-sim10b-nospillover-only-long-electrons
+  # - bu2kstee-sim10aU1-xdigi
+
+preprocessing:
+  input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover
+  subdirs: {"start": 15, "stop": 80}
+  output_subdirectory: "preprocessed"
+  selection: triplets_first_selection
+  n_events: null # if `null`, default to `n_train_events + n_test_events`
+  num_true_hits_threshold: 500
+  hits_particles_columns: ["x", "y", "z", "plane"]
+  particles_columns: null
+
+processing:
+  input_subdirectory: "preprocessed"
+  output_subdirectory: "processed"
+  n_workers: 32
+  features: ["r", "phi", "z", "plane"]
+  feature_means: [18., 0.0, 281.0, 7.5]
+  feature_scales: [9.75, 1.82, 287.0, 12.5]
+  kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}]
+  kept_particles_columns: ["n_unique_planes", "nhits_velo"]
+  n_train_events: 100000
+  n_val_events: 1000
+  split_seed: 0
+  true_edges_column: planewise
+
+metric_learning:
+  # Dataset parameters
+  input_subdirectory: "processed"
+  output_subdirectory: "metric_learning_processed"
+
+  # Model parameters
+  feature_indices: 4
+  emb_hidden: 256
+  nb_layer: 6
+  emb_dim: 4
+  activation: Tanh
+  weight: 2
+  randomisation: 2
+  points_per_batch: 100000
+  r: 0.015
+  r_inference: 0.020
+  knn: 50
+  warmup: 8
+  margin: 0.1
+  lr: 0.001
+  factor: 0.7
+  patience: 10
+  regime: [rp, hnm, norm]
+  bidir: False
+  max_epochs: 20
+
+  # Building
+  test_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+  training_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+
+
+gnn:
+  # Dataset parameters
+  input_subdirectory: "metric_learning_processed"
+  output_subdirectory: "gnn_processed"
+  edge_cut: 0.5
+  noise: True
+  bidir: False
+
+  # Model parameters
+  feature_indices: 4 # mmh I'm actually using the plane number, which is not deliberate
+  hidden: 256
+  n_graph_iters: 8
+  nb_node_layers: 6
+  nb_node_encoder_layers: 6
+  nb_edge_layers: 10
+  nb_edge_encoder_layers: 6
+  nb_edge_classifier_layers: 6
+  layernorm: True
+  aggregation: sum_max
+  hidden_activation: SiLU
+  weight: 0.25
+  warmup: 10
+  lr: 0.0002
+  factor: 0.7
+  patience: 8
+  regime: ["pid"]
+  max_epochs: 50
+  gradient_clip_val: 0.5
+  focal_loss: true
+
+triplet_building:
+  input_subdirectory: "gnn_processed"
+  output_subdirectory: "triplet_building"
+
+track_building:
+  score_cut: 0.44
+  # input_subdirectory: "gnn_processed"
+  input_subdirectory: "gnn_processed"
+  output_subdirectory: "track_building_processed"
diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-20000.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-20000.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bbb483bdbac86ddacc459277c7165a4b61ca327b
--- /dev/null
+++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-20000.yaml
@@ -0,0 +1,105 @@
+common:
+  experiment_name: focal-loss-pid-fixed-20000
+  data_directory: /scratch/acorreia/data
+  artifact_directory: artifacts
+  performance_directory: output # plots and reports
+  gpus: 1
+  test_dataset_names:
+  - velo-sim10b-nospillover
+  - velo-sim10b-nospillover-only-long-electrons
+  # - bu2kstee-sim10aU1-xdigi
+
+preprocessing:
+  input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover
+  subdirs: {"start": 10, "stop": 20}
+  output_subdirectory: "preprocessed"
+  selection: triplets_first_selection
+  n_events: null # if `null`, default to `n_train_events + n_test_events`
+  num_true_hits_threshold: 500
+  hits_particles_columns: ["x", "y", "z", "plane"]
+  particles_columns: null
+
+processing:
+  input_subdirectory: "preprocessed"
+  output_subdirectory: "processed"
+  n_workers: 32
+  features: ["r", "phi", "z", "plane"]
+  feature_means: [18., 0.0, 281.0, 7.5]
+  feature_scales: [9.75, 1.82, 287.0, 12.5]
+  kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}]
+  kept_particles_columns: ["n_unique_planes", "nhits_velo"]
+  n_train_events: 20000
+  n_val_events: 1000
+  split_seed: 0
+  true_edges_column: planewise
+
+metric_learning:
+  # Dataset parameters
+  input_subdirectory: "processed"
+  output_subdirectory: "metric_learning_processed"
+
+  # Model parameters
+  feature_indices: 4
+  emb_hidden: 256
+  nb_layer: 6
+  emb_dim: 4
+  activation: Tanh
+  weight: 2
+  randomisation: 2
+  points_per_batch: 100000
+  r: 0.015
+  r_inference: 0.020
+  knn: 50
+  warmup: 8
+  margin: 0.1
+  lr: 0.001
+  factor: 0.7
+  patience: 10
+  regime: [rp, hnm, norm]
+  bidir: False
+  max_epochs: 20
+
+  # Building
+  test_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+  training_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+
+
+gnn:
+  # Dataset parameters
+  input_subdirectory: "metric_learning_processed"
+  output_subdirectory: "gnn_processed"
+  edge_cut: 0.5
+  noise: True
+  bidir: False
+
+  # Model parameters
+  feature_indices: 4 # mmh I'm actually using the plane number, which is not deliberate
+  hidden: 256
+  n_graph_iters: 8
+  nb_node_layers: 6
+  nb_node_encoder_layers: 6
+  nb_edge_layers: 10
+  nb_edge_encoder_layers: 6
+  nb_edge_classifier_layers: 6
+  layernorm: True
+  aggregation: sum_max
+  hidden_activation: SiLU
+  weight: 0.25
+  warmup: 10
+  lr: 0.0002
+  factor: 0.7
+  patience: 8
+  regime: ["pid"]
+  max_epochs: 50
+  gradient_clip_val: 0.5
+  focal_loss: true
+
+triplet_building:
+  input_subdirectory: "gnn_processed"
+  output_subdirectory: "triplet_building"
+
+track_building:
+  score_cut: 0.44
+  # input_subdirectory: "gnn_processed"
+  input_subdirectory: "gnn_processed"
+  output_subdirectory: "track_building_processed"
diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-250000.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-250000.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..14f85df36d0d992fa93c0d73a7e83b022d14cae1
--- /dev/null
+++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-250000.yaml
@@ -0,0 +1,108 @@
+common:
+  experiment_name: focal-loss-pid-fixed-250000
+  data_directory: /scratch/acorreia/data
+  artifact_directory: artifacts
+  performance_directory: output # plots and reports
+  gpus: 1
+  test_dataset_names:
+  - velo-sim10b-nospillover_choice
+  - velo-sim10b-nospillover
+  # - velo-sim10b-nospillover-only-long-electrons
+  # - bu2kstee-sim10aU1-xdigi
+
+preprocessing:
+  input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover
+  subdirs: {"start": 30, "stop": 80}
+  output_subdirectory: "preprocessed"
+  selection: triplets_first_selection
+  n_events: null # if `null`, default to `n_train_events + n_test_events`
+  num_true_hits_threshold: 500
+  hits_particles_columns: ["x", "y", "z", "plane"]
+  particles_columns: null
+
+processing:
+  input_subdirectory: "preprocessed"
+  output_subdirectory: "processed"
+  n_workers: 32
+  features: ["r", "phi", "z", "plane"]
+  feature_means: [18., 0.0, 281.0, 7.5]
+  feature_scales: [9.75, 1.82, 287.0, 12.5]
+  kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}]
+  kept_particles_columns: ["n_unique_planes", "nhits_velo"]
+  n_train_events: 220000
+  n_val_events: 1000
+  split_seed: 0
+  true_edges_column: planewise
+
+metric_learning:
+  # Dataset parameters
+  input_subdirectory: "processed"
+  output_subdirectory: "metric_learning_processed"
+
+  # Model parameters
+  feature_indices: 4
+  emb_hidden: 256
+  nb_layer: 6
+  emb_dim: 4
+  activation: Tanh
+  weight: 2
+  randomisation: 2
+  points_per_batch: 100000
+  r: 0.015
+  r_inference: 0.020
+  knn: 50
+  warmup: 8
+  margin: 0.1
+  lr: 0.001
+  factor: 0.7
+  patience: 10
+  regime: [rp, hnm, norm]
+  bidir: False
+  max_epochs: 20
+
+  # Building
+  test_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+  training_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+
+
+gnn:
+  # Dataset parameters
+  input_subdirectory: "metric_learning_processed"
+  output_subdirectory: "gnn_processed"
+  edge_cut: 0.5
+  noise: True
+  bidir: False
+  n_train_events: 1000
+  lazy: true
+
+  # Model parameters
+  feature_indices: 4 # mmh I'm actually using the plane number, which is not deliberate
+  hidden: 256
+  n_graph_iters: 8
+  nb_node_layers: 6
+  nb_node_encoder_layers: 6
+  nb_edge_layers: 10
+  nb_edge_encoder_layers: 6
+  nb_edge_classifier_layers: 6
+  layernorm: True
+  aggregation: sum_max
+  hidden_activation: SiLU
+  weight: 0.25
+  warmup: 10
+  lr: 0.0002
+  factor: 0.7
+  patience: 8
+  regime: ["pid"]
+  max_epochs: 50
+  gradient_clip_val: 0.5
+  focal_loss: true
+
+triplet_building:
+  input_subdirectory: "gnn_processed"
+  output_subdirectory: "triplet_building"
+
+track_building:
+  score_cut: 0.44
+  # input_subdirectory: "gnn_processed"
+  input_subdirectory: "gnn_processed"
+  output_subdirectory: "track_building_processed"
diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-80000-2.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-80000-2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d43c8d542ec3d100b00dc2109ffd0a2b1075427
--- /dev/null
+++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-80000-2.yaml
@@ -0,0 +1,106 @@
+common:
+  experiment_name: focal-loss-pid-fixed-80000-2
+  data_directory: /scratch/acorreia/data
+  artifact_directory: artifacts
+  performance_directory: output # plots and reports
+  gpus: 1
+  test_dataset_names:
+  - velo-sim10b-nospillover_choice
+  - velo-sim10b-nospillover
+  # - velo-sim10b-nospillover-only-long-electrons
+  # - bu2kstee-sim10aU1-xdigi
+
+preprocessing:
+  input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover
+  subdirs: {"start": 30, "stop": 50}
+  output_subdirectory: "preprocessed"
+  selection: triplets_first_selection
+  n_events: null # if `null`, default to `n_train_events + n_test_events`
+  num_true_hits_threshold: 500
+  hits_particles_columns: ["x", "y", "z", "plane"]
+  particles_columns: null
+
+processing:
+  input_subdirectory: "preprocessed"
+  output_subdirectory: "processed"
+  n_workers: 32
+  features: ["r", "phi", "z", "plane"]
+  feature_means: [18., 0.0, 281.0, 7.5]
+  feature_scales: [9.75, 1.82, 287.0, 12.5]
+  kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}]
+  kept_particles_columns: ["n_unique_planes", "nhits_velo"]
+  n_train_events: 80000
+  n_val_events: 1000
+  split_seed: 0
+  true_edges_column: planewise
+
+metric_learning:
+  # Dataset parameters
+  input_subdirectory: "processed"
+  output_subdirectory: "metric_learning_processed"
+
+  # Model parameters
+  feature_indices: 4
+  emb_hidden: 256
+  nb_layer: 6
+  emb_dim: 4
+  activation: Tanh
+  weight: 2
+  randomisation: 2
+  points_per_batch: 100000
+  r: 0.015
+  r_inference: 0.020
+  knn: 50
+  warmup: 8
+  margin: 0.1
+  lr: 0.001
+  factor: 0.7
+  patience: 10
+  regime: [rp, hnm, norm]
+  bidir: False
+  max_epochs: 20
+
+  # Building
+  test_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+  training_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+
+
+gnn:
+  # Dataset parameters
+  input_subdirectory: "metric_learning_processed"
+  output_subdirectory: "gnn_processed"
+  edge_cut: 0.5
+  noise: True
+  bidir: False
+
+  # Model parameters
+  feature_indices: 4 # mmh I'm actually using the plane number, which is not deliberate
+  hidden: 256
+  n_graph_iters: 8
+  nb_node_layers: 6
+  nb_node_encoder_layers: 6
+  nb_edge_layers: 10
+  nb_edge_encoder_layers: 6
+  nb_edge_classifier_layers: 6
+  layernorm: True
+  aggregation: sum_max
+  hidden_activation: SiLU
+  weight: 0.25
+  warmup: 10
+  lr: 0.0002
+  factor: 0.7
+  patience: 8
+  regime: ["pid"]
+  max_epochs: 50
+  gradient_clip_val: 0.5
+  focal_loss: true
+
+triplet_building:
+  input_subdirectory: "gnn_processed"
+  output_subdirectory: "triplet_building"
+
+track_building:
+  score_cut: 0.43
+  # input_subdirectory: "gnn_processed"
+  input_subdirectory: "gnn_processed"
+  output_subdirectory: "track_building_processed"
diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-80000.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-80000.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..26b33b4425ce78eb32b4a1faa7235e6ef84088d4
--- /dev/null
+++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed-80000.yaml
@@ -0,0 +1,106 @@
+common:
+  experiment_name: focal-loss-pid-fixed-80000
+  data_directory: /scratch/acorreia/data
+  artifact_directory: artifacts
+  performance_directory: output # plots and reports
+  gpus: 1
+  test_dataset_names:
+  - velo-sim10b-nospillover_choice
+  - velo-sim10b-nospillover
+  # - velo-sim10b-nospillover-only-long-electrons
+  # - bu2kstee-sim10aU1-xdigi
+
+preprocessing:
+  input_dir: /scratch/acorreia/minbias-sim10b-xdigi-nospillover
+  subdirs: {"start": 15, "stop": 50}
+  output_subdirectory: "preprocessed"
+  selection: triplets_first_selection
+  n_events: null # if `null`, default to `n_train_events + n_test_events`
+  num_true_hits_threshold: 500
+  hits_particles_columns: ["x", "y", "z", "plane"]
+  particles_columns: null
+
+processing:
+  input_subdirectory: "preprocessed"
+  output_subdirectory: "processed"
+  n_workers: 32
+  features: ["r", "phi", "z", "plane"]
+  feature_means: [18., 0.0, 281.0, 7.5]
+  feature_scales: [9.75, 1.82, 287.0, 12.5]
+  kept_hits_columns: ["plane", {"un_x": "x"}, {"un_y": "y"}, {"un_z": "z"}]
+  kept_particles_columns: ["n_unique_planes", "nhits_velo"]
+  n_train_events: 80000
+  n_val_events: 1000
+  split_seed: 0
+  true_edges_column: planewise
+
+metric_learning:
+  # Dataset parameters
+  input_subdirectory: "processed"
+  output_subdirectory: "metric_learning_processed"
+
+  # Model parameters
+  feature_indices: 4
+  emb_hidden: 256
+  nb_layer: 6
+  emb_dim: 4
+  activation: Tanh
+  weight: 2
+  randomisation: 2
+  points_per_batch: 100000
+  r: 0.015
+  r_inference: 0.020
+  knn: 50
+  warmup: 8
+  margin: 0.1
+  lr: 0.001
+  factor: 0.7
+  patience: 10
+  regime: [rp, hnm, norm]
+  bidir: False
+  max_epochs: 20
+
+  # Building
+  test_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+  training_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+
+
+gnn:
+  # Dataset parameters
+  input_subdirectory: "metric_learning_processed"
+  output_subdirectory: "gnn_processed"
+  edge_cut: 0.5
+  noise: True
+  bidir: False
+
+  # Model parameters
+  feature_indices: 4 # mmh I'm actually using the plane number, which is not deliberate
+  hidden: 256
+  n_graph_iters: 8
+  nb_node_layers: 6
+  nb_node_encoder_layers: 6
+  nb_edge_layers: 10
+  nb_edge_encoder_layers: 6
+  nb_edge_classifier_layers: 6
+  layernorm: True
+  aggregation: sum_max
+  hidden_activation: SiLU
+  weight: 0.25
+  warmup: 10
+  lr: 0.0002
+  factor: 0.7
+  patience: 8
+  regime: ["pid"]
+  max_epochs: 50
+  gradient_clip_val: 0.5
+  focal_loss: true
+
+triplet_building:
+  input_subdirectory: "gnn_processed"
+  output_subdirectory: "triplet_building"
+
+track_building:
+  score_cut: 0.44
+  # input_subdirectory: "gnn_processed"
+  input_subdirectory: "gnn_processed"
+  output_subdirectory: "track_building_processed"
diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed.yaml
index 20cdef02a0838282c026619c4c3dd50a89546aaf..dfbc6ee5466f02aded8c8eb44e5294c7a7bf32b1 100644
--- a/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed.yaml
+++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid-fixed.yaml
@@ -59,6 +59,9 @@ metric_learning:
   bidir: False
   max_epochs: 20
 
+  # Building
+  test_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
+  training_processing: ["edges_at_least_3_hits", "remove_edges_in_same_plane"]
 
 gnn:
   # Dataset parameters
@@ -95,7 +98,7 @@ triplet_building:
   output_subdirectory: "triplet_building"
 
 track_building:
-  score_cut: 0.73
+  score_cut: 0.45
   # input_subdirectory: "gnn_processed"
   input_subdirectory: "gnn_processed"
   output_subdirectory: "track_building_processed"
diff --git a/LHCb_Pipeline/pipeline_configs/focal-loss-pid.yaml b/LHCb_Pipeline/pipeline_configs/focal-loss-pid.yaml
index 043a6e732f3616862f9196b495b01dbdaa157729..8d7b5aa96fa13f2b285d94fb9fe81908e3b7a104 100644
--- a/LHCb_Pipeline/pipeline_configs/focal-loss-pid.yaml
+++ b/LHCb_Pipeline/pipeline_configs/focal-loss-pid.yaml
@@ -60,7 +60,7 @@ metric_learning:
   max_epochs: 20
 
   # Building
-  building: null
+  train_val_processing: null
   filtering: "edges_at_least_3_hits"
 
 gnn:
diff --git a/LHCb_Pipeline/test_samples.yaml b/LHCb_Pipeline/test_samples.yaml
index b76aeb12dc4fea9c89b8f6bba0f3ef7dd6f9ec58..425c03da8081257f41fe15109c52dd176b8ebe41 100644
--- a/LHCb_Pipeline/test_samples.yaml
+++ b/LHCb_Pipeline/test_samples.yaml
@@ -4,6 +4,12 @@ velo-sim10b-nospillover:
   n_events: 1000
   num_true_hits_threshold: null
 
+velo-sim10b-nospillover_choice:
+  input_dir: /scratch/acorreia/data_validation/minbias-sim10b-xdigi-nospillover/498
+  selection: null
+  n_events: 1000
+  num_true_hits_threshold: null
+
 velo-sim10b-nospillover-only-long-electrons:
   input_dir: /scratch/acorreia/data_validation/minbias-sim10b-xdigi-nospillover/500
   selection: only_long_electrons
diff --git a/LHCb_Pipeline/utils/loaderutils/__init__.py b/LHCb_Pipeline/utils/loaderutils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..911fec669a0ea23515d9664500cd48479bb52a6b
--- /dev/null
+++ b/LHCb_Pipeline/utils/loaderutils/__init__.py
@@ -0,0 +1,2 @@
+"""A package that contains utilities to load files for training, test and validation.
+"""
diff --git a/LHCb_Pipeline/utils/loaderutils/dataiterator.py b/LHCb_Pipeline/utils/loaderutils/dataiterator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f1a7eac5d6707bb8af846955f8201b1d5e1f618
--- /dev/null
+++ b/LHCb_Pipeline/utils/loaderutils/dataiterator.py
@@ -0,0 +1,61 @@
+"""Implement a general data loader that does not load all the data into
+memory, in order to deal with large datasets.
+"""
+from __future__ import annotations
+import torch
+from torch.utils.data import Dataset
+from torch_geometric.data import Data
+
+from .pathandling import get_input_paths
+
+
+class LazyDatasetBase(Dataset):
+    def __init__(
+        self,
+        input_dir: str,
+        n_events: int | None = None,
+        shuffle: bool = False,
+        seed: int | None = None,
+        **kwargs,
+    ):
+        self.input_dir = str(input_dir)
+
+        self.input_paths = get_input_paths(
+            input_dir=self.input_dir,
+            n_events=n_events,
+            shuffle=shuffle,
+            seed=seed,
+        )
+        self.fetch_dataset_kwargs = kwargs
+
+    def __len__(self) -> int:
+        """Number of input files"""
+        return len(self.input_paths)
+
+    def fetch_dataset(self, input_path: str, map_location: str = "cpu", **kwargs):
+        """Load and process one PyTorch DataSet.
+
+        Args:
+            input_path: path to the PyTorch dataset
+            map_location: location where to load the dataset
+            **kwargs: Other keyword arguments passed to :py:func:`torch.load`
+
+        Returns:
+            Loaded PyTorch data object
+        """
+        fetch_dataset_kwargs = self.fetch_dataset_kwargs.copy()
+        map_location_kwargs = fetch_dataset_kwargs.pop("map_location", None)
+
+        return torch.load(
+            input_path,
+            map_location=(
+                map_location_kwargs if map_location_kwargs is not None else map_location
+            ),
+            **fetch_dataset_kwargs,
+            **kwargs,
+        )
+
+    def __getitem__(self, idx: int) -> Data:
+        input_path = self.input_paths[idx]
+        dataset = self.fetch_dataset(input_path=input_path)
+        return dataset
diff --git a/LHCb_Pipeline/utils/loaderutils/pathandling.py b/LHCb_Pipeline/utils/loaderutils/pathandling.py
new file mode 100644
index 0000000000000000000000000000000000000000..b74f26648d5f5027b0ba3132f761552806a25877
--- /dev/null
+++ b/LHCb_Pipeline/utils/loaderutils/pathandling.py
@@ -0,0 +1,36 @@
+"""Utilities to handle datasets without loading them.
+"""
+import typing
+import os
+import numpy as np
+
+
+def get_input_paths(
+    input_dir: str,
+    n_events: int | None = None,
+    shuffle: bool = False,
+    seed: int | None = None,
+) -> typing.List[str]:
+    """Get the paths of the datasets located in a given directory.
+
+    Args:
+        input_dir: input directory
+        n_events: number of paths to keep. If ``None``, all the paths are kept
+        shuffle: whether to shuffle the input paths (applied before
+            selecting the first ``n_events``)
+        seed: seed for the shuffling. It is passed to
+            :py:func:`numpy.random.default_rng` and is only used if ``shuffle``
+            is ``True``
+
+    Returns:
+        List of paths to the PyTorch Data objects
+    """
+    all_input_paths = [entry.path for entry in os.scandir(input_dir) if entry.is_file()]
+    if shuffle:
+        rng = np.random.default_rng(seed=seed)
+        rng.shuffle(all_input_paths)
+
+    if n_events is not None:
+        all_input_paths = all_input_paths[:n_events]
+
+    return all_input_paths
diff --git a/LHCb_Pipeline/utils/modelutils/basemodel.py b/LHCb_Pipeline/utils/modelutils/basemodel.py
index 027e39acb8eac64687029c1a1ab030ab12926894..6f1de1c80568cf54cbd7793ceeac358b40e9fba1 100644
--- a/LHCb_Pipeline/utils/modelutils/basemodel.py
+++ b/LHCb_Pipeline/utils/modelutils/basemodel.py
@@ -7,17 +7,20 @@ import os
 import os.path as op
 from tqdm.auto import tqdm
 
-import numpy as np
 import torch
 from pytorch_lightning import LightningModule
 from torch_geometric.data import Data
 from torch_geometric.loader import DataLoader
 
 from utils.commonutils.cfeatures import get_input_features
+from utils.loaderutils.dataiterator import LazyDatasetBase
 
 
 class ModelBase(LightningModule):
-    def __init__(self, hparams):
+    def __init__(
+        self,
+        hparams,
+    ):
         super().__init__()
         self._trainset = None
         self._valset = None
@@ -30,7 +33,14 @@ class ModelBase(LightningModule):
         self.testset = None
 
     @property
-    def trainset(self) -> typing.List[Data]:
+    def lazy(self) -> bool:
+        """Whether to load the training set and val set into memory only when
+        needed.
+        """
+        return self.hparams.get("lazy", False)
+
+    @property
+    def trainset(self) -> typing.List[Data] | LazyDatasetBase:
         if self._trainset is None:
             self.load_partition(partition="train")
         assert self._trainset is not None
@@ -69,15 +79,15 @@ class ModelBase(LightningModule):
         else:
             return None
 
-    def fetch_datasets(
+    def get_lazy_dataset(
         self,
         input_dir: str,
         n_events: int | None = None,
         shuffle: bool = False,
         seed: int | None = None,
         **kwargs,
-    ) -> typing.List[Data]:
-        """Get the datasets located in a given directory.
+    ) -> LazyDatasetBase:
+        """Get the lazy dataset object.
 
         Args:
             input_dir: input directory
@@ -85,74 +95,75 @@ class ModelBase(LightningModule):
             shuffle: whether to shuffle the input paths (applied before
                 selected the first ``n_events``)
             seed: seed for the shuffling
-            **kwargs: Other keyword arguments passed to
-                :py:func:`ModelBase.fetch_dataset`
+            **kwargs: Other keyword arguments passed to the
+                :py:class:`utils.loaderutils.dataiterator.LazyDatasetBase` constructor.
 
         Returns:
-            List of loaded PyTorch Geometric Data objects
+            :py:class:`utils.loaderutils.dataiterator.LazyDatasetBase` object
         """
-        all_input_paths = [
-            entry.path for entry in os.scandir(input_dir) if entry.is_file()
-        ]
-        if shuffle:
-            rng = np.random.default_rng(seed=seed)
-            rng.shuffle(all_input_paths)
-
-        if n_events is not None:
-            all_input_paths = all_input_paths[:n_events]
-
-        logging.info(f"Load {len(all_input_paths)} files located in {input_dir}")
-        return [
-            self.fetch_dataset(input_path=input_path, **kwargs)
-            for input_path in tqdm(all_input_paths)
-        ]
+        return LazyDatasetBase(
+            input_dir=input_dir,
+            n_events=n_events,
+            shuffle=shuffle,
+            seed=seed,
+            **kwargs,
+        )
 
-    def fetch_dataset(
-        self, input_path: str, map_location: str = "cpu", **kwargs
-    ) -> Data:
-        """Load and process one PyTorch DataSet.
+    def fetch_datasets(self, lazy_dataset: LazyDatasetBase) -> typing.List[Data]:
+        """Load into memory all the datasets listed by a lazy dataset.
 
         Args:
-            input_path: path to the PyTorch dataset
-            map_location: location where to load the dataset
-            **kwargs: Other keyword arguments passed to :py:func:`torch.load`
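+            lazy_dataset: lazy dataset that lists the input files to load, as
+                returned by :py:func:`ModelBase.get_lazy_dataset`. The input
+                directory, the number of events, the shuffling and the seed are
+                chosen when this object is created, so they are not parameters
+                of this method. All the events of ``lazy_dataset`` are loaded
+                into memory, one after the other, with a progress bar that
+                shows the number of events already loaded.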
 
         Returns:
-            Load PyTorch data object
+            List of loaded PyTorch Geometric Data objects
         """
-        return torch.load(input_path, map_location=map_location, **kwargs)
+        logging.info(
+            f"Load {len(lazy_dataset)} files located in {lazy_dataset.input_dir}"
+        )
+        return [event for event in tqdm(iter(lazy_dataset), total=len(lazy_dataset))]
 
-    def load_testset_from_directory(self, input_dir: str):
+    def load_testset_from_directory(self, input_dir: str, **kwargs):
         """Load a test dataset from a path to a directory.
 
         Args:
             input_dir: path to the directory that contains the PyTorch Geometric Data
                 pickles files.
         """
-        self.testset = self.fetch_datasets(input_dir=input_dir)
+        lazy_dataset = self.get_lazy_dataset(input_dir=input_dir, **kwargs)
+        self.testset = self.fetch_datasets(lazy_dataset=lazy_dataset)
 
-    def fetch_partition(
+    def get_lazy_dataset_partition(
         self,
         partition: str,
         n_events: int | None = None,
         shuffle: bool = False,
         seed: int | None = None,
         **kwargs,
-    ) -> typing.List[Data]:
-        """Load a partition.
+    ) -> LazyDatasetBase:
+        """Get the lazy dataset of a partition.
 
         Args:
             partition: ``train``, ``val`` or name of the test dataset
-            n_events: number of events to load for this partition
+            n_events: number of events to load
             shuffle: whether to shuffle the input paths (applied before
                 selected the first ``n_events``)
             seed: seed for the shuffling
             **kwargs: Other keyword arguments passed to
-                :py:func:`ModelBase.fetch_dataset`
+                :py:func:`ModelBase.get_lazy_dataset`
+
+        Returns:
+            Lazy dataset of the ``partition``
         """
         if partition in ["train", "val"]:
-            datasets = self.fetch_datasets(
-                op.join(self.hparams["input_dir"], partition),
+            lazy_dataset = self.get_lazy_dataset(
+                input_dir=op.join(self.hparams["input_dir"], partition),
                 n_events=(
                     self.hparams.get(f"n_{partition}_events")
                     if n_events is None
@@ -164,7 +175,7 @@ class ModelBase(LightningModule):
             )
 
         else:
-            datasets = self.fetch_datasets(
+            lazy_dataset = self.get_lazy_dataset(
                 input_dir=op.join(self.hparams["input_dir"], "test", partition),
                 n_events=n_events,
                 shuffle=shuffle,
@@ -172,7 +183,38 @@ class ModelBase(LightningModule):
                 **kwargs,
             )
 
-        return datasets
+        return lazy_dataset
+
+    def fetch_partition(
+        self,
+        partition: str,
+        n_events: int | None = None,
+        shuffle: bool = False,
+        seed: int | None = None,
+        **kwargs,
+    ) -> typing.List[Data] | LazyDatasetBase:
+        """Load a partition.
+
+        Args:
+            partition: ``train``, ``val`` or name of the test dataset
+            n_events: number of events to load for this partition
+            shuffle: whether to shuffle the input paths (applied before
+                selected the first ``n_events``)
+            seed: seed for the shuffling
+            **kwargs: Other keyword arguments passed to
+                :py:func:`ModelBase.get_lazy_dataset_partition`
+        """
+        lazy_dataset = self.get_lazy_dataset_partition(
+            partition=partition,
+            n_events=n_events,
+            shuffle=shuffle,
+            seed=seed,
+            **kwargs,
+        )
+        if partition == "train" and self.lazy:
+            return lazy_dataset
+        else:
+            return self.fetch_datasets(lazy_dataset=lazy_dataset)
 
     def load_partition(
         self,
@@ -180,7 +222,7 @@ class ModelBase(LightningModule):
         n_events: int | None = None,
         shuffle: bool = False,
         seed: int | None = None,
-    ) -> typing.List[Data]:
+    ):
         """Load datasets of a partition.
 
         Args:
diff --git a/LHCb_Pipeline/utils/modelutils/build.py b/LHCb_Pipeline/utils/modelutils/build.py
index 6fc1d779bfd009e6b6216b56a7ef2537976b9851..807f2813906b9ab4aad60a8299b4d0bcddf5a1e9 100644
--- a/LHCb_Pipeline/utils/modelutils/build.py
+++ b/LHCb_Pipeline/utils/modelutils/build.py
@@ -29,8 +29,7 @@ class BuilderBase(abc.ABC):
         input_dir: str,
         output_dir: str,
         reproduce: bool = True,
-        filtering: str | None = None,
-        building: str | None = None,
+        processing: str | typing.List[str] | None = None,
         file_names: typing.List[str] | None = None,
         parallel: bool = False,
     ):
@@ -42,10 +41,8 @@ class BuilderBase(abc.ABC):
             output_dir: output directory path
             reproduce: whether to delete the output directory if it exists,
                 and run again the inference
-            filtering: name of the function that filters the event. This would only
-                be applied to the train and val sets.
-            building: name of the function that compute columns for the event. This
-                would be applied to all the samples (train, val and test samples).
+            processing: name(s) of supplementary function(s) that process the event
+                after :py:func:`BuilderBase.construct_downstream`.
             file_names: list of file names to run the inference on. If not specified,
                 the inference is run on all the datasets located in the input directory.
             parallel:
@@ -72,8 +69,7 @@ class BuilderBase(abc.ABC):
                     self.infer_one_step,
                     input_dir=input_dir,
                     output_dir=output_dir,
-                    building=building,
-                    filtering=filtering,
+                    processing=processing,
                 )
                 if parallel:
                     process_map(infer_one_step_partial, file_names, chunksize=1)
@@ -86,8 +82,7 @@ class BuilderBase(abc.ABC):
         file_name: str,
         input_dir: str,
         output_dir: str,
-        filtering: str | typing.List[str] | None = None,
-        building: str | typing.List[str] | None = None,
+        processing: str | typing.List[str] | None = None,
     ):
         """Run the inference on a single file and save the output in another file.
 
@@ -95,35 +90,29 @@ class BuilderBase(abc.ABC):
             file_name: input file name
             input_dir: input directory path
             output_dir: output directory path
-            filtering: name of the function that filters the event. This would only
-                be applied to the train and val sets.
-            building: name of the function that compute columns for the event. This
-                would be applied to all the samples (train, val and test samples).
+            processing: name(s) of supplementary function(s) that process the event
+                after :py:func:`BuilderBase.construct_downstream`.
         """
         input_path = os.path.join(input_dir, file_name)
         if not os.path.exists(os.path.join(output_dir, file_name)):
             batch = self.load_batch(input_path)
             batch = self.process_one_step(
                 batch=batch,
-                filtering=filtering,
-                building=building,
+                processing=processing,
             )
             self.save_downstream(batch, os.path.join(output_dir, batch.event_str))
 
     def process_one_step(
         self,
         batch: Data,
-        filtering: str | typing.List[str] | None = None,
-        building: str | typing.List[str] | None = None,
+        processing: str | typing.List[str] | None = None,
     ) -> Data:
         """Process one event.
 
         Args:
             batch: event stored in a PyTorch Geometric data object
-            filtering: name of the function that filters the event. This would only
-                be applied to the train and val sets.
-            building: name of the function that compute columns for the event. This
-                would be applied to all the samples (train, val and test samples).
+            processing: name(s) of supplementary function(s) that process the event
+                after :py:func:`BuilderBase.construct_downstream`.
 
         Returns:
             Processed event, first by :py:func:`BuilderBase.construct_downstream`,
@@ -131,19 +120,16 @@ class BuilderBase(abc.ABC):
         """
         batch = self.construct_downstream(batch)
 
-        # Apply filtering and building
-        for processing_step in [filtering, building]:
-            if processing_step is not None:
-                processing_fct_names = (
-                    [processing_step]
-                    if isinstance(processing_step, str)
-                    else processing_step
+        if processing is not None:
+            # Apply processing functions (building or filtering)
+            processing_fct_names = (
+                [processing] if isinstance(processing, str) else processing
+            )
+            for processing_fct_name in processing_fct_names:
+                processing_fct = getattr(
+                    self._get_building_custom_module(), str(processing_fct_name)
                 )
-                for processing_fct_name in processing_fct_names:
-                    processing_fct = getattr(
-                        self._get_building_custom_module(), str(processing_fct_name)
-                    )
-                    batch = processing_fct(batch)
+                batch = processing_fct(batch)
         return batch
 
     def _get_building_custom_module(self) -> ModuleType:
diff --git a/montetracko b/montetracko
index e033e8a998cfa59df42f4e39c068e6ce4671ffae..d9c733560e13ef6075d16b28f682f4bc7bd37add 160000
--- a/montetracko
+++ b/montetracko
@@ -1 +1 @@
-Subproject commit e033e8a998cfa59df42f4e39c068e6ce4671ffae
+Subproject commit d9c733560e13ef6075d16b28f682f4bc7bd37add