Commit f2117df9 authored by Tomke Schroer
parents 94efc61a b1d2965c
Pipeline #4137019 passed with stages
in 31 minutes and 1 second
......@@ -20,7 +20,6 @@
retry: 2
# We need to define an empty array for the dependencies to NOT include any artifacts from previous jobs
dependencies: []
needs: []
.requirement_changes: &requirement_changes
changes:
......@@ -59,6 +58,7 @@ build_umamibase_cpu:
BASE: 'BASE_IMAGE=tensorflow/tensorflow:$TFTAG'
DOCKER_FILE: docker/umamibase/Dockerfile
IMAGE_DESTINATION: '${REGISTY_PATH}/umamibase:latest'
needs: []
rules:
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
<<: *requirement_changes
......@@ -73,6 +73,7 @@ build_umamibase_gpu:
BASE: 'BASE_IMAGE=tensorflow/tensorflow:$TFTAG-gpu'
DOCKER_FILE: docker/umamibase/Dockerfile
IMAGE_DESTINATION: '${REGISTY_PATH}/umamibase:latest-gpu'
needs: []
rules:
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
<<: *requirement_changes
......@@ -177,6 +178,7 @@ build_umamibase_cpu_MR:
BASE: 'BASE_IMAGE=tensorflow/tensorflow:$TFTAG'
DOCKER_FILE: docker/umamibase/Dockerfile
IMAGE_DESTINATION: '${REGISTY_PATH}/temporary_images:${CI_MERGE_REQUEST_IID}-base'
needs: []
rules:
- if: $CI_PIPELINE_SOURCE == "merge_request_event" && $CI_PROJECT_PATH=="atlas-flavor-tagging-tools/algorithms/umami"
<<: *requirement_changes
......
......@@ -22,7 +22,7 @@
script:
- pip install darglint
- darglint --list-errors
- find . -name "*.py" ! -name *PlottingFunctions.py ! -name *Plotting.py ! -name *conf.py | xargs -n 1 -P 8 -t darglint
- find . -name "*.py" ! -name *conf.py | xargs -n 1 -P 8 -t darglint
.pylint_template: &pylint_template
stage: linting
......
......@@ -36,6 +36,7 @@
- helper_tools
- input_vars_tools
- metrics
- plotting_tools
- preprocessing
- tf_tools
- train_tools
......
include umami/preprocessing_tools/configs/*.yaml
include umami/tools/PyATLASstyle/fonts/*.ttf
include umami/configs/global_config.yaml
\ No newline at end of file
......@@ -4,10 +4,30 @@
### Latest
- Switch to latest puma version (v0.1.3) [!572](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/572)
- Splitting CADS and DIPS Attention [!569](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/569)
- Fixing docker image builds [!571](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/571)
- Fixing uncertainty calculation for the ROC curves [!566](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/566)
### [v0.9](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/tags/0.9)
- Fixing Callback error when LRR is not used [!567](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/567)
- Fixing stacking issue for the jet variables in the PDFSampling [!565](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/565)
- Fixing problem with 4 classes integration test [!564](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/564)
- Rework saliency plots to use puma [!556](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/556)
- Fixing generation of class ids for only one class [!563](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/563)
- Removing hardcoded tmp directories in the integration tests [!562](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/562)
- Fixing x range in metrics plots + correct tagger name in results files [!560](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/560)
- Fixing issue with the PDFSampling shuffling + Fixing small issue with the loaders [!558](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/558)
- Fixing ylabel issue in ROC plots [!555](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/555)
- Adding verbose option to executable scripts [!557](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/557)
- Moving Plotting Files into one folder [!554](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/554)
- Adding classes to global config (light-flavour jets split by quark flavour/gluons, leptonic b-hadron decays) to define extended tagger output [!553](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/553)
- Fixing issues with trained_taggers and taggers_from_file in plotting_epoch_performance.py [!549](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/549)
- Adding plotting API to Contour plots + Updating plotting_umami docs [!537](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/537)
- Adding unit test for prepare_model and minor bug fixes [!546](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/546)
- Adding unit tests for tf generators [!542](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/542)
- Fix epoch bug in continue_training [!543](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/543)
- Updating tensorflow to version `2.9.0` and pytorch to `1.11.0-cuda11.3-cudnn8-runtime` [!547](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/547)
- Removing plotting API code and switch to puma [!540](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/540) [!548](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/548)
......
# Plotting Input Variables
The input variables for different files can also be plotted using the `plot_input_variables.py` script. It is also steered by a yaml file. An example of such a file can be found [here](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml). The structure is close to the one from `plotting_umami` but differs slightly.
To start the plotting of the input variables, run one of the following commands:
```bash
plot_input_vars.py -c <path/to/config> --tracks
```
or
```bash
plot_input_vars.py -c <path/to/config> --jets
```
which will plot all plots defined using track or jet variables, respectively. You can also pass the `-f` or `--format` option to choose the output format of the plots; the default is `pdf`.
### Yaml File
In the following, the possible configuration parameters are listed with a brief description.
......@@ -7,17 +20,17 @@ In the following, the possible configration parameters are listed with a brief d
#### Variable dict and number of jets
Here you can define the number of jets that are used and also the variable dict, in which all available variables are stored.
??? example "Click to see corresponding code highlighted in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
```yaml linenums="1", hl_lines="9-14"
§§§examples/plotting_input_vars.yaml§§§
??? example "Click to see corresponding code in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
```yaml
§§§examples/plotting_input_vars.yaml:9:14§§§
```
#### Number of Tracks per Jet
The number of tracks per jet can be plotted for all the different files. The corresponding configuration looks like this:
??? example "Click to see corresponding code highlighted in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
```yaml linenums="1", hl_lines="117-133"
§§§examples/plotting_input_vars.yaml§§§
??? example "Click to see corresponding code in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
```yaml
§§§examples/plotting_input_vars.yaml:91:108§§§
```
| Options | Data Type | Necessary/Optional | Explanation |
......@@ -36,9 +49,9 @@ The number of tracks per jet can be plotted for all different files. This can be
#### Input Variables Tracks
To plot the track input variables, the following options are used.
??? example "Click to see corresponding code highlighted in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
```yaml linenums="1", hl_lines="135-168"
§§§examples/plotting_input_vars.yaml§§§
??? example "Click to see corresponding code in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
```yaml
§§§examples/plotting_input_vars.yaml:110:144§§§
```
......@@ -58,9 +71,9 @@ To plot the track input variables, the following options are used.
#### Input Variables Jets
To plot the jet input variables, the following options are used.
??? example "Click to see corresponding code highlighted in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
```yaml linenums="1", hl_lines="16-115"
§§§examples/plotting_input_vars.yaml§§§
??? example "Click to see corresponding code in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
```yaml
§§§examples/plotting_input_vars.yaml:16:89§§§
```
| Options | Data Type | Necessary/Optional | Explanation |
......@@ -81,9 +94,9 @@ The `plot_settings` section is similar for all three cases described above.
In order to define some settings you want to apply to all plots, use yaml anchors
as shown here:
??? example "Click to see corresponding code highlighted in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
```yaml linenums="1", hl_lines="1-7"
§§§examples/plotting_input_vars.yaml§§§
??? example "Click to see corresponding code in the [example config file](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/plotting_input_vars.yaml)"
```yaml
§§§examples/plotting_input_vars.yaml:1:7§§§
```
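For illustration, here is a minimal sketch of the anchor-and-merge mechanism itself (the anchor and plot names below are made up and not taken from the example config):

```yaml
# Hypothetical shared settings defined once via a yaml anchor
.common_plot_settings: &common_plot_settings
  logy: True
  use_atlas_tag: True
  atlas_second_tag: "$\\sqrt{s}=13$ TeV, PFlow jets"

# Hypothetical plot entry re-using the shared settings
some_variable_plot:
  plot_settings:
    <<: *common_plot_settings   # merge the shared settings into this plot
    logy: False                 # individual keys can still be overridden per plot
```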
Most of the plot settings are valid for all types of input variable plots
......
......@@ -95,7 +95,7 @@ Plotting the ROC Curves of the rejection rates against the b-tagging efficiency.
| `label` | `str` | Necessary | Legend label of the model. |
| `tagger_name` | `str` | Necessary | Name of the tagger which is to be plotted. |
| `rejection_class` | `str` | Necessary | Class which the main flavour is plotted against. |
| `binomialErrors` | `bool` | Optional | Plot binomial errors to plot. |
| `draw_errors` | `bool` | Optional | Draw binomial errors on the plot. |
| `xmin` | `float` | Optional | Set the minimum b efficiency in the plot (which is the xmin limit). |
| `ymax` | `float` | Optional | The maximum y axis. |
| `working_points` | `list` | Optional | The specified WPs are calculated, and a vertical line with a small label stating the WP is drawn at the corresponding b-tagging discriminant value. |
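As a rough sketch of how these options fit together (the plot name and the nesting are illustrative and abbreviated, not the authoritative layout of the config):

```yaml
DL1r_light_flavour:   # illustrative plot name
  # ... per-model blocks with data_set_name, label, tagger_name, rejection_class ...
  plot_settings:
    draw_errors: True             # optional: draw binomial errors
    xmin: 0.5                     # optional: minimum b efficiency shown on the x axis
    ymax: 1000000                 # optional: upper limit of the y axis
    working_points: [0.70, 0.77]  # optional: vertical lines at these WPs
```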
......@@ -124,10 +124,26 @@ Plot the b efficiency/c-rejection/light-rejection against the pT. For example:
| `flavour` | `str` | Necessary | Flavour class whose rejection is to be plotted. |
| `class_labels` | `list` | Necessary | List of class labels that were used in the preprocessing/training. They must be the same in all three files! Order is important! |
| `main_class` | `str` | Necessary | Class which is to be tagged. |
| `WP` | `float` | Necessary | Float of the working point that will be used. |
| `WP_line` | `float` | Optional | Print a horizontal line at this value efficiency. |
| `working_point` | `float` | Necessary | Float of the working point that will be used. |
| `working_point_line` | `float` | Optional | Draw a horizontal line at this efficiency value. |
| `fixed_eff_bin` | `bool` | Optional | Calculate the WP cut on the discriminant per bin. |
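A minimal sketch combining the options above (the plot name is illustrative and the `type`/model definition is omitted here):

```yaml
Dips_pT_vs_beff:   # illustrative plot name
  # ... type and model definition omitted ...
  flavour: "cjets"                 # plot the c-rejection against pT
  class_labels: ["ujets", "cjets", "bjets"]
  main_class: "bjets"
  working_point: 0.77              # WP used for the discriminant cut
  working_point_line: True         # draw a horizontal line at this efficiency
  fixed_eff_bin: False             # do not recalculate the WP cut per bin
```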
#### Saliency Plots
The impact of the track variables on the final b-tagging discriminant cannot be evaluated with SHAPley. To make this impact visible (for each track of the jet), so-called saliency maps are used instead. These maps are calculated when evaluating the model you have trained (if the option is activated). A lot of different options can be set. An example is given here:
```yaml
§§§examples/plotting_umami_config_dips.yaml:149:160§§§
```
| Options | Data Type | Necessary/Optional | Explanation |
|---------|-----------|--------------------|-------------|
| `type` | `str` | Necessary | This gives the type of plot function used. Must be `"saliency"` here. |
| `data_set_name` | `str` | Necessary | Name of the dataset that is used. This is the name of the test_file which you want to use. |
| `target_eff` | `float` | Necessary | Efficiency of the target flavour (i.e. which WP you want to use). The value is given between 0 and 1. |
| `jet_flavour` | `str` | Necessary | Name of the flavour you want to plot. |
| `PassBool` | `str` | Necessary | Decide if the jets need to pass the working point discriminant cut or not. `False` would give you, for example, truth b-jets which do not pass the working point discriminant cut and are therefore not tagged as b-jets. |
| `nFixedTrks` | `int` | Necessary | The saliency maps can only be calculated for jets with a fixed number of tracks, which is set with this parameter. For example, if this value is `8`, then only jets with exactly 8 tracks are used for the saliency maps. This value needs to be set in the train config when you run the evaluation! If you run the evaluation with, for example, `5`, you can't plot the saliency map for `8`. |
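As a concrete sketch of such an entry (names and values are illustrative; see the example config referenced above for the authoritative version):

```yaml
Dips_saliency_b_WP77_passed_ttbar:   # illustrative plot name
  type: "saliency"
  data_set_name: "ttbar_r21"   # test file to use
  target_eff: 0.77             # working point, given between 0 and 1
  jet_flavour: "bjets"         # flavour to plot
  PassBool: True               # only jets passing the WP discriminant cut
  nFixedTrks: 8                # only jets with exactly 8 tracks
```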
#### Fraction Contour Plot
Plot two rejections against each other for a given working point with different fraction values.
......
......@@ -16,7 +16,7 @@ After the previous step the ntuples need to be further processed. We can use dif
This processing can be done using the preprocessing capabilities of Umami via the [`preprocessing.py`](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/umami/preprocessing.py) script.
Please refer to the [documentation on preprocessing](preprocessing.md) for additional information.
Please refer to the [documentation on preprocessing](https://umami-docs.web.cern.ch/preprocessing/preprocessing/) for additional information.
For the GNN, we use the `PFlow-Preprocessing-GNN.yaml` config file, found [here](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami-config-tags/-/blob/master/offline/PFlow-Preprocessing-GNN.yaml).
......
......@@ -112,7 +112,7 @@ Here are all important settings defined for the evaluation process (evaluating v
| `frac_step` | All | `float` | Optional | Step size of the fraction value scan. Please keep in mind that the fractions given to the background classes need to add up to one! All combinations that do not add up to one are ignored. If you choose a combination of `frac_min`, `frac_max` and `frac_step` where the fractions of the background classes do not add up to one, you will get an error while running `evaluate_model.py`. |
| `frac_min` | All | `float` | Optional | Minimal fraction value which is set for a background class in the fraction scan. |
| `frac_max` | All | `float` | Optional | Maximal fraction value which is set for a background class in the fraction scan. |
| `Calculate_Saliency` | DIPS | `bool` | Optional | Decide, if the saliency maps are calculated or not. This takes a lot of time and resources! |
| `calculate_saliency` | DIPS | `bool` | Optional | Decide, if the saliency maps are calculated or not. This takes a lot of time and resources! |
| `add_variables_eval` | DL1r, DL1d | `list` | Optional | A list to add available variables to the evaluation files. |
| `shapley` | DL1r, DL1d | `dict` | Optional | `dict` with the options for the feature importance explanation with SHAPley |
| `feature_sets` | DL1r, DL1d | `int` | Optional | Number of full feature sets to calculate over. Corresponds to the dots in the beeswarm plot; 200 takes roughly 10-15 minutes for DL1r on a 32-core CPU. |
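A minimal sketch of how these options could appear in the `Eval_parameters_validation` block of a train config (values are illustrative and the exact nesting may differ):

```yaml
Eval_parameters_validation:
  # ... other evaluation settings ...
  WP: 0.77
  frac_min: 0.01               # minimal background fraction in the scan
  frac_max: 1.0                # maximal background fraction in the scan
  frac_step: 0.01              # step size; the fractions must add up to one
  calculate_saliency: False    # DIPS only; takes a lot of time and resources
```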
......
......@@ -163,4 +163,4 @@ Eval_parameters_validation:
WP: 0.77
# Decide, if the Saliency maps are calculated or not.
Calculate_Saliency: False
calculate_saliency: False
......@@ -145,4 +145,4 @@ Eval_parameters_validation:
WP: 0.77
# Decide, if the Saliency maps are calculated or not.
Calculate_Saliency: True
calculate_saliency: True
......@@ -57,7 +57,7 @@ DL1r_light_flavour:
tagger_name: "DL1"
rejection_class: "cjets"
plot_settings: # These settings are given to the umami.evaluation_tools.plotROCRatio() function by unpacking them.
binomialErrors: True
draw_errors: True
xmin: 0.5
ymax: 1000000
figsize: [7, 6] # [width, height]
......
......@@ -94,7 +94,7 @@ beff_scan_tagger_umami:
tagger_name: "umami"
rejection_class: "cjets"
plot_settings:
binomialErrors: True
draw_errors: True
xmin: 0.5
ymax: 1000000
figsize: [7, 6] # [width, height]
......@@ -168,7 +168,7 @@ beff_scan_tagger_compare_umami:
tagger_name: "umami"
rejection_class: "cjets"
plot_settings:
binomialErrors: True
draw_errors: True
xmin: 0.5
ymax: 1000000
figsize: [9, 9] # [width, height]
......
......@@ -87,9 +87,9 @@ Dips_pT_vs_beff:
flavour: "cjets"
class_labels: ["ujets", "cjets", "bjets"]
main_class: "bjets"
WP: 0.77
WP_Line: True
Fixed_WP_Bin: False
working_point: 0.77
working_point_line: True
fixed_eff_bin: False
figsize: [7, 5]
logy: False
use_atlas_tag: True
......@@ -104,9 +104,9 @@ Dips_light_flavour_ttbar:
data_set_name: "ttbar_r21"
label: "DIPS"
tagger_name: "dips"
rejection_class: "cjets"
rejection_class: "ujets"
plot_settings:
binomialErrors: True
draw_errors: True
xmin: 0.5
ymax: 1000000
figsize: [7, 6] # [width, height]
......@@ -129,7 +129,7 @@ Dips_Comparison_flavour_ttbar:
tagger_name: "dips"
rejection_class: "cjets"
plot_settings:
binomialErrors: True
draw_errors: True
xmin: 0.5
ymax: 1000000
figsize: [9, 9] # [width, height]
......@@ -149,12 +149,12 @@ confusion_matrix_Dips_ttbar:
Dips_saliency_b_WP77_passed_ttbar:
type: "saliency"
data_set_name: "ttbar_r21"
target_eff: 0.77
jet_flavour: "bjets"
PassBool: True
nFixedTrks: 8
plot_settings:
title: "Saliency map for $b$ jets from \n $t\\bar{t}$ who passed WP = 77% \n with exactly 8 tracks"
target_beff: 0.77
jet_flavour: "cjets"
PassBool: True
FlipAxis: True
use_atlas_tag: True # Enable/Disable atlas_first_tag
atlas_first_tag: "Simulation Internal"
atlas_second_tag: "$\\sqrt{s}=13$ TeV, PFlow jets"
......
......@@ -2,7 +2,7 @@ matplotlib==3.5.1
mlxtend==0.19.0
numpy==1.21.0
pandas==1.3.5
puma-hep==0.1.1
puma-hep==0.1.3
pydash==5.1.0
ruamel.yaml==0.17.21
seaborn==0.11.2
......
......@@ -16,6 +16,7 @@ scripts =
umami/evaluate_model.py
umami/plotting_umami.py
umami/plotting_epoch_performance.py
umami/plot_input_variables.py
[isort]
multi_line_output=3
......
"""Umami framework used in ATLAS FTAG for dataset preparation and tagger training."""
__version__ = "0.8"
__version__ = "0.9"
......@@ -58,6 +58,60 @@ flavour_categories:
colour: ""
legend_label: $cc$-jets
prob_var_name: "pcc"
upjets:
label_var: PartonTruthLabelID
label_value: 1
colour: ""
legend_label: $u$-jets
prob_var_name: "pup"
djets:
label_var: PartonTruthLabelID
label_value: 2
colour: ""
legend_label: $d$-jets
prob_var_name: "pd"
sjets:
label_var: PartonTruthLabelID
label_value: 3
colour: ""
legend_label: $s$-jets
prob_var_name: "ps"
gluonjets:
label_var: PartonTruthLabelID
label_value: 21
colour: ""
legend_label: gluon-jets
prob_var_name: "pg"
lightwogluons:
label_var: PartonTruthLabelID
label_value: [1,2,3]
colour: ""
legend_label: light-jets w/o gluons
prob_var_name: "plwog"
hadrbdecay:
label_var: LeptonDecayLabel
label_value: 0
colour: ""
legend_label: hadronic $b$-hadron decay
prob_var_name: "phadrb"
singleebdecay:
label_var: LeptonDecayLabel
label_value: 1
colour: ""
legend_label: $e$'s in $b$- or $c$-hadron decay
prob_var_name: "pe"
singlemubdecay:
label_var: LeptonDecayLabel
label_value: 2
colour: ""
legend_label: $\\mu$'s in $b$- or $c$-hadron decay
prob_var_name: "pmu"
singletaubdecay:
label_var: LeptonDecayLabel
label_value: 3
colour: ""
legend_label: \u03C4's in $b$- or $c$-hadron decay # \u03C4: unicode for small tau
prob_var_name: "ptau"
# plot style definitions
hist_err_style:
......
......@@ -186,7 +186,8 @@ def LoadJetsFromFile(
)
# Remove all unused jets
jets = jets.drop(indices_toremove)
if len(indices_toremove) != 0:
jets = jets.drop(indices_toremove)
# If the first file processed, set as the global one
if j == 0 and infile_counter == 0:
......@@ -389,30 +390,33 @@ def LoadTrksFromFile(
)
)
# Remove unused jets from labels
labels = labels.drop(indices_toremove)
Umami_labels = labels["Umami_labels"].values
# Load tracks and delete unused classes
trks = np.delete(
arr=np.asarray(
h5py.File(file, "r")[f"/{tracks_name}"][
infile_counter * chunk_size : (infile_counter + 1) * chunk_size
]
),
obj=indices_toremove,
axis=0,
# Load tracks
trks = np.asarray(
h5py.File(file, "r")[f"/{tracks_name}"][
infile_counter * chunk_size : (infile_counter + 1) * chunk_size
]
)
if len(indices_toremove) != 0:
# Remove unused jets from labels
labels = labels.drop(indices_toremove)
# Delete unused classes and cut tracks
trks = np.delete(
arr=trks,
obj=indices_toremove,
axis=0,
)
# If the first file processed, set as the global one
if j == 0 and infile_counter == 0:
all_trks = trks
all_labels = Umami_labels
all_labels = labels["Umami_labels"].values
# If not the first file processed, append to the global one
else:
all_trks = np.append(all_trks, trks, axis=0)
all_labels = np.append(all_labels, Umami_labels)
all_labels = np.append(all_labels, labels["Umami_labels"].values)
# Adding the loaded jets to counter
nJets_counter += len(trks)
......
#!/usr/bin/env python
"""Execution script for training model evaluations."""
from umami.configuration import global_config, logger # isort:skip
from umami.configuration import global_config, logger, set_log_level # isort:skip
import argparse
import os
import pickle
......@@ -58,6 +58,13 @@ def get_parser():
help="Decide, which tagger was used and is to be evaluated.",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Set verbose level to debug for the logger.",
)
parser.add_argument(
"--nJets",
type=int,
......@@ -193,7 +200,8 @@ def evaluate_model(
if "exclude" in train_config.config:
exclude = train_config.config["exclude"]
# Check which test files need to be loaded depending on the CADS version
# Check which test files need to be loaded depending on the umami version
logger.info("Start loading %s test file", data_set_name)
if tagger.casefold() == "umami_cond_att".casefold():
# Load the test jets
x_test, x_test_trk, _ = utt.GetTestFile(
......@@ -289,7 +297,7 @@ def evaluate_model(
)
# Get the discriminant values and probabilities of each tagger for each jet
df_discs_dict = uet.GetScoresProbsDict(
df_discs_dict = uet.get_scores_probs_dict(
jets=jets,
y_true=truth_internal_labels,
tagger_preds=tagger_preds,
......@@ -321,7 +329,7 @@ def evaluate_model(
)
# Get the rejections, discs and effs of the taggers
tagger_rej_dicts = uet.GetRejectionPerEfficiencyDict(
tagger_rej_dicts = uet.get_rej_per_eff_dict(
jets=jets,
y_true=truth_internal_labels,
tagger_preds=tagger_preds,
......@@ -356,10 +364,14 @@ def evaluate_model(
f"results{results_filename_extension}-rej_per_eff-{epoch}.h5",
"a",
) as h5_file:
h5_file.attrs["N_test"] = len(jets)
# Put the number of jets per class in the dict for unc calculation
for flav_counter, flavour in enumerate(class_labels):
h5_file.attrs[f"njets_{flavour}"] = len(
truth_internal_labels[truth_internal_labels == flav_counter]
)
# Get the rejections, discs and f_* values for the taggers
tagger_fraction_rej_dict = uet.GetRejectionPerFractionDict(
tagger_fraction_rej_dict = uet.get_rej_per_frac_dict(
jets=jets,
y_true=truth_internal_labels,
tagger_preds=tagger_preds,
......@@ -390,7 +402,11 @@ def evaluate_model(
f"results{results_filename_extension}-rej_per_fractions-{args.epoch}.h5",
"a",
) as h5_file:
h5_file.attrs["N_test"] = len(jets)
# Put the number of jets per class in the dict for unc calculation
for flav_counter, flavour in enumerate(class_labels):
h5_file.attrs[f"njets_{flavour}"] = len(
truth_internal_labels[truth_internal_labels == flav_counter]
)
def evaluate_model_dips(
......@@ -492,7 +508,8 @@ def evaluate_model_dips(
logger.info(f"Evaluating {model_file}")
# Check which test files need to be loaded depending on the CADS version
if tagger.casefold() == "CADS".casefold():
logger.info("Start loading %s test file", data_set_name)
if tagger.casefold() == "cads":
# Load the test jets
x_test, x_test_trk, y_test = utt.GetTestFile(
input_file=test_file,
......@@ -574,15 +591,15 @@ def evaluate_model_dips(
)
# Get the discriminant values and probabilities of each tagger for each jet
df_discs_dict = uet.GetScoresProbsDict(
df_discs_dict = uet.get_scores_probs_dict(
jets=jets,
y_true=truth_internal_labels,
tagger_preds=[pred_dips],
tagger_names=["dips"],
tagger_names=[tagger.casefold()],
tagger_list=tagger_list,
class_labels=class_labels,
main_class=main_class,
frac_values={"dips": eval_params["frac_values"]},
frac_values={tagger.casefold(): eval_params["frac_values"]},
frac_values_comp=frac_values_comp,
)
......@@ -601,15 +618,15 @@ def evaluate_model_dips(
)
# Get the rejections, discs and effs of the taggers
tagger_rej_dicts = uet.GetRejectionPerEfficiencyDict(
tagger_rej_dicts = uet.get_rej_per_eff_dict(
jets=jets,
y_true=truth_internal_labels,
tagger_preds=[pred_dips],
tagger_names=["dips"],
tagger_names=[tagger.casefold()],
tagger_list=tagger_list,
class_labels=class_labels,
main_class=main_class,
frac_values={"dips": eval_params["frac_values"]},
frac_values={tagger.casefold(): eval_params["frac_values"]},
frac_values_comp=frac_values_comp,
eff_min=0.49 if "eff_min" not in eval_params else eval_params["eff_min"],
eff_max=1.0 if "eff_max" not in eval_params else eval_params["eff_max"],
......@@ -632,14 +649,18 @@ def evaluate_model_dips(
f"results{results_filename_extension}-rej_per_eff-{args.epoch}.h5",
"a",
) as h5_file:
h5_file.attrs["N_test"] = len(jets)
# Put the number of jets per class in the dict for unc calculation
for flav_counter, flavour in enumerate(class_labels):
h5_file.attrs[f"njets_{flavour}"] = len(
truth_internal_labels[truth_internal_labels == flav_counter]
)
# Get the rejections, discs and f_* values for the taggers
tagger_fraction_rej_dict = uet.GetRejectionPerFractionDict(
tagger_fraction_rej_dict = uet.get_rej_per_frac_dict(
jets=jets,
y_true=truth_internal_labels,
tagger_preds=[pred_dips],
tagger_names=["dips"],
tagger_names=[tagger