diff --git a/changelog.md b/changelog.md index ee0d3e1cce8f5874e6ce08d646e3be9cac1d1af1..1fc1070243481d582e6d6250cde14da63faf1406 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,7 @@ ### Latest +- Fixing issues with trained_taggers and taggers_from_file in plotting_epoch_performance.py [!549](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/549) - Adding plotting API to Contour plots + Updating plotting_umami docs [!537](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/537) - Adding unit test for prepare_model and minor bug fixes [!546](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/546) - Adding unit tests for tf generators[!542](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/542) diff --git a/docs/plotting/plotting_umami.md b/docs/plotting/plotting_umami.md index 70d789d4745fe1ecd320a748d36f1e6fc5bf10eb..ffbc9a29be175822279475764bef5f5a74a484c9 100644 --- a/docs/plotting/plotting_umami.md +++ b/docs/plotting/plotting_umami.md @@ -95,7 +95,7 @@ Plotting the ROC Curves of the rejection rates against the b-tagging efficiency. | `label` | `str` | Necessary | Legend label of the model. | | `tagger_name` | `str` | Necessary | Name of the tagger which is to be plotted. | | `rejection_class` | `str` | Necessary | Class which the main flavour is plotted against. | -| `binomialErrors` | `bool` | Optional | Plot binomial errors to plot. | +| `draw_errors` | `bool` | Optional | Plot binomial errors to plot. | | `xmin` | `float` | Optional | Set the minimum b efficiency in the plot (which is the xmin limit). | | `ymax` | `float` | Optional | The maximum y axis. | | `working_points` | `list` | Optional | The specified WPs are calculated and at the calculated b-tagging discriminant there will be a vertical line with a small label on top which prints the WP. | @@ -124,8 +124,8 @@ Plot the b efficiency/c-rejection/light-rejection against the pT. For example: | `flavour` | `str` | Necessary | Flavour class rejection which is to be plotted. | | `class_labels` | List of class labels that were used in the preprocessing/training. They must be the same in all three files! Order is important! | | `main_class` | `str` | Class which is to be tagged. | -| `WP` | `float` | Necessary | Float of the working point that will be used. | -| `WP_line` | `float` | Optional | Print a horizontal line at this value efficiency. | +| `working_point` | `float` | Necessary | Float of the working point that will be used. | +| `working_point_line` | `float` | Optional | Print a horizontal line at this value efficiency. | | `fixed_eff_bin` | `bool` | Optional | Calculate the WP cut on the discriminant per bin. | #### Fraction Contour Plot diff --git a/examples/plotting_umami_config_DL1r.yaml b/examples/plotting_umami_config_DL1r.yaml index 15d81da4b8cccd91edd2165581bf1857635cbc0f..bd25af435dc96e0d5ea6039dce4768c8c2f64fa9 100644 --- a/examples/plotting_umami_config_DL1r.yaml +++ b/examples/plotting_umami_config_DL1r.yaml @@ -57,7 +57,7 @@ DL1r_light_flavour: tagger_name: "DL1" rejection_class: "cjets" plot_settings: # These settings are given to the umami.evaluation_tools.plotROCRatio() function by unpacking them. - binomialErrors: True + draw_errors: True xmin: 0.5 ymax: 1000000 figsize: [7, 6] # [width, hight] diff --git a/examples/plotting_umami_config_Umami.yaml b/examples/plotting_umami_config_Umami.yaml index b02a28002bb2b96e8455e9421a201be54839e300..8cea72ee415f3d92f2302316a2a2ad6940a811ab 100644 --- a/examples/plotting_umami_config_Umami.yaml +++ b/examples/plotting_umami_config_Umami.yaml @@ -94,7 +94,7 @@ beff_scan_tagger_umami: tagger_name: "umami" rejection_class: "cjets" plot_settings: - binomialErrors: True + draw_errors: True xmin: 0.5 ymax: 1000000 figsize: [7, 6] # [width, hight] @@ -168,7 +168,7 @@ beff_scan_tagger_compare_umami: tagger_name: "umami" rejection_class: "cjets" plot_settings: - binomialErrors: True + draw_errors: True xmin: 0.5 ymax: 1000000 figsize: [9, 9] # [width, hight] diff --git a/examples/plotting_umami_config_dips.yaml b/examples/plotting_umami_config_dips.yaml index 517ad0de60225cf1fc0742e07b64cc3df21fa3ef..bda566825c8091efe723cc0ad8b66015db3c90c3 100644 --- a/examples/plotting_umami_config_dips.yaml +++ b/examples/plotting_umami_config_dips.yaml @@ -87,9 +87,9 @@ Dips_pT_vs_beff: flavour: "cjets" class_labels: ["ujets", "cjets", "bjets"] main_class: "bjets" - WP: 0.77 - WP_Line: True - Fixed_WP_Bin: False + working_point: 0.77 + working_point_line: True + fixed_eff_bin: False figsize: [7, 5] logy: False use_atlas_tag: True @@ -106,7 +106,7 @@ Dips_light_flavour_ttbar: tagger_name: "dips" rejection_class: "cjets" plot_settings: - binomialErrors: True + draw_errors: True xmin: 0.5 ymax: 1000000 figsize: [7, 6] # [width, hight] @@ -129,7 +129,7 @@ Dips_Comparison_flavour_ttbar: tagger_name: "dips" rejection_class: "cjets" plot_settings: - binomialErrors: True + draw_errors: True xmin: 0.5 ymax: 1000000 figsize: [9, 9] # [width, hight] diff --git a/umami/evaluation_tools/PlottingFunctions.py b/umami/evaluation_tools/PlottingFunctions.py index 0614b58d4fb650278f056a7a9a2a6683646e9872..16ff150f0e5b03137471cbf9aa7fb25db4ecfdaf 100644 --- a/umami/evaluation_tools/PlottingFunctions.py +++ b/umami/evaluation_tools/PlottingFunctions.py @@ -37,7 +37,7 @@ def plot_pt_dependence( disc_cut: float = None, fixed_eff_bin: bool = False, bin_edges: list = None, - wp_line: bool = False, + working_point_line: bool = False, grid: bool = False, colours: list = None, alpha: float = 0.8, @@ -75,7 +75,7 @@ def plot_pt_dependence( bin_edges : list, optional As the name says, the edges of the bins used. Will be set automatically, if None. By default None. - wp_line : bool, optional + working_point_line : bool, optional Print a WP line in the upper plot, by default False. grid : bool, optional Use a grid in the plots, by default False @@ -95,24 +95,6 @@ def plot_pt_dependence( ValueError If deprecated options are given. """ - if "colors" in kwargs: - colours = kwargs["colors"] - kwargs.pop("colors") - if "WP" in kwargs: - working_point = kwargs["WP"] - kwargs.pop("WP") - if "Disc_Cut_Value" in kwargs: - disc_cut = kwargs["Disc_Cut_Value"] - kwargs.pop("Disc_Cut_Value") - if "Fixed_WP_Bin" in kwargs: - fixed_eff_bin = kwargs["Fixed_WP_Bin"] - kwargs.pop("Fixed_WP_Bin") - if "Grid" in kwargs: - grid = kwargs["Grid"] - kwargs.pop("Grid") - if "WP_Line" in kwargs: - wp_line = kwargs["WP_Line"] - kwargs.pop("WP_Line") # Translate the kwargs to new naming scheme kwargs = translate_kwargs(kwargs) @@ -220,12 +202,12 @@ def plot_pt_dependence( if grid is True: plot_pt.set_grid() # Set WP Line - if wp_line is True: + if working_point_line is True: plot_pt.draw_hline(working_point) if main_class != flavour: logger.warning( - "You set `wp_line` to True but you are not looking at the singal " - "efficiency. It will probably not be visible on your plot." + "You set `working_point_line` to True but you are not looking at the" + " singal efficiency. It will probably not be visible on your plot." ) plot_pt.savefig(plot_name, transparent=trans) @@ -309,22 +291,6 @@ def plotROCRatio( # Check for number of provided Rocs n_rocs = len(df_results_list) - # maintain backwards compatibility - if "nTest" in kwargs: - if n_test is None: - n_test = kwargs["nTest"] - kwargs.pop("nTest") - if "colors" in kwargs: - if colours is None: - colours = kwargs["colors"] - # remnant of old implementation passing empty list as default - if kwargs["colors"] == []: - colours = None - kwargs.pop("colors") - if "binomialErrors" in kwargs: - if draw_errors is None: - draw_errors = kwargs["binomialErrors"] - kwargs.pop("binomialErrors") if "ratio_id" in kwargs: if reference_ratio is None and kwargs["ratio_id"] is not None: # if old keyword is used the syntax was also different @@ -409,6 +375,7 @@ def plotROCRatio( elif isinstance(n_test, (int, float)): n_test = [n_test] * len(df_results_list) + elif isinstance(n_test, list): if len(n_test) != len(df_results_list): raise ValueError( diff --git a/umami/plotting_umami.py b/umami/plotting_umami.py index ffed1149d32695198c37c82d163ac13f61aef048..f4714c2b8e4e79f6918bbeafbc4449e66f3f3a29 100644 --- a/umami/plotting_umami.py +++ b/umami/plotting_umami.py @@ -159,14 +159,18 @@ def plot_roc( rej_class_list = [] labels = [] linestyles = [] - colors = [] + colours = [] # Get the epoch which is to be evaluated eval_epoch = int(eval_params["epoch"]) - if "nTest" not in plot_config["plot_settings"].keys(): + if ( + "n_test" not in plot_config["plot_settings"] + or plot_config["plot_settings"]["n_test"] is None + ): n_test_provided = False - plot_config["plot_settings"]["nTest"] = [] + plot_config["plot_settings"]["n_test"] = [] + else: n_test_provided = True @@ -195,25 +199,29 @@ def plot_roc( if "linestyle" in model_config: linestyles.append(model_config["linestyle"]) - if "color" in model_config: - colors.append(model_config["color"]) + if "colour" in model_config: + colours.append(model_config["colour"]) - # nTest is only needed to calculate binomial errors + # n_test is only needed to calculate binomial errors if not n_test_provided and ( - "binomialErrors" in plot_config["plot_settings"] - and plot_config["plot_settings"]["binomialErrors"] + "draw_errors" in plot_config["plot_settings"] + and plot_config["plot_settings"]["draw_errors"] ): with h5py.File( eval_file_dir + f"/results-rej_per_eff-{eval_epoch}.h5", "r" ) as h5_file: - plot_config["plot_settings"]["nTest"].append(h5_file.attrs["N_test"]) + plot_config["plot_settings"]["n_test"].append(h5_file.attrs["N_test"]) + else: - plot_config["plot_settings"]["nTest"] = None + plot_config["plot_settings"]["n_test"] = None # Get the right ratio id for correct ratio calculation ratio_dict = {} ratio_id = [] + if len(colours) == 0: + colours = None + for i, which_a in enumerate(rej_class_list): if which_a not in ratio_dict: ratio_dict.update({which_a: i}) @@ -230,7 +238,7 @@ def plot_roc( plot_name=plot_name, ratio_id=ratio_id, linestyles=linestyles, - colors=colors, + colours=colours, **plot_config["plot_settings"], ) diff --git a/umami/tests/unit/evaluation_tools/test_PlottingFunctions.py b/umami/tests/unit/evaluation_tools/test_PlottingFunctions.py index 0660d11759a600b8ec3c0680366bd9abab8a176d..bb1d17037e5599c774b30242983417d150f12f14 100644 --- a/umami/tests/unit/evaluation_tools/test_PlottingFunctions.py +++ b/umami/tests/unit/evaluation_tools/test_PlottingFunctions.py @@ -144,7 +144,7 @@ class plot_score_TestCase(unittest.TestCase): labels=["RNNIP ttbar", "DIPS ttbar"], # plot_name=self.expected_plots_dir + "ROC_Test.png", plot_name=self.actual_plots_dir + "ROC_Test.png", - nTest=[100000, 100000], + n_test=[100_000, 100_000], working_points=[0.60, 0.70, 0.77, 0.85], main_class="bjets", atlas_second_tag=( @@ -179,7 +179,7 @@ class plot_score_TestCase(unittest.TestCase): labels=["RNNIP", "DIPS", "RNNIP", "DIPS"], # plot_name=self.expected_plots_dir + "ROC_Comparison_Test.png", plot_name=self.actual_plots_dir + "ROC_Comparison_Test.png", - nTest=[100000, 100000, 100000, 100000], + n_test=[100_000, 100_000, 100_000, 100_000], working_points=[0.60, 0.70, 0.77, 0.85], reference_ratio=[True, False, True, False], atlas_second_tag=( diff --git a/umami/tests/unit/train_tools/plots/PlotRejPerEpochComparison.png b/umami/tests/unit/train_tools/plots/PlotRejPerEpochComparison.png index 8b6e3ed48b9d6be4b6e777723f56f5a5f3f3059b..c9429035fa51725d71c39578b3991647e9fe57a2 100644 Binary files a/umami/tests/unit/train_tools/plots/PlotRejPerEpochComparison.png and b/umami/tests/unit/train_tools/plots/PlotRejPerEpochComparison.png differ diff --git a/umami/tests/unit/train_tools/plots/PlotRejPerEpoch_cjets_rejection.png b/umami/tests/unit/train_tools/plots/PlotRejPerEpoch_cjets_rejection.png index 52268937d3fd4f03d0a12a4088b0eff0393af525..c0f166287291570d9f19368d22520049315cc1e5 100644 Binary files a/umami/tests/unit/train_tools/plots/PlotRejPerEpoch_cjets_rejection.png and b/umami/tests/unit/train_tools/plots/PlotRejPerEpoch_cjets_rejection.png differ diff --git a/umami/tests/unit/train_tools/plots/PlotRejPerEpoch_ujets_rejection.png b/umami/tests/unit/train_tools/plots/PlotRejPerEpoch_ujets_rejection.png index fe9f7ad997b86afb069abbd33ef01a7cd017a66a..5a851f5f92a0d4b58559aa9b35b6b6ae2d7cad1a 100644 Binary files a/umami/tests/unit/train_tools/plots/PlotRejPerEpoch_ujets_rejection.png and b/umami/tests/unit/train_tools/plots/PlotRejPerEpoch_ujets_rejection.png differ diff --git a/umami/train_tools/Plotting.py b/umami/train_tools/Plotting.py index fa12b66c8f26bcbf7dcf627c9f9ffb83c7b84467..1798c535005449431651fdd1ff180dbfc71fac76 100644 --- a/umami/train_tools/Plotting.py +++ b/umami/train_tools/Plotting.py @@ -159,7 +159,7 @@ def PlotDiscCutPerEpoch( target_beff: float = 0.77, frac: float = 0.018, plot_datatype: str = "pdf", - **kwargs, # pylint: disable=unused-argument + **kwargs, ): """Plot the discriminant cut value for a specific working point over all epochs. @@ -218,7 +218,7 @@ def PlotDiscCutPerEpochUmami( val_files: dict = None, target_beff: float = 0.77, plot_datatype: str = "pdf", - **kwargs, # pylint: disable=unused-argument + **kwargs, ): """Plot the discriminant cut value for a specific working point over all epochs. DIPS and Umami are both shown. @@ -282,11 +282,11 @@ def PlotRejPerEpochComparison( label_extension: str, rej_string: str, taggers_from_file: dict = None, - trained_taggers: list = None, + trained_taggers: dict = None, target_beff: float = 0.77, plot_datatype: str = "pdf", leg_fontsize: int = 10, - **kwargs, # pylint: disable=unused-argument + **kwargs, ): """Plotting the Rejections per Epoch for the trained tagger and the provided comparison taggers. @@ -365,30 +365,11 @@ def PlotRejPerEpochComparison( # Init a linestyle counter counter_models = 0 - # Plot rejection - lines = lines + axes[counter].plot( - df_results["epoch"], - df_results[f"{iter_class}_{rej_string}"], - linestyle=linestyle_list[counter], - color=f"C{counter_models}", - label=tagger_label, - ) - - # Set up the counter - counter_models += 1 - - # Set y label - rej_plot.set_ylabel( - ax_mpl=axes[counter], - label=f'{flav_cat[iter_class]["legend_label"]} rejection', - align_right=False, - ) - if comp_tagger_rej_dict is None: logger.info("No comparison tagger defined. Not plotting those!") else: - for _, comp_tagger in enumerate(comp_tagger_rej_dict): + for comp_tagger in comp_tagger_rej_dict: try: tmp_line = axes[counter].axhline( y=comp_tagger_rej_dict[comp_tagger][ @@ -420,7 +401,7 @@ def PlotRejPerEpochComparison( logger.debug("No local taggers defined. Not plotting those!") else: - for _, tt in enumerate(trained_taggers): + for tt in trained_taggers: try: # Get the needed rejection info from json tt_rej_dict = pd.read_json(trained_taggers[tt]["path"]) @@ -442,6 +423,25 @@ def PlotRejPerEpochComparison( # Set up the counter counter_models += 1 + # Plot rejection + lines = lines + axes[counter].plot( + df_results["epoch"], + df_results[f"{iter_class}_{rej_string}"], + linestyle=linestyle_list[counter], + color=f"C{counter_models}", + label=tagger_label, + ) + + # Set up the counter + counter_models += 1 + + # Set y label + rej_plot.set_ylabel( + ax_mpl=axes[counter], + label=f'{flav_cat[iter_class]["legend_label"]} rejection', + align_right=False, + ) + ax_left.xaxis.set_major_locator(MaxNLocator(integer=True)) # Create the two legends for rejection and model @@ -516,10 +516,10 @@ def PlotRejPerEpoch( label_extension: str, rej_string: str, taggers_from_file: dict = None, - trained_taggers: list = None, + trained_taggers: dict = None, target_beff: float = 0.77, plot_datatype: str = "pdf", - **kwargs, # pylint: disable=unused-argument + **kwargs, ): """Plotting the Rejections per Epoch for the trained tagger and the provided comparison taggers in separate plots. One per rejection. @@ -581,18 +581,6 @@ def PlotRejPerEpoch( # Init a linestyle counter counter_models = 0 - # Plot rejection - rej_plot.axis_top.plot( - df_results["epoch"], - df_results[f"{iter_class}_{rej_string}"], - linestyle="-", - color=f"C{counter_models}", - label=tagger_label, - ) - - # Set up the counter - counter_models += 1 - if comp_tagger_rej_dict is None: logger.info("No comparison tagger defined. Not plotting those!") @@ -625,7 +613,7 @@ def PlotRejPerEpoch( logger.debug("No local taggers defined. Not plotting those!") else: - for _, tt in enumerate(trained_taggers): + for tt in trained_taggers: try: # Get the needed rejection info from json tt_rej_dict = pd.read_json(trained_taggers[tt]["path"]) @@ -647,6 +635,18 @@ def PlotRejPerEpoch( # Set up the counter counter_models += 1 + # Plot rejection + rej_plot.axis_top.plot( + df_results["epoch"], + df_results[f"{iter_class}_{rej_string}"], + linestyle="-", + color=f"C{counter_models}", + label=tagger_label, + ) + + # Set up the counter + counter_models += 1 + rej_plot.atlas_second_tag += ( f"\nWP={int(target_beff * 100):02d}% {label_extension} sample" ) @@ -663,7 +663,7 @@ def PlotLosses( plot_name: str, val_files: dict = None, plot_datatype: str = "pdf", - **kwargs, # pylint: disable=unused-argument + **kwargs, ): """Plot the training loss and the validation losses per epoch. @@ -719,7 +719,7 @@ def PlotAccuracies( plot_name: str, val_files: dict = None, plot_datatype: str = "pdf", - **kwargs, # pylint: disable=unused-argument + **kwargs, ): """Plot the training and validation accuracies per epoch. @@ -773,7 +773,7 @@ def PlotLossesUmami( plot_name: str, val_files: dict = None, plot_datatype: str = "pdf", - **kwargs, # pylint: disable=unused-argument + **kwargs, ): """Plot the training loss and the validation losses per epoch for Umami model (with DIPS and Umami losses). @@ -844,7 +844,7 @@ def PlotAccuraciesUmami( plot_name: str, val_files: dict = None, plot_datatype: str = "pdf", - **kwargs, # pylint: disable=unused-argument + **kwargs, ): """Plot the training and validation accuracies per epoch for Umami model (with DIPS and Umami accuracies). @@ -1066,6 +1066,8 @@ def RunPerformanceCheck( label_extension=val_file_config["label"], rej_string=f"rej_{subtagger}_{val_file_identifier}", target_beff=working_point, + taggers_from_file=Val_settings["taggers_from_file"], + trained_taggers=Val_settings["trained_taggers"], **plot_args, ) @@ -1082,6 +1084,8 @@ def RunPerformanceCheck( label_extension=val_file_config["label"], rej_string=f"rej_{subtagger}_{val_file_identifier}", target_beff=working_point, + taggers_from_file=Val_settings["taggers_from_file"], + trained_taggers=Val_settings["trained_taggers"], **plot_args, ) @@ -1132,8 +1136,11 @@ def RunPerformanceCheck( label_extension=val_file_config["label"], rej_string=f"rej_{val_file_identifier}", target_beff=working_point, + taggers_from_file=Val_settings["taggers_from_file"], + trained_taggers=Val_settings["trained_taggers"], **plot_args, ) + for val_file_identifier, val_file_config in val_files.items(): # Plot rejections in one plot per rejection PlotRejPerEpoch( @@ -1147,6 +1154,8 @@ def RunPerformanceCheck( rej_string=f"rej_{val_file_identifier}", target_beff=working_point, tagger_label=Val_settings["tagger_label"], + taggers_from_file=Val_settings["taggers_from_file"], + trained_taggers=Val_settings["trained_taggers"], **plot_args, ) diff --git a/umami/train_tools/configs/default_train_config.yaml b/umami/train_tools/configs/default_train_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bf5b594aa2b65dec1aefd4db5ec7d24c836ae568 --- /dev/null +++ b/umami/train_tools/configs/default_train_config.yaml @@ -0,0 +1,10 @@ +model_name: Null +test_files: Null +NN_structure: Null +Validation_metrics_settings: + trained_taggers: Null + taggers_from_file: Null + val_batch_size: Null +Eval_parameters_validation: + eval_batch_size: Null +tracks_name: Null diff --git a/umami/train_tools/configuration.py b/umami/train_tools/configuration.py index 693d5eb671084a661abb7d335de83ae083927eec..b4b7cf1853e28f77db6fb6b71ab967e28cea3906 100644 --- a/umami/train_tools/configuration.py +++ b/umami/train_tools/configuration.py @@ -1,4 +1,6 @@ """Configuration module for NN trainings.""" +import os + import pydash import yaml @@ -15,6 +17,7 @@ class Configuration: super().__init__() self.yaml_config = yaml_config self.config = {} + self.yaml_default_config = "configs/default_train_config.yaml" self.load_config_file() self.get_configuration() @@ -45,10 +48,27 @@ class Configuration: def load_config_file(self): """Load config file from disk.""" + self.yaml_default_config = os.path.join( + os.path.dirname(__file__), self.yaml_default_config + ) + with open(self.yaml_default_config, "r") as conf: + self.default_config = yaml.load(conf, Loader=yaml_loader) + logger.info(f"Using train config file {self.yaml_config}") with open(self.yaml_config, "r") as conf: self.config = yaml.load(conf, Loader=yaml_loader) + # Check if values in default config are defined in loaded config + # If not, set default values + for elem in self.default_config: + if elem not in self.config or self.config[elem] is None: + self.config[elem] = self.default_config[elem] + + if isinstance(self.default_config[elem], dict): + for item in self.default_config[elem]: + if item not in self.config[elem]: + self.config[elem][item] = self.default_config[elem][item] + def get_configuration(self): """Assigne configuration from file to class variables. @@ -98,12 +118,15 @@ class Configuration: if "evaluate_trained_model" in self.config: if self.config["evaluate_trained_model"] is True: iterate_list = config_train_items + bool_evaluate_trained_model = True elif self.config["evaluate_trained_model"] is False: iterate_list = config_evaluation_items + bool_evaluate_trained_model = False else: iterate_list = config_train_items + bool_evaluate_trained_model = True if "Plotting_settings" in self.config: raise KeyError( @@ -118,61 +141,63 @@ class Configuration: if item == "tracks_name": setattr(self, "tracks_key", f"X_{self.config[item]}_train") - elif item in ( - "Validation_metrics_settings", - "Eval_parameters_validation", - ): - batch_param = ( - "val_batch_size" - if item.startswith("Val") - else "eval_batch_size" - ) - + elif item == "Validation_metrics_settings": try: - if (self.config[item] is not None) and ( - batch_param not in self.config[item] - or self.config[item][batch_param] is None + if ( + self.config["Validation_metrics_settings"]["val_batch_size"] + is None + and self.config["Eval_parameters_validation"][ + "eval_batch_size" + ] + is None + ): + logger.warning( + "Neither eval_batch_size nor " + "val_batch_size was defined. Using " + "training batch_size for " + "validation/evaluation!" + ) + + self.config["Validation_metrics_settings"][ + "val_batch_size" + ] = int(self.config["NN_structure"]["batch_size"]) + self.config["Eval_parameters_validation"][ + "eval_batch_size" + ] = int(self.config["NN_structure"]["batch_size"]) + + elif ( + self.config["Validation_metrics_settings"]["val_batch_size"] + is None + ): + logger.warning( + "No val_batch_size defined. Using training batch size" + " for validation" + ) + + self.config["Validation_metrics_settings"][ + "val_batch_size" + ] = int(self.config["NN_structure"]["batch_size"]) + + elif ( + self.config["Eval_parameters_validation"]["eval_batch_size"] + is None ): - if batch_param == "eval_batch_size": - try: - self.config[item][batch_param] = int( - self.config["Validation_metrics_settings"][ - "val_batch_size" - ] - ) - - logger.warning( - "No eval_batch_size was defined. Using " - "val_batch_size for evaluation!" - ) - - except KeyError: - self.config[item][batch_param] = int( - self.config["NN_structure"]["batch_size"] - ) - - logger.warning( - "Neither eval_batch_size nor " - "val_batch_size was defined. Using " - "training batch_size for " - "validation/evaluation!" - ) - - else: - self.config[item][batch_param] = int( - self.config["NN_structure"]["batch_size"] - ) - - logger.warning( - "No val_batch_size was defined. Using " - "training batch_size for validation!" - ) + logger.warning( + "No eval_batch_size defined. Using validation batch" + " size for evaluation." + ) + + self.config["Eval_parameters_validation"][ + "eval_batch_size" + ] = int( + self.config["Validation_metrics_settings"][ + "val_batch_size" + ] + ) except KeyError as Error: - raise ValueError( - f"Neither batch_size in NN_structure nor {batch_param} " - f"in {item} was given!" - ) from Error + if bool_evaluate_trained_model: + raise ValueError("No batch size given!") from Error setattr(self, item, self.config[item])