Commit f72b3ebb authored by Maxence Draguet

Merge branch 'master' into maxence-adding-loading-DL1

parents 6efdb4ca 5431cce0
......@@ -45,7 +45,8 @@ test_coverage:
image: python:3.7-slim
script:
- pip install --upgrade pip setuptools wheel
- pip install -r requirements.txt
- pip install pytest==6.2.4
- pip install pytest-cov==2.12.0
- cd ./coverage_files/
- coverage combine
- coverage report
......
......@@ -11,6 +11,7 @@ RUN apt-get update && \
echo "krb5-config krb5-config/add_servers_realm string CERN.CH" | debconf-set-selections && \
echo "krb5-config krb5-config/default_realm string CERN.CH" | debconf-set-selections && \
apt-get install -y krb5-user && \
apt-get install -y build-essential && \
apt-get install -y vim nano emacs less screen graphviz python3-tk wget
COPY requirements.txt .
......
# Explaining the importance of features with SHAPley
[SHAPley](https://github.com/slundberg/shap) is a framework that helps you understand how the training of a machine learning model is affected by its input variables, in other words, which variables the model most likely learns from. You just need to add the `--shapley` flag to `evaluate_model.py --dl1`, e.g.
```bash
python umami/evaluate_model.py -c examples/DL1r-PFlow-Training-config.yaml -e 230 --dl1 --shapley
```
and it will output a beeswarm plot into `modelname/plots/`. Each dot in this plot represents one full set of features (i.e. one jet). Dots are stacked vertically once there is no horizontal space left, which indicates the density. The colour map shows the actual value of the feature that entered the model. The SHAP value is essentially calculated by removing features, letting the model make a prediction, and then observing what happens when the features are reintroduced. Doing this over all possible combinations gives an estimate of each feature's impact on the model. This is what the x-axis (SHAP value) tells you: the average(!) contribution of a variable to the output node you are interested in (the default is the output node for b-jets). In practice, large magnitudes (which is also how these plots are ordered by default in umami) are desirable, as they give the model a better handle to discriminate. Features with large negative SHAP values therefore help the model to reject, whereas features with large positive SHAP values help the model to learn that these are most probably jets from the category of interest. If you want to know more about Shapley values, here is a [talk](https://indico.cern.ch/event/1071129/#4-shapely-for-nn-input-ranking) from our algorithms meeting.
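For orientation, here is a minimal, self-contained sketch of how such a beeswarm plot is produced with the `shap` package; the toy model, data and variable names below are placeholders and not part of umami.
```python
import numpy as np
import pandas as pd
import shap
from sklearn.ensemble import GradientBoostingClassifier

# Toy stand-in for a trained tagger: 500 "jets" with 5 input variables.
rng = np.random.default_rng(42)
X = pd.DataFrame(
    rng.normal(size=(500, 5)), columns=[f"var_{i}" for i in range(5)]
)
y = (X["var_0"] + 0.5 * X["var_1"] > 0).astype(int)
model = GradientBoostingClassifier().fit(X, y)

# Let shap pick a suitable explainer and compute SHAP values for a subset
# of jets (this subset plays the role of `feature_sets` in umami).
explainer = shap.Explainer(lambda x: model.predict_proba(x)[:, 1], X)
shap_values = explainer(X.iloc[:200])

# Beeswarm plot: one dot per jet and feature, coloured by the feature value.
shap.summary_plot(
    shap_values.values,
    features=X.iloc[:200],
    feature_names=list(X.columns),
)
```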
You have some options to play with in the `Eval_parameters_validation` section of the [DL1r-PFlow-Training-config.yaml](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/blob/master/examples/DL1r-PFlow-Training-config.yaml):
```yaml
Eval_parameters_validation:
(...other stuff...)
shapley:
# Number of full feature sets (jets) to calculate the SHAP values over.
# Corresponds to the number of dots per feature in the beeswarm plot.
# 200 takes roughly 10-15 min for DL1r on a 32-core CPU
feature_sets: 200
# defines which of the model outputs (flavor) you want to explain
# [tau,b,c,u] := [3, 2, 1, 0]
model_output: 2
# You can also choose if you want to plot the magnitude of feature
# importance for all output nodes (flavors) in another plot. This
# will give you a bar plot of the mean SHAP value magnitudes.
bool_all_flavor_plot: False
# As this takes much longer, you can average the feature_sets down to a
# smaller set; 50 is a good choice for DL1r
averaged_sets: 50
# [11, 11] works well for DL1r
plot_size: [11, 11]
```
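These options are consumed by the SHAP code roughly as in the sketch below, which mirrors the `FeatureImportance.ShapleyOneFlavor` call added to `evaluate_model.py` in this commit. `model` and `X_test` are placeholders for the trained DL1 model and a DataFrame of its scaled test inputs; the model directory and plot name are hypothetical.
```python
from umami.evaluation_tools import FeatureImportance

# Values below follow the Eval_parameters_validation block above;
# `model` and `X_test` stand for the trained DL1 Keras model and a pandas
# DataFrame with the scaled input features of one test file (placeholders).
shapley_cfg = {
    "feature_sets": 200,
    "model_output": 2,  # b-jet output node, [tau, b, c, u] := [3, 2, 1, 0]
    "bool_all_flavor_plot": False,
    "averaged_sets": 50,
    "plot_size": [11, 11],
}

FeatureImportance.ShapleyOneFlavor(
    model=model,        # trained DL1 model (placeholder)
    test_data=X_test,   # scaled input features (placeholder)
    model_output=shapley_cfg["model_output"],
    feature_sets=shapley_cfg["feature_sets"],
    plot_size=shapley_cfg["plot_size"],
    plot_path="my_DL1_model/",             # hypothetical model directory
    plot_name="ttbar_r21_shapley_b-jets",  # hypothetical plot name
)
```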
MC Samples
==============
The FTAG1 derivations and the most recent ntuples for PFlow with the new RNNIP, SMT and the latest DL1* recommendations inside are shown in the following table. DIPS Default, DIPS Loose, DL1d Default and DL1d Loose are added as DL2 in the h5 ntuples.
The FTAG1 derivations and the most recent ntuples for PFlow with the new RNNIP, SMT and the latest DL1* recommendations inside are shown in the following table. DIPS Default, DIPS Loose, DL1d Default, DL1d Loose and UMAMI are added as DL2 in the h5 ntuples.
## Default FTAG Samples (ttbar and Z')
| Sample | h5 ntuples | h5 ntuples (looser track selection) | FTAG1 derivations | AOD |
| ------------- | ---------------- | ---------------- | ---------------- | ---------------- |
|MC16a - ttbar | user.alfroch.410470.btagTraining.e6337_s3126_r9364_p3985.EMPFlow.2021-07-28-T130145-R11969_output.h5 | user.alfroch.410470.btagTraining.e6337_s3126_r9364_p3985.EMPFlow_loose.2021-07-30-T132351-R32377_output.h5 | mc16_13TeV.410470.PhPy8EG_A14_ttbar_hdamp258p75_nonallhad.deriv.DAOD_FTAG1.e6337_s3126_r9364_p3985 | |
|MC16a - Z' | user.alfroch.427080.btagTraining.e5362_s3126_r9364_p3985.EMPFlow.2021-07-28-T130145-R11969_output.h5 | user.alfroch.427080.btagTraining.e5362_s3126_r9364_p3985.EMPFlow_loose.2021-07-30-T132351-R32377_output.h5 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.deriv.DAOD_FTAG1.e5362_s3126_r9364_p3985 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.recon.AOD.e5362_s3126_r9364 |
|MC16d - ttbar | user.alfroch.410470.btagTraining.e6337_s3126_r10201_p3985.EMPFlow.2021-07-28-T130145-R11969_output.h5 | user.alfroch.410470.btagTraining.e6337_s3126_r10201_p3985.EMPFlow_loose.2021-07-30-T132351-R32377_output.h5 | mc16_13TeV.410470.PhPy8EG_A14_ttbar_hdamp258p75_nonallhad.deriv.DAOD_FTAG1.e6337_s3126_r10201_p3985 | |
|MC16d - Z' | user.alfroch.427080.btagTraining.e5362_s3126_r10201_p3985.EMPFlow.2021-07-28-T130145-R11969_output.h5 | user.alfroch.427080.btagTraining.e5362_s3126_r10201_p3985.EMPFlow_loose.2021-07-30-T132351-R32377_output.h5 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.deriv.DAOD_FTAG1.e5362_s3126_r10201_p3985 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.recon.AOD.e5362_s3126_r10201 |
|MC16d - Z' extended | user.alfroch.427081.btagTraining.e6928_e5984_s3126_r10201_r10210_p3985.EMPFlow.2021-07-28-T130145-R11969_output.h5 | user.alfroch.427081.btagTraining.e6928_e5984_s3126_r10201_r10210_p3985.EMPFlow_loose.2021-07-30-T132351-R32377_output.h5 | mc16_13TeV.427081.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime_Extended.deriv.DAOD_FTAG1.e6928_e5984_s3126_r10201_r10210_p3985 | mc16_13TeV.427081.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime_Extended.recon.AOD.e6928_s3126_r10201 |
|MC16d - Z' extended (QSP on)| user.mguth.800030.btagTraining.e7954_e7400_s3663_r10201_p4207.EMPFlow-DL1d.2021-06-21-T181646-R9170_output.h5 | user.alfroch.800030.btagTraining.e7954_e7400_s3663_r10201_p4207.EMPFlow_loose.2021-07-30-T140220-R6689 | mc16_13TeV.800030.Py8EG_A14NNPDF23LO_flatpT_Zprime_Extended.deriv.DAOD_FTAG1.e7954_e7400_s3663_r10201_p4207 ||
|MC16e - ttbar | user.alfroch.410470.btagTraining.e6337_s3126_r10724_p3985.EMPFlow.2021-07-28-T130145-R11969_output.h5 | user.alfroch.410470.btagTraining.e6337_s3126_r10724_p3985.EMPFlow_loose.2021-07-30-T132351-R32377_output.h5 | mc16_13TeV.410470.PhPy8EG_A14_ttbar_hdamp258p75_nonallhad.deriv.DAOD_FTAG1.e6337_s3126_r10724_p3985 | |
|MC16e - Z' | user.alfroch.427080.btagTraining.e5362_s3126_r10724_p3985.EMPFlow.2021-07-28-T130145-R11969_output.h5 | user.alfroch.427080.btagTraining.e5362_s3126_r10724_p3985.EMPFlow_loose.2021-07-30-T132351-R32377_output.h5 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.deriv.DAOD_FTAG1.e5362_s3126_r10724_p3985 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.recon.AOD.e5362_s3126_r10724 |
|MC16a - ttbar | user.alfroch.410470.btagTraining.e6337_s3126_r9364_p3985.EMPFlow.2021-09-07-T122808-R14883_output.h5 | user.alfroch.410470.btagTraining.e6337_s3126_r9364_p3985.EMPFlow_loose.2021-09-07-T122950-R13989_output.h5 | mc16_13TeV.410470.PhPy8EG_A14_ttbar_hdamp258p75_nonallhad.deriv.DAOD_FTAG1.e6337_s3126_r9364_p3985 | |
|MC16a - Z' | user.alfroch.427080.btagTraining.e5362_s3126_r9364_p3985.EMPFlow.2021-09-07-T122808-R14883_output.h5 | user.alfroch.427080.btagTraining.e5362_s3126_r9364_p3985.EMPFlow_loose.2021-09-07-T122950-R13989_output.h5 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.deriv.DAOD_FTAG1.e5362_s3126_r9364_p3985 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.recon.AOD.e5362_s3126_r9364 |
|MC16d - ttbar | user.alfroch.410470.btagTraining.e6337_s3126_r10201_p3985.EMPFlow.2021-09-07-T122808-R14883_output.h5 | user.alfroch.410470.btagTraining.e6337_s3126_r10201_p3985.EMPFlow_loose.2021-09-07-T122950-R13989_output.h5 | mc16_13TeV.410470.PhPy8EG_A14_ttbar_hdamp258p75_nonallhad.deriv.DAOD_FTAG1.e6337_s3126_r10201_p3985 | |
|MC16d - Z' | user.alfroch.427080.btagTraining.e5362_s3126_r10201_p3985.EMPFlow.2021-09-07-T122808-R14883_output.h5 | user.alfroch.427080.btagTraining.e5362_s3126_r10201_p3985.EMPFlow_loose.2021-09-07-T122950-R13989_output.h5 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.deriv.DAOD_FTAG1.e5362_s3126_r10201_p3985 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.recon.AOD.e5362_s3126_r10201 |
|MC16d - Z' extended | user.alfroch.427081.btagTraining.e6928_e5984_s3126_r10201_r10210_p3985.EMPFlow.2021-09-07-T122808-R14883_output.h5 | user.alfroch.427081.btagTraining.e6928_e5984_s3126_r10201_r10210_p3985.EMPFlow_loose.2021-09-07-T122950-R13989_output.h5 | mc16_13TeV.427081.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime_Extended.deriv.DAOD_FTAG1.e6928_e5984_s3126_r10201_r10210_p3985 | mc16_13TeV.427081.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime_Extended.recon.AOD.e6928_s3126_r10201 |
|MC16d - Z' extended (QSP on)| user.alfroch.800030.btagTraining.e7954_e7400_s3663_r10201_p4207.EMPFlow.2021-09-07-T122808-R14883_output.h5 | user.alfroch.800030.btagTraining.e7954_e7400_s3663_r10201_p4207.EMPFlow_loose.2021-09-07-T122950-R13989_output.h5 | mc16_13TeV.800030.Py8EG_A14NNPDF23LO_flatpT_Zprime_Extended.deriv.DAOD_FTAG1.e7954_e7400_s3663_r10201_p4207 ||
|MC16e - ttbar | user.alfroch.410470.btagTraining.e6337_s3126_r10724_p3985.EMPFlow.2021-09-07-T122808-R14883_output.h5 | user.alfroch.410470.btagTraining.e6337_s3126_r10724_p3985.EMPFlow_loose.2021-09-07-T122950-R13989_output.h5 | mc16_13TeV.410470.PhPy8EG_A14_ttbar_hdamp258p75_nonallhad.deriv.DAOD_FTAG1.e6337_s3126_r10724_p3985 | |
|MC16e - Z' | user.alfroch.427080.btagTraining.e5362_s3126_r10724_p3985.EMPFlow.2021-09-07-T122808-R14883_output.h5 | user.alfroch.427080.btagTraining.e5362_s3126_r10724_p3985.EMPFlow_loose.2021-09-07-T122950-R13989_output.h5 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.deriv.DAOD_FTAG1.e5362_s3126_r10724_p3985 | mc16_13TeV.427080.Pythia8EvtGen_A14NNPDF23LO_flatpT_Zprime.recon.AOD.e5362_s3126_r10724 |
## Release 22 Samples with Muons
The round 2 release 22 samples with RNNIP, DL1* and DIPS.
| Sample | h5 ntuples | h5 ntuples (looser track selection) | DAOD_PHYSVAL derivations| AOD |
| ------------- | ---------------- | ---------------- | ---------------- | ---------------- |
| ttbar | user.alfroch.410470.btagTraining.e6337_e5984_s3126_r12629_p4724.EMPFlow.2021-09-20-T161046-R30966_output.h5 | user.alfroch.410470.btagTraining.e6337_e5984_s3126_r12629_p4724.EMPFlow_loose.2021-09-20-T165329-R29738_output.h5 | mc16_13TeV.410470.PhPy8EG_A14_ttbar_hdamp258p75_nonallhad.deriv.DAOD_PHYSVAL.e6337_e5984_s3126_r12629_p4724 | mc16_13TeV.410470.PhPy8EG_A14_ttbar_hdamp258p75_nonallhad.recon.AOD.e6337_e5984_s3126_r12629
| Z' Extended (With QSP, Yes shower weights) | user.alfroch.800030.btagTraining.e7954_s3672_r12629_r12636_p4724.EMPFlow.2021-09-20-T161046-R30966_output.h5 | user.alfroch.800030.btagTraining.e7954_s3672_r12629_r12636_p4724.EMPFlow_loose.2021-09-20-T165329-R29738_output.h5 | mc16_13TeV.800030.Py8EG_A14NNPDF23LO_flatpT_Zprime_Extended.deriv.DAOD_PHYSVAL.e7954_s3672_r12629_r12636_p4724 | |
| Z' | user.alfroch.500567.btagTraining.e7954_e7400_s3672_r12629_r12636_p4724.EMPFlow.2021-09-20-T161046-R30966_output.h5 | user.alfroch.500567.btagTraining.e7954_e7400_s3672_r12629_r12636_p4724.EMPFlow_loose.2021-09-20-T165329-R29738_output.h5 | mc16_13TeV.500567.MGH7EG_NNPDF23ME_Zprime.deriv.DAOD_PHYSVAL.e7954_e7400_s3672_r12629_r12636_p4724 | mc16_13TeV.500567.MGH7EG_NNPDF23ME_Zprime.merge.AOD.e7954_e7400_s3672_r12629_r12636 |
## Release 22 Samples
The Round 2 release 22 samples with RNNIP, DL1* and DIPS.
......
......@@ -37,7 +37,7 @@ For more information, consider the [FTAG TWiki about flavour labelling](https://
| -------------------------- | ---------------- |
| 0 | light jets |
| 4 | c-jets |
| 5 | single b-jets |
| 5 | b-jets |
| 15 | tau-jets |
| HadronConeExclExtendedTruthLabelID | Category |
......
......@@ -106,3 +106,26 @@ Eval_parameters_validation:
# Set the datatype of the plots
plot_datatype: "pdf"
# some properties for the feature importance explanation with SHAPley
shapley:
# Over how many full sets of features it should calculate over.
# Corresponds to the dots in the beeswarm plot.
# 200 takes like 10-15 min for DL1r on a 32 core-cpu
feature_sets: 200
# defines which of the model outputs (flavor) you want to explain
# [tau,b,c,u] := [3, 2, 1, 0]
model_output: 2
# You can also choose if you want to plot the magnitude of feature
# importance for all output nodes (flavors) in another plot. This
# will give you a bar plot of the mean SHAP value magnitudes.
bool_all_flavor_plot: False
# as this takes much longer you can average the feature_sets to a
# smaller set, 50 is a good choice for DL1r
averaged_sets: 50
# [11,11] works well for dl1r
plot_size: [11, 11]
......@@ -31,6 +31,7 @@ nav:
- LWTNN Conversion: LWTNN-conversion.md
- Evaluate Taggers in Samples: WO_trained_model.md
- Plotting evaluated Results: plotting_umami.md
- Feature Importance: Feature_Importance.md
plugins:
- search
......@@ -32,3 +32,4 @@ flake8==3.9.2
black==21.5b1
pre-commit==2.12.1
yamllint==1.26.2
shap==0.39.0
\ No newline at end of file
File mode changed from 100644 to 100755
......@@ -21,6 +21,7 @@ from umami.evaluation_tools.PlottingFunctions import (
getDiscriminant,
)
from umami.preprocessing_tools import Configuration
from umami.evaluation_tools import FeatureImportance
# from plottingFunctions import sigBkgEff
tf.compat.v1.disable_eager_execution()
......@@ -87,6 +88,12 @@ def GetParser():
affected by this! They are 0.018 / 0.08 for DL1r / RNNIP.""",
)
parser.add_argument(
"--shapley",
action="store_true",
help="Calculates feature importance for DL1",
)
args = parser.parse_args()
return args
......@@ -674,7 +681,12 @@ def EvaluateModelDips(
def EvaluateModelDL1(
args, train_config, preprocess_config, test_file, data_set_name
args,
train_config,
preprocess_config,
test_file,
data_set_name,
test_file_entry,
):
# Check if epochs are set or not
if args.epoch is None:
......@@ -1086,6 +1098,29 @@ def EvaluateModelDL1(
f.attrs["N_test"] = len(df)
f.close()
if args.shapley:
logger.info("Explaining feature importance with SHAPley")
FeatureImportance.ShapleyOneFlavor(
model=model,
test_data=X_test,
model_output=Eval_params["shapley"]["model_output"],
feature_sets=Eval_params["shapley"]["feature_sets"],
plot_size=Eval_params["shapley"]["plot_size"],
plot_path=f"{train_config.model_name}/",
plot_name=test_file_entry + "_shapley_b-jets",
)
if Eval_params["shapley"]["bool_all_flavor_plot"]:
FeatureImportance.ShapleyAllFlavors(
model=model,
test_data=X_test,
feature_sets=Eval_params["shapley"]["feature_sets"],
averaged_sets=Eval_params["shapley"]["averaged_sets"],
plot_size=Eval_params["shapley"]["plot_size"],
plot_path=f"{train_config.model_name}/",
plot_name=test_file_entry + "_shapley_all_flavors",
)
if not bool_use_taus:
return
......@@ -1183,6 +1218,7 @@ if __name__ == "__main__":
train_config.ttbar_test_files[ttbar_models][
"data_set_name"
],
ttbar_models,
)
if train_config.zpext_test_files is not None:
......@@ -1196,6 +1232,7 @@ if __name__ == "__main__":
train_config.zpext_test_files[zpext_models][
"data_set_name"
],
zpext_models,
)
elif args.dips:
......@@ -1227,7 +1264,7 @@ if __name__ == "__main__":
else:
if train_config.zpext_test_files is not None:
if train_config.ttbar_test_files is not None:
logger.info("Start evaluating UMAMI with ttbar test files...")
for ttbar_models in train_config.ttbar_test_files:
EvaluateModel(
......
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import shap
import os
def ShapleyOneFlavor(
model,
test_data,
model_output=2,
feature_sets=200,
plot_size=(11, 11),
plot_path=None,
plot_name="shapley_b-jets",
):
"""
https://github.com/slundberg/shap
Calculates SHAP values with the shap package and plots the result as a beeswarm plot
(the explainer is chosen automatically by shap depending on the feature size).
model_output: output node of the model to explain, following
tau_index, b_index, c_index, u_index = 3, 2, 1, 0
feature_sets: number of full feature sets (jets) to calculate SHAP values for.
Corresponds to the number of dots per feature in the beeswarm plot
plot_size: (11,11) works well for DL1r
"""
explainer = shap.Explainer(model.predict, masker=test_data.values)
shap_values = explainer(test_data.values[:feature_sets, :])
# From ReDevVerse comments https://github.com/slundberg/shap/issues/1460
# model_output in np.take takes the according flavor
# max_display defines how many features will be shown in the plot
shap.summary_plot(
shap_values=np.take(shap_values.values, model_output, axis=-1),
features=test_data.values[:feature_sets, :],
feature_names=list(test_data.keys()),
plot_size=plot_size,
max_display=100,
)
plt.tight_layout()
if not os.path.exists(os.path.abspath(plot_path + "/plots")):
os.makedirs(os.path.abspath(plot_path + "/plots"))
plt.savefig(plot_path + "plots/" + plot_name + ".pdf")
plt.close("all")
def ShapleyAllFlavors(
model,
test_data,
feature_sets=200,
averaged_sets=50,
plot_size=(11, 11),
plot_path=None,
plot_name="shapley_all_flavors",
):
"""
Makes a bar plot for the influence of features for all flavour outputs as
categories in one plot.
averaged_sets: lets you average the input feature sets before
they are handed to the shap framework to decrease the runtime.
"""
# summarise the background data into averaged_sets representative samples (shap.kmeans)
averaged_data = shap.kmeans(
test_data.values[:feature_sets, :], averaged_sets
)
explainer = shap.KernelExplainer(model.predict, data=averaged_data)
shap_values = explainer.shap_values(test_data.values[:feature_sets, :])
# b: "#1f77b4"
# c: "#ff7f0e"
# u: "#2ca02c"
# make colors for flavor outputs
jet_cmap = colors.ListedColormap(["#2ca02c", "#ff7f0e", "#1f77b4"])
# class_inds="original" gives you the right label order
# max_display: defines how many features will be shown in the plot
# class_names: plot labels
shap.summary_plot(
shap_values=shap_values,
features=test_data,
feature_names=list(test_data.keys()),
class_names=["u-jets", "c-jets", "b-jets"],
class_inds="original",
plot_type="bar",
color=jet_cmap,
plot_size=plot_size,
max_display=100,
)
plt.tight_layout()
plt.savefig(plot_path + "/plots/" + plot_name + ".pdf")
plt.close("all")
......@@ -77,6 +77,31 @@ def runTrainingDL1(config):
if isSuccess is True:
run_evaluate_model_DL1
logger.info(
"Test: running evaluate_model.py for DL1 with shapley option..."
)
run_evaluate_model_DL1_with_shapley = run(
[
"evaluate_model.py",
"-c",
f"{config}",
"-e",
"1",
"--dl1",
"--shapley",
]
)
try:
run_evaluate_model_DL1_with_shapley.check_returncode()
except CalledProcessError:
logger.info(
"Test failed: evaluate_model.py for DL1 with shapley option."
)
isSuccess = False
if isSuccess is True:
run_evaluate_model_DL1_with_shapley
return isSuccess
......
......@@ -53,3 +53,24 @@ Eval_parameters_validation:
WP_b: 0.77
# fc_value and WP_b are automatically added to the plot label
SecondTag: "\n$\\sqrt{s}=13$ TeV, PFlow jets"
# some properties for the feature importance explanation with SHAPley
shapley:
# Over how many full sets of features it should calculate over
# 200 takes like 10-15 min for DL1r on 32 core-cpu
feature_sets: 4
# defines which of the model outputs (flavor) you want to explain
# [tau,b,c,u] := [3, 2, 1, 0]
model_output: 2
# You can also choose if you want to view the feature importance of all
# flavors in one plot (usually we are just interested in the b-output)
bool_all_flavor_plot: True
# as this takes much longer you can average the shapley_feature_sets to a
# smaller set, 50 is a good choice for DL1r
averaged_sets: 2
# [11,11] works well for dl1r
plot_size: [11, 11]
......@@ -296,12 +296,22 @@ def Dips(args, train_config, preprocess_config):
min_lr=0.000001,
)
# Convert numpy arrays to tensors to avoid memory leak in callbacks
X_valid_tensor = tf.convert_to_tensor(X_valid, dtype=tf.float64)
Y_valid_tensor = tf.convert_to_tensor(Y_valid, dtype=tf.int64)
if train_config.add_validation_file is not None:
X_valid_add_tensor = tf.convert_to_tensor(X_valid_add, dtype=tf.float64)
Y_valid_add_tensor = tf.convert_to_tensor(Y_valid_add, dtype=tf.int64)
else:
X_valid_add_tensor = None
Y_valid_add_tensor = None
# Forming a dict for Callback
val_data_dict = {
"X_valid": X_valid,
"Y_valid": Y_valid,
"X_valid_add": X_valid_add,
"Y_valid_add": Y_valid_add,
"X_valid": X_valid_tensor,
"Y_valid": Y_valid_tensor,
"X_valid_add": X_valid_add_tensor,
"Y_valid_add": Y_valid_add_tensor,
}
# Set my_callback as callback. Writes history information
......
......@@ -727,6 +727,9 @@ def GetTestSample(
all_jets = all_jets.append(jets, ignore_index=True)
all_labels = all_labels.append(labels, ignore_index=True)
# Add the number of loaded jets to counter
nJets_counter += len(all_jets)
# Stop loading if enough jets are loaded
if nJets_counter >= nJets:
break
......@@ -818,6 +821,9 @@ def GetTestSampleTrks(
np.append(all_trks, np.stack(var_arr_list, axis=-1))
np.append(all_labels, labels)
# Add the number of jets to counter
nJets_counter += len(all_trks)
# Stop loading if enough jets are loaded
if nJets_counter >= nJets:
break
......@@ -875,37 +881,26 @@ def load_validation_data(train_config, preprocess_config, nJets: int):
def load_validation_data_dips(train_config, preprocess_config, nJets: int):
exclude = None
if "exclude" in train_config.config:
exclude = train_config.config["exclude"]
val_data_dict = {}
(_, val_data_dict["X_valid"], val_data_dict["Y_valid"],) = GetTestFile(
(val_data_dict["X_valid"], val_data_dict["Y_valid"],) = GetTestSampleTrks(
train_config.validation_file,
train_config.var_dict,
preprocess_config,
nJets=nJets,
exclude=exclude,
)
(
val_data_dict["X_valid_add"],
val_data_dict["Y_valid_add"],
val_data_dict["X_valid_trk_add"],
) = (None, None, None)
) = (None, None)
if train_config.add_validation_file is not None:
(
_,
val_data_dict["X_valid_add"],
val_data_dict["Y_valid_add"],
) = GetTestFile(
) = GetTestSampleTrks(
train_config.add_validation_file,
train_config.var_dict,
preprocess_config,
nJets=nJets,
exclude=exclude,
)
assert (
val_data_dict["X_valid"].shape[1]
== val_data_dict["X_valid_add"].shape[1]
)
return val_data_dict
......@@ -1063,7 +1058,8 @@ def evaluate_model_dips(model, data_dict, target_beff=0.77, cfrac=0.018):
accuracy_add,
c_rej_add,
u_rej_add,
) = (None, None, None, None)
disc_cut_add,
) = (None, None, None, None, None)
if data_dict["X_valid_add"] is not None:
loss_add, accuracy_add = model.evaluate(
......@@ -1119,16 +1115,22 @@ def calc_validation_metrics(
for f in os.listdir(train_config.model_name)
if "model" in f
]
with open(
get_validation_dict_name(
WP_b=Eval_parameters["WP_b"],
fc_value=Eval_parameters["fc_value"],
n_jets=Eval_parameters["n_jets"],
dir_name=train_config.model_name,
),
"r",
) as training_out_json:
training_output_list = json.load(training_out_json)
try:
with open(
get_validation_dict_name(
WP_b=Eval_parameters["WP_b"],
fc_value=Eval_parameters["fc_value"],
n_jets=Eval_parameters["n_jets"],
dir_name=train_config.model_name,
),
"r",
) as training_out_json:
training_output_list = json.load(training_out_json)
except FileNotFoundError:
training_output_list = [
{"epoch": n} for n in range(train_config.NN_structure["epochs"])
]
results = []
for n, model_file in enumerate(training_output):
......@@ -1177,16 +1179,22 @@ def calc_validation_metrics_dips(
for f in os.listdir(train_config.model_name)
if "model_epoch" in f
]
with open(
get_validation_dict_name(
WP_b=Eval_parameters["WP_b"],
fc_value=Eval_parameters["fc_value"],
n_jets=Eval_parameters["n_jets"],
dir_name=train_config.model_name,
),
"r",
) as training_out_json:
training_output_list = json.load(training_out_json)
try:
with open(
get_validation_dict_name(
WP_b=Eval_parameters["WP_b"],
fc_value=Eval_parameters["fc_value"],
n_jets=Eval_parameters["n_jets"],
dir_name=train_config.model_name,
),
"r",
) as training_out_json:
training_output_list = json.load(training_out_json)
except FileNotFoundError:
training_output_list = [
{"epoch": n} for n in range(train_config.NN_structure["epochs"])
]
results = []
for n, model_file in enumerate(sorted(training_output, key=natural_keys)):
......