Commit b68ba2a7 authored by Joschka Birk's avatar Joschka Birk
Browse files

Merge branch alfroch-add-preprocessing-plots with refs/heads/master into...

Merge branch alfroch-add-preprocessing-plots with refs/heads/master into refs/merge-requests/440/train
parents 67b5fcfc 4c77aea8
Pipeline #3636252 passed with stages
in 25 minutes and 56 seconds
......@@ -307,6 +307,10 @@ For an explanation of the resampling function specific `options`, have a look in
# this stores the indices per sample into an intermediate file
intermediate_index_file: *intermediate_index_file
# How many jets you want to use for the plotting of the results
# Give null (the yaml None) if you don't want to plot them
njets_to_plot: 3e4
```
| Setting | Type | Explanation |
......@@ -316,6 +320,7 @@ For an explanation of the resampling function specific `options`, have a look in
| `save_tracks` | `bool` | Define if tracks are processed or not. These are not needed to train DL1r/DL1d |
| `tracks_names` | `list` of `str` | Name of the tracks (in the .h5 files coming from the dumper) which are processed. Multiple tracks datasets can be preprocessed simultaneously when two `str` are given in the list. |
| `intermediate_index_file` | `str` | For the resampling, the indicies of the jets to use are saved in an intermediate indicies `.h5` file. You can define a name and path in the [Preprocessing-parameters.yaml](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/Preprocessing-parameters.yaml). |
| `njets_to_plot` | `int` | Number of jets which are used for plotting the variables of the jets/tracks after each preprocessing step (resampling, scaling, shuffling/writing). If `null` is given, the plotting is skipped. |
**Note**: `nJets` are the number of jets you want to have in your final training file for the `count` and `weighting` method. For the `pdf` method, this is the number of jets per flavour in the training file!
......
......@@ -266,6 +266,10 @@ sampling:
# If you want to attach weights to the final files
bool_attach_sample_weights: False
# How many jets you want to use for the plotting of the results
# Give null (the yaml None) if you don't want to plot them
njets_to_plot: 3e4
# Name of the output file from the preprocessing
outfile_name: *outfile_name
plot_name: PFlow_ext-hybrid
......
......@@ -258,6 +258,10 @@ sampling: &sampling
# If you want to attach weights to the final files
bool_attach_sample_weights: False
# How many jets you want to use for the plotting of the results
# Give null (the yaml None) if you don't want to plot them
njets_to_plot: 3e4
# Name of the output file from the preprocessing
outfile_name: *outfile_name
plot_name: PFlow_ext-hybrid
......
......@@ -9,7 +9,7 @@ import pandas as pd
from umami.configuration import logger
from .utils import GetVariableDict
from .utils import GetVariableDict, preprocessing_plots
def Gen_default_dict(scale_dict: dict) -> dict:
......@@ -163,7 +163,7 @@ class Scaling:
and can apply it.
"""
def __init__(self, config: object, compression: str = "gzip") -> None:
def __init__(self, config: object) -> None:
"""
Init the needed configs and variables
......@@ -171,15 +171,13 @@ class Scaling:
----------
config : object
Loaded config file for the preprocessing.
compression : str, optional
Type of compression which should be used., by default "gzip"
"""
self.config = config
self.scale_dict_path = config.dict_file
self.bool_use_tracks = config.sampling["options"]["save_tracks"]
self.tracks_names = self.config.sampling["options"]["tracks_names"]
self.compression = compression
self.compression = self.config.compression
logger.info(f"Using variable dict at {config.var_file}")
self.variable_config = GetVariableDict(config.var_file)
......@@ -1087,3 +1085,24 @@ class Scaling:
break
chunk_counter += 1
# Plot the variables from the output file of the resampling process
if (
"njets_to_plot" in self.config.sampling["options"]
and self.config.sampling["options"]["njets_to_plot"]
):
preprocessing_plots(
sample=self.config.GetFileName(option="resampled_scaled"),
var_dict=self.variable_config,
class_labels=self.config.sampling["class_labels"],
plots_dir=os.path.join(
self.config.config["parameters"]["file_path"],
"plots/scaling/",
),
track_collection_list=self.config.sampling["options"]["tracks_names"]
if "tracks_names" in self.config.sampling["options"]
and "save_tracks" in self.config.sampling["options"]
and self.config.sampling["options"]["save_tracks"] is True
else None,
nJets=self.config.sampling["options"]["njets_to_plot"],
)
"""Module handling training file writing to disk."""
import json
import os
import pickle
import h5py
......@@ -8,7 +9,7 @@ from numpy.lib.recfunctions import repack_fields, structured_to_unstructured
from scipy.stats import binned_statistic_2d
from umami.configuration import logger
from umami.preprocessing_tools import GetVariableDict
from umami.preprocessing_tools import GetVariableDict, preprocessing_plots
class TrainSampleWriter:
......@@ -357,6 +358,27 @@ class TrainSampleWriter:
chunk_counter += 1
jet_idx = jet_idx_end
# Plot the variables from the output file of the resampling process
if (
"njets_to_plot" in self.config.sampling["options"]
and self.config.sampling["options"]["njets_to_plot"]
):
preprocessing_plots(
sample=self.config.GetFileName(option="resampled_scaled_shuffled"),
var_dict=self.variable_config,
class_labels=self.config.sampling["class_labels"],
plots_dir=os.path.join(
self.config.config["parameters"]["file_path"],
"plots/resampling_scaled_shuffled/",
),
track_collection_list=self.config.sampling["options"]["tracks_names"]
if "tracks_names" in self.config.sampling["options"]
and "save_tracks" in self.config.sampling["options"]
and self.config.sampling["options"]["save_tracks"] is True
else None,
nJets=self.config.sampling["options"]["njets_to_plot"],
)
def calculateWeights(
self,
weights_dict: dict,
......
......@@ -40,5 +40,7 @@ from umami.preprocessing_tools.utils import (
GetVariableDict,
ResamplingPlots,
generate_process_tag,
plot_variable,
preprocessing_plots,
)
from umami.preprocessing_tools.Writing_Train_File import TrainSampleWriter
......@@ -10,7 +10,12 @@ from umami.preprocessing_tools.resampling.resampling_base import (
CorrectFractions,
ResamplingTools,
)
from umami.preprocessing_tools.utils import ResamplingPlots, generate_process_tag
from umami.preprocessing_tools.utils import (
GetVariableDict,
ResamplingPlots,
generate_process_tag,
preprocessing_plots,
)
class UnderSampling(ResamplingTools):
......@@ -226,6 +231,24 @@ class UnderSampling(ResamplingTools):
# Write file to disk
self.WriteFile(self.indices_to_keep)
# Plot the variables from the output file of the resampling process
if "njets_to_plot" in self.options and self.options["njets_to_plot"]:
preprocessing_plots(
sample=self.config.GetFileName(option="resampled"),
var_dict=GetVariableDict(self.config.var_file),
class_labels=self.config.sampling["class_labels"],
plots_dir=os.path.join(
self.resampled_path,
"plots/resampling/",
),
track_collection_list=self.options["tracks_names"]
if "tracks_names" in self.options
and "save_tracks" in self.options
and self.options["save_tracks"] is True
else None,
nJets=self.options["njets_to_plot"],
)
class ProbabilityRatioUnderSampling(UnderSampling):
"""
......
......@@ -20,7 +20,12 @@ from umami.preprocessing_tools.resampling.resampling_base import (
SamplingGenerator,
read_dataframe_repetition,
)
from umami.preprocessing_tools.utils import ResamplingPlots, generate_process_tag
from umami.preprocessing_tools.utils import (
GetVariableDict,
ResamplingPlots,
generate_process_tag,
preprocessing_plots,
)
class PDFSampling(Resampling): # pylint: disable=too-many-public-methods
......@@ -1274,7 +1279,7 @@ class PDFSampling(Resampling): # pylint: disable=too-many-public-methods
f.create_dataset(
"jets",
data=selected_indices,
compression="gzip",
compression=self.config.compression,
chunks=True,
maxshape=(None,),
)
......@@ -1399,14 +1404,14 @@ class PDFSampling(Resampling): # pylint: disable=too-many-public-methods
out_file.create_dataset(
"jets",
data=jets,
compression="gzip",
compression=self.config.compression,
chunks=True,
maxshape=(None,),
)
out_file.create_dataset(
"labels",
data=labels,
compression="gzip",
compression=self.config.compression,
chunks=True,
maxshape=(None, labels.shape[1]),
)
......@@ -1415,7 +1420,7 @@ class PDFSampling(Resampling): # pylint: disable=too-many-public-methods
out_file.create_dataset(
tracks_name,
data=tracks[i],
compression="gzip",
compression=self.config.compression,
chunks=True,
maxshape=(None, tracks[i].shape[1]),
)
......@@ -1597,14 +1602,14 @@ class PDFSampling(Resampling): # pylint: disable=too-many-public-methods
out_file.create_dataset(
"jets",
data=jets,
compression="gzip",
compression=self.config.compression,
chunks=True,
maxshape=(None,),
)
out_file.create_dataset(
"labels",
data=labels,
compression="gzip",
compression=self.config.compression,
chunks=True,
maxshape=(None, labels.shape[1]),
)
......@@ -1613,7 +1618,7 @@ class PDFSampling(Resampling): # pylint: disable=too-many-public-methods
out_file.create_dataset(
tracks_name,
data=tracks[i],
compression="gzip",
compression=self.config.compression,
chunks=True,
maxshape=(None, tracks[i].shape[1]),
)
......@@ -2012,7 +2017,7 @@ class PDFSampling(Resampling): # pylint: disable=too-many-public-methods
f.create_dataset(
"jets",
data=selected_indices,
compression="gzip",
compression=self.config.compression,
)
# Return the selected indicies
......@@ -2517,6 +2522,25 @@ class PDFSampling(Resampling): # pylint: disable=too-many-public-methods
# single large file
if self.do_combination:
self.Combine_Flavours()
# Plot the variables from the output file of the resampling process
if "njets_to_plot" in self.options and self.options["njets_to_plot"]:
preprocessing_plots(
sample=self.config.GetFileName(option="resampled"),
var_dict=GetVariableDict(self.config.var_file),
class_labels=self.config.sampling["class_labels"],
plots_dir=os.path.join(
self.resampled_path,
"plots/resampling/",
),
track_collection_list=self.options["tracks_names"]
if "tracks_names" in self.options
and "save_tracks" in self.options
and self.options["save_tracks"] is True
else None,
nJets=self.options["njets_to_plot"],
)
else:
logger.warning("Skipping combining step (not in list to execute).")
......
......@@ -8,7 +8,12 @@ import numpy as np
from umami.configuration import logger
from umami.preprocessing_tools.resampling.resampling_base import ResamplingTools
from umami.preprocessing_tools.utils import ResamplingPlots, generate_process_tag
from umami.preprocessing_tools.utils import (
GetVariableDict,
ResamplingPlots,
generate_process_tag,
preprocessing_plots,
)
class Weighting(ResamplingTools):
......@@ -146,3 +151,21 @@ class Weighting(ResamplingTools):
# write out indices.h5 to use preprocessing chain
self.GetIndices()
self.WriteFile(self.indices_to_keep)
# Plot the variables from the output file of the resampling process
if "njets_to_plot" in self.options and self.options["njets_to_plot"]:
preprocessing_plots(
sample=self.config.GetFileName(option="resampled"),
var_dict=GetVariableDict(self.config.var_file),
class_labels=self.config.sampling["class_labels"],
plots_dir=os.path.join(
self.resampled_path,
"plots/resampling/",
),
track_collection_list=self.options["tracks_names"]
if "tracks_names" in self.options
and "save_tracks" in self.options
and self.options["save_tracks"] is True
else None,
nJets=self.options["njets_to_plot"],
)
"""Collection of utility functions for preprocessing tools."""
import os
import h5py
import matplotlib as mtp
import matplotlib.pyplot as plt
import numpy as np
......@@ -9,6 +10,7 @@ import yaml
from sklearn.preprocessing import LabelBinarizer
from umami.configuration import global_config, logger
from umami.helper_tools import hist_w_unc
from umami.tools import applyATLASstyle, makeATLAStag, yaml_loader
......@@ -78,6 +80,444 @@ def GetBinaryLabels(
return lb.fit_transform(labels)
def plot_variable(
df,
labels: np.ndarray,
variable: str,
variable_index: int,
var_type: str,
class_labels: list,
output_dir: str,
binning: dict = None,
figsize: list = None,
normed: bool = True,
fileformat: str = "pdf",
UseAtlasTag: bool = True,
AtlasTag: str = "Internal Simulation",
SecondTag: str = "$\\sqrt{s}=13$ TeV, PFlow Jets",
y_scale: float = 1.3,
yAxisAtlasTag: float = 0.9,
leg_loc: str = "upper right",
label_fontsize: int = 12,
leg_fontsize: int = 10,
leg_ncol: int = 1,
logy: bool = True,
**kwargs, # pylint: disable=unused-argument
):
"""
Plot a given variable.
Parameters
----------
df : pd.DataFrame or np.ndarray
DataFrame (for jets) or ndarray (for tracks) with
the jets/tracks inside.
labels : np.ndarray
One hot encoded array with the truth values.
variable : str
Name of the variable which is to be plotted.
variable_index : int
Index of the variable in the final training set. This
is used to identify the variables in the final training
set.
var_type : str
Type of the variable that is used. Either `jets` or
`tracks`.
class_labels : list
List with the flavours used (ORDER IMPORTANT).
output_dir : str
Directory where the plot is saved.
binning : dict, optional
Dict with the variables as keys and binning as item,
by default None
figsize : list, optional
List with the size of the figure, by default None
normed : bool, optional
Normalise the flavours, by default True
fileformat : str, optional
Fileformat of the plots, by default "pdf"
UseAtlasTag : bool, optional
Use a ATLAS tag, by default True
AtlasTag : str, optional
First line of ATLAS tag, by default "Internal Simulation"
SecondTag : str, optional
Second line of ATLAS tag, by default "$sqrt{s}=13$ TeV, PFlow Jets"
y_scale : float, optional
Increase the y-axis to fit the ATALS tag in, by default 1.3
yAxisAtlasTag : float, optional
Relative y axis position of the ATLAS Tag, by default 0.9
leg_loc : str, optional
Position of the legend in the plot, by default "upper right"
label_fontsize : int, optional
Fontsize of the axis labels, by default 12
leg_fontsize : int, optional
Fontsize of the legend, by default 10
leg_ncol : int, optional
Number of columns in the legend, by default 1
logy : bool, optional
Plot a logarithmic y-axis, by default True
**kwargs : kwargs
kwargs from `plot_object`
Raises
------
TypeError
If the given variable type is not supported.
"""
# Check if binning is given. If not, init an empty dict
if not binning:
binning = {}
# Check if figsize is given. If not, init default size
if not figsize:
figsize = [11.69 * 0.8, 8.27 * 0.8]
# Set ATLAS plot style
applyATLASstyle(mtp)
# Give a debug logger
logger.debug(f"Plotting variable {variable}...")
# Get the binning
try:
_, bins = np.histogram(
a=np.nan_to_num(df[variable]),
bins=binning[variable]
if variable in binning and binning is not None
else 50,
)
except IndexError as Error:
if var_type.casefold() == "jets":
array = np.nan_to_num(df[:, variable_index])
elif var_type.casefold() == "tracks":
array = np.nan_to_num(df[:, :, variable_index])
else:
raise TypeError(
f"Variable type {var_type} not supported! Only jets and tracks!"
) from Error
_, bins = np.histogram(
a=array,
bins=binning[variable]
if variable in binning and binning is not None
else 50,
)
# Init a new figure
fig = plt.figure(figsize=(figsize[0], figsize[1]))
ax = fig.subplots()
# Loop over the flavours
for flav_counter, flavour in enumerate(class_labels):
# Get all jets with the correct flavour
try:
flavour_jets = df[variable][labels[:, flav_counter] == 1].values
except AttributeError:
flavour_jets = df[variable][labels[:, flav_counter] == 1]
except IndexError as Error:
if var_type.casefold() == "jets":
flavour_jets = df[:, variable_index][
labels[:, flav_counter] == 1
].flatten()
elif var_type.casefold() == "tracks":
flavour_jets = df[:, :, variable_index][labels[:, flav_counter] == 1]
else:
raise TypeError(
f"Variable type {var_type} not supported! Only jets and tracks!"
) from Error
# Calculate bins
hist_bins, weights, unc, band = hist_w_unc(
a=flavour_jets,
bins=bins,
normed=normed,
)
# Plot the bins
ax.hist(
x=hist_bins[:-1],
bins=hist_bins,
weights=weights,
histtype="step",
linewidth=1.0,
color=global_config.flavour_categories[flavour]["colour"],
stacked=False,
fill=False,
label=global_config.flavour_categories[flavour]["legend_label"],
)
# Plot uncertainty
ax.hist(
x=hist_bins[:-1],
bins=hist_bins,
bottom=band,
weights=unc * 2,
label="stat. unc." if flavour == class_labels[-1] else None,
**global_config.hist_err_style,
)
# Set xlabel
ax.set_xlabel(
variable,
fontsize=label_fontsize,
horizontalalignment="right",
x=1.0,
)
if normed:
ax.set_ylabel(
"Normalised Number of Jets",
fontsize=label_fontsize,
horizontalalignment="right",
y=1.0,
)
else:
ax.set_ylabel(
"Number of Jets",
fontsize=label_fontsize,
horizontalalignment="right",
y=1.0,
)
# Set logscale for y axis
if logy is True:
ax.set_yscale("log")
# Increase ymax so atlas tag don't cut plot
ymin, ymax = ax.get_ylim()
ax.set_ylim(
ymin,
ymax * np.log(ymax / ymin) * 10 * y_scale,
)
else:
# Increase ymax so atlas tag don't cut plot
ymin, ymax = ax.get_ylim()
ax.set_ylim(bottom=ymin, top=y_scale * ymax)
# ATLAS tag
if UseAtlasTag is True:
makeATLAStag(
ax=ax,
fig=fig,
first_tag=AtlasTag,
second_tag=SecondTag,
ymax=yAxisAtlasTag,
)
# Set legend
ax.legend(
loc=leg_loc,