Commit ae88caf8 authored by Alexander Froch's avatar Alexander Froch
Browse files

Merge branch birk-transition-to-puma with refs/heads/master into refs/merge-requests/548/train

parents 26ca0dd5 a32d20fc
Pipeline #3999416 passed with stages
in 13 minutes and 9 seconds
......@@ -2,10 +2,9 @@
These .md files can then be used in the documentation."""
import re
import puma
from npdoc_to_md import render_md_from_obj_docstring # pylint: disable=import-error
import umami.plotting
def generate_parameters_table(
obj: object,
......@@ -79,15 +78,22 @@ def main():
"""Main function which is called when the script is executed"""
# define here the objects of which you want the parameters as markdown table
objects_to_render = {
"umami.plotting.plot_object": {
"obj": umami.plotting.plot_object,
"filename": "docstring_input_var_plots_umami.plotting.plot_object.md",
"exclude": ["logy", "plotting_done", "n_ratio_panels"],
"puma.PlotObject": {
"obj": puma.PlotObject,
"filename": "docstring_puma_PlotObject.md",
"exclude": ["logy", "plotting_done"],
# Excluded because:
# logy -> has different default (False) in Histogram plot
# plotting_done -> attribute that should not be modified by the user
},
"umami.plotting.histogram_plot": {
"obj": umami.plotting.histogram_plot.__init__,
"filename": "docstring_input_var_plots_umami.plotting.histogram_plot.md",
"puma.HistogramPlot": {
"obj": puma.HistogramPlot.__init__,
"filename": "docstring_puma_HistogramPlot.md",
"exclude": ["bins", "bins_range", "**kwargs"],
# Excluded because:
# bins -> are handled differently (defined in the input var plot config)
# bins_range -> same here
# **kwargs -> we specifically put the **kwargs from PlotObject in the docs
},
}
for name, config in objects_to_render.items():
......
......@@ -8,7 +8,7 @@
- Adding unit tests for tf generators [!542](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/542)
- Fix epoch bug in continue_training [!543](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/543)
- Updating tensorflow to version `2.9.0` and pytorch to `1.11.0-cuda11.3-cudnn8-runtime` [!547](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/547)
- Removing plotting API code and switch to puma [!540](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/540)
- Removing plotting API code and switch to puma [!540](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/540) [!548](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/548)
- Fix epoch bug in continue_training[!543](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/543)
- Remove IPxD from default configs [!544](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/544)
......
......@@ -105,11 +105,15 @@ You can use the following parameters. Note that some parameters are not supporte
| `n_leading` | Track variables | `list` | Optional | `list` of the x leading tracks. If `None`, all tracks will be plotted. If `0`, the leading tracks sorted by `sorting variable` will be plotted. You can combine several options, e.g. `None`, `0` and `1`, and all three will be plotted, each in its own folder with corresponding labeling. This must be a `list`, even if only one option is given. |
| `track_origins` | Track variables and n_tracks plot | `list` | Optional | `list` that gives the desired track origins when plotting. |
All remaining plot settings are parameters which are handed to the plotting API,
more specifically the `histogram_plot` class.
Therefore, all parameters supported by the `histogram_plot` class can be specified there.
All remaining plot settings are parameters which are handed to `puma` (Plotting
UMami API), more specifically to the `HistogramPlot` class.
Therefore, all parameters supported by the `HistogramPlot` class can be specified there.
### List of plotting API parameters
[`puma` documentation](https://umami-hep.github.io/puma/)
§§§docs/ci_assets/docstring_input_var_plots_umami.plotting.histogram_plot.md§§§
§§§docs/ci_assets/docstring_input_var_plots_umami.plotting.plot_object.md:3:§§§
### List of `puma` parameters
§§§docs/ci_assets/docstring_puma_HistogramPlot.md§§§
<!-- in the docstring for the PlotObject class, start at line 3, since we don't want
the header to be included (column names of the md-table) -->
§§§docs/ci_assets/docstring_puma_PlotObject.md:3:§§§
......@@ -14,7 +14,7 @@ import numpy as np
import pandas as pd
from umami.metrics import calc_rej
from puma.metrics import calc_rej
```
???+ example "Reading `.h5` file"
......
......@@ -11,16 +11,9 @@ import numpy as np
from mlxtend.evaluate import confusion_matrix
from mlxtend.plotting import plot_confusion_matrix as mlxtend_plot_cm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from puma import Histogram, HistogramPlot, Roc, RocPlot, VarVsEff, VarVsEffPlot
import umami.tools.PyATLASstyle.PyATLASstyle as pas
from umami.plotting import (
histogram,
histogram_plot,
roc,
roc_plot,
var_vs_eff,
var_vs_eff_plot,
)
from umami.plotting.utils import translate_kwargs
from umami.tools import applyATLASstyle
......@@ -88,7 +81,7 @@ def plot_pt_dependence(
linewidth : float, optional
Define the linewidth of the plotted lines, by default 1.6
**kwargs : kwargs
kwargs for `var_vs_eff_plot` function
kwargs for `VarVsEffPlot` function
Raises
------
......@@ -168,7 +161,7 @@ def plot_pt_dependence(
mode = "sig_eff"
y_label = f'{flav_cat[flavour]["legend_label"]} efficiency'
plot_pt = var_vs_eff_plot(
plot_pt = VarVsEffPlot(
mode=mode,
ylabel=y_label,
n_ratio_panels=1,
......@@ -198,7 +191,7 @@ def plot_pt_dependence(
)
disc = df_results[f"disc_{tagger}"]
plot_pt.add(
var_vs_eff(
VarVsEff(
x_var_sig=jetPts[is_signal],
disc_sig=disc[is_signal],
x_var_bkg=jetPts[is_bkg] if mode == "bkg_rej" else None,
......@@ -296,7 +289,7 @@ def plotROCRatio(
List of bools indicating which roc used as reference for ratio calculation,
by default None
**kwargs : kwargs
kwargs passed to roc_plot
kwargs passed to RocPlot
Raises
------
......@@ -306,7 +299,7 @@ def plotROCRatio(
if lists don't have the same length
"""
# Check for number of provided rocs
# Check for number of provided Rocs
n_rocs = len(df_results_list)
# maintain backwards compatibility
......@@ -433,7 +426,7 @@ def plotROCRatio(
):
raise ValueError("Passed lists do not have same length.")
plot_roc = roc_plot(
plot_roc = RocPlot(
n_ratio_panels=n_ratio_panels,
ylabel=ylabel,
xlabel=f'{flav_cat[main_class]["legend_label"]} efficiency',
......@@ -478,7 +471,7 @@ def plotROCRatio(
n_test,
reference_ratio,
):
roc_curve = roc(
roc_curve = Roc(
df_results[df_eff_key],
df_results[f"{tagger}_{rej_class}_rej"],
n_test=nte,
......@@ -724,7 +717,7 @@ def plot_score(
Decide, if all working points lines have the same height or
not, by default True
**kwargs : kwargs
kwargs for `histogram_plot` function
kwargs for `HistogramPlot` function
"""
# Set number of ratio panels if not specified
......@@ -744,8 +737,8 @@ def plot_score(
# Get index dict
index_dict = {f"{flavour}": i for i, flavour in enumerate(class_labels_list[0])}
# Init the histogram plot object
score_plot = histogram_plot(**kwargs)
# Init the Histogram plot object
score_plot = HistogramPlot(**kwargs)
# Set the xlabel
if score_plot.xlabel is None:
......@@ -790,7 +783,7 @@ def plot_score(
for iter_flavour in class_labels:
score_plot.add(
histogram(
Histogram(
values=df_results.query(f"labels=={index_dict[iter_flavour]}")[
f"disc_{tagger}"
],
......@@ -835,7 +828,7 @@ def plot_prob(
plot_name : str
Path, Name and format of the resulting plot file.
**kwargs : kwargs
kwargs for `var_vs_eff_plot` function
kwargs for `HistogramPlot` class
"""
# Set number of ratio panels if not specified
......@@ -856,7 +849,7 @@ def plot_prob(
index_dict = {f"{iter_flav}": i for i, iter_flav in enumerate(class_labels_list[0])}
# Init the histogram plot object
prob_plot = histogram_plot(**kwargs)
prob_plot = HistogramPlot(**kwargs)
# Set the xlabel
if prob_plot.xlabel is None:
......@@ -882,7 +875,7 @@ def plot_prob(
for iter_flavour in class_labels:
prob_plot.add(
histogram(
Histogram(
values=df_results.query(f"labels=={index_dict[iter_flavour]}")[
f'{tagger}_{flav_cat[flavour]["prob_var_name"]}'
],
......
......@@ -6,10 +6,10 @@ import os
import numpy as np
from pandas import DataFrame
from puma import Histogram, HistogramPlot
import umami.data_tools as udt
from umami.configuration import global_config, logger
from umami.plotting import histogram, histogram_plot
from umami.plotting.utils import translate_binning
from umami.preprocessing_tools import GetVariableDict
......@@ -93,7 +93,7 @@ def plot_n_tracks_per_jet(
Track set that is to be used for plotting, by default "All"
**kwargs: dict
Keyword arguments passed to the plot. You can use all arguments that are
supported by the `histogram_plot` class in the plotting API.
supported by the `HistogramPlot` class in the plotting API.
"""
kwargs = check_kwargs_for_ylabel_and_n_ratio_panel(
......@@ -133,7 +133,7 @@ def plot_n_tracks_per_jet(
logger.info(f"Track origin: {track_origin}\n")
# Initialise plot
n_tracks_plot = histogram_plot(**kwargs)
n_tracks_plot = HistogramPlot(**kwargs)
# Set xlabel
n_tracks_plot.xlabel = (
"Number of tracks per jet"
......@@ -166,7 +166,7 @@ def plot_n_tracks_per_jet(
n_tracks_means[label].update({flavour: n_tracks_flavour.mean()})
n_tracks_plot.add(
histogram(
Histogram(
values=n_tracks_flavour,
flavour=flavour,
label=label,
......@@ -234,7 +234,7 @@ def plot_input_vars_trks(
Track set that is to be used for plotting, by default "All"
**kwargs: dict
Keyword arguments passed to the plot. You can use all arguments that are
supported by the `histogram_plot` class in the plotting API.
supported by the `HistogramPlot` class in the plotting API.
"""
......@@ -352,7 +352,7 @@ def plot_input_vars_trks(
logger.info(f"Plotting {var}...")
# Initialise plot for this variable
var_plot = histogram_plot(bins=bins_dict[var], **kwargs)
var_plot = HistogramPlot(bins=bins_dict[var], **kwargs)
if n_lead is None:
var_plot.xlabel = (
......@@ -410,7 +410,7 @@ def plot_input_vars_trks(
# Add histogram to plot
var_plot.add(
histogram(
Histogram(
values=track_values,
flavour=flavour,
label=label,
......@@ -469,7 +469,7 @@ def plot_input_vars_jets(
Option to make the background of the plot transparent, by default True
**kwargs: dict
Keyword arguments passed to the plot. You can use all arguments that are
supported by the `histogram_plot` class in the plotting API.
supported by the `HistogramPlot` class in the plotting API.
"""
kwargs = check_kwargs_for_ylabel_and_n_ratio_panel(
......@@ -523,7 +523,7 @@ def plot_input_vars_jets(
if var in bins_dict:
# Initialise plot for this variable
var_plot = histogram_plot(bins=bins_dict[var], xlabel=var, **kwargs)
var_plot = HistogramPlot(bins=bins_dict[var], xlabel=var, **kwargs)
# setting range based on value from config file
if special_param_jets is not None and var in special_param_jets:
if (
......@@ -553,7 +553,7 @@ def plot_input_vars_jets(
# Add histogram to plot
var_plot.add(
histogram(
Histogram(
values=jets_flavour,
flavour=flavour,
label=label,
......
......@@ -2,11 +2,8 @@
# pylint: skip-file
from umami.metrics.metrics import (
calc_disc_values,
calc_eff,
calc_rej,
discriminant_output_shape,
get_gradients,
get_rejection,
get_score,
)
from umami.metrics.tools import eff_err, rej_err
......@@ -9,7 +9,6 @@ import copy
import numpy as np
from umami.helper_tools import save_divide
from umami.tools import check_main_class_input
# Try to import keras from tensorflow
......@@ -545,98 +544,3 @@ def get_rejection(
) from error
return rej_dict, cutvalue
def calc_eff(
    sig_disc: np.ndarray,
    bkg_disc: np.ndarray,
    target_eff,
    return_cuts: bool = False,
):
    """Calculate the background efficiency at given signal working points.

    For every working point the cut value is the
    ``100 * (1 - target_eff)`` percentile of the signal discriminant. The
    returned efficiency is the fraction of background entries whose
    discriminant lies above that cut.

    Parameters
    ----------
    sig_disc : np.ndarray
        signal discriminant
    bkg_disc : np.ndarray
        background discriminant
    target_eff : float or list
        WP(s) which is/are used for the cut value calculation. Any
        0-dimensional number (Python or NumPy scalar, 0-d array) is treated
        as a single working point, everything else as a sequence of them.
    return_cuts : bool
        Specifies if the cut values corresponding to the provided WPs are returned.
        If target_eff is a scalar, only one cut value will be returned. If
        target_eff is a sequence, the cut values are an array as well.

    Returns
    -------
    float or np.ndarray
        efficiency
        if target_eff is a scalar, a single value is returned, if it's a
        sequence a np.ndarray
    float or np.ndarray
        cutvalue if return_cuts is True
        if target_eff is a scalar, a single value is returned, if it's a
        sequence a np.ndarray
    """
    # TODO: with python 3.10 using type union operator
    # float | np.ndarray for both target_eff and the returned values

    # Checking the dimensionality instead of `isinstance(..., float)` keeps
    # scalar ints, np.float32 values and 0-d arrays out of the `len()` branch,
    # where they would raise a TypeError.
    if np.ndim(target_eff) == 0:
        cutvalue = np.percentile(sig_disc, 100.0 * (1.0 - target_eff))
        # Third argument is handed to save_divide as fallback value
        # (presumably used for an empty background sample - TODO confirm).
        eff = save_divide(len(bkg_disc[bkg_disc > cutvalue]), len(bkg_disc), 0)
    else:
        eff = np.zeros(len(target_eff))
        cutvalue = np.zeros(len(target_eff))
        for i, t_eff in enumerate(target_eff):
            cutvalue[i] = np.percentile(sig_disc, 100.0 * (1.0 - t_eff))
            eff[i] = save_divide(
                len(bkg_disc[bkg_disc > cutvalue[i]]), len(bkg_disc), 0
            )

    if return_cuts:
        return eff, cutvalue
    return eff
def calc_rej(
    sig_disc: np.ndarray,
    bkg_disc: np.ndarray,
    target_eff,
    return_cuts: bool = False,
):
    """Calculate the background rejection at given signal working points.

    The rejection is the inverse of the efficiency returned by `calc_eff`
    for the same arguments; the division is delegated to `save_divide`
    with `np.inf` as fallback value.

    Parameters
    ----------
    sig_disc : np.ndarray
        signal discriminant
    bkg_disc : np.ndarray
        background discriminant
    target_eff : float or list
        WP(s) which is/are used for the cut value calculation
    return_cuts : bool
        Specifies if the cut values corresponding to the provided WPs are
        returned in addition to the rejection.

    Returns
    -------
    float or np.ndarray
        rejection
        if target_eff is a float, a float is returned if it's a list a np.ndarray
    float or np.ndarray
        cutvalue if return_cuts is True
        if target_eff is a float, a float is returned if it's a list a np.ndarray
    """
    # TODO: with python 3.10 using type union operator
    # float | np.ndarray for both target_eff and the returned values
    eff_result = calc_eff(
        sig_disc=sig_disc,
        bkg_disc=bkg_disc,
        target_eff=target_eff,
        return_cuts=return_cuts,
    )

    if not return_cuts:
        return save_divide(1, eff_result, np.inf)

    # With return_cuts=True calc_eff hands back (efficiency, cut values)
    bkg_eff, cut_values = eff_result
    return save_divide(1, bkg_eff, np.inf), cut_values
"""Tools for metrics module."""
import numpy as np
from umami.configuration import logger
def eff_err(
    arr: np.ndarray,
    n_counts: int,
    suppress_zero_divison_error: bool = False,
    norm: bool = False,
) -> np.ndarray:
    """Calculate statistical efficiency uncertainty.

    Parameters
    ----------
    arr : numpy.array
        efficiency values
    n_counts : int
        number of used statistics to calculate efficiency
    suppress_zero_divison_error : bool
        not raising Error for zero division
    norm : bool, optional
        if True, normed (relative) error is being calculated, by default False

    Returns
    -------
    numpy.array
        efficiency uncertainties

    Raises
    ------
    ValueError
        if n_counts <=0 and suppress_zero_divison_error is False

    Notes
    -----
    This method uses binomial errors as described in section 2.2 of
    https://inspirehep.net/files/57287ac8e45a976ab423f3dd456af694
    """
    logger.debug("Calculating efficiency error.")
    logger.debug(f"arr: {arr}")
    logger.debug(f"n_counts: {n_counts}")
    logger.debug(f"suppress_zero_divison_error: {suppress_zero_divison_error}")
    logger.debug(f"norm: {norm}")

    # TODO: suppress_zero_divison_error should not be necessary, but functions calling
    # eff_err seem to need this functionality - should be deprecated though.
    if np.any(n_counts <= 0) and not suppress_zero_divison_error:
        # Name the actual parameter (`n_counts`) so the message matches the
        # signature and the equivalent check in `rej_err`.
        raise ValueError(
            f"You passed as argument `n_counts` {n_counts} but it has to be larger 0."
        )

    # Binomial uncertainty: sqrt(eff * (1 - eff) / N)
    abs_err = np.sqrt(arr * (1 - arr) / n_counts)
    if norm:
        return abs_err / arr
    return abs_err
def rej_err(
    arr: np.ndarray,
    n_counts: int,
    norm: bool = False,
) -> np.ndarray:
    """Calculate the rejection uncertainties.

    The rejection values are inverted, handed to `eff_err()` and the
    result is scaled by the squared rejection.

    Parameters
    ----------
    arr : numpy.array
        rejection values
    n_counts : int
        number of used statistics to calculate rejection
    norm : bool, optional
        if True, normed (relative) error is being calculated, by default False

    Returns
    -------
    numpy.array
        rejection uncertainties

    Raises
    ------
    ValueError
        if n_counts <=0
    ValueError
        if any rejection value is 0

    Notes
    -----
    special case of `eff_err()`
    """
    logger.debug("Calculating rejection error.")
    for message in (f"arr: {arr}", f"n_counts: {n_counts}", f"norm: {norm}"):
        logger.debug(message)

    if np.any(n_counts <= 0):
        raise ValueError(
            f"You passed as argument `n_counts` {n_counts} but it has to be larger 0."
        )

    # A rejection of 0 cannot be inverted into an efficiency
    if np.any(arr == 0):
        raise ValueError("One rejection value is 0, cannot calculate error.")

    abs_err = np.power(arr, 2) * eff_err(1 / arr, n_counts)
    return abs_err / arr if norm else abs_err
"""Plotting module."""
# flake8: noqa
# pylint: skip-file
"""Plotting functions for umami"""
# This implementation is just temporary to get puma as a replacement for the plotting
# API within umami.
# Later, everything in umami needs to be changed to the new naming.
from puma.histogram import Histogram as histogram
from puma.histogram import HistogramPlot as histogram_plot
from puma.plot_base import PlotBase as plot_base
from puma.plot_base import PlotLineObject as plot_line_object
from puma.plot_base import PlotObject as plot_object
from puma.roc import Roc as roc
from puma.roc import RocPlot as roc_plot
from puma.var_vs_eff import VarVsEff as var_vs_eff
from puma.var_vs_eff import VarVsEffPlot as var_vs_eff_plot
# from umami.plotting.histogram import histogram, histogram_plot
# from umami.plotting.plot_base import plot_base, plot_line_object, plot_object
# from umami.plotting.roc import roc, roc_plot
# from umami.plotting.var_vs_eff import var_vs_eff, var_vs_eff_plot
# TODO: move plotting code in umami/plotting directory and import here
# (e.g. input var plots, evaluation plots, ...)
"""Helper functions for the plotting API"""
"""Helper functions for plotting"""
import numpy as np
import pandas as pd
from scipy.special import softmax
from umami.configuration import logger # isort:skip
......@@ -119,185 +117,3 @@ def translate_binning(
raise ValueError(f"Type {type(binning)} is not supported!")
return bins
def set_xaxis_ticklabels_invisible(ax):
    """Hide all tick labels of the x-axis of the given axis.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis whose x-axis tick labels are switched to invisible
    """
    tick_labels = ax.get_xticklabels()
    for tick_label in tick_labels:
        tick_label.set_visible(False)
def get_good_pie_colours(colour_scheme=None):
"""Helper function to get good colours for a pie chart. You can
choose between a specific colour scheme or use the default colours
for a pie chart
Parameters
----------
colour_scheme : string, optional
colour scheme for the pie chart. Can be None to use default colours
or blue, red, green or yellow to use a specific colour scheme
Returns
-------
list
returns a list of colours in the specified colour scheme
Raises
------
KeyError
If colour_scheme is not in ["blue", "red", "green", "yellow", None]
"""
# TODO change in python 3.10 -> case syntax
if colour_scheme is None:
return [
"#1F77B4",
"#FF7F0E",
"#2CA02C",
"#D62728",
"#9467BD",
"#8C564B",
"#E377C2",