diff --git a/.gitignore b/.gitignore
index 21fa5fcb65758e1c1396738c9a796d91c3f6d8e4..3b50feb0a0bc53e9ed146ca9a7f26518adba703d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,4 +21,3 @@ plots/*
 env/*
 .vscode/*
 !umami/tests/unit/**/*.png
-python_install/
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index ad2eb5af93310ea501509ea98e582d954462b625..819431cc43942eef8b624077f584f3afd2331c6f 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,8 +1,7 @@
 # Tag of TensorFlow base image
 # https://pypi.org/project/tensorflow/#history
 variables:
-  TFTAG: 2.6.0
-  TORCHTAG: 1.9.0-cuda11.1-cudnn8-runtime
+  TFTAG: 2.5.0
 
 stages:
   - linting
@@ -31,20 +30,11 @@ linter:
    - if: $CI_COMMIT_BRANCH != ''
    - if: $CI_PIPELINE_SOURCE == "merge_request_event"
 
-yaml_linter:
-  stage: linting
-  image: sdesbure/yamllint
-  script:
-    - 'yamllint -d "{extends: relaxed, rules: {line-length: disable}}" .'
-  rules:
-    - if: $CI_COMMIT_BRANCH != ''
-    - if: $CI_PIPELINE_SOURCE == "merge_request_event"
-
 test_coverage:
   stage: coverage_test_stage
  image: python:3.7-slim
   script:
-    - pip install --upgrade pip setuptools wheel
+    - pip install --upgrade pip
    - pip install -r requirements.txt
    - cd ./coverage_files/
    - coverage combine
@@ -52,14 +42,13 @@ test_coverage:
    - coverage xml
  rules:
    - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
-    - if: $CI_PIPELINE_SOURCE == "merge_request_event" && $CI_PROJECT_PATH=="atlas-flavor-tagging-tools/algorithms/umami"
+    - if: $CI_PIPELINE_SOURCE == "merge_request_event" && $CI_PROJECT_PATH=="atlas-flavor-tagging-tools/algorithms/umami" 
  artifacts:
    when: always
    paths:
      - coverage_files/
    reports:
      cobertura: coverage_files/coverage.xml
-  retry: 2
 
 include:
  - 'pipelines/.unit_test-gitlab-ci.yaml'
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 99fecb49b3d0579a02b6fd2916ae46609a69417c..a49e88dbc35b8f77c89d9bf57038ca68b7af9910 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,4 +19,4 @@ repos:
      language: system
      entry: flake8
      types: [python]
-      exclude: setup.py
+      exclude: setup.py
\ No newline at end of file
diff --git a/docker/umamibase/Dockerfile b/docker/umamibase/Dockerfile
index bf6503c44057c6173d583da951fa01069b02cddf..d55268e441ca307c0d88244f0fa35549368466e4 100644
--- a/docker/umamibase/Dockerfile
+++ b/docker/umamibase/Dockerfile
@@ -11,7 +11,7 @@ RUN apt-get update && \
    echo "krb5-config krb5-config/add_servers_realm string CERN.CH" | debconf-set-selections && \
    echo "krb5-config krb5-config/default_realm string CERN.CH" | debconf-set-selections && \
    apt-get install -y krb5-user && \
-    apt-get install -y vim nano emacs less screen graphviz python3-tk wget
+    apt-get install -y vim nano emacs less screen graphviz python3-tk wget 
 
 COPY requirements.txt .
diff --git a/examples/DL1r-PFlow-Training-config.yaml b/examples/DL1r-PFlow-Training-config.yaml
index 2079f2ea55562c5605555c3253871361c8b4e30e..fc21bc0f500e37fc72e101091936011ad786448a 100644
--- a/examples/DL1r-PFlow-Training-config.yaml
+++ b/examples/DL1r-PFlow-Training-config.yaml
@@ -26,7 +26,7 @@ ttbar_test_files:
    data_set_name: "ttbar_comparison"
 
 zpext_test_files:
-  zpext_r21:
+  zpext_r21: 
    Path: /work/ws/nemo/fr_af1100-Training-Simulations-0/hybrids/MC16d_hybrid-ext_odd_0_PFlow-no_pTcuts-file_1.h5
    data_set_name: "zpext"
 
@@ -45,11 +45,11 @@ bool_use_taus: False
 
 exclude: []
 
-NN_structure:
+NN_structure: 
  lr: 0.01
  batch_size: 15000
  activations: ["relu", "relu", "relu", "relu", "relu", "relu", "relu", "relu"]
-  units: [256, 128, 60, 48, 36, 24, 12, 6]
+  units: [256, 128, 60, 48, 36, 24, 12, 6] 
 
 # Eval parameters for validation evaluation while training
 Eval_parameters_validation:
@@ -106,3 +106,4 @@ Eval_parameters_validation:
 
  # Set the datatype of the plots
  plot_datatype: "pdf"
+
diff --git a/examples/Dips-PFlow-Training-config.yaml b/examples/Dips-PFlow-Training-config.yaml
index 8d9d67d2b58fc92473c59fe3c267a689a197641d..add66aeb2c1d92d623921ffb2bcb1ff1ff5d514c 100755
--- a/examples/Dips-PFlow-Training-config.yaml
+++ b/examples/Dips-PFlow-Training-config.yaml
@@ -26,7 +26,7 @@ ttbar_test_files:
    data_set_name: "ttbar_comparison"
 
 zpext_test_files:
-  zpext_r21:
+  zpext_r21: 
    Path: /work/ws/nemo/fr_af1100-Training-Simulations-0/hybrids/MC16d_hybrid-ext_odd_0_PFlow-no_pTcuts-file_1.h5
    data_set_name: "zpext"
 
@@ -100,4 +100,4 @@ Eval_parameters_validation:
  SecondTag: "\n$\\sqrt{s}=13$ TeV, PFlow jets"
 
  # Set the datatype of the plots
-  plot_datatype: "pdf"
+  plot_datatype: "pdf"
\ No newline at end of file
diff --git a/examples/PFlow-Preprocessing-DESY.yaml b/examples/PFlow-Preprocessing-DESY.yaml
index fd0e3dc0dde768799aa38de3d7fc32b21c093f92..66c04e433de84ac107f0bd14a52048b6c80e6675 100644
--- a/examples/PFlow-Preprocessing-DESY.yaml
+++ b/examples/PFlow-Preprocessing-DESY.yaml
@@ -42,7 +42,7 @@ preparation:
  training_ttbar_cjets:
    type: ttbar
    category: cjets
-    # Number of c jets available in MC16d
+    # Number of c jets available in MC16d 
    n_jets: 12745953
    n_split: 13
    cuts:
@@ -170,7 +170,7 @@ bool_process_taus: False
 # set to true if extended flavour labelling scheme is used in preprocessing
 bool_extended_labelling: False
 
-# Define undersampling method used. Valid are "count", "weight",
+# Define undersampling method used. Valid are "count", "weight", 
 # "count_bcl_weight_tau", "template_b" and "template_b_count"
 # count_bcl_weight_tau is a hybrid of count and weight to deal with taus.
 # template_b uses the b as the target distribution, but does not guarantee
