Commit 80b89173 authored by Joschka Birk's avatar Joschka Birk

Merge branch alfroch-adding-unit-test with refs/heads/master into refs/merge-requests/546/train

parents 61598d34 a7192d58
Pipeline #3996729 passed with stages in 12 minutes and 14 seconds
@@ -22,7 +22,7 @@
  script:
    - pip install darglint
    - darglint --list-errors
    - find . -name "*.py" ! -name *PlottingFunctions.py ! -name *Plotting.py | xargs -n 1 -P 8 darglint -s numpy -z full --log-level INFO
    - find . -name "*.py" ! -name *PlottingFunctions.py ! -name *Plotting.py ! -name *conf.py | xargs -n 1 -P 8 -t darglint

.pylint_template: &pylint_template
  stage: linting
......
@@ -4,6 +4,7 @@
### Latest
- Adding unit test for prepare_model and minor bug fixes [!546](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/546)
- Adding unit tests for tf generators [!542](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/542)
- Fix epoch bug in continue_training [!543](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/543)
- Updating tensorflow to version `2.9.0` and pytorch to `1.11.0-cuda11.3-cudnn8-runtime` [!547](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/umami/-/merge_requests/547)
......
import os
import tempfile
import unittest
from subprocess import run

import numpy as np
import tensorflow as tf

from umami.configuration import logger, set_log_level
from umami.tf_tools import Attention, DeepSet, DenseNet
from umami.tf_tools import Attention, DeepSet, DenseNet, prepare_model

set_log_level(logger, "DEBUG")
@@ -326,3 +331,117 @@ class test_DenseNet(tf.test.TestCase):
        # Test output
        np.testing.assert_almost_equal(expected_output, out)
class TestPrepareModel(unittest.TestCase):
    """Test the prepare_model function."""

    def setUp(self) -> None:
        self.tmp_dir = tempfile.TemporaryDirectory()
        self.tmp_test_dir = f"{self.tmp_dir.name}/"
        self.model_name = self.tmp_test_dir + "Test_prepare_model"
        self.model_file = None
        self.NN_structure = {"load_optimiser": False}

        # Create the directory structure in which the checkpoints are stored.
        # self.model_name is already an absolute path, so it is used directly.
        os.makedirs(
            os.path.join(
                self.model_name,
                "model_files",
            )
        )

        # Download a small test model from the umami CI storage
        run(
            [
                "wget",
                "https://umami-ci-provider.web.cern.ch/umami/test_model_file.h5",
                "--directory-prefix",
                self.tmp_test_dir,
            ],
            check=True,
        )

        # Stage the downloaded model as the saved checkpoint of epoch 1
        run(
            [
                "cp",
                os.path.join(self.tmp_test_dir, "test_model_file.h5"),
                os.path.join(
                    self.model_name,
                    "model_files",
                    "model_epoch001.h5",
                ),
            ],
            check=True,
        )

    def test_init_fresh_model(self):
        """Test fresh model init."""
        model, init_epoch, load_optimiser = prepare_model(train_config=self)

        with self.subTest("Check Model"):
            self.assertIsNone(model)
        with self.subTest("Check init_epoch"):
            self.assertEqual(init_epoch, 0)
        with self.subTest("Check load_optimiser"):
            self.assertFalse(load_optimiser)

    def test_init_fresh_model_no_load_optimiser_given(self):
        """Test fresh model init with no load_optimiser given."""
        self.NN_structure = {}
        model, init_epoch, load_optimiser = prepare_model(train_config=self)

        with self.subTest("Check Model"):
            self.assertIsNone(model)
        with self.subTest("Check init_epoch"):
            self.assertEqual(init_epoch, 0)
        with self.subTest("Check load_optimiser"):
            self.assertFalse(load_optimiser)

    def test_load_optimiser_ValueError(self):
        """Test load optimiser error."""
        self.NN_structure = {"load_optimiser": True}

        with self.assertRaises(ValueError):
            _, _, _ = prepare_model(train_config=self)

    def test_load_model_without_continue_training(self):
        """Test loading of a model without continuation."""
        self.model_file = os.path.join(
            self.tmp_test_dir,
            "test_model_file.h5",
        )
        model, init_epoch, load_optimiser = prepare_model(train_config=self)

        with self.subTest("Check Model"):
            self.assertIsNotNone(model)
        with self.subTest("Check init_epoch"):
            self.assertEqual(init_epoch, 0)
        with self.subTest("Check load_optimiser"):
            self.assertFalse(load_optimiser)

    def test_load_model_with_continue_training(self):
        """Test loading of a model with continuation."""
        model, init_epoch, load_optimiser = prepare_model(
            train_config=self,
            continue_training=True,
        )

        with self.subTest("Check Model"):
            self.assertIsNotNone(model)
        with self.subTest("Check init_epoch"):
            # Keras counts epochs from init_epoch: a fresh training starts at
            # init_epoch = 0 and runs epoch init_epoch + 1 first. When a
            # training is continued, init_epoch must be the last saved epoch,
            # which in this test case is epoch 1.
            self.assertEqual(init_epoch, 1)
        with self.subTest("Check load_optimiser"):
            self.assertTrue(load_optimiser)
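The comment in `test_load_model_with_continue_training` describes how the init epoch relates to the last saved checkpoint. Below is a minimal sketch of that logic, assuming checkpoints named `model_epochXXX.h5` as staged in `setUp`; the helper `latest_init_epoch` is hypothetical and not the actual `prepare_model` implementation.

```python
import os
import re


def latest_init_epoch(model_file_dir: str) -> int:
    """Return the last saved epoch, or 0 if no checkpoint exists yet."""
    epochs = [
        int(match.group(1))
        for file_name in os.listdir(model_file_dir)
        # hypothetical naming scheme, taken from the test fixture above
        if (match := re.match(r"model_epoch(\d+)\.h5$", file_name))
    ]
    # Keras counts from init_epoch: a fresh training returns 0, a continued
    # training resumes after the newest checkpoint (epoch 1 in the test).
    return max(epochs, default=0)
```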
@@ -942,7 +942,6 @@ def RunPerformanceCheck(
    Eval_parameters = train_config.Eval_parameters_validation
    Val_settings = train_config.Validation_metrics_settings
    plot_args = train_config.plot_args
    logger.warning(f"plot_args = {plot_args}")
    frac_dict = Eval_parameters["frac_values"]
    class_labels = train_config.NN_structure["class_labels"]
    main_class = train_config.NN_structure["main_class"]
@@ -951,6 +950,9 @@ def RunPerformanceCheck(
    # labels. These are used several times in this function
    val_files = train_config.validation_files

    # Printing the given plot args for debugging
    logger.debug(f"plot_args = {plot_args}")

    # Check the main class input and transform it into a set
    main_class = check_main_class_input(main_class)
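The comment above says the main class input is checked and transformed into a set. A minimal, hypothetical sketch of such a normalisation, assuming `main_class` may arrive either as a single string or as a list of strings; this is an illustration, not the actual umami `check_main_class_input`.

```python
def check_main_class_input(main_class) -> set:
    """Normalise the main_class input into a set of class labels."""
    if isinstance(main_class, str):
        # A single label such as "bjets" becomes {"bjets"}
        return {main_class}
    if isinstance(main_class, (list, tuple, set)):
        return set(main_class)
    raise ValueError(f"Invalid main_class input: {main_class!r}")
```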
......
@@ -32,6 +32,7 @@ class Configuration:
"val_batch_size",
"eval_batch_size",
"taggers_from_file",
"trained_taggers",
"tagger_label",
"WP",
]
......