diff --git a/changelog.md b/changelog.md
index 0cc4b01007adcbfdc7d0696bbda76292c5771dcb..3c368f1c7d673a70eab0b27d2b323b6ddee7501a 100644
--- a/changelog.md
+++ b/changelog.md
@@ -4,6 +4,7 @@
 
 ### Latest
 
+- Adding support for dataset-specific class labels in input var plots [!623](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/623)
 - Apply naming scheme for WP and nEpochs [!621](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/621)
 - Adding correct naming scheme for train config sections [!617](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/617)
 
diff --git a/docs/plotting/plotting_inputs.md b/docs/plotting/plotting_inputs.md
index ceec1f07402ab4aadb664eaf112247dfd6a1553d..cbe3c7ce9df62b243e25eb9180e9f71f4e7a06e6 100644
--- a/docs/plotting/plotting_inputs.md
+++ b/docs/plotting/plotting_inputs.md
@@ -30,7 +30,7 @@ The number of tracks per jet can be plotted for all different files. This can be
 
 ??? example "Click to see corresponding code in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
     ```yaml
-    §§§examples/plotting_input_vars.yaml:91:108§§§
+    §§§examples/plotting_input_vars.yaml:95:112§§§
     ```
 
 | Options | Data Type | Necessary/Optional | Explanation |
@@ -51,7 +51,7 @@ To plot the track input variables, the following options are used.
 
 ??? example "Click to see corresponding code in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
     ```yaml
-    §§§examples/plotting_input_vars.yaml:110:144§§§
+    §§§examples/plotting_input_vars.yaml:114:148§§§
     ```
 
 
@@ -73,7 +73,7 @@ To plot the jet input variables, the following options are used.
 
 ??? example "Click to see corresponding code in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
     ```yaml
-    §§§examples/plotting_input_vars.yaml:16:89§§§
+    §§§examples/plotting_input_vars.yaml:16:93§§§
     ```
 
 | Options | Data Type | Necessary/Optional | Explanation |