diff --git a/examples/PFlow-Preprocessing.yaml b/examples/PFlow-Preprocessing.yaml
index 0b43aaf085786dbb239d0a501341877a787beae5..229afc186b52ee748b440cf09cbe42c2f05f7409 100755
--- a/examples/PFlow-Preprocessing.yaml
+++ b/examples/PFlow-Preprocessing.yaml
@@ -42,7 +42,7 @@ preparation:
  training_ttbar_cjets:
    type: ttbar
    category: cjets
-    # Number of c jets available in MC16d
+    # Number of c jets available in MC16d 
    n_jets: 12745953
    n_split: 13
    cuts:
@@ -170,7 +170,7 @@ bool_process_taus: False
 # set to true if extended flavour labelling scheme is used in preprocessing
 bool_extended_labelling: False
 
-# Define undersampling method used. Valid are "count", "weight",
+# Define undersampling method used. Valid are "count", "weight", 
 # "count_bcl_weight_tau", "template_b" and "template_b_count"
 # count_bcl_weight_tau is a hybrid of count and weight to deal with taus.
 # template_b uses the b as the target distribution, but does not guarantee
diff --git a/examples/custom-pdf.py b/examples/custom-pdf.py
deleted file mode 100644
index 52b6f42c1ef5a1b031b76604a5c41d69c955cebc..0000000000000000000000000000000000000000
--- a/examples/custom-pdf.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import numpy as np
-
-from umami.preprocessing import PDFSampling
-
-# create some dummy data
-x = np.random.default_rng().normal(size=1000)
-y = np.random.default_rng().normal(1, 2, size=1000)
-
-# get 2d histogram of our dummy data
-h_original, x_bins, y_bins = np.histogram2d(x, y, [4, 5])
-
-# calculate a custom function
-pt = np.cos(x ** 2) + np.sin(x + y) + np.exp(x)
-eta = 20 - y ** 2
-
-h_target, _, _ = np.histogram2d(pt, eta, bins=[x_bins, y_bins])
-
-ps = PDFSampling()
-ps.CalculatePDFRatio(h_target, h_original, x_bins, y_bins)
-ps.save("custom-pdf.pkl")
diff --git a/examples/dumper-evalute-config.yaml b/examples/dumper-evalute-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2a6eecc1e0dd5e4798ebfcf577c74e70833d03cd
--- /dev/null
+++ b/examples/dumper-evalute-config.yaml
@@ -0,0 +1,14 @@
+# Add train_config filepath
+train_config: /work/ws/nemo/fr_af1100-Training-Simulations-0/b-Tagging/Submission_Scripts/Train_Dips_Loose/configs/Dips-PFlow-Training-config.yaml
+
+# Add preprocess_config with which the used model is trained
+preprocess_config: /work/ws/nemo/fr_af1100-Training-Simulations-0/b-Tagging/Submission_Scripts/Train_Dips_Loose/configs/PFlow-Preprocessing.yaml
+
+# Path to Variable dict used in preprocessing
+var_dict: /work/ws/nemo/fr_af1100-Training-Simulations-0/b-Tagging/Submission_Scripts/Train_Dips_Loose/configs/Dips_Variables.yaml
+
+# File to test the dips integration in the dumper
+test_file: /work/ws/nemo/fr_af1100-Training-Simulations-0/DAOD_PHYSVAL.25394913._000001.pool.root.1_loose.h5
+
+# Path to model dirs
+model_file: /work/ws/nemo/fr_af1100-Training-Simulations-0/b-Tagging/packages/umami/umami/dips_Loose_lr_0.001_bs_15000_epoch_200_nTrainJets_Full/dips_model_59.h5
\ No newline at end of file
diff --git a/examples/evalute_comp_taggers.yaml b/examples/evalute_comp_taggers.yaml
index f57855de87a9534ff80f3eaa7b02f2603388546f..d10f92fbfb4f3bab63b3c1a56eac6c8c7b981992 100644
--- a/examples/evalute_comp_taggers.yaml
+++ b/examples/evalute_comp_taggers.yaml
@@ -17,7 +17,7 @@ ttbar_test_files:
    data_set_name: "ttbar_comparison"
 
 zpext_test_files:
-  zpext_r21:
+  zpext_r21: 
    Path: /work/ws/nemo/fr_af1100-Training-Simulations-0/hybrids/MC16d_hybrid-ext_odd_0_PFlow-no_pTcuts-file_1.h5
    data_set_name: "zpext"
 
@@ -42,4 +42,4 @@ Eval_parameters_validation:
  fc_values_comp: {
    "rnnip": 0.08,
    "DL1r": 0.018,
-  }
+  }
\ No newline at end of file
diff --git a/examples/plotting_input_vars.yaml b/examples/plotting_input_vars.yaml
index db6597beb95bbea7479c898e149c9d5d6de32710..330eb13102946da7ae4fb404840ce94540de0ca7 100644
--- a/examples/plotting_input_vars.yaml
+++ b/examples/plotting_input_vars.yaml
@@ -27,31 +27,31 @@ jets_input_vars:
  special_param_jets:
    IP2D_cu:
      lim_left: -30
-      lim_right: 30
-    IP2D_bu:
+      lim_right: 30 
+    IP2D_bu: 
      lim_left: -30
-      lim_right: 30
-    IP2D_bc:
+      lim_right: 30 
+    IP2D_bc: 
      lim_left: -30
-      lim_right: 30
-    IP3D_cu:
+      lim_right: 30 
+    IP3D_cu: 
      lim_left: -30
-      lim_right: 30
-    IP3D_bu:
+      lim_right: 30 
+    IP3D_bu: 
      lim_left: -30
-      lim_right: 30
-    IP3D_bc:
+      lim_right: 30 
+    IP3D_bc: 
      lim_left: -30
      lim_right: 30
    SV1_NGTinSvx:
      lim_left: 0
-      lim_right: 19
+      lim_right: 19 
    JetFitterSecondaryVertex_nTracks:
      lim_left: 0
-      lim_right: 17
+      lim_right: 17 
    JetFitter_nTracksAtVtx:
      lim_left: 0
-      lim_right: 19
+      lim_right: 19 
    JetFitter_nSingleTracks:
      lim_left: 0
      lim_right: 18
@@ -62,50 +62,50 @@ jets_input_vars:
      lim_left: 0
      lim_right: 200
  binning:
