train_Dips.py
from umami.configuration import logger  # isort:skip
import argparse
import json

import h5py
import tensorflow as tf
from tensorflow.keras import activations
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Dense,
    Dropout,
    Input,
    Masking,
    TimeDistributed,
)
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam

import umami.tf_tools as utf
import umami.train_tools as utt
from umami.institutes.utils import is_qsub_available, submit_zeuthen
from umami.preprocessing_tools import Configuration


def GetParser():
    """Argument parser for Preprocessing script."""
29
30
31
32
33
    parser = argparse.ArgumentParser(
        description="Preprocessing command line options."
    )

    parser.add_argument(
        "-c",
        "--config_file",
        type=str,
        required=True,
        help="Name of the training config file",
    )

    parser.add_argument(
        "-e", "--epochs", type=int, help="Number of training epochs."
    )

    parser.add_argument(
        "-z",
        "--zeuthen",
        action="store_true",
        help="Train on Zeuthen with GPU support",
    )

    # TODO: implement vr_overlap
    parser.add_argument(
        "--vr_overlap",
        action="store_true",
        help="Option to enable vr overlap removall for validation sets.",
57
58
    )

    parser.add_argument(
        "-o",
        "--overwrite_config",
        action="store_true",
        help="Overwrite the configs files saved in metadata folder",
    )

    args = parser.parse_args()
    return args


def Dips_model(train_config=None, input_shape=None):
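    """Build and compile the DIPS model, or load an existing one from
    the file given in the train config. Returns the compiled model and
    the number of training epochs from the config."""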
    # Load NN structure and training parameters from file
    NN_structure = train_config.NN_structure

    # Set NN options
    batch_norm = NN_structure["Batch_Normalisation"]
    dropout = NN_structure["dropout"]
    class_labels = NN_structure["class_labels"]

    if train_config.model_file is not None:
        # Load DIPS model from file
        logger.info(f"Loading model from: {train_config.model_file}")
        dips = load_model(
            train_config.model_file, {"Sum": utf.Sum}, compile=False
        )

    else:
        logger.info("No modelfile provided! Initialize a new one!")

        # Set the track input
        trk_inputs = Input(shape=input_shape)

        # Masking the missing tracks
        masked_inputs = Masking(mask_value=0)(trk_inputs)
        tdd = masked_inputs

        # Define the TimeDistributed layers for the different tracks
        for i, phi_nodes in enumerate(NN_structure["ppm_sizes"]):

            tdd = TimeDistributed(
                Dense(phi_nodes, activation="linear"), name=f"Phi{i}_Dense"
            )(tdd)

            if batch_norm:
                tdd = TimeDistributed(
                    BatchNormalization(), name=f"Phi{i}_BatchNormalization"
                )(tdd)

            if dropout != 0:
                tdd = TimeDistributed(
                    Dropout(rate=dropout), name=f"Phi{i}_Dropout"
                )(tdd)

            tdd = TimeDistributed(
                Activation(activations.relu), name=f"Phi{i}_ReLU"
            )(tdd)

        # Sum up the track features into a permutation-invariant
        # per-jet representation (this is where the magic happens!)
        F = utf.Sum(name="Sum")(tdd)

        # Define the main dips structure
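        # (zip pairs each layer size with a dropout rate;
        # the last dense layer gets no dropout)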
        for j, (F_nodes, p) in enumerate(
            zip(
                NN_structure["dense_sizes"],
                [dropout] * len(NN_structure["dense_sizes"][:-1]) + [0],
            )
        ):

            F = Dense(F_nodes, activation="linear", name=f"F{j}_Dense")(F)
            if batch_norm:
                F = BatchNormalization(name=f"F{j}_BatchNormalization")(F)
            if dropout != 0:
                F = Dropout(rate=p, name=f"F{j}_Dropout")(F)
            F = Activation(activations.relu, name=f"F{j}_ReLU")(F)

        # Set output and activation function
        output = Dense(
            len(class_labels), activation="softmax", name="Jet_class"
        )(F)
        dips = Model(inputs=trk_inputs, outputs=output)

    # Print DIPS model summary when log level is INFO or lower
    if logger.level <= 20:
        dips.summary()

    # Set optimizer and loss
    model_optimizer = Adam(learning_rate=NN_structure["lr"])
    dips.compile(
        loss="categorical_crossentropy",
        optimizer=model_optimizer,
        metrics=["accuracy"],
    )
    return dips, NN_structure["epochs"]


def Dips(args, train_config, preprocess_config):
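    """Run the DIPS training: load the validation data, stream the
    training set from the generator, set up the callbacks and fit the
    model, then dump the training history to json."""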
    # Load NN structure and training parameters from file
    NN_structure = train_config.NN_structure
    Val_params = train_config.Eval_parameters_validation

    # Load the validation tracks
    X_valid, Y_valid = utt.GetTestSampleTrks(
        input_file=train_config.validation_file,
        var_dict=train_config.var_dict,
        preprocess_config=preprocess_config,
        class_labels=NN_structure["class_labels"],
        nJets=int(Val_params["n_jets"]),
    )

    # Load the extra validation tracks if defined.
    # If not, set them to None
    if train_config.add_validation_file is not None:
        X_valid_add, Y_valid_add = utt.GetTestSampleTrks(
            input_file=train_config.add_validation_file,
            var_dict=train_config.var_dict,
            preprocess_config=preprocess_config,
            class_labels=NN_structure["class_labels"],
            nJets=int(Val_params["n_jets"]),
        )

    else:
        X_valid_add = None
        Y_valid_add = None

    # Get the shapes for training
    with h5py.File(train_config.train_file, "r") as f:
        nJets, nTrks, nFeatures = f["X_trk_train"].shape
        nJets, nDim = f["Y_train"].shape

    if NN_structure["nJets_train"] is not None:
        nJets = NN_structure["nJets_train"]

    # Print how many jets are used
    logger.info(f"Number of Jets used for training: {nJets}")

    # Init the DIPS model
    dips, epochs = Dips_model(
        train_config=train_config, input_shape=(nTrks, nFeatures)
    )

    # Get training set from generator
    train_dataset = (
        tf.data.Dataset.from_generator(
            utf.dips_generator(
                train_file_path=train_config.train_file,
                X_trk_Name="X_trk_train",
                Y_Name="Y_train",
                n_jets=NN_structure["nJets_train"],
                batch_size=NN_structure["batch_size"],
            ),
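            # Output types and shapes produced by the generator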
            (tf.float32, tf.float32),
            (
                tf.TensorShape([None, nTrks, nFeatures]),
                tf.TensorShape([None, nDim]),
            ),
        )
        .repeat()
        .prefetch(3)
    )

    # Check if the number of epochs is set via the argparser;
    # if not, use the value from the config file
    if args.epochs is None:
        nEpochs = epochs

    # If set, the argparser value takes precedence
    else:
        nEpochs = args.epochs

    # Set ModelCheckpoint as callback
    dips_mChkPt = ModelCheckpoint(
        f"{train_config.model_name}" + "/dips_model_{epoch:03d}.h5",
        monitor="val_loss",
        verbose=True,
        save_best_only=False,
        validation_batch_size=NN_structure["batch_size"],
        save_weights_only=False,
    )

    # Set ReduceLROnPlateau as callback
    reduce_lr = ReduceLROnPlateau(
        monitor="loss",
        factor=0.8,
        patience=3,
        verbose=1,
        mode="auto",
        cooldown=5,
        min_lr=0.000001,
    )

    # Convert numpy arrays to tensors to avoid memory leak in callbacks
    X_valid_tensor = tf.convert_to_tensor(X_valid, dtype=tf.float64)
    Y_valid_tensor = tf.convert_to_tensor(Y_valid, dtype=tf.int64)
    if train_config.add_validation_file is not None:
        X_valid_add_tensor = tf.convert_to_tensor(
            X_valid_add, dtype=tf.float64
        )
        Y_valid_add_tensor = tf.convert_to_tensor(Y_valid_add, dtype=tf.int64)
    else:
        X_valid_add_tensor = None
        Y_valid_add_tensor = None

    # Form a dict with the validation data for the callback
    val_data_dict = {
        "X_valid": X_valid_tensor,
        "Y_valid": Y_valid_tensor,
        "X_valid_add": X_valid_add_tensor,
        "Y_valid_add": Y_valid_add_tensor,
    }

    # Set my_callback as callback. Writes history information
    # to json file.
    my_callback = utt.MyCallback(
        model_name=train_config.model_name,
        class_labels=train_config.NN_structure["class_labels"],
        main_class=train_config.NN_structure["main_class"],
        val_data_dict=val_data_dict,
        target_beff=train_config.Eval_parameters_validation["WP"],
        frac_dict=train_config.Eval_parameters_validation["frac_values"],
        dict_file_name=utt.get_validation_dict_name(
            WP=train_config.Eval_parameters_validation["WP"],
            n_jets=train_config.Eval_parameters_validation["n_jets"],
            dir_name=train_config.model_name,
        ),
    )

    logger.info("Start training")
    history = dips.fit(
        train_dataset,
        epochs=nEpochs,
        validation_data=(X_valid, Y_valid),
        callbacks=[dips_mChkPt, reduce_lr, my_callback],
        steps_per_epoch=int(nJets / NN_structure["batch_size"]),
        use_multiprocessing=True,
        workers=8,
    )

    # Dump dict into json
    logger.info(
        f"Dumping history file to {train_config.model_name}/history.json"
    )

    # Make the history dict the same shape as the dict from the callbacks
    hist_dict = utt.prepare_history_dict(history.history)

    # Dump history file to json
    with open(f"{train_config.model_name}/history.json", "w") as outfile:
        json.dump(hist_dict, outfile, indent=4)


def DipsZeuthen(args, train_config, preprocess_config):
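    """Submit the training to the Zeuthen batch system if qsub is
    available; otherwise fall back to training locally."""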
    if is_qsub_available():
        args.model_name = train_config.model_name
        args.dips = True
        submit_zeuthen(args)
    else:
        logger.warning(
            "No Zeuthen batch system found, training locally instead."
        )
        Dips(args, train_config, preprocess_config)


if __name__ == "__main__":
    args = GetParser()

    # Enable memory growth so TensorFlow only allocates GPU memory as needed
    gpus = tf.config.experimental.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    train_config = utt.Configuration(args.config_file)
    preprocess_config = Configuration(train_config.preprocess_config)

    utt.create_metadata_folder(
        train_config_path=args.config_file,
        var_dict_path=train_config.var_dict,
        model_name=train_config.model_name,
        preprocess_config_path=train_config.preprocess_config,
        overwrite_config=args.overwrite_config,
    )

    # Start the real training
    if args.zeuthen:
        DipsZeuthen(args, train_config, preprocess_config)
    else:
        Dips(args, train_config, preprocess_config)