diff --git a/examples/plotting_input_vars.yaml b/examples/plotting_input_vars.yaml
index 9d63278beeddde863f127d8a2c93cafdcb7849d9..87ee955fbbca5c4bb01d5adc17d9a3df99ec06e6 100644
--- a/examples/plotting_input_vars.yaml
+++ b/examples/plotting_input_vars.yaml
@@ -23,6 +23,10 @@ jets_input_vars:
     R22:
       files: <path_palce_holder>/user.alfroch.410470.btagTraining.e6337_s3126_r12305_r12253_r12305_p4441.EMPFlow_loose.2021-04-20-T171733-R21211_output.h5/*
       label: "R22 Loose"
+      # If you want to specify the `class_labels` per dataset you can add it here
+      # If you don't specify anything here, the overall defined `class_labels` will be
+      # used
+      # class_labels: ["bjets", "cjets", "ujets"]
   plot_settings:
     <<: *default_plot_settings
   class_labels: ["bjets", "cjets", "ujets"]
diff --git a/umami/input_vars_tools/__init__.py b/umami/input_vars_tools/__init__.py
index 2469dc070249f869c6e100a5a5e58776629d814b..1314762b0be9fd1459aa7706f59df3e597362d79 100644
--- a/umami/input_vars_tools/__init__.py
+++ b/umami/input_vars_tools/__init__.py
@@ -1,6 +1,7 @@
 # flake8: noqa
 # pylint: skip-file
 from umami.input_vars_tools.plotting_functions import (
+    get_datasets_configuration,
     plot_input_vars_jets,
     plot_input_vars_trks,
     plot_n_tracks_per_jet,
diff --git a/umami/input_vars_tools/plotting_functions.py b/umami/input_vars_tools/plotting_functions.py
index a989c72baeae6bebb39575f54370e8072070b2e4..ce77901354601defcc384635314380ca77e0ea69 100644
--- a/umami/input_vars_tools/plotting_functions.py
+++ b/umami/input_vars_tools/plotting_functions.py
@@ -11,6 +11,56 @@ from umami.plotting_tools.utils import translate_binning
 from umami.preprocessing_tools import get_variable_dict
 
 
+def get_datasets_configuration(plotting_config: dict, tracks: bool = False):
+    """Helper function to transform dict that stores the configuration of the different
+    datasets into lists of certain parameters.
+
+    Parameters
+    ----------
+    plotting_config : dict
+        Plotting configuration
+    tracks : bool, optional
+        Bool if the function should look for the `tracks_name` variable in the dataset
+        configurations.
+
+    Returns
+    -------
+    filepath_list : list
+        List with the filepaths of all the datasets.
+    labels_list : list
+        List with the 'dataset label' of each dataset.
+    class_labels_list : list
+        List with the class labels for each dataset. If no dataset-specific class labels
+        are provided, the globally defined class labels are used.
+    tracks_name_list : list
+        List with the track names of the datasets. Only returned if `tracks` is True.
+    """
+
+    filepath_list = []
+    labels_list = []
+    class_labels_list = []
+    tracks_name_list = []
+
+    datasets_config = plotting_config["Datasets_to_plot"]
+
+    for dataset_name in datasets_config:
+        if not datasets_config[dataset_name]["files"] is None:
+            filepath_list.append(datasets_config[dataset_name]["files"])
+            labels_list.append(datasets_config[dataset_name]["label"])
+            # check if this dataset has a specific list of class labels
+            class_labels_list.append(
+                datasets_config[dataset_name]["class_labels"]
+                if "class_labels" in datasets_config[dataset_name]
+                else plotting_config["class_labels"]
+            )
+            if tracks:
+                tracks_name_list.append(datasets_config[dataset_name]["tracks_name"])
+
+    if tracks:
+        return filepath_list, labels_list, class_labels_list, tracks_name_list
+    return filepath_list, labels_list, class_labels_list
+
+
 def check_kwargs_for_ylabel_and_n_ratio_panel(
     kwargs: dict,
     fallback_ylabel: str,
@@ -54,9 +104,9 @@ def check_kwargs_for_ylabel_and_n_ratio_panel(
 def plot_n_tracks_per_jet(
     datasets_filepaths: list,
     datasets_labels: list,
+    datasets_class_labels: list,
     datasets_track_names: list,
     n_jets: int,
-    class_labels: list,
     output_directory: str = "input_vars_trks",
     plot_type: str = "pdf",
     track_origin: str = "All",
@@ -73,12 +123,14 @@ def plot_n_tracks_per_jet(
         List of filepaths to the files.
     datasets_labels : list
         Label of the dataset for the legend.
+    datasets_class_labels : list
+        List with dataset-specific class labels, e.g. [["ujets", "cjets"], ["cjets"]]
+        to plot light-jets and c-jets for the first but only c-jets for the second
+        dataset
     datasets_track_names : list
         List with the track names of the files.
     n_jets : int
         Number of jets to use.
-    class_labels : list
-        List of classes that are to be plotted.
     output_directory : str
         Name of the output directory. Only the dir name not path!
     plot_type: str, optional
@@ -104,8 +156,8 @@ def plot_n_tracks_per_jet(
     flavour_label_dict = {}
 
     # Iterate over the different dataset filepaths and labels defined in the config
-    for (filepath, label, tracks_name) in zip(
-        datasets_filepaths, datasets_labels, datasets_track_names
+    for (filepath, label, tracks_name, class_labels) in zip(
+        datasets_filepaths, datasets_labels, datasets_track_names, datasets_class_labels
     ):
         loaded_trks, loaded_flavour_labels = udt.LoadTrksFromFile(
             filepath=filepath,
@@ -138,8 +190,8 @@ def plot_n_tracks_per_jet(
     # Store the means of the n_tracks distributions to print them at the end
     n_tracks_means = {label: {} for label in datasets_labels}
     # Iterate over datasets
-    for dataset_number, (label, linestyle) in enumerate(
-        zip(datasets_labels, linestyles[: len(datasets_labels)])
+    for dataset_number, (label, linestyle, class_labels) in enumerate(
+        zip(datasets_labels, linestyles[: len(datasets_labels)], datasets_class_labels)
     ):
         # Sort after given variable
         trks = np.asarray(trks_dict[label])
@@ -180,11 +232,11 @@ def plot_n_tracks_per_jet(
 def plot_input_vars_trks(
     datasets_filepaths: list,
     datasets_labels: list,
+    datasets_class_labels: list,
     datasets_track_names: list,
     var_dict: dict,
     n_jets: int,
     binning: dict,
-    class_labels: list,
     sorting_variable: str = "ptfrac",
     n_leading: list = None,
     output_directory: str = "input_vars_trks",
@@ -203,6 +255,10 @@ def plot_input_vars_trks(
         List of filepaths to the files.
     datasets_labels : list
         Label of the dataset for the legend.
+    datasets_class_labels : list
+        List with dataset-specific class labels, e.g. [["ujets", "cjets"], ["cjets"]]
+        to plot light-jets and c-jets for the first but only c-jets for the second
+        dataset
     datasets_track_names : list
         List with the track names of the files.
     var_dict : dict
@@ -211,8 +267,6 @@ def plot_input_vars_trks(
         Number of jets to use for plotting.
     binning : dict
         Decide which binning is used.
-    class_labels : list
-        List of class_labels which are to be plotted.
     sorting_variable : str
         Variable which is used for sorting.
     n_leading : list
@@ -252,9 +306,10 @@ def plot_input_vars_trks(
     flavour_label_dict = {}
 
     # Iterate over the different dataset filepaths and labels defined in the config
-    for filepath, label, tracks_name in zip(
+    for filepath, label, class_labels, tracks_name in zip(
         datasets_filepaths,
         datasets_labels,
+        datasets_class_labels,
         datasets_track_names,
     ):
 
@@ -359,8 +414,12 @@ def plot_input_vars_trks(
                     )
 
                 # Iterate over datasets
-                for dataset_number, (label, linestyle) in enumerate(
-                    zip(datasets_labels, linestyles[: len(datasets_labels)])
+                for dataset_number, (label, linestyle, class_labels) in enumerate(
+                    zip(
+                        datasets_labels,
+                        linestyles[: len(datasets_labels)],
+                        datasets_class_labels,
+                    )
                 ):
                     # Sort after given variable
                     sorting = np.argsort(-1 * trks_dict[label][sorting_variable])
@@ -450,10 +509,10 @@ def plot_input_vars_trks(
 def plot_input_vars_jets(
     datasets_filepaths: list,
     datasets_labels: list,
+    datasets_class_labels: list,
     var_dict: dict,
     n_jets: int,
     binning: dict,
-    class_labels: list,
     special_param_jets: dict = None,
     output_directory: str = "input_vars_jets",
     plot_type: str = "pdf",
@@ -470,14 +529,16 @@ def plot_input_vars_jets(
         List of filepaths to the files.
     datasets_labels : list
         Label of the dataset for the legend.
+    datasets_class_labels : list
+        List with dataset-specific class labels, e.g. [["ujets", "cjets"], ["cjets"]]
+        to plot light-jets and c-jets for the first but only c-jets for the second
+        dataset
     var_dict : dict
         Variable dict where all variables of the files are saved.
     n_jets : int
         Number of jets to use for plotting.
     binning : dict
         Decide which binning is used.
-    class_labels : list
-        List of class_labels which are to be plotted.
     special_param_jets : dict
         Dict with special x-axis cuts for the given variable.
     output_directory : str
@@ -508,7 +569,9 @@ def plot_input_vars_jets(
     flavour_label_dict = {}
 
     # Iterate over the different dataset filepaths and labels defined in the config
-    for (filepath, label) in zip(datasets_filepaths, datasets_labels):
+    for (filepath, label, class_labels) in zip(
+        datasets_filepaths, datasets_labels, datasets_class_labels
+    ):
         # Get the tracks and the labels from the file/files
         jets, flavour_labels = udt.LoadJetsFromFile(
             filepath=filepath,
@@ -554,8 +617,12 @@ def plot_input_vars_jets(
             logger.info("Plotting %s ...", var)
 
             # Iterate over datasets
-            for dataset_number, (label, linestyle) in enumerate(
-                zip(datasets_labels, linestyles[: len(datasets_labels)])
+            for dataset_number, (label, linestyle, class_labels) in enumerate(
+                zip(
+                    datasets_labels,
+                    linestyles[: len(datasets_labels)],
+                    datasets_class_labels,
+                )
             ):
                 # Get variable and the labels of the jets
                 jets_var = jets_dict[label][var]
diff --git a/umami/plot_input_variables.py b/umami/plot_input_variables.py
index c0cf806a8047d035e5289f9605aae82ff5128330..379133d41fe4689aac73232ca8b85808959f1ae5 100644
--- a/umami/plot_input_variables.py
+++ b/umami/plot_input_variables.py
@@ -83,9 +83,13 @@ def plot_trks_variables(plot_config, plot_type):
         plotting_config["plot_settings"] = translate_kwargs(
             plotting_config["plot_settings"]
         )
-        filepath_list = []
-        labels_list = []
-        tracks_list = []
+
+        (
+            filepath_list,
+            labels_list,
+            class_labels_list,
+            tracks_name_list,
+        ) = uit.get_datasets_configuration(plotting_config, tracks=True)
 
         # Default to no selection based on track_origin
         trk_origins = ["All"]
@@ -94,28 +98,13 @@ def plot_trks_variables(plot_config, plot_type):
         if "track_origins" in plotting_config:
             trk_origins = plotting_config["track_origins"]
 
-        for model_name, _ in plotting_config["Datasets_to_plot"].items():
-            if (
-                not plotting_config["Datasets_to_plot"][f"{model_name}"]["files"]
-                is None
-            ):
-                filepath_list.append(
-                    plotting_config["Datasets_to_plot"][f"{model_name}"]["files"]
-                )
-                labels_list.append(
-                    plotting_config["Datasets_to_plot"][f"{model_name}"]["label"]
-                )
-                tracks_list.append(
-                    plotting_config["Datasets_to_plot"][f"{model_name}"]["tracks_name"]
-                )
-
         for trk_origin in trk_origins:
             if ("nTracks" in plotting_config) and (plotting_config["nTracks"] is True):
                 uit.plot_n_tracks_per_jet(
                     datasets_filepaths=filepath_list,
                     datasets_labels=labels_list,
-                    datasets_track_names=tracks_list,
-                    class_labels=plotting_config["class_labels"],
+                    datasets_class_labels=class_labels_list,
+                    datasets_track_names=tracks_name_list,
                     n_jets=int(plot_config["Eval_parameters"]["n_jets"]),
                     output_directory=plotting_config["folder_to_save"]
                     if plotting_config["folder_to_save"]
@@ -129,8 +118,8 @@ def plot_trks_variables(plot_config, plot_type):
                 uit.plot_input_vars_trks(
                     datasets_filepaths=filepath_list,
                     datasets_labels=labels_list,
-                    datasets_track_names=tracks_list,
-                    class_labels=plotting_config["class_labels"],
+                    datasets_class_labels=class_labels_list,
+                    datasets_track_names=tracks_name_list,
                     var_dict=plot_config["Eval_parameters"]["var_dict"],
                     n_jets=int(plot_config["Eval_parameters"]["n_jets"]),
                     binning=plotting_config["binning"],
@@ -167,25 +156,17 @@ def plot_jets_variables(plot_config, plot_type):
         plotting_config["plot_settings"] = translate_kwargs(
             plotting_config["plot_settings"]
         )
-        filepath_list = []
-        labels_list = []
-
-        for model_name, _ in plotting_config["Datasets_to_plot"].items():
-            if (
-                not plotting_config["Datasets_to_plot"][f"{model_name}"]["files"]
-                is None
-            ):
-                filepath_list.append(
-                    plotting_config["Datasets_to_plot"][f"{model_name}"]["files"]
-                )
-                labels_list.append(
-                    plotting_config["Datasets_to_plot"][f"{model_name}"]["label"]
-                )
+
+        (  # pylint: disable=unbalanced-tuple-unpacking
+            filepath_list,
+            labels_list,
+            class_labels_list,
+        ) = uit.get_datasets_configuration(plotting_config)
 
         uit.plot_input_vars_jets(
             datasets_filepaths=filepath_list,
             datasets_labels=labels_list,
-            class_labels=plotting_config["class_labels"],
+            datasets_class_labels=class_labels_list,
             var_dict=plot_config["Eval_parameters"]["var_dict"],
             n_jets=int(plot_config["Eval_parameters"]["n_jets"]),
             binning=plotting_config["binning"],
diff --git a/umami/tests/unit/input_vars_tools/test_input_vars_tools.py b/umami/tests/unit/input_vars_tools/test_input_vars_tools.py
index 005fe577ec13860351dd6c4a96e71ad478d47c47..e2ceeecdd89c75a95a226cd31c9ec6a57e00fa71 100644
--- a/umami/tests/unit/input_vars_tools/test_input_vars_tools.py
+++ b/umami/tests/unit/input_vars_tools/test_input_vars_tools.py
@@ -10,6 +10,7 @@ from matplotlib.testing.compare import compare_images
 
 from umami.configuration import logger, set_log_level
 from umami.input_vars_tools.plotting_functions import (
+    get_datasets_configuration,
     plot_input_vars_jets,
     plot_input_vars_trks,
     plot_n_tracks_per_jet,
@@ -19,6 +20,81 @@ from umami.tools import yaml_loader
 set_log_level(logger, "DEBUG")
 
 
+class HelperFunction_TestCase(unittest.TestCase):
+    """Test class for helper functions."""
+
+    def setUp(self):
+        self.plotting_config = {
+            "class_labels": ["ujets", "cjets", "bjets"],
+            "Datasets_to_plot": {
+                "ds_1": {
+                    "files": "dummy_path_1",
+                    "label": "dummy_label_1",
+                },
+                "ds_2": {
+                    "files": "dummy_path_2",
+                    "label": "dummy_label_2",
+                },
+            },
+        }
+
+    def test_get_datasets_configuration_all_default(self):
+        """Test the helper function for the default case (same class labels for both
+        datasets)"""
+        exp_filepath_list = ["dummy_path_1", "dummy_path_2"]
+        exp_labels_list = ["dummy_label_1", "dummy_label_2"]
+        exp_class_labels_list = [
+            ["ujets", "cjets", "bjets"],
+            ["ujets", "cjets", "bjets"],
+        ]
+        (  # pylint: disable=unbalanced-tuple-unpacking
+            filepath_list,
+            labels_list,
+            class_labels_list,
+        ) = get_datasets_configuration(self.plotting_config)
+
+        with self.subTest():
+            self.assertEqual(exp_filepath_list, filepath_list)
+        with self.subTest():
+            self.assertEqual(exp_labels_list, labels_list)
+        with self.subTest():
+            self.assertEqual(exp_class_labels_list, class_labels_list)
+
+    def test_get_datasets_configuration_specific_class_labels(self):
+        """Test the helper function for the case of specifying specific class labels
+        for one of the datasets"""
+
+        # modify the config for this test
+        plotting_config = self.plotting_config
+        plotting_config["Datasets_to_plot"]["ds_2"]["class_labels"] = ["bjets"]
+        plotting_config["Datasets_to_plot"]["ds_1"]["tracks_name"] = "tracks_loose"
+        plotting_config["Datasets_to_plot"]["ds_2"]["tracks_name"] = "tracks"
+
+        # define expected outcome
+        exp_filepath_list = ["dummy_path_1", "dummy_path_2"]
+        exp_labels_list = ["dummy_label_1", "dummy_label_2"]
+        exp_class_labels_list = [
+            ["ujets", "cjets", "bjets"],
+            ["bjets"],
+        ]
+        exp_tracks_name_list = ["tracks_loose", "tracks"]
+        (
+            filepath_list,
+            labels_list,
+            class_labels_list,
+            tracks_name_list,
+        ) = get_datasets_configuration(self.plotting_config, tracks=True)
+
+        with self.subTest():
+            self.assertEqual(exp_filepath_list, filepath_list)
+        with self.subTest():
+            self.assertEqual(exp_labels_list, labels_list)
+        with self.subTest():
+            self.assertEqual(exp_class_labels_list, class_labels_list)
+        with self.subTest():
+            self.assertEqual(exp_tracks_name_list, tracks_name_list)
+
+
 class JetPlotting_TestCase(unittest.TestCase):
     """Test class for jet plotting functions."""
 
@@ -70,6 +146,7 @@ class JetPlotting_TestCase(unittest.TestCase):
         """Test jet input variable plots with wrong type."""
         plotting_config = self.plot_config["jets_input_vars"]
         filepath_list = [self.r21_test_file]
+        class_labels_list = [["ujets", "cjets"]]
         labels_list = ["R21 Test"]
 
         # Change type in plotting_config to string to produce error
@@ -79,7 +156,7 @@ class JetPlotting_TestCase(unittest.TestCase):
             plot_input_vars_jets(
                 datasets_filepaths=filepath_list,
                 datasets_labels=labels_list,
-                class_labels=plotting_config["class_labels"],
+                datasets_class_labels=class_labels_list,
                 var_dict=self.plot_config["Eval_parameters"]["var_dict"],
                 n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]),
                 binning=plotting_config["binning"],
@@ -95,11 +172,12 @@ class JetPlotting_TestCase(unittest.TestCase):
         plotting_config = self.plot_config["jets_input_vars"]
         filepath_list = [self.r21_test_file]
         labels_list = ["R21 Test"]
+        class_labels_list = [["bjets", "cjets", "ujets", "taujets"]]
 
         plot_input_vars_jets(
             datasets_filepaths=filepath_list,
             datasets_labels=labels_list,
-            class_labels=plotting_config["class_labels"],
+            datasets_class_labels=class_labels_list,
             var_dict=self.plot_config["Eval_parameters"]["var_dict"],
             n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]),
             binning=plotting_config["binning"],
@@ -138,6 +216,7 @@ class JetPlotting_TestCase(unittest.TestCase):
         plotting_config = self.plot_config["jets_input_vars"]
         filepath_list = [self.r21_test_file]
         labels_list = ["R21 Test"]
+        class_labels_list = [["ujets", "cjets"]]
 
         # Change type in plotting_config to string to produce error
         plotting_config["binning"]["SV1_NGTinSvx"] = "test"
@@ -146,7 +225,7 @@ class JetPlotting_TestCase(unittest.TestCase):
             plot_input_vars_jets(
                 datasets_filepaths=filepath_list,
                 datasets_labels=labels_list,
-                class_labels=plotting_config["class_labels"],
+                datasets_class_labels=class_labels_list,
                 var_dict=self.plot_config["Eval_parameters"]["var_dict"],
                 n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]),
                 binning=plotting_config["binning"],
@@ -160,12 +239,16 @@ class JetPlotting_TestCase(unittest.TestCase):
         """Test jet input variable plot with comparison."""
         plotting_config = self.plot_config["jets_input_vars"]
         filepath_list = [self.r21_test_file, self.r22_test_file]
+        class_labels_list = [
+            ["bjets", "cjets", "ujets", "taujets"],
+            ["bjets", "cjets", "ujets", "taujets"],
+        ]
         labels_list = ["R21 Test", "R22 Test"]
 
         plot_input_vars_jets(
             datasets_filepaths=filepath_list,
             datasets_labels=labels_list,
-            class_labels=plotting_config["class_labels"],
+            datasets_class_labels=class_labels_list,
             var_dict=self.plot_config["Eval_parameters"]["var_dict"],
             n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]),
             binning=plotting_config["binning"],
@@ -202,6 +285,7 @@ class JetPlotting_TestCase(unittest.TestCase):
         """Test track input variable plots with wrong type."""
         plotting_config = self.plot_config["Tracks_Test"]
         filepath_list = [self.r21_test_file]
+        class_labels_list = [["bjets", "cjets", "ujets"]]
         tracks_name_list = ["tracks"]
         labels_list = ["R21 Test"]
 
@@ -213,7 +297,7 @@ class JetPlotting_TestCase(unittest.TestCase):
                 datasets_filepaths=filepath_list,
                 datasets_labels=labels_list,
                 datasets_track_names=tracks_name_list,
-                class_labels=plotting_config["class_labels"],
+                datasets_class_labels=class_labels_list,
                 var_dict=self.plot_config["Eval_parameters"]["var_dict"],
                 n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]),
                 binning=plotting_config["binning"],
@@ -226,14 +310,15 @@ class JetPlotting_TestCase(unittest.TestCase):
         """Test track input variables with wrong type."""
         plotting_config = self.plot_config["Tracks_Test"]
         filepath_list = [self.r21_test_file]
+        class_labels_list = [["bjets", "cjets", "ujets"]]
         tracks_name_list = ["tracks"]
         labels_list = ["R21 Test"]
 
         plot_input_vars_trks(
             datasets_filepaths=filepath_list,
             datasets_labels=labels_list,
+            datasets_class_labels=class_labels_list,
             datasets_track_names=tracks_name_list,
-            class_labels=plotting_config["class_labels"],
             var_dict=self.plot_config["Eval_parameters"]["var_dict"],
             n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]),
             binning=plotting_config["binning"],
@@ -317,6 +402,7 @@ class JetPlotting_TestCase(unittest.TestCase):
         """Test track input variables with comparison and wrong type."""
         plotting_config = self.plot_config["Tracks_Test"]
         filepath_list = [self.r21_test_file]
+        class_labels_list = [["bjets", "cjets", "ujets"]]
         tracks_name_list = ["tracks"]
         labels_list = ["R21 Test"]
 
@@ -328,7 +414,7 @@ class JetPlotting_TestCase(unittest.TestCase):
                 datasets_filepaths=filepath_list,
                 datasets_labels=labels_list,
                 datasets_track_names=tracks_name_list,
-                class_labels=plotting_config["class_labels"],
+                datasets_class_labels=class_labels_list,
                 var_dict=self.plot_config["Eval_parameters"]["var_dict"],
                 n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]),
                 binning=plotting_config["binning"],
@@ -341,14 +427,15 @@ class JetPlotting_TestCase(unittest.TestCase):
         """Test track variable plots without normalisation."""
         plotting_config = self.plot_config["tracks_test_not_normalised"]
         filepath_list = [self.r21_test_file, self.r22_test_file]
+        class_labels_list = [["bjets", "cjets", "ujets"], ["bjets", "cjets", "ujets"]]
         tracks_name_list = ["tracks", "tracks_loose"]
         labels_list = ["R21 Test", "R22 Test"]
 
         plot_input_vars_trks(
             datasets_filepaths=filepath_list,
             datasets_labels=labels_list,
+            datasets_class_labels=class_labels_list,
             datasets_track_names=tracks_name_list,
-            class_labels=plotting_config["class_labels"],
             var_dict=self.plot_config["Eval_parameters"]["var_dict"],
             n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]),
             binning=plotting_config["binning"],
@@ -374,6 +461,7 @@ class JetPlotting_TestCase(unittest.TestCase):
         """Test plotting track input variables with comparison."""
         plotting_config = self.plot_config["Tracks_Test"]
         filepath_list = [self.r21_test_file, self.r22_test_file]
+        class_labels_list = [["bjets", "cjets", "ujets"], ["bjets", "cjets", "ujets"]]
         tracks_name_list = ["tracks", "tracks_loose"]
         labels_list = ["R21 Test", "R22 Test"]
 
@@ -381,7 +469,7 @@ class JetPlotting_TestCase(unittest.TestCase):
             datasets_filepaths=filepath_list,
             datasets_labels=labels_list,
             datasets_track_names=tracks_name_list,
-            class_labels=plotting_config["class_labels"],
+            datasets_class_labels=class_labels_list,
             var_dict=self.plot_config["Eval_parameters"]["var_dict"],
             n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]),
             binning=plotting_config["binning"],
@@ -466,13 +554,14 @@ class JetPlotting_TestCase(unittest.TestCase):
         plotting_config = self.plot_config["nTracks_Test"]
         filepath_list = [self.r21_test_file, self.r22_test_file]
         tracks_name_list = ["tracks", "tracks_loose"]
+        class_labels_list = [["bjets", "cjets", "ujets"], ["bjets", "cjets", "ujets"]]
         labels_list = ["R21 Test", "R22 Test"]
 
         plot_n_tracks_per_jet(
             datasets_filepaths=filepath_list,
             datasets_labels=labels_list,
+            datasets_class_labels=class_labels_list,
             datasets_track_names=tracks_name_list,
-            class_labels=plotting_config["class_labels"],
             n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]),
             output_directory=f"{self.actual_plots_dir}",
             # output_directory=f"{self.expected_plots_dir}",