-    IP2D_cu: 100
-    IP2D_bu: 100
-    IP2D_bc: 100
-    IP2D_isDefaults: 2
-    IP3D_cu: 100
-    IP3D_bu: 100
-    IP3D_bc: 100
-    IP3D_isDefaults: 2
-    JetFitter_mass: 100
-    JetFitter_energyFraction: 100
-    JetFitter_significance3d: 100
-    JetFitter_deltaR: 100
-    JetFitter_nVTX: 7
-    JetFitter_nSingleTracks: 19
-    JetFitter_nTracksAtVtx: 20
-    JetFitter_N2Tpair: 201
-    JetFitter_isDefaults: 2
+    IP2D_cu : 100
+    IP2D_bu : 100
+    IP2D_bc : 100
+    IP2D_isDefaults : 2
+    IP3D_cu : 100
+    IP3D_bu : 100
+    IP3D_bc : 100
+    IP3D_isDefaults : 2
+    JetFitter_mass : 100
+    JetFitter_energyFraction : 100
+    JetFitter_significance3d : 100
+    JetFitter_deltaR : 100
+    JetFitter_nVTX : 7
+    JetFitter_nSingleTracks : 19
+    JetFitter_nTracksAtVtx : 20
+    JetFitter_N2Tpair : 201
+    JetFitter_isDefaults : 2
    JetFitterSecondaryVertex_minimumTrackRelativeEta: 11
    JetFitterSecondaryVertex_averageTrackRelativeEta: 11
    JetFitterSecondaryVertex_maximumTrackRelativeEta: 11
-    JetFitterSecondaryVertex_maximumAllJetTrackRelativeEta: 11
-    JetFitterSecondaryVertex_minimumAllJetTrackRelativeEta: 11
-    JetFitterSecondaryVertex_averageAllJetTrackRelativeEta: 11
-    JetFitterSecondaryVertex_displacement2d: 100
-    JetFitterSecondaryVertex_displacement3d: 100
-    JetFitterSecondaryVertex_mass: 100
-    JetFitterSecondaryVertex_energy: 100
-    JetFitterSecondaryVertex_energyFraction: 100
-    JetFitterSecondaryVertex_isDefaults: 2
-    JetFitterSecondaryVertex_nTracks: 18
-    pt_btagJes: 100
-    absEta_btagJes: 100
-    SV1_Lxy: 100
-    SV1_N2Tpair: 8
-    SV1_NGTinSvx: 20
-    SV1_masssvx: 100
-    SV1_efracsvx: 100
-    SV1_significance3d: 100
-    SV1_deltaR: 10
-    SV1_L3d: 100
-    SV1_isDefaults: 2
-    rnnip_pb: 50
-    rnnip_pc: 50
-    rnnip_pu: 50
+    JetFitterSecondaryVertex_maximumAllJetTrackRelativeEta : 11
+    JetFitterSecondaryVertex_minimumAllJetTrackRelativeEta : 11
+    JetFitterSecondaryVertex_averageAllJetTrackRelativeEta : 11
+    JetFitterSecondaryVertex_displacement2d : 100
+    JetFitterSecondaryVertex_displacement3d : 100
+    JetFitterSecondaryVertex_mass : 100
+    JetFitterSecondaryVertex_energy : 100
+    JetFitterSecondaryVertex_energyFraction : 100
+    JetFitterSecondaryVertex_isDefaults : 2
+    JetFitterSecondaryVertex_nTracks : 18
+    pt_btagJes : 100
+    absEta_btagJes : 100
+    SV1_Lxy : 100
+    SV1_N2Tpair : 8
+    SV1_NGTinSvx : 20
+    SV1_masssvx : 100
+    SV1_efracsvx : 100
+    SV1_significance3d : 100
+    SV1_deltaR : 10
+    SV1_L3d : 100
+    SV1_isDefaults : 2
+    rnnip_pb : 50
+    rnnip_pc : 50
+    rnnip_pu : 50
  flavours:
    b: 5
    c: 4
@@ -178,4 +178,4 @@ tracks_input_vars:
  flavours:
    b: 5
    c: 4
-    u: 0
+    u: 0
\ No newline at end of file
diff --git a/examples/plotting_umami_config_DL1r.yaml b/examples/plotting_umami_config_DL1r.yaml
index bb64912fbc0d53c0b136a8f12490d17852c28a9b..46510896cf5eb75c1b517b624e2db8ef91edcde1 100644
--- a/examples/plotting_umami_config_DL1r.yaml
+++ b/examples/plotting_umami_config_DL1r.yaml
@@ -9,7 +9,7 @@ scores_DL1r: # Each item on this level defines one plot. The name of this key is
  type: "scores"
  data_set_name: "ttbar" # data set to use. This chooses either the test_file ('ttbar') or the add_test_file ('zpext')
  # To include taus, add "ptau" as the last entry
-  prediction_labels: ["DL1_pb", "DL1_pc", "DL1_pu"] # For umami use umami_pX or dips_pX.
+ prediction_labels: ["DL1_pb", "DL1_pc", "DL1_pu"] # For umami use umami_pX or dips_pX. plot_settings: # All options of the score plot can be changed here UseAtlasTag: True # Enable/Disable AtlasTag AtlasTag: "Internal Simulation" @@ -58,9 +58,9 @@ DL1r_c_flavour: binomialErrors: true SecondTag: "\n$\\sqrt{s}=13$ TeV, PFlow jets,\n$t\\bar{t}$ test sample, fc=0.018" -# To do a DL1r_t_flavour ROC plot: +# To do a DL1r_t_flavour ROC plot: # Same as above with DL1_trej and DL1r_trej (replace "c" flavour by "t"). - + # Example of a pt vs efficiency plot in a small pT region eff_vs_pt_small: type: "ROCvsVar" @@ -69,8 +69,8 @@ eff_vs_pt_small: flat_eff: True # bool whether to plot a flat b-efficiency as a function of var efficiency: 70 # the targeted efficiency fc: 0.018 - prediction_labels: ["DL1_pb", "DL1_pc", "DL1_pu"] # the prediction label to use - variable: pt # which variable to plot the efficiency as a function of. + prediction_labels: ["DL1_pb", "DL1_pc", "DL1_pu"] # the prediction label to use + variable: pt # which variable to plot the efficiency as a function of. max_variable: 1500000 #maximum value of the range of variable. min_variable: 10000 #minimum value of the range of variable. nbin: 100 #number of bin to use @@ -93,8 +93,8 @@ eff_vs_pt_large: data_set_name: "ttbar" flat_eff: True #bool whether to plot a flat b-efficiency as a function of var efficiency: 70 #the targeted efficiency - prediction_labels: ["DL1_pb", "DL1_pc", "DL1_pu"] # the prediction label to use - variable: pt #which variable to plot the efficiency as a function of. + prediction_labels: ["DL1_pb", "DL1_pc", "DL1_pu"] # the prediction label to use + variable: pt #which variable to plot the efficiency as a function of. max_variable: 5000000 #maximum value of the range of variable. min_variable: 200000 #minimum value of the range of variable. nbin: 15 #number of bin to use diff --git a/examples/plotting_umami_config_Umami.yaml b/examples/plotting_umami_config_Umami.yaml index 64f1e60292c6f8dd7b9300d5aa10f4e4339caea6..76098a50f3dfea1835a55a0966a779d48dc5886f 100644 --- a/examples/plotting_umami_config_Umami.yaml +++ b/examples/plotting_umami_config_Umami.yaml @@ -48,7 +48,7 @@ confusion_matrix_Umami_ttbar: prediction_labels: ["umami_pb", "umami_pc", "umami_pu"] # For umami use umami_pX or dips_pX. The order matters! 
 
 # Scanning b-eff, comparing Umami and DL1r, ttbar
-beff_scan_tagger_umami:
+beff_scan_tagger_compare_umami:
  type: "ROC"
  models_to_plot:
    umami_ttbar_urej:
@@ -111,7 +111,6 @@ Umami_prob_comparison_pb:
      SecondTag: "\n$\\sqrt{s}=13$ TeV, PFlow Jets"
      yAxisAtlasTag: 0.9
      Ratio_Cut: [0.5, 1.5]
-
 # Scanning b-eff, comparing Umami and DL1r, ttbar
 beff_scan_tagger_compare_umami:
  type: "ROC_Comparison"
diff --git a/examples/umami-PFlow-Training-config.yaml b/examples/umami-PFlow-Training-config.yaml
index 109b325a56c631dc9a08f6f63827638c6b137b8f..6c29a8f6af5ec02711b837c26040fef233bf7b31 100644
--- a/examples/umami-PFlow-Training-config.yaml
+++ b/examples/umami-PFlow-Training-config.yaml
@@ -22,7 +22,7 @@ ttbar_test_files:
    data_set_name: "ttbar"
 
 zpext_test_files:
-  zpext_r21:
+  zpext_r21: 
    Path: /nfs/dust/atlas/user/ahnenjan/phd/umami/run/samples/standard_mc16d_ttbar_Zext__2M/hybridLargeFiles/MC16d_hybrid-ext_odd_0_PFlow-no_pTcuts-file_1.h5
    data_set_name: "zpext"
 
@@ -33,7 +33,7 @@ bool_use_taus: False
 
 exclude: []
 
-NN_structure:
+NN_structure: 
  lr: 0.01
  batch_size: 5000
  epochs: 200
diff --git a/mkdocs.yml b/mkdocs.yml
index 5e6f21b4a1390acf77a51c83c7f87ebb1f4889ed..13d500f600b6bdf605ff3649bc47fbfa442c0f56 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -28,7 +28,7 @@ nav:
    - Running DL1r: DL1r-instructions.md
    - Running Dips: Dips-instructions.md
    - Running Umami: Umami-instructions.md
-    - LWTNN Conversion: LWTNN-conversion.md
+    - LWTNN Conversion: LWTNN-conversion.md 
    - Evaluate Taggers in Samples: WO_trained_model.md
    - Plotting evaluated Results: plotting_umami.md
diff --git a/pipelines/.docker-gitlab-ci.yaml b/pipelines/.docker-gitlab-ci.yaml
index 67a9edee670b848e5e3f9ca290e8e1487b52e09c..3af3ad9385bc72c98cba0a8d6382cf6cec5f0fb7 100644
--- a/pipelines/.docker-gitlab-ci.yaml
+++ b/pipelines/.docker-gitlab-ci.yaml
@@ -1,5 +1,6 @@
 # ----------------------------------------------------------------------------
-# Umami base + Umami images: only get built on master and tags
+# Umami base + Umami images: only get built on master
+# (see below for tags)
 # ----------------------------------------------------------------------------
 
 .image_build_template: &image_build_template
@@ -8,7 +9,6 @@
    - ignore
  tags:
    - docker-image-build
-  retry: 2
 
 
 build_umamibase_cpu:
@@ -37,17 +37,6 @@ build_umamibase_gpu:
  variables:
    TO: '${CI_REGISTRY}/${CI_PROJECT_NAMESPACE}/umami/umamibase:$CI_COMMIT_REF_SLUG-gpu'
 
-build_umamibase_gpu_pytorch:
-  <<: *image_build_template
-  stage: builds
-  variables:
-    DOCKER_FILE: docker/umamibase/Dockerfile
-    FROM: pytorch/pytorch:$TORCHTAG
-    TO: '${CI_REGISTRY}/${CI_PROJECT_NAMESPACE}/umami/umamibase:latest-pytorch-gpu'
-  rules:
-    - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
-
-
 # Umami images: use base image as a foundation to speed up build process
 build_umami_cpu:
  <<: *image_build_template
@@ -98,18 +87,7 @@ build_umamibase_cpu_MR:
  rules:
    - if: $CI_PIPELINE_SOURCE == "merge_request_event" && $CI_PROJECT_PATH=="atlas-flavor-tagging-tools/algorithms/umami"
 
-# possibility to trigger also the GPU image in a MR - but only manually
-build_umamibase_gpu_MR:
-  <<: *image_build_template
-  stage: image_build_umamibase
-  variables:
-    DOCKER_FILE: docker/umamibase/Dockerfile
-    FROM: tensorflow/tensorflow:$TFTAG-gpu
-    TO: '${CI_REGISTRY}/${CI_PROJECT_NAMESPACE}/umami/temporary_images:${CI_MERGE_REQUEST_IID}-gpu-base'
-  rules:
-    - if: $CI_PIPELINE_SOURCE == "merge_request_event" && $CI_PROJECT_PATH=="atlas-flavor-tagging-tools/algorithms/umami"
-      when: manual
-      allow_failure: true
+ 
 # ----------------------------------------------------------------------------
 # Publishing:
 # copies of the images built in gitlab CI/CD will be deployed to Docker Hub
diff --git a/pipelines/.docs-gitlab-ci.yaml b/pipelines/.docs-gitlab-ci.yaml
index 4f7bdfac5ba8dc941ed3196877d4961e02e82275..78849082641c748f78760fb420fb10d85661e118 100644
--- a/pipelines/.docs-gitlab-ci.yaml
+++ b/pipelines/.docs-gitlab-ci.yaml
@@ -1,15 +1,16 @@
+
 build_docs:
  stage: builds
  image: gitlab-registry.cern.ch/authoring/documentation/mkdocs:stable
  script:
-    - mkdocs build --strict --clean --site-dir www
-    - echo -e "AddDefaultCharset UTF-8\nSSLRequireSSL\n" > www/.htaccess
+    - mkdocs build --strict --clean --site-dir www 
+    - echo -e "AddDefaultCharset UTF-8\nSSLRequireSSL\n" > www/.htaccess 
  before_script:
-    - "" # overwrite default, do nothing
+    - "" # overwrite default, do nothing 
  artifacts:
-    paths:
-      - www
-    expire_in: 1 hour
+    paths: 
+      - www 
+    expire_in: 1 hour 
  only:
    - master
diff --git a/pipelines/.integration_test-gitlab-ci.yaml b/pipelines/.integration_test-gitlab-ci.yaml
index 0c8a92f43651d417c9b3c217a6dec85a3e892843..117b09cd4510c1ef6be42256cb3cbdaefb317431 100644
--- a/pipelines/.integration_test-gitlab-ci.yaml
+++ b/pipelines/.integration_test-gitlab-ci.yaml
@@ -9,7 +9,6 @@
      IMAGE_TYPE: "umamibase:latest"
    - if: $CI_PIPELINE_SOURCE == "merge_request_event" && $CI_PROJECT_PATH=="atlas-flavor-tagging-tools/algorithms/umami"
  image: '${CI_REGISTRY}/${CI_PROJECT_NAMESPACE}/umami/$IMAGE_TYPE'
-  retry: 2
 
 .artifact_template: &artifact_template
  name: "$CI_JOB_NAME"
diff --git a/pipelines/.unit_test-gitlab-ci.yaml b/pipelines/.unit_test-gitlab-ci.yaml
index 49ceb356c9767ba0ffc39c0703c9fb05a77bf3da..580cebe12259ce007b4602f9f1ea2aa1de45c159 100644
--- a/pipelines/.unit_test-gitlab-ci.yaml
+++ b/pipelines/.unit_test-gitlab-ci.yaml
@@ -4,7 +4,7 @@ unittest:
  script:
    - pip install -r requirements.txt
    - apt-get update
-    - apt-get install -y wget
+    - apt-get install -y wget 
    - python setup.py develop
    - pytest ./umami/tests/unit/preprocessing -v
    - pytest ./umami/tests/unit/evaluation_tools -v
@@ -33,7 +33,6 @@ unittest:
      - coverage_files/
    reports:
      junit: report.xml
-  retry: 2
 
 unittest_preprocessing:
  <<: *unittest_template
@@ -57,4 +56,4 @@ unittest_input_vars_tools:
  <<: *unittest_template
  script:
    - pytest --cov=./ --cov-report= ./umami/tests/unit/input_vars_tools/ -v --junitxml=report.xml
-    - cp .coverage coverage_files/.coverage.unittest_input_vars_tools
+    - cp .coverage coverage_files/.coverage.unittest_input_vars_tools
\ No newline at end of file
diff --git a/umami/configs/DL1r_Variables.yaml b/umami/configs/DL1r_Variables.yaml
index ec52e3339e5c139b3ff07a7d92ee0ff77246635f..f01332f6975e88c5c6af24a2fba9c915faa01be4 100644
--- a/umami/configs/DL1r_Variables.yaml
+++ b/umami/configs/DL1r_Variables.yaml
@@ -47,11 +47,11 @@ train_variables:
    - IP3D_bu
    - IP3D_bc
    - IP3D_cu
-  RNNIP:
+  RNNIP: 
    - rnnip_pb
    - rnnip_pc
    - rnnip_pu
-
+ 
 # useful variables which might want to be kept but being used for training
 spectator_variables:
  - DL1r_pb
@@ -68,4 +68,4 @@ custom_defaults_vars:
  SV1_NGTinSvx: -1
  SV1_efracsvx: 0
  JetFitterSecondaryVertex_nTracks: 0
-  JetFitterSecondaryVertex_energyFraction: 0
+  JetFitterSecondaryVertex_energyFraction: 0
\ No newline at end of file
diff --git a/umami/configs/DL1r_Variables_R22.yaml b/umami/configs/DL1r_Variables_R22.yaml
index b90fb5345b9fed590d6bbb4a0bd277685bf3e108..da29bc7e9c9a987f9cbfe7abc9171058ea7f4ff6 100644
--- a/umami/configs/DL1r_Variables_R22.yaml
+++ b/umami/configs/DL1r_Variables_R22.yaml
@@ -48,7 +48,7 @@ train_variables:
    - IP3D_bu
    - IP3D_bc
    - IP3D_cu
-  RNNIP:
+  RNNIP: 
    - rnnip_pb
    - rnnip_pc
    - rnnip_pu
diff --git a/umami/configs/Dips_Variables.yaml b/umami/configs/Dips_Variables.yaml
index b1f4118f6a45b6b29b0b6182bb9c20f23d248bdb..83bed05c29ebbeed2adfdfed48d3e8ccd730c941 100644
--- a/umami/configs/Dips_Variables.yaml
+++ b/umami/configs/Dips_Variables.yaml
@@ -47,7 +47,7 @@ train_variables:
    - IP3D_bu
    - IP3D_bc
    - IP3D_cu
-
+ 
 # useful variables which might want to be kept but being used for training
 spectator_variables:
  - DL1r_pb
@@ -67,7 +67,7 @@ custom_defaults_vars:
  JetFitterSecondaryVertex_energyFraction: 0
 
 track_train_variables:
-  noNormVars:
+  noNormVars: 
    - IP3D_signed_d0_significance
    - IP3D_signed_z0_significance
    - numberOfInnermostPixelLayerHits
@@ -84,4 +84,4 @@ track_train_variables:
    - numberOfPixelHits
    - numberOfSCTHits
    - btagIp_d0
-    - btagIp_z0SinTheta
+    - btagIp_z0SinTheta
\ No newline at end of file
diff --git a/umami/configs/Umami_Variables.yaml b/umami/configs/Umami_Variables.yaml
index b1f4118f6a45b6b29b0b6182bb9c20f23d248bdb..80373ac3481ce9f217e1e8d1f04769d0324285c1 100644
--- a/umami/configs/Umami_Variables.yaml
+++ b/umami/configs/Umami_Variables.yaml
@@ -67,7 +67,7 @@ custom_defaults_vars:
  JetFitterSecondaryVertex_energyFraction: 0
 
 track_train_variables:
-  noNormVars:
+  noNormVars: 
    - IP3D_signed_d0_significance
    - IP3D_signed_z0_significance
    - numberOfInnermostPixelLayerHits
diff --git a/umami/configs/global_config.yaml b/umami/configs/global_config.yaml
index a8773fae57ddb4d706ba80f19c128aba77d9ca22..fb741a85020d94b575748249d038b4dfc6a7d533 100644
--- a/umami/configs/global_config.yaml
+++ b/umami/configs/global_config.yaml
@@ -22,4 +22,4 @@ hist_err_style:
  fill: False
  linewidth: 0
  hatch: "/////"
-  edgecolor: "#666666"
+  edgecolor: "#666666"
\ No newline at end of file
diff --git a/umami/preprocessing_tools/PDF_Sampling.py b/umami/preprocessing_tools/PDF_Sampling.py
deleted file mode 100644
index 15ba2c58e0d79c943e9d0d64371e87da8800c698..0000000000000000000000000000000000000000
--- a/umami/preprocessing_tools/PDF_Sampling.py
+++ /dev/null
@@ -1,257 +0,0 @@
-import os
-import pickle
-
-import numpy as np
-from scipy.interpolate import RectBivariateSpline
-
-from umami.configuration import logger
-
-
-class PDFSampling(object):
-    """
-    Sampling method using ratios between distributions to sample training
-    file.
-    An importance sampling approach
-    """
-
-    def __init__(self):
-        self.inter_func = None
-        self._ratio = None
-
-    def CalculatePDF(
-        self,
-        x_y_target,
-        x_y_original,
-        bins=[100, 9],
-        ratio_max: float = 1,
-    ):
-        """
-        Calculates the histograms of the input data and uses them to
-        calculate the PDF Ratio.
-        CalculatePDFRatio is invoked here.
-
-        Inputs:
-        x_y_target: A 2D tuple of the target datapoints of x and y.
-        x_y_original: A 2D tuple of the to resample datapoints of x and y.
-        bins: This can be all possible binning inputs as for numpy
-              histogram2d.
-        ratio_max: Maximum Ratio difference which is used for upsampling
-                   the inputs.
-
-        Output:
-        Provides the PDF interpolation function which is used for sampling.
-        This can be returned with Inter_Func. It is a property of the class.
- """ - - # Calculate the corresponding histograms - h_target, self._x_bin_edges, self._y_bin_edges = np.histogram2d( - x_y_target[0], x_y_target[1], bins - ) - - h_original, _, _ = np.histogram2d( - x_y_original[0], - x_y_original[1], - [self._x_bin_edges, self._y_bin_edges], - ) - - # Calculate the PDF Ratio - self.CalculatePDFRatio( - h_target, - h_original, - self._x_bin_edges, - self._y_bin_edges, - ratio_max=ratio_max, - ) - - def CalculatePDFRatio( - self, - h_target, - h_original, - x_bin_edges, - y_bin_edges, - ratio_max: float = 1, - ): - """ - Receives the histograms of the target and original data, the bins - and a max ratio value. Latter is optional. - - Inputs: - h_target: Output of numpy histogram2D for the target datapoints - h_original: Output of numpy histogram2D for the original datapoints - bins: The bin edges of the binning used for the numpy histogram2D. - This is also returned from numpy histgram2D - ratio_max: Maximum Ratio difference which is used for upsampling - the inputs. - - Output: - Provides the PDF interpolation function which is used for sampling. - This can be returned with Inter_Func. It is a property of the class. - """ - - # Normalise the histograms to unity - h_target = h_target / np.sum(h_target) - h_original = h_original / np.sum(h_original) - - # Transform bin edges to bin centres - self.x_bins = (x_bin_edges[:-1] + x_bin_edges[1:]) / 2 - self.y_bins = (y_bin_edges[:-1] + y_bin_edges[1:]) / 2 - - # Calculating the ratio of the reference distribution w.r.t. the target distribution - ratio = np.divide( - h_target, - h_original, - out=np.zeros( - h_original.shape, - dtype=float, - ), - where=(h_original != 0), - ) - - # Setting max ratio value - self._ratio = ratio / ratio.max() * ratio_max - - # Calculate interpolation function - logger.info("Retrieve interpolation function") - self.inter_func = RectBivariateSpline( - self.x_bins, self.y_bins, self._ratio - ) - - def save(self, file_name: str, overwrite: bool = False): - """ - Save the interpolation function to file - - Input: - file_name: Path where the pickle file is saved. - - Output: - Pickle file of the PDF interpolation function. - """ - - if self.inter_func is not None: - if os.path.isfile(file_name) is True: - if overwrite is True: - logger.warning( - "File already exists at given path! Overwrite existing file!" - ) - - # Dump function into pickle file - with open(file_name, "wb") as file: - pickle.dump(self.inter_func, file) - - else: - logger.warning( - "File already exists at given path! PDF interpolation function not saved!" - ) - - else: - # Dump function into pickle file - with open(file_name, "wb") as file: - pickle.dump(self.inter_func, file) - - else: - raise ValueError("Interpolation function not calculated/given!") - - def load(self, file_name: str): - """ - Load the interpolation function from file. - - Input: - file_name: Path where the pickle file is saved. - - Output: - PDF interpolation function of the pickle file is added as property - to the class. - """ - - with open(file_name, "rb") as file: - self.inter_func = pickle.load(file) - - # the resampling is so far only working for a batch which is being normalised - # TODO: rename - def inMemoryResample(self, x_values, y_values, size): - """ - Resample all of the datapoints at once. Requirement for that - is that all datapoints fit in the RAM. 
-
-        Input:
-        x_values: x values of the datapoints which are to be resampled from (i.e pT)
-        y_values: y values of the datapoints which are to be resampled from (i.e eta)
-        size: Number of jets which are resampled.
-
-        Output:
-        Resampled jets
-        """
-
-        if type(x_values) == float or type(x_values) == int:
-            x_values = np.asarray([x_values])
-
-        if type(y_values) == float or type(y_values) == int:
-            y_values = np.asarray([y_values])
-
-        # Check for sizes of x_values and y_values
-        if len(y_values) != len(x_values):
-            raise ValueError("x_values and y_values need to have same size!")
-
-        # Evaluate the datapoints with the PDF function
-        r_resamp = self.inter_func.ev(x_values, y_values)
-
-        # Discard all datapoints where the ratio is 0 or less
-        indices = np.where(r_resamp >= 0)[0]
-        r_resamp = r_resamp[indices]
-
-        # Normalise the datapoints for sampling
-        r_resamp = r_resamp / np.sum(r_resamp)
-
-        # Resample the datapoints based on their PDF Ratio value
-        sampled_indices = np.random.default_rng().choice(
-            indices, p=r_resamp, size=size
-        )
-
-        # Return the resampled datapoints
-        return x_values[sampled_indices], y_values[sampled_indices]
-
-    # TODO: rename
-    def Resample(self, x_values, y_values):
-        """
-        Resample a batch of datapoints at once. This function is used
-        if multiple files need to resampled and also if the datapoints
-        does not fit in the RAM.
-
-        Input:
-        x_values: x values of the datapoints which are to be resampled from (i.e pT)
-        y_values: y values of the datapoints which are to be resampled from (i.e eta)
-
-        Output:
-        Resampled jets
-        """
-
-        if type(x_values) == float or type(x_values) == int:
-            x_values = np.asarray([x_values])
-
-        if type(y_values) == float or type(y_values) == int:
-            y_values = np.asarray([y_values])
-
-        # Check for sizes of x_values and y_values
-        if len(y_values) != len(x_values):
-            raise ValueError("x_values and y_values need to have same size!")
-
-        # Evaluate the datapoints with the PDF function
-        r_resamp = self.inter_func.ev(x_values, y_values)
-
-        # Get random numbers from generator
-        rnd_numbers = np.random.default_rng().uniform(0, 1, len(r_resamp))
-
-        # Decide, based on the PDF values for the datapoints and the random
-        # numbers which datapoints are sampled
-        sampled_indices = np.where(rnd_numbers < r_resamp)
-
-        # Return sampled datapoints
-        return x_values[sampled_indices], y_values[sampled_indices]
-
-    @property
-    def ratio(self):
-        return self._ratio
-
-    @property
-    def Inter_Func(self):
-        return self.inter_func
diff --git a/umami/preprocessing_tools/__init__.py b/umami/preprocessing_tools/__init__.py
index 8db17a09af729d9bd9a2ba6dace76b1dfce8babc..acfb90ba3b796dc245301ecd46876f1f8d93f8a5 100644
--- a/umami/preprocessing_tools/__init__.py
+++ b/umami/preprocessing_tools/__init__.py
@@ -9,20 +9,19 @@ from umami.preprocessing_tools.Merging import (
     create_datasets,
     get_size,
 )
-from umami.preprocessing_tools.PDF_Sampling import PDFSampling
 from umami.preprocessing_tools.Preparation import get_jets
 from umami.preprocessing_tools.Resampling import (
-    EnforceFraction,
     Gen_default_dict,
     GetNJetsPerIteration,
     GetScales,
-    RunSampling,
-    RunStatSamples,
     UnderSampling,
     UnderSamplingProp,
     UnderSamplingTemplate,
     Weighting2D,
     dict_in,
+    EnforceFraction,
+    RunStatSamples,
+    RunSampling,
 )
 from umami.preprocessing_tools.utils import (
     GetBinaryLabels,
diff --git a/umami/preprocessing_tools/configs/preprocessing_default_config.yaml b/umami/preprocessing_tools/configs/preprocessing_default_config.yaml
index 2c2e1b4b89c966274ec23cb8d1c62b8d9c5be499..6912415129936d61c2af7f62a7c9ffbf749ee73b 100644
--- a/umami/preprocessing_tools/configs/preprocessing_default_config.yaml
+++ b/umami/preprocessing_tools/configs/preprocessing_default_config.yaml
@@ -16,7 +16,7 @@ pT_max: False
 bool_process_taus: False
 # set to true if extended flavour labelling is used in preprocessing
 bool_extended_labelling: False
-sampling_method: count
+sampling_method: count 
 f_z:
  path: Null
  file: Null
diff --git a/umami/tests/integration/fixtures/testSetup.yaml b/umami/tests/integration/fixtures/testSetup.yaml
index ea413c08be166ee2c8b7bdd85ccece47ca3cd094..9221cfa8c252469e2ffcb54fb3c2e4732625234d 100644
--- a/umami/tests/integration/fixtures/testSetup.yaml
+++ b/umami/tests/integration/fixtures/testSetup.yaml
@@ -51,4 +51,4 @@ test_input_vars_plot:
  testdir: /tmp/umami/plot_input_vars/
  files:
    - plot_input_vars_r21_check.h5
-    - plot_input_vars_r22_check.h5
+    - plot_input_vars_r22_check.h5
\ No newline at end of file
diff --git a/umami/tests/unit/input_vars_tools/fixtures/plot_input_variables.yaml b/umami/tests/unit/input_vars_tools/fixtures/plot_input_variables.yaml
index 835cda8f79e18409cff5421cec6381b0e56676a6..ec8dcf8c240ed862ec74f92d508f3250bd55311d 100644
--- a/umami/tests/unit/input_vars_tools/fixtures/plot_input_variables.yaml
+++ b/umami/tests/unit/input_vars_tools/fixtures/plot_input_variables.yaml
@@ -71,6 +71,21 @@ Tracks_Test:
    R22:
      files: /tmp/umami/plot_input_vars/plot_input_vars_r22_check.h5
      label: "R22 Test"
+  plot_settings:
+    Log: True
+    UseAtlasTag: True
+    AtlasTag: "Internal Simulation"
+    SecondTag: "$\\sqrt{s}$ = 13 TeV, $t\\bar{t}$ PFlow Jets \n3000 Jets"
+    yAxisAtlasTag: 0.925
+    yAxisIncrease: 2
+    figsize: [7, 5]
+    Ratio_Cut: [0.5, 2]
+    bool_use_taus: False
+    flavours:
+      b: 5
+      c: 4
+      u: 0
+
  plot_settings:
    sorting_variable: "ptfrac"
    n_Leading: [None, 0]
diff --git a/umami/tests/unit/preprocessing/fixtures/PDF_interpolation_function.pkl b/umami/tests/unit/preprocessing/fixtures/PDF_interpolation_function.pkl
deleted file mode 100644
index 27cf3e71b7e9da2fbc46518b990bd5aa1acb47df..0000000000000000000000000000000000000000
Binary files a/umami/tests/unit/preprocessing/fixtures/PDF_interpolation_function.pkl and /dev/null differ
diff --git a/umami/tests/unit/preprocessing/test_preprocessing_tools_PDF_Sampling.py b/umami/tests/unit/preprocessing/test_preprocessing_tools_PDF_Sampling.py
deleted file mode 100644
index e6771440993a640f13ecc61b4713af2e56540dd0..0000000000000000000000000000000000000000
--- a/umami/tests/unit/preprocessing/test_preprocessing_tools_PDF_Sampling.py
+++ /dev/null
@@ -1,200 +0,0 @@
-import os
-import tempfile
-import unittest
-
-import numpy as np
-
-from umami.preprocessing_tools import PDFSampling
-
-
-class PDFSampling_TestCase(unittest.TestCase):
-    """
-    Unit test the PDFSampling class
-    """
-
-    def setUp(self):
-        """
-        Set-Up a few testarrays for PDF Sampling
-        """
-
-        self.tmp_dir = tempfile.TemporaryDirectory()
-        self.tmp_func_dir = f"{self.tmp_dir.name}/"
-        self.func_dir = os.path.join(os.path.dirname(__file__), "fixtures/")
-        self.x_y_target = (
-            np.random.default_rng().uniform(-1, 1, 10000),
-            np.random.default_rng().uniform(-1, 1, 10000),
-        )
-
-        self.x_y_original = (
-            np.random.default_rng().normal(0, 1, size=10000),
-            np.random.default_rng().normal(0, 1, size=10000),
-        )
-
-        self.ratio_max = 1
-        self.bins = [50, 50]
-
-    def test_Init_Properties(self):
-        # Init new Sampler
-        Sampler = PDFSampling()
-
-        # Check if Interpolation function is None
-        self.assertTrue(Sampler.Inter_Func is None)
-        self.assertTrue(Sampler.ratio is None)
-
-    def test_CalculatePDF(self):
-        # Init new Sampler
-        Sampler = PDFSampling()
-
-        # Calculate Interpolation function
-        Sampler.CalculatePDF(
-            x_y_target=self.x_y_target,
-            x_y_original=self.x_y_original,
-            bins=self.bins,
-        )
-
-        self.assertTrue((self.bins[0], self.bins[1]) == Sampler.ratio.shape)
-        self.assertTrue(Sampler.Inter_Func is not None)
-
-    def test_CalculatePDFRatio(self):
-        # Init new Sampler
-        Sampler = PDFSampling()
-
-        h_target, x_bin_edges, y_bin_edges = np.histogram2d(
-            x=self.x_y_target[0],
-            y=self.x_y_target[1],
-            bins=self.bins,
-        )
-
-        h_original, _, _ = np.histogram2d(
-            x=self.x_y_original[0],
-            y=self.x_y_original[1],
-            bins=[x_bin_edges, y_bin_edges],
-        )
-
-        Sampler.CalculatePDFRatio(
-            h_target=h_target,
-            h_original=h_original,
-            x_bin_edges=x_bin_edges,
-            y_bin_edges=y_bin_edges,
-            ratio_max=self.ratio_max,
-        )
-
-        self.assertTrue((self.bins[0], self.bins[1]) == Sampler.ratio.shape)
-        self.assertTrue(Sampler.Inter_Func is not None)
-
-    def test_inMemoryResample(self):
-        # Init new Sampler
-        Sampler = PDFSampling()
-
-        # Calculate Interpolation function
-        Sampler.CalculatePDF(
-            x_y_target=self.x_y_target,
-            x_y_original=self.x_y_original,
-            bins=self.bins,
-        )
-
-        x_values, y_values = Sampler.inMemoryResample(
-            x_values=self.x_y_original[0],
-            y_values=self.x_y_original[1],
-            size=1000,
-        )
-
-        self.assertEqual(
-            len(x_values),
-            1000,
-        )
-
-        self.assertEqual(
-            len(y_values),
-            1000,
-        )
-
-    def test_Resample_Array(self):
-        # Init new Sampler
-        Sampler = PDFSampling()
-
-        # Calculate Interpolation function
-        Sampler.CalculatePDF(
-            x_y_target=self.x_y_target,
-            x_y_original=self.x_y_original,
-            bins=self.bins,
-        )
-
-        x_values, y_values = Sampler.Resample(
-            x_values=self.x_y_original[0],
-            y_values=self.x_y_original[1],
-        )
-
-        self.assertEqual(
-            len(x_values),
-            len(y_values),
-        )
-
-    def test_Resample_Float(self):
-        # Init new Sampler
-        Sampler = PDFSampling()
-
-        # Calculate Interpolation function
-        Sampler.CalculatePDF(
-            x_y_target=self.x_y_target,
-            x_y_original=self.x_y_original,
-            bins=self.bins,
-        )
-
-        x_values, y_values = Sampler.Resample(
-            y_values=2,
-            x_values=2,
-        )
-
-        self.assertEqual(
-            len(x_values),
-            len(y_values),
-        )
-
-    def test_save(self):
-        # Init new Sampler
-        Sampler = PDFSampling()
-
-        # Calculate Interpolation function
-        Sampler.CalculatePDF(
-            x_y_target=self.x_y_target,
-            x_y_original=self.x_y_original,
-            bins=self.bins,
-        )
-
-        # Save function to pickle file
-        Sampler.save(
-            os.path.join(self.tmp_func_dir, "PDF_interpolation_function.pkl")
-        )
-
-        self.assertTrue(
-            os.path.isfile(
-                os.path.join(
-                    self.tmp_func_dir, "PDF_interpolation_function.pkl"
-                )
-            )
-        )
-
-    def test_load(self):
-        # Init new Sampler
-        Sampler = PDFSampling()
-
-        Sampler.load(
-            os.path.join(self.func_dir, "PDF_interpolation_function.pkl")
-        )
-
-        x_values, y_values = Sampler.inMemoryResample(
-            x_values=self.x_y_original[0],
-            y_values=self.x_y_original[1],
-            size=1000,
-        )
-
-        self.assertEqual(
-            len(x_values),
-            1000,
-        )
-
-        self.assertEqual(
-            len(y_values),
-            1000,
-        )
diff --git a/umami/tests/unit/train_tools/fixtures/test_train_config.yaml b/umami/tests/unit/train_tools/fixtures/test_train_config.yaml
index 52047aa78cb086ee8d03cba50ee19d68c7d24a9d..2ab63a38656c7bc434e5bd32ca016aafa7694861 100644
--- a/umami/tests/unit/train_tools/fixtures/test_train_config.yaml
+++ b/umami/tests/unit/train_tools/fixtures/test_train_config.yaml
@@ -21,7 +21,7 @@ ttbar_test_files:
    data_set_name: "ttbar_unit_test"
 
 zpext_test_files:
-  zpext_r21:
+  zpext_r21: 
    Path: dummy.h5
    data_set_name: "zpext_unit_test"
 
@@ -32,7 +32,7 @@ bool_use_taus: False
 
 exclude: []
 
-NN_structure:
+NN_structure: 
  lr: 0.01
  batch_size: 100
  epoch: 5
@@ -52,4 +52,4 @@ Eval_parameters_validation:
  fc_value: 0.018
  WP_b: 0.77
  # fc_value and WP_b are autmoatically added to the plot label
-  SecondTag: "\n$\\sqrt{s}=13$ TeV, PFlow jets"
+  SecondTag: "\n$\\sqrt{s}=13$ TeV, PFlow jets"
\ No newline at end of file
diff --git a/umami/tests/unit/train_tools/fixtures/var_dict_test.yaml b/umami/tests/unit/train_tools/fixtures/var_dict_test.yaml
index 34f1abd48e231363b8d69459c9816c9c6586e8b4..36c83742f60144894c6663ec621f465ec657d525 100644
--- a/umami/tests/unit/train_tools/fixtures/var_dict_test.yaml
+++ b/umami/tests/unit/train_tools/fixtures/var_dict_test.yaml
@@ -8,7 +8,7 @@ train_variables:
    - JetFitter_mass
 
 track_train_variables:
-  noNormVars:
+  noNormVars: 
    - IP3D_signed_d0_significance
 
  logNormVars:
@@ -16,4 +16,4 @@ track_train_variables:
    - dr
  jointNormVars:
    - numberOfPixelHits
-    - numberOfSCTHits
+    - numberOfSCTHits
\ No newline at end of file
diff --git a/umami/train_DL1.py b/umami/train_DL1.py
index f85b5f29724af03f4f241848d7a347b6f1ad6186..7965c41257c163ccab28b962585b187b380751ef 100644
--- a/umami/train_DL1.py
+++ b/umami/train_DL1.py
@@ -14,7 +14,7 @@ from tensorflow.keras.layers import (
     Dropout,
     Input,
 )
-from tensorflow.keras.models import Model, load_model
+from tensorflow.keras.models import Model
 from tensorflow.keras.optimizers import Adam
 
 import umami.train_tools as utt
@@ -80,28 +80,25 @@ def NN_model(train_config, input_shape):
     bool_use_taus = train_config.bool_use_taus
     n_units_end = 4 if bool_use_taus else 3
     NN_config = train_config.NN_structure
-    if train_config.model_file is not None:
-        logger.info(f"Loading model from: {train_config.model_file}")
-        model = load_model(train_config.model_file, compile=False)
-    else:
-        inputs = Input(shape=input_shape)
-        x = inputs
-        for i, unit in enumerate(NN_config["units"]):
-            x = Dense(
-                units=unit,
-                activation="linear",
-                kernel_initializer="glorot_uniform",
-            )(x)
-            x = BatchNormalization()(x)
-            x = Activation(NN_config["activations"][i])(x)
-            if "dropout_rate" in NN_config:
-                x = Dropout(NN_config["dropout_rate"][i])(x)
-        predictions = Dense(
-            units=n_units_end,
-            activation="softmax",
+    inputs = Input(shape=input_shape)
+    x = inputs
+    for i, unit in enumerate(NN_config["units"]):
+        x = Dense(
+            units=unit,
+            activation="linear",
             kernel_initializer="glorot_uniform",
         )(x)
-        model = Model(inputs=inputs, outputs=predictions)
+        x = BatchNormalization()(x)
+        x = Activation(NN_config["activations"][i])(x)
+        if "dropout_rate" in NN_config:
+            x = Dropout(NN_config["dropout_rate"][i])(x)
+    predictions = Dense(
+        units=n_units_end,
+        activation="softmax",
+        kernel_initializer="glorot_uniform",
+    )(x)
+
+    model = Model(inputs=inputs, outputs=predictions)
     # Print DL1 model summary when log level lower or equal INFO level
     if logger.level <= 20:
         model.summary()