#!/usr/bin/env python
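"""Training script for the Umami tagger.

Builds and trains the combined Umami model: a deep-sets (DIPS-like) branch
running on the track inputs and a DL1-like branch running on the jet-level
features, merged into one network with two softmax outputs ("dips" and
"umami"). All settings are taken from a YAML training config file; the
training can optionally be submitted to the Zeuthen batch system.

Typical invocation (config path illustrative):

    python train_umami.py -c <training-config>.yaml
"""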
from umami.configuration import logger  # isort:skip
import argparse
import json

import h5py
import tensorflow as tf
import yaml
from tensorflow.keras import activations
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Concatenate,
    Dense,
    Dropout,
    Input,
    Masking,
    TimeDistributed,
)
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam

import umami.tf_tools as utf
import umami.train_tools as utt
from umami.institutes.utils import is_qsub_available, submit_zeuthen
from umami.preprocessing_tools import Configuration
from umami.tools import yaml_loader


def GetParser():
    """Argument parser for Preprocessing script."""
    parser = argparse.ArgumentParser(
        description="Training command line options."
    )

    parser.add_argument(
        "-c",
        "--config_file",
        type=str,
        required=True,
        help="Name of the training config file",
    )

    parser.add_argument(
        "-e",
        "--epochs",
        type=int,
        help="Number of training epochs.",
    )

    parser.add_argument(
        "-z",
        "--zeuthen",
        action="store_true",
        help="Train on Zeuthen with GPU support",
    )

    # TODO: implement vr_overlap
    parser.add_argument(
        "--vr_overlap",
        action="store_true",
        help="Option to enable VR overlap removal for validation sets.",
    )
    parser.add_argument(
        "-o",
        "--overwrite_config",
        action="store_true",
        help="Overwrite the configs files saved in metadata folder",
    )
    args = parser.parse_args()
    return args


def Umami_model(train_config=None, input_shape=None, njet_features=None):
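    """Build (or load) and compile the Umami model.

    The model combines a deep-sets (DIPS) branch over the track inputs with
    dense layers over the jet features. Returns the compiled model and the
    number of training epochs from the config.
    """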
    # Load NN Structure and training parameter from file
    NN_structure = train_config.NN_structure
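    # For orientation, the NN_structure block read here and below is expected
    # to provide keys such as the following (values purely illustrative):
    #
    #   Batch_Normalisation: True
    #   dropout: 0.1
    #   class_labels: ["ujets", "cjets", "bjets"]
    #   DIPS_ppm_units: [100, 100, 128]
    #   DIPS_dense_units: [100, 100]
    #   intermediate_units: [72]
    #   DL1_units: [57, 60, 48]
    #   lr: 0.01
    #   dips_loss_weight: 1
    #   epochs: 200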

    # Set NN options
    batch_norm = NN_structure["Batch_Normalisation"]
    dropout = NN_structure["dropout"]
    class_labels = NN_structure["class_labels"]

    if train_config.model_file is not None:
        # Load an existing model from file
        logger.info(f"Loading model from: {train_config.model_file}")
        umami = load_model(
            train_config.model_file, {"Sum": utf.Sum}, compile=False
        )

    else:
        logger.info("No modelfile provided! Initialize a new one!")

        # Set the track input
        trk_inputs = Input(shape=input_shape)

        # Masking the missing tracks
        masked_inputs = Masking(mask_value=0)(trk_inputs)
        tdd = masked_inputs

        # Define the TimeDistributed layers for the different tracks
        for i, phi_nodes in enumerate(NN_structure["DIPS_ppm_units"]):

            tdd = TimeDistributed(
                Dense(phi_nodes, activation="linear"), name=f"Phi{i}_Dense"
            )(tdd)

            if batch_norm:
                tdd = TimeDistributed(
                    BatchNormalization(), name=f"Phi{i}_BatchNormalization"
                )(tdd)

            if dropout != 0:
                tdd = TimeDistributed(
                    Dropout(rate=dropout), name=f"Phi{i}_Dropout"
                )(tdd)

            tdd = TimeDistributed(
                Activation(activations.relu), name=f"Phi{i}_ReLU"
            )(tdd)

        # This is where the magic happens... sum up the track features!
        F = utf.Sum(name="Sum")(tdd)

        for j, (F_nodes, p) in enumerate(
            zip(
                NN_structure["DIPS_dense_units"],
                [dropout] * len(NN_structure["DIPS_dense_units"][:-1]) + [0],
            )
        ):

            F = Dense(F_nodes, activation="linear", name=f"F{j}_Dense")(F)
            if batch_norm:
                F = BatchNormalization(name=f"F{j}_BatchNormalization")(F)
            if dropout != 0:
                F = Dropout(rate=p, name=f"F{j}_Dropout")(F)
            F = Activation(activations.relu, name=f"F{j}_ReLU")(F)

        dips_output = Dense(
            len(class_labels), activation="softmax", name="dips"
        )(F)

        # Set the jet input
        jet_inputs = Input(shape=(njet_features,))

        # Adding the intermediate dense layers for DL1
        x = jet_inputs
        for unit in NN_structure["intermediate_units"]:
            x = Dense(
                units=unit,
                activation="linear",
                kernel_initializer="glorot_uniform",
            )(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)

        # Concatenate the inputs
        x = Concatenate()([F, x])

        # Loop to initialise the hidden layers
        for unit in NN_structure["DL1_units"]:
            x = Dense(
                units=unit,
                activation="linear",
                kernel_initializer="glorot_uniform",
            )(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)

        jet_output = Dense(
            units=len(class_labels),
            activation="softmax",
            kernel_initializer="glorot_uniform",
            name="umami",
        )(x)

        umami = Model(
            inputs=[trk_inputs, jet_inputs], outputs=[dips_output, jet_output]
        )

    # Print the Umami model summary when the log level is INFO (20) or lower
    if logger.level <= 20:
        umami.summary()

    # Set optimizer and loss
    model_optimizer = Adam(learning_rate=NN_structure["lr"])
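    # Both softmax outputs are trained jointly: the DIPS loss is scaled by
    # dips_loss_weight from the config, the final Umami output has weight 1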
    umami.compile(
        loss="categorical_crossentropy",
        loss_weights={"dips": NN_structure["dips_loss_weight"], "umami": 1},
        optimizer=model_optimizer,
        metrics=["accuracy"],
    )

    return umami, NN_structure["epochs"]


def Umami(args, train_config, preprocess_config):
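    """Run the Umami training.

    Loads the validation data, builds the model via Umami_model, trains it
    on a tf.data generator pipeline and writes the training history to
    <model_name>/history.json.
    """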
    # Load NN Structure and training parameter from file
    NN_structure = train_config.NN_structure

    val_data_dict = None
    if train_config.Eval_parameters_validation["n_jets"] > 0:
        val_data_dict = utt.load_validation_data_umami(
            train_config,
            preprocess_config,
            train_config.Eval_parameters_validation["n_jets"],
        )

    # Load the excluded variables from train_config
    if "exclude" in train_config.config:
        exclude = train_config.config["exclude"]

    else:
        exclude = None

    # Load variable config
    with open(train_config.var_dict, "r") as conf:
        variable_config = yaml.load(conf, Loader=yaml_loader)

    # Get excluded variables
    _, _, excluded_var = utt.get_jet_feature_indices(
        variable_config["train_variables"], exclude
    )

    with h5py.File(train_config.train_file, "r") as f:
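        # Infer the input dimensions and the number of classes from the
        # training file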
        nJets, nTrks, nFeatures = f["X_trk_train"].shape
        nJets, nDim = f["Y_train"].shape
        nJets, njet_features = f["X_train"].shape

    logger.info(f"nJets: {nJets}, nTrks: {nTrks}")
    logger.info(f"nFeatures: {nFeatures}, njet_features: {njet_features}")

    if NN_structure["nJets_train"] is not None:
        logger.info(
            f"Training only on {NN_structure['nJets_train']} jets as specified in the config."
        )
        nJets = NN_structure["nJets_train"]

    umami, _ = Umami_model(
        train_config=train_config,
        input_shape=(nTrks, nFeatures),
        njet_features=njet_features,
    )

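    # Build the tf.data input pipeline from the umami generator: it yields
    # batches of track inputs ("input_1"), jet inputs ("input_2") and the
    # corresponding labels, repeated indefinitely and prefetched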
    train_dataset = (
        tf.data.Dataset.from_generator(
            utf.umami_generator(
                train_file_path=train_config.train_file,
                X_Name="X_train",
                X_trk_Name="X_trk_train",
                Y_Name="Y_train",
                n_jets=NN_structure["nJets_train"],
                batch_size=train_config.NN_structure["batch_size"],
                excluded_var=excluded_var,
            ),
            output_types=(
                {"input_1": tf.float32, "input_2": tf.float32},
                tf.float32,
            ),
            output_shapes=(
                {
                    "input_1": tf.TensorShape([None, nTrks, nFeatures]),
                    "input_2": tf.TensorShape([None, njet_features]),
                },
                tf.TensorShape([None, nDim]),
            ),
        )
        .repeat()
        .prefetch(3)
    )

    # Use the number of epochs from the command line if given,
    # otherwise fall back to the value in the config file
    if args.epochs is None:
        nEpochs = NN_structure["epochs"]
    else:
        nEpochs = args.epochs

    # Define LearningRate Reducer as Callback
    reduce_lr = ReduceLROnPlateau(
        monitor="loss",
        factor=0.8,
        patience=3,
        verbose=1,
        mode="auto",
        cooldown=5,
        min_lr=0.000001,
    )

    # Set ModelCheckpoint as callback
    umami_mChkPt = ModelCheckpoint(
        f"{train_config.model_name}" + "/umami_model_{epoch:03d}.h5",
        monitor="val_loss",
        verbose=True,
        save_best_only=False,
        save_weights_only=False,
    )

    # Init the Umami callback
    my_callback = utt.MyCallbackUmami(
        model_name=train_config.model_name,
        class_labels=train_config.NN_structure["class_labels"],
        main_class=train_config.NN_structure["main_class"],
        val_data_dict=val_data_dict,
        target_beff=train_config.Eval_parameters_validation["WP"],
        frac_dict=train_config.Eval_parameters_validation["frac_values"],
        dict_file_name=utt.get_validation_dict_name(
            WP=train_config.Eval_parameters_validation["WP"],
            n_jets=train_config.Eval_parameters_validation["n_jets"],
            dir_name=train_config.model_name,
        ),
    )

    # tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
    logger.info("Start training")
    history = umami.fit(
        train_dataset,
        epochs=nEpochs,
        callbacks=[umami_mChkPt, reduce_lr, my_callback],
        steps_per_epoch=int(nJets / train_config.NN_structure["batch_size"]),
        use_multiprocessing=True,
        workers=8,
    )

    # Dump dict into json
    logger.info(
        f"Dumping history file to {train_config.model_name}/history.json"
    )

    # Make the history dict the same shape as the dict from the callbacks
    hist_dict = utt.prepare_history_dict(history.history)

    # Dump history file to json
    with open(f"{train_config.model_name}/history.json", "w") as outfile:
        json.dump(hist_dict, outfile, indent=4)


def UmamiZeuthen(args, train_config, preprocess_config):
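    """Submit the training as a job to the Zeuthen batch system (qsub).

    Falls back to local training if no batch system is available.
    """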
    if is_qsub_available():
        args.model_name = train_config.model_name
        args.umami = True
        submit_zeuthen(args)
    else:
        logger.warning(
            "No Zeuthen batch system found, training locally instead."
        )
        Umami(args, train_config, preprocess_config)


if __name__ == "__main__":
    args = GetParser()

    gpus = tf.config.experimental.list_physical_devices("GPU")
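    # Allow GPU memory to grow on demand instead of allocating it all up front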
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    train_config = utt.Configuration(args.config_file)
    preprocess_config = Configuration(train_config.preprocess_config)

    utt.create_metadata_folder(
        train_config_path=args.config_file,
        var_dict_path=train_config.var_dict,
        model_name=train_config.model_name,
        preprocess_config_path=train_config.preprocess_config,
        overwrite_config=args.overwrite_config,
    )

    if args.zeuthen:
        UmamiZeuthen(args, train_config, preprocess_config)
    else:
        Umami(args, train_config, preprocess_config)