diff --git a/changelog.md b/changelog.md index 0cc4b01007adcbfdc7d0696bbda76292c5771dcb..3c368f1c7d673a70eab0b27d2b323b6ddee7501a 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,7 @@ ### Latest +- Adding support for dataset-specific class labels in input var plots [!623](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/623) - Apply naming scheme for WP and nEpochs [!621](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/621) - Adding correct naming scheme for train config sections [!617](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/617) diff --git a/docs/plotting/plotting_inputs.md b/docs/plotting/plotting_inputs.md index ceec1f07402ab4aadb664eaf112247dfd6a1553d..cbe3c7ce9df62b243e25eb9180e9f71f4e7a06e6 100644 --- a/docs/plotting/plotting_inputs.md +++ b/docs/plotting/plotting_inputs.md @@ -30,7 +30,7 @@ The number of tracks per jet can be plotted for all different files. This can be ??? example "Click to see corresponding code in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)" ```yaml - §§§examples/plotting_input_vars.yaml:91:108§§§ + §§§examples/plotting_input_vars.yaml:95:112§§§ ``` | Options | Data Type | Necessary/Optional | Explanation | @@ -51,7 +51,7 @@ To plot the track input variables, the following options are used. ??? example "Click to see corresponding code in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)" ```yaml - §§§examples/plotting_input_vars.yaml:110:144§§§ + §§§examples/plotting_input_vars.yaml:114:148§§§ ``` @@ -73,7 +73,7 @@ To plot the jet input variables, the following options are used. ??? example "Click to see corresponding code in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)" ```yaml - §§§examples/plotting_input_vars.yaml:16:89§§§ + §§§examples/plotting_input_vars.yaml:16:93§§§ ``` | Options | Data Type | Necessary/Optional | Explanation | diff --git a/examples/plotting_input_vars.yaml b/examples/plotting_input_vars.yaml index 9d63278beeddde863f127d8a2c93cafdcb7849d9..87ee955fbbca5c4bb01d5adc17d9a3df99ec06e6 100644 --- a/examples/plotting_input_vars.yaml +++ b/examples/plotting_input_vars.yaml @@ -23,6 +23,10 @@ jets_input_vars: R22: files: <path_palce_holder>/user.alfroch.410470.btagTraining.e6337_s3126_r12305_r12253_r12305_p4441.EMPFlow_loose.2021-04-20-T171733-R21211_output.h5/* label: "R22 Loose" + # If you want to specify the `class_labels` per dataset you can add it here + # If you don't specify anything here, the overall defined `class_labels` will be + # used + # class_labels: ["bjets", "cjets", "ujets"] plot_settings: <<: *default_plot_settings class_labels: ["bjets", "cjets", "ujets"] diff --git a/umami/input_vars_tools/__init__.py b/umami/input_vars_tools/__init__.py index 2469dc070249f869c6e100a5a5e58776629d814b..1314762b0be9fd1459aa7706f59df3e597362d79 100644 --- a/umami/input_vars_tools/__init__.py +++ b/umami/input_vars_tools/__init__.py @@ -1,6 +1,7 @@ # flake8: noqa # pylint: skip-file from umami.input_vars_tools.plotting_functions import ( + get_datasets_configuration, plot_input_vars_jets, plot_input_vars_trks, plot_n_tracks_per_jet, diff --git a/umami/input_vars_tools/plotting_functions.py b/umami/input_vars_tools/plotting_functions.py index a989c72baeae6bebb39575f54370e8072070b2e4..ce77901354601defcc384635314380ca77e0ea69 100644 --- a/umami/input_vars_tools/plotting_functions.py +++ b/umami/input_vars_tools/plotting_functions.py @@ -11,6 +11,56 @@ from umami.plotting_tools.utils import translate_binning from umami.preprocessing_tools import get_variable_dict +def get_datasets_configuration(plotting_config: dict, tracks: bool = False): + """Helper function to transform dict that stores the configuration of the different + datasets into lists of certain parameters. + + Parameters + ---------- + plotting_config : dict + Plotting configuration + tracks : bool, optional + Bool if the function should look for the `tracks_name` variable in the dataset + configurations. + + Returns + ------- + filepath_list : list + List with the filepaths of all the datasets. + labels_list : list + List with the 'dataset label' of each dataset. + class_labels_list : list + List with the class labels for each dataset. If no dataset-specific class labels + are provided, the globally defined class labels are used. + tracks_name_list : list + List with the track names of the datasets. Only returned if `tracks` is True. + """ + + filepath_list = [] + labels_list = [] + class_labels_list = [] + tracks_name_list = [] + + datasets_config = plotting_config["Datasets_to_plot"] + + for dataset_name in datasets_config: + if not datasets_config[dataset_name]["files"] is None: + filepath_list.append(datasets_config[dataset_name]["files"]) + labels_list.append(datasets_config[dataset_name]["label"]) + # check if this dataset has a specific list of class labels + class_labels_list.append( + datasets_config[dataset_name]["class_labels"] + if "class_labels" in datasets_config[dataset_name] + else plotting_config["class_labels"] + ) + if tracks: + tracks_name_list.append(datasets_config[dataset_name]["tracks_name"]) + + if tracks: + return filepath_list, labels_list, class_labels_list, tracks_name_list + return filepath_list, labels_list, class_labels_list + + def check_kwargs_for_ylabel_and_n_ratio_panel( kwargs: dict, fallback_ylabel: str, @@ -54,9 +104,9 @@ def check_kwargs_for_ylabel_and_n_ratio_panel( def plot_n_tracks_per_jet( datasets_filepaths: list, datasets_labels: list, + datasets_class_labels: list, datasets_track_names: list, n_jets: int, - class_labels: list, output_directory: str = "input_vars_trks", plot_type: str = "pdf", track_origin: str = "All", @@ -73,12 +123,14 @@ def plot_n_tracks_per_jet( List of filepaths to the files. datasets_labels : list Label of the dataset for the legend. + datasets_class_labels : list + List with dataset-specific class labels, e.g. [["ujets", "cjets"], ["cjets"]] + to plot light-jets and c-jets for the first but only c-jets for the second + dataset datasets_track_names : list List with the track names of the files. n_jets : int Number of jets to use. - class_labels : list - List of classes that are to be plotted. output_directory : str Name of the output directory. Only the dir name not path! plot_type: str, optional @@ -104,8 +156,8 @@ def plot_n_tracks_per_jet( flavour_label_dict = {} # Iterate over the different dataset filepaths and labels defined in the config - for (filepath, label, tracks_name) in zip( - datasets_filepaths, datasets_labels, datasets_track_names + for (filepath, label, tracks_name, class_labels) in zip( + datasets_filepaths, datasets_labels, datasets_track_names, datasets_class_labels ): loaded_trks, loaded_flavour_labels = udt.LoadTrksFromFile( filepath=filepath, @@ -138,8 +190,8 @@ def plot_n_tracks_per_jet( # Store the means of the n_tracks distributions to print them at the end n_tracks_means = {label: {} for label in datasets_labels} # Iterate over datasets - for dataset_number, (label, linestyle) in enumerate( - zip(datasets_labels, linestyles[: len(datasets_labels)]) + for dataset_number, (label, linestyle, class_labels) in enumerate( + zip(datasets_labels, linestyles[: len(datasets_labels)], datasets_class_labels) ): # Sort after given variable trks = np.asarray(trks_dict[label]) @@ -180,11 +232,11 @@ def plot_n_tracks_per_jet( def plot_input_vars_trks( datasets_filepaths: list, datasets_labels: list, + datasets_class_labels: list, datasets_track_names: list, var_dict: dict, n_jets: int, binning: dict, - class_labels: list, sorting_variable: str = "ptfrac", n_leading: list = None, output_directory: str = "input_vars_trks", @@ -203,6 +255,10 @@ def plot_input_vars_trks( List of filepaths to the files. datasets_labels : list Label of the dataset for the legend. + datasets_class_labels : list + List with dataset-specific class labels, e.g. [["ujets", "cjets"], ["cjets"]] + to plot light-jets and c-jets for the first but only c-jets for the second + dataset datasets_track_names : list List with the track names of the files. var_dict : dict @@ -211,8 +267,6 @@ def plot_input_vars_trks( Number of jets to use for plotting. binning : dict Decide which binning is used. - class_labels : list - List of class_labels which are to be plotted. sorting_variable : str Variable which is used for sorting. n_leading : list @@ -252,9 +306,10 @@ def plot_input_vars_trks( flavour_label_dict = {} # Iterate over the different dataset filepaths and labels defined in the config - for filepath, label, tracks_name in zip( + for filepath, label, class_labels, tracks_name in zip( datasets_filepaths, datasets_labels, + datasets_class_labels, datasets_track_names, ): @@ -359,8 +414,12 @@ def plot_input_vars_trks( ) # Iterate over datasets - for dataset_number, (label, linestyle) in enumerate( - zip(datasets_labels, linestyles[: len(datasets_labels)]) + for dataset_number, (label, linestyle, class_labels) in enumerate( + zip( + datasets_labels, + linestyles[: len(datasets_labels)], + datasets_class_labels, + ) ): # Sort after given variable sorting = np.argsort(-1 * trks_dict[label][sorting_variable]) @@ -450,10 +509,10 @@ def plot_input_vars_trks( def plot_input_vars_jets( datasets_filepaths: list, datasets_labels: list, + datasets_class_labels: list, var_dict: dict, n_jets: int, binning: dict, - class_labels: list, special_param_jets: dict = None, output_directory: str = "input_vars_jets", plot_type: str = "pdf", @@ -470,14 +529,16 @@ def plot_input_vars_jets( List of filepaths to the files. datasets_labels : list Label of the dataset for the legend. + datasets_class_labels : list + List with dataset-specific class labels, e.g. [["ujets", "cjets"], ["cjets"]] + to plot light-jets and c-jets for the first but only c-jets for the second + dataset var_dict : dict Variable dict where all variables of the files are saved. n_jets : int Number of jets to use for plotting. binning : dict Decide which binning is used. - class_labels : list - List of class_labels which are to be plotted. special_param_jets : dict Dict with special x-axis cuts for the given variable. output_directory : str @@ -508,7 +569,9 @@ def plot_input_vars_jets( flavour_label_dict = {} # Iterate over the different dataset filepaths and labels defined in the config - for (filepath, label) in zip(datasets_filepaths, datasets_labels): + for (filepath, label, class_labels) in zip( + datasets_filepaths, datasets_labels, datasets_class_labels + ): # Get the tracks and the labels from the file/files jets, flavour_labels = udt.LoadJetsFromFile( filepath=filepath, @@ -554,8 +617,12 @@ def plot_input_vars_jets( logger.info("Plotting %s ...", var) # Iterate over datasets - for dataset_number, (label, linestyle) in enumerate( - zip(datasets_labels, linestyles[: len(datasets_labels)]) + for dataset_number, (label, linestyle, class_labels) in enumerate( + zip( + datasets_labels, + linestyles[: len(datasets_labels)], + datasets_class_labels, + ) ): # Get variable and the labels of the jets jets_var = jets_dict[label][var] diff --git a/umami/plot_input_variables.py b/umami/plot_input_variables.py index c0cf806a8047d035e5289f9605aae82ff5128330..379133d41fe4689aac73232ca8b85808959f1ae5 100644 --- a/umami/plot_input_variables.py +++ b/umami/plot_input_variables.py @@ -83,9 +83,13 @@ def plot_trks_variables(plot_config, plot_type): plotting_config["plot_settings"] = translate_kwargs( plotting_config["plot_settings"] ) - filepath_list = [] - labels_list = [] - tracks_list = [] + + ( + filepath_list, + labels_list, + class_labels_list, + tracks_name_list, + ) = uit.get_datasets_configuration(plotting_config, tracks=True) # Default to no selection based on track_origin trk_origins = ["All"] @@ -94,28 +98,13 @@ def plot_trks_variables(plot_config, plot_type): if "track_origins" in plotting_config: trk_origins = plotting_config["track_origins"] - for model_name, _ in plotting_config["Datasets_to_plot"].items(): - if ( - not plotting_config["Datasets_to_plot"][f"{model_name}"]["files"] - is None - ): - filepath_list.append( - plotting_config["Datasets_to_plot"][f"{model_name}"]["files"] - ) - labels_list.append( - plotting_config["Datasets_to_plot"][f"{model_name}"]["label"] - ) - tracks_list.append( - plotting_config["Datasets_to_plot"][f"{model_name}"]["tracks_name"] - ) - for trk_origin in trk_origins: if ("nTracks" in plotting_config) and (plotting_config["nTracks"] is True): uit.plot_n_tracks_per_jet( datasets_filepaths=filepath_list, datasets_labels=labels_list, - datasets_track_names=tracks_list, - class_labels=plotting_config["class_labels"], + datasets_class_labels=class_labels_list, + datasets_track_names=tracks_name_list, n_jets=int(plot_config["Eval_parameters"]["n_jets"]), output_directory=plotting_config["folder_to_save"] if plotting_config["folder_to_save"] @@ -129,8 +118,8 @@ def plot_trks_variables(plot_config, plot_type): uit.plot_input_vars_trks( datasets_filepaths=filepath_list, datasets_labels=labels_list, - datasets_track_names=tracks_list, - class_labels=plotting_config["class_labels"], + datasets_class_labels=class_labels_list, + datasets_track_names=tracks_name_list, var_dict=plot_config["Eval_parameters"]["var_dict"], n_jets=int(plot_config["Eval_parameters"]["n_jets"]), binning=plotting_config["binning"], @@ -167,25 +156,17 @@ def plot_jets_variables(plot_config, plot_type): plotting_config["plot_settings"] = translate_kwargs( plotting_config["plot_settings"] ) - filepath_list = [] - labels_list = [] - - for model_name, _ in plotting_config["Datasets_to_plot"].items(): - if ( - not plotting_config["Datasets_to_plot"][f"{model_name}"]["files"] - is None - ): - filepath_list.append( - plotting_config["Datasets_to_plot"][f"{model_name}"]["files"] - ) - labels_list.append( - plotting_config["Datasets_to_plot"][f"{model_name}"]["label"] - ) + + ( # pylint: disable=unbalanced-tuple-unpacking + filepath_list, + labels_list, + class_labels_list, + ) = uit.get_datasets_configuration(plotting_config) uit.plot_input_vars_jets( datasets_filepaths=filepath_list, datasets_labels=labels_list, - class_labels=plotting_config["class_labels"], + datasets_class_labels=class_labels_list, var_dict=plot_config["Eval_parameters"]["var_dict"], n_jets=int(plot_config["Eval_parameters"]["n_jets"]), binning=plotting_config["binning"], diff --git a/umami/tests/unit/input_vars_tools/test_input_vars_tools.py b/umami/tests/unit/input_vars_tools/test_input_vars_tools.py index 005fe577ec13860351dd6c4a96e71ad478d47c47..e2ceeecdd89c75a95a226cd31c9ec6a57e00fa71 100644 --- a/umami/tests/unit/input_vars_tools/test_input_vars_tools.py +++ b/umami/tests/unit/input_vars_tools/test_input_vars_tools.py @@ -10,6 +10,7 @@ from matplotlib.testing.compare import compare_images from umami.configuration import logger, set_log_level from umami.input_vars_tools.plotting_functions import ( + get_datasets_configuration, plot_input_vars_jets, plot_input_vars_trks, plot_n_tracks_per_jet, @@ -19,6 +20,81 @@ from umami.tools import yaml_loader set_log_level(logger, "DEBUG") +class HelperFunction_TestCase(unittest.TestCase): + """Test class for helper functions.""" + + def setUp(self): + self.plotting_config = { + "class_labels": ["ujets", "cjets", "bjets"], + "Datasets_to_plot": { + "ds_1": { + "files": "dummy_path_1", + "label": "dummy_label_1", + }, + "ds_2": { + "files": "dummy_path_2", + "label": "dummy_label_2", + }, + }, + } + + def test_get_datasets_configuration_all_default(self): + """Test the helper function for the default case (same class labels for both + datasets)""" + exp_filepath_list = ["dummy_path_1", "dummy_path_2"] + exp_labels_list = ["dummy_label_1", "dummy_label_2"] + exp_class_labels_list = [ + ["ujets", "cjets", "bjets"], + ["ujets", "cjets", "bjets"], + ] + ( # pylint: disable=unbalanced-tuple-unpacking + filepath_list, + labels_list, + class_labels_list, + ) = get_datasets_configuration(self.plotting_config) + + with self.subTest(): + self.assertEqual(exp_filepath_list, filepath_list) + with self.subTest(): + self.assertEqual(exp_labels_list, labels_list) + with self.subTest(): + self.assertEqual(exp_class_labels_list, class_labels_list) + + def test_get_datasets_configuration_specific_class_labels(self): + """Test the helper function for the case of specifying specific class labels + for one of the datasets""" + + # modify the config for this test + plotting_config = self.plotting_config + plotting_config["Datasets_to_plot"]["ds_2"]["class_labels"] = ["bjets"] + plotting_config["Datasets_to_plot"]["ds_1"]["tracks_name"] = "tracks_loose" + plotting_config["Datasets_to_plot"]["ds_2"]["tracks_name"] = "tracks" + + # define expected outcome + exp_filepath_list = ["dummy_path_1", "dummy_path_2"] + exp_labels_list = ["dummy_label_1", "dummy_label_2"] + exp_class_labels_list = [ + ["ujets", "cjets", "bjets"], + ["bjets"], + ] + exp_tracks_name_list = ["tracks_loose", "tracks"] + ( + filepath_list, + labels_list, + class_labels_list, + tracks_name_list, + ) = get_datasets_configuration(self.plotting_config, tracks=True) + + with self.subTest(): + self.assertEqual(exp_filepath_list, filepath_list) + with self.subTest(): + self.assertEqual(exp_labels_list, labels_list) + with self.subTest(): + self.assertEqual(exp_class_labels_list, class_labels_list) + with self.subTest(): + self.assertEqual(exp_tracks_name_list, tracks_name_list) + + class JetPlotting_TestCase(unittest.TestCase): """Test class for jet plotting functions.""" @@ -70,6 +146,7 @@ class JetPlotting_TestCase(unittest.TestCase): """Test jet input variable plots with wrong type.""" plotting_config = self.plot_config["jets_input_vars"] filepath_list = [self.r21_test_file] + class_labels_list = [["ujets", "cjets"]] labels_list = ["R21 Test"] # Change type in plotting_config to string to produce error @@ -79,7 +156,7 @@ class JetPlotting_TestCase(unittest.TestCase): plot_input_vars_jets( datasets_filepaths=filepath_list, datasets_labels=labels_list, - class_labels=plotting_config["class_labels"], + datasets_class_labels=class_labels_list, var_dict=self.plot_config["Eval_parameters"]["var_dict"], n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]), binning=plotting_config["binning"], @@ -95,11 +172,12 @@ class JetPlotting_TestCase(unittest.TestCase): plotting_config = self.plot_config["jets_input_vars"] filepath_list = [self.r21_test_file] labels_list = ["R21 Test"] + class_labels_list = [["bjets", "cjets", "ujets", "taujets"]] plot_input_vars_jets( datasets_filepaths=filepath_list, datasets_labels=labels_list, - class_labels=plotting_config["class_labels"], + datasets_class_labels=class_labels_list, var_dict=self.plot_config["Eval_parameters"]["var_dict"], n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]), binning=plotting_config["binning"], @@ -138,6 +216,7 @@ class JetPlotting_TestCase(unittest.TestCase): plotting_config = self.plot_config["jets_input_vars"] filepath_list = [self.r21_test_file] labels_list = ["R21 Test"] + class_labels_list = [["ujets", "cjets"]] # Change type in plotting_config to string to produce error plotting_config["binning"]["SV1_NGTinSvx"] = "test" @@ -146,7 +225,7 @@ class JetPlotting_TestCase(unittest.TestCase): plot_input_vars_jets( datasets_filepaths=filepath_list, datasets_labels=labels_list, - class_labels=plotting_config["class_labels"], + datasets_class_labels=class_labels_list, var_dict=self.plot_config["Eval_parameters"]["var_dict"], n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]), binning=plotting_config["binning"], @@ -160,12 +239,16 @@ class JetPlotting_TestCase(unittest.TestCase): """Test jet input variable plot with comparison.""" plotting_config = self.plot_config["jets_input_vars"] filepath_list = [self.r21_test_file, self.r22_test_file] + class_labels_list = [ + ["bjets", "cjets", "ujets", "taujets"], + ["bjets", "cjets", "ujets", "taujets"], + ] labels_list = ["R21 Test", "R22 Test"] plot_input_vars_jets( datasets_filepaths=filepath_list, datasets_labels=labels_list, - class_labels=plotting_config["class_labels"], + datasets_class_labels=class_labels_list, var_dict=self.plot_config["Eval_parameters"]["var_dict"], n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]), binning=plotting_config["binning"], @@ -202,6 +285,7 @@ class JetPlotting_TestCase(unittest.TestCase): """Test track input variable plots with wrong type.""" plotting_config = self.plot_config["Tracks_Test"] filepath_list = [self.r21_test_file] + class_labels_list = [["bjets", "cjets", "ujets"]] tracks_name_list = ["tracks"] labels_list = ["R21 Test"] @@ -213,7 +297,7 @@ class JetPlotting_TestCase(unittest.TestCase): datasets_filepaths=filepath_list, datasets_labels=labels_list, datasets_track_names=tracks_name_list, - class_labels=plotting_config["class_labels"], + datasets_class_labels=class_labels_list, var_dict=self.plot_config["Eval_parameters"]["var_dict"], n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]), binning=plotting_config["binning"], @@ -226,14 +310,15 @@ class JetPlotting_TestCase(unittest.TestCase): """Test track input variables with wrong type.""" plotting_config = self.plot_config["Tracks_Test"] filepath_list = [self.r21_test_file] + class_labels_list = [["bjets", "cjets", "ujets"]] tracks_name_list = ["tracks"] labels_list = ["R21 Test"] plot_input_vars_trks( datasets_filepaths=filepath_list, datasets_labels=labels_list, + datasets_class_labels=class_labels_list, datasets_track_names=tracks_name_list, - class_labels=plotting_config["class_labels"], var_dict=self.plot_config["Eval_parameters"]["var_dict"], n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]), binning=plotting_config["binning"], @@ -317,6 +402,7 @@ class JetPlotting_TestCase(unittest.TestCase): """Test track input variables with comparison and wrong type.""" plotting_config = self.plot_config["Tracks_Test"] filepath_list = [self.r21_test_file] + class_labels_list = [["bjets", "cjets", "ujets"]] tracks_name_list = ["tracks"] labels_list = ["R21 Test"] @@ -328,7 +414,7 @@ class JetPlotting_TestCase(unittest.TestCase): datasets_filepaths=filepath_list, datasets_labels=labels_list, datasets_track_names=tracks_name_list, - class_labels=plotting_config["class_labels"], + datasets_class_labels=class_labels_list, var_dict=self.plot_config["Eval_parameters"]["var_dict"], n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]), binning=plotting_config["binning"], @@ -341,14 +427,15 @@ class JetPlotting_TestCase(unittest.TestCase): """Test track variable plots without normalisation.""" plotting_config = self.plot_config["tracks_test_not_normalised"] filepath_list = [self.r21_test_file, self.r22_test_file] + class_labels_list = [["bjets", "cjets", "ujets"], ["bjets", "cjets", "ujets"]] tracks_name_list = ["tracks", "tracks_loose"] labels_list = ["R21 Test", "R22 Test"] plot_input_vars_trks( datasets_filepaths=filepath_list, datasets_labels=labels_list, + datasets_class_labels=class_labels_list, datasets_track_names=tracks_name_list, - class_labels=plotting_config["class_labels"], var_dict=self.plot_config["Eval_parameters"]["var_dict"], n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]), binning=plotting_config["binning"], @@ -374,6 +461,7 @@ class JetPlotting_TestCase(unittest.TestCase): """Test plotting track input variables with comparison.""" plotting_config = self.plot_config["Tracks_Test"] filepath_list = [self.r21_test_file, self.r22_test_file] + class_labels_list = [["bjets", "cjets", "ujets"], ["bjets", "cjets", "ujets"]] tracks_name_list = ["tracks", "tracks_loose"] labels_list = ["R21 Test", "R22 Test"] @@ -381,7 +469,7 @@ class JetPlotting_TestCase(unittest.TestCase): datasets_filepaths=filepath_list, datasets_labels=labels_list, datasets_track_names=tracks_name_list, - class_labels=plotting_config["class_labels"], + datasets_class_labels=class_labels_list, var_dict=self.plot_config["Eval_parameters"]["var_dict"], n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]), binning=plotting_config["binning"], @@ -466,13 +554,14 @@ class JetPlotting_TestCase(unittest.TestCase): plotting_config = self.plot_config["nTracks_Test"] filepath_list = [self.r21_test_file, self.r22_test_file] tracks_name_list = ["tracks", "tracks_loose"] + class_labels_list = [["bjets", "cjets", "ujets"], ["bjets", "cjets", "ujets"]] labels_list = ["R21 Test", "R22 Test"] plot_n_tracks_per_jet( datasets_filepaths=filepath_list, datasets_labels=labels_list, + datasets_class_labels=class_labels_list, datasets_track_names=tracks_name_list, - class_labels=plotting_config["class_labels"], n_jets=int(self.plot_config["Eval_parameters"]["n_jets"]), output_directory=f"{self.actual_plots_dir}", # output_directory=f"{self.expected_plots_dir}",