Commit 030fb43a authored by Joschka Birk

Merge branch alfroch-moving-plot-files with refs/heads/master into refs/merge-requests/554/train

parents 573e06c0 dd699538
Pipeline #4066999 passed with stages in 25 minutes and 51 seconds
......@@ -22,7 +22,7 @@
script:
- pip install darglint
- darglint --list-errors
- find . -name "*.py" ! -name *PlottingFunctions.py ! -name *Plotting.py ! -name *conf.py | xargs -n 1 -P 8 -t darglint
- find . -name "*.py" ! -name *conf.py | xargs -n 1 -P 8 -t darglint
.pylint_template: &pylint_template
stage: linting
......
......@@ -36,6 +36,7 @@
- helper_tools
- input_vars_tools
- metrics
- plotting_tools
- preprocessing
- tf_tools
- train_tools
......
......@@ -3,8 +3,9 @@
### Latest
- Adding classes to global config (light-flavour jets split by quark flavour/gluons, leptonic b-hadron decays) to define extended tagger output [!553](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/553)
- Moving plotting files into one folder [!554](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/554)
- Adding classes to global config (light-flavour jets split by quark flavour/gluons, leptonic b-hadron decays) to define extended tagger output [!553](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/553)
- Fixing issues with trained_taggers and taggers_from_file in plotting_epoch_performance.py [!549](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/549)
- Adding plotting API to Contour plots + Updating plotting_umami docs [!537](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/537)
- Adding unit test for prepare_model and minor bug fixes [!546](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/546)
......
......@@ -7,12 +7,3 @@ from umami.evaluation_tools.eval_tools import (
GetScoresProbsDict,
RecomputeScore,
)
from umami.evaluation_tools.PlottingFunctions import (
plot_confusion,
plot_prob,
plot_pt_dependence,
plot_score,
plotFractionContour,
plotROCRatio,
plotSaliency,
)
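The plotting helpers previously re-exported from umami.evaluation_tools.PlottingFunctions now live in umami.plotting_tools under snake_case names (see the new package __init__.py further down). A minimal migration sketch for downstream code, assuming only the import path and the names change:

# Before this merge request:
# from umami.evaluation_tools.PlottingFunctions import plotROCRatio, plot_confusion
# After this merge request:
from umami.plotting_tools import plot_roc, plot_confusion_matrix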
......@@ -10,7 +10,7 @@ from puma import Histogram, HistogramPlot
import umami.data_tools as udt
from umami.configuration import global_config, logger
from umami.plotting.utils import translate_binning
from umami.plotting_tools.utils import translate_binning
from umami.preprocessing_tools import GetVariableDict
......
......@@ -10,7 +10,7 @@ import yaml
import umami.input_vars_tools as uit
from umami.configuration import logger
from umami.plotting.utils import translate_kwargs
from umami.plotting_tools.utils import translate_kwargs
from umami.tools import yaml_loader
......
"""Plotting functions for umami"""
# TODO: move plotting code in umami/plotting directory and import here
# (e.g. input var plots, evaluation plots, ...)
......@@ -9,8 +9,8 @@ import tensorflow as tf
import umami.train_tools as utt
from umami.helper_tools import get_class_prob_var_names
from umami.plotting_tools import run_validation_check
from umami.preprocessing_tools import Configuration
from umami.train_tools import RunPerformanceCheck
def get_parser():
......@@ -180,7 +180,7 @@ def main(args, train_config, preprocess_config):
)
# Run the Performance check with the values from the dict and plot them
RunPerformanceCheck(
run_validation_check(
train_config=train_config,
tagger=tagger,
tagger_comp_vars={
......
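Only the helper's module and name change here: RunPerformanceCheck from umami.train_tools becomes run_validation_check from umami.plotting_tools, while the keyword arguments of the call stay the same. Under that assumption, downstream code would switch like this:

# Before this merge request:
# from umami.train_tools import RunPerformanceCheck
# After this merge request:
from umami.plotting_tools import run_validation_check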
"""Plotting functions for umami"""
# flake8: noqa
# pylint: skip-file
from umami.plotting_tools.eval_plotting_functions import (
plot_confusion_matrix,
plot_fraction_contour,
plot_prob,
plot_pt_dependence,
plot_roc,
plot_saliency,
plot_score,
)
from umami.plotting_tools.preprocessing_plotting_functions import (
plot_resampling_variables,
plot_variable,
preprocessing_plots,
)
from umami.plotting_tools.train_plotting_functions import (
get_comp_tagger_rej_dict,
plot_accuracies,
plot_accuracies_umami,
plot_disc_cut_per_epoch,
plot_disc_cut_per_epoch_umami,
plot_losses,
plot_losses_umami,
plot_rej_per_epoch,
plot_rej_per_epoch_comp,
run_validation_check,
)
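The new package __init__.py re-exports the functions of its submodules, so callers can import from the package top level without knowing the internal layout. For example (names taken from the re-exports above, usage is illustrative):

from umami.plotting_tools import plot_fraction_contour, plot_resampling_variables
# equivalent to the explicit submodule imports:
# from umami.plotting_tools.eval_plotting_functions import plot_fraction_contour
# from umami.plotting_tools.preprocessing_plotting_functions import plot_resampling_variables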
......@@ -22,7 +22,7 @@ from puma import (
)
import umami.tools.PyATLASstyle.PyATLASstyle as pas
from umami.plotting.utils import translate_kwargs
from umami.plotting_tools.utils import translate_kwargs
def plot_pt_dependence(
......@@ -212,7 +212,7 @@ def plot_pt_dependence(
plot_pt.savefig(plot_name, transparent=trans)
def plotROCRatio(
def plot_roc(
df_results_list: list,
tagger_list: list,
rej_class_list: list,
......@@ -400,27 +400,27 @@ def plotROCRatio(
):
raise ValueError("Passed lists do not have same length.")
plot_roc = RocPlot(
roc_plot = RocPlot(
n_ratio_panels=n_ratio_panels,
ylabel=ylabel,
xlabel=f'{flav_cat[main_class]["legend_label"]} efficiency',
**kwargs,
)
plot_roc.set_ratio_class(
roc_plot.set_ratio_class(
ratio_panel=1,
rej_class=flav_list[0],
label=f'{flav_cat[flav_list[0]]["legend_label"]} ratio',
)
if n_ratio_panels > 1:
plot_roc.set_ratio_class(
roc_plot.set_ratio_class(
ratio_panel=2,
rej_class=flav_list[1],
label=f'{flav_cat[flav_list[1]]["legend_label"]} ratio',
)
if working_points is not None:
plot_roc.draw_vlines(
roc_plot.draw_vlines(
vlines_xvalues=working_points,
same_height=same_height_WP,
)
......@@ -455,25 +455,25 @@ def plotROCRatio(
colour=colour,
linestyle=linestyle,
)
plot_roc.add_roc(roc_curve, reference=ratio_ref)
roc_plot.add_roc(roc_curve, reference=ratio_ref)
plot_roc.set_leg_rej_labels(
roc_plot.set_leg_rej_labels(
flav_list[0],
label=f'{flav_cat[flav_list[0]]["legend_label"]} rejection',
)
if n_ratio_panels > 1:
plot_roc.set_leg_rej_labels(
roc_plot.set_leg_rej_labels(
flav_list[1],
label=f'{flav_cat[flav_list[1]]["legend_label"]} rejection',
)
# Draw and save the plot
plot_roc.draw(labelpad=labelpad)
plot_roc.savefig(plot_name)
roc_plot.draw(labelpad=labelpad)
roc_plot.savefig(plot_name)
def plotSaliency(
def plot_saliency(
maps_dict: dict,
plot_name: str,
title: str,
......@@ -865,7 +865,7 @@ def plot_prob(
prob_plot.savefig(plot_name, transparent=True)
def plot_confusion(
def plot_confusion_matrix(
df_results: dict,
tagger_name: str,
class_labels: list,
......@@ -934,7 +934,7 @@ def plot_confusion(
plt.close()
def plotFractionContour(
def plot_fraction_contour(
df_results_list: list,
tagger_list: list,
label_list: list,
......
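This hunk renames the public evaluation plotting functions to snake_case; apart from renaming the local RocPlot instance from plot_roc to roc_plot, presumably so it does not shadow the new function name, the signatures are unchanged. A hypothetical compatibility shim (not part of this merge request) that keeps the old camelCase names importable could look like:

# Hypothetical aliases, not included in this MR:
from umami.plotting_tools import (
    plot_confusion_matrix as plot_confusion,
    plot_fraction_contour as plotFractionContour,
    plot_roc as plotROCRatio,
    plot_saliency as plotSaliency,
)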
"""Plotting functions for the preprocessing."""
import os
import h5py
import numpy as np
import pandas as pd
from puma import Histogram, HistogramPlot
from umami.configuration import global_config, logger
from umami.plotting_tools.utils import translate_kwargs
def plot_variable(
df,
labels: np.ndarray,
variable: str,
variable_index: int,
var_type: str,
class_labels: list,
output_dir: str,
fileformat: str = "pdf",
**kwargs,
) -> None:
"""
Plot a given variable.
Parameters
----------
df : pd.DataFrame or np.ndarray
DataFrame (for jets) or ndarray (for tracks) with
the jets/tracks inside.
labels : np.ndarray
One hot encoded array with the truth values.
variable : str
Name of the variable which is to be plotted.
variable_index : int
Index of the variable in the final training set. This is used to
identify the variable when the training set has already been
converted to numpy arrays.
var_type : str
Type of the variable that is used. Either `jets` or
`tracks`.
class_labels : list
List with the flavours used (ORDER IMPORTANT).
output_dir : str
Directory where the plot is saved.
fileformat : str, optional
Fileformat of the plots, by default "pdf"
**kwargs : kwargs
kwargs from `plot_object`
Raises
------
TypeError
If the given variable type is not supported.
"""
# Translate the kwargs
kwargs = translate_kwargs(kwargs)
# Give a debug logger
logger.debug(f"Plotting variable {variable}...")
# Init the histogram plot object
histo_plot = HistogramPlot(**kwargs)
# Set the x-label
if histo_plot.xlabel is None:
histo_plot.xlabel = variable
# Loop over the flavours
for flav_counter, flavour in enumerate(class_labels):
# This is the case if a pandas Dataframe is given
try:
flavour_jets = df[variable][labels[:, flav_counter] == 1].values.flatten()
# This is the case when a numpy ndarray is given
except AttributeError:
flavour_jets = df[variable][labels[:, flav_counter] == 1].flatten()
# This is the case if the training set is already converted to X_train etc.
except IndexError as error:
if var_type.casefold() == "jets":
flavour_jets = df[:, variable_index][
labels[:, flav_counter] == 1
].flatten()
elif var_type.casefold() == "tracks":
flavour_jets = df[:, :, variable_index][
labels[:, flav_counter] == 1
].flatten()
else:
raise TypeError(
f"Variable type {var_type} not supported! Only jets and tracks!"
) from error
# Add the flavour to the histogram
histo_plot.add(
Histogram(
values=np.nan_to_num(flavour_jets),
flavour=flavour,
),
reference=False,
)
# Draw and save the plot
histo_plot.draw()
histo_plot.savefig(
plot_name=os.path.join(
output_dir,
f"{variable}.{fileformat}",
),
**kwargs,
)
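# Example usage of plot_variable (hypothetical values, for illustration only):
#   jets = pd.DataFrame({"pt_btagJes": np.random.default_rng(42).exponential(5e4, 1000)})
#   labels = np.eye(3)[np.random.default_rng(42).integers(0, 3, 1000)]  # one-hot truth labels
#   plot_variable(
#       df=jets,
#       labels=labels,
#       variable="pt_btagJes",
#       variable_index=0,
#       var_type="jets",
#       class_labels=["bjets", "cjets", "ujets"],
#       output_dir="plots",
#   )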
def plot_resampling_variables(
concat_samples: dict,
var_positions: list,
variable_names: list,
sample_categories: list,
output_dir: str,
bins_dict: dict,
sample_id_position: int = 3,
fileformat: str = "pdf",
**kwargs,
) -> None:
"""
Plot the variables which are used for resampling before the resampling
starts.
Parameters
----------
concat_samples : dict
Dict in the format of the `concat_samples` object built by the
Undersampling class.
var_positions : list
The position where the variables are stored in the sub-dict `jets`.
variable_names : list
The names of the two variables which will be plotted.
sample_categories : list
List with the names of the sample categories (e.g. ["ttbar", "zprime"]).
output_dir : str
Name of the output directory where the plots will be saved.
bins_dict : dict
Dict with the binning for the resampling variables. Each key must be a
variable name; the value is either an int (the number of bins) or a
list/tuple of 3 values giving the lower limit, upper limit and the
number of bins to use.
sample_id_position : int, optional
Position in the numpy.ndarray of the concat_samples where the sample
id is stored. By default 3
fileformat : str, optional
Format of the plot file, by default "pdf".
**kwargs : kwargs
kwargs from `plot_object`
Raises
------
ValueError
If unsupported binning is provided.
"""
# Check if output directory exists
os.makedirs(
output_dir,
exist_ok=True,
)
# Defining two linestyles for the resampling variables
linestyles = ["-", "--"]
# Translate the kwargs to new naming scheme
kwargs = translate_kwargs(kwargs)
# Loop over the variables which are used for resampling
for var, varpos in zip(variable_names, var_positions):
if isinstance(bins_dict[var], int):
bins = bins_dict[var]
bins_range = None
elif isinstance(bins_dict[var], (list, tuple)) and len(bins_dict[var]) == 3:
bins = bins_dict[var][2]
bins_range = (
bins_dict[var]["bins_range"][0],
bins_dict[var]["bins_range"][1],
)
else:
raise ValueError(
"Provided binning for plot_resampling_variables is "
"neither a list with three entries nor an int!"
)
# Init a new histogram
histo_plot = HistogramPlot(
bins=bins,
bins_range=bins_range,
**kwargs,
)
# Set the x-label
if histo_plot.xlabel is None:
histo_plot.xlabel = f"{var}"
# Check if the variable is pT (which is in the files in MeV)
# and set the scale value to make it GeV in the plots
if var in ["pT", "pt_btagJes"] or var == global_config.pTvariable:
scale_val = 1e3
histo_plot.xlabel += " [GeV]"
else:
scale_val = 1
# Loop over the different flavours
for flavour in concat_samples:
# Loop over sample ids (ttbar and zprime for example)
for sample_id in np.unique(
concat_samples[flavour]["jets"][:, sample_id_position]
).astype("int"):
# Add the histogram for the flavour
histo_plot.add(
Histogram(
values=concat_samples[flavour]["jets"][:, varpos] / scale_val,
flavour=flavour,
label=sample_categories[sample_id]
if sample_categories
else None,
linestyle=linestyles[sample_id],
),
reference=False,
)
# Draw and save the plot
histo_plot.draw()
histo_plot.savefig(
plot_name=os.path.join(
output_dir,
f"{var}_before_resampling.{fileformat}",
),
**kwargs,
)
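# Example usage of plot_resampling_variables (hypothetical values, for illustration only):
#   plot_resampling_variables(
#       concat_samples=concat_samples,            # as produced by the resampling class
#       var_positions=[0, 1],
#       variable_names=["pT", "absEta_btagJes"],
#       sample_categories=["ttbar", "zprime"],
#       output_dir="plots/resampling",
#       bins_dict={
#           "pT": (0, 250e3, 50),                 # (lower limit, upper limit, number of bins)
#           "absEta_btagJes": (0, 2.5, 25),
#       },
#   )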
def preprocessing_plots(
sample: str,
var_dict: dict,
class_labels: list,
plots_dir: str,
use_random_jets: bool = False,
jet_collection: str = "jets",
track_collection_list: list = None,
nJets: int = int(3e4),
seed: int = 42,
**kwargs,
):
"""
Plotting the different track and jet variables after
the preprocessing steps.
Parameters
----------
sample : str
Path to output file of the preprocessing step.
var_dict : dict
Loaded variable dict.
class_labels : list
List with the flavours used (ORDER IMPORTANT).
plots_dir : str
Path to folder where the plots are saved.
use_random_jets : bool, optional
Decide if random jets are drawn from the sample to
ensure correct mixing. Otherwise the first nJets are
used for plotting, by default False
jet_collection : str, optional
Name of the jet collection, by default "jets"
track_collection_list : list, optional
List of str of the track collections which are to be
plotted, by default None
nJets : int, optional
Number of jets to plot, by default int(3e4)
seed : int, optional
Random seed for the selection of the jets, by default 42
**kwargs : kwargs
kwargs from `plot_object`
Raises
------
TypeError
If the provided track collection list is neither a string nor
a list.
"""
# Get max number of available jets
with h5py.File(sample, "r") as f:
try:
nJets_infile = len(f["/jets"])
except KeyError:
nJets_infile = len(f["/X_train"])
# Check if random values are used or not
if use_random_jets is True:
# Get a random generator with specified seed
rng = np.random.default_rng(seed=seed)
# Randomly select indices; keep them sorted since h5py fancy indexing requires increasing order
selected_indicies = sorted(
rng.choice(
np.arange(nJets_infile, dtype=int),
int(nJets),
replace=False,
)
)
else:
# if the number of requested jets is larger than what is available,
# plot all available jets.
if nJets > nJets_infile:
logger.warning(
f"You requested {nJets} jets,"
f"but there are only {nJets_infile} jets in the input!"
)
selected_indicies = np.arange(min(nJets, nJets_infile), dtype=int)
# Check if track collection list is valid
if isinstance(track_collection_list, str):
track_collection_list = [track_collection_list]
elif track_collection_list is None:
track_collection_list = []
elif not isinstance(track_collection_list, list):
raise TypeError(
"Track Collection list for variable plotting must be a list or a string!"
)
# Open the file which is to be plotted
with h5py.File(sample, "r") as infile:
# Get the labels of the jets to plot
try:
labels = infile["/labels"][selected_indicies]
except KeyError:
labels = infile["Y_train"][selected_indicies]
# Check if jet collection is given
if jet_collection:
# Check if output directory exists
os.makedirs(
plots_dir,
exist_ok=True,
)
# Extract the correct variables
variables_header = var_dict["train_variables"]
jet_var_list = [i for j in variables_header for i in variables_header[j]]
# Get the jets from file
try:
jets = pd.DataFrame(
infile["/jets"].fields(jet_var_list)[selected_indicies]
)
except KeyError:
jets = np.asarray(infile["X_train"][selected_indicies])
# Loop over variables
for jet_var_counter, jet_var in enumerate(jet_var_list):