Commit f095580b authored by Manuel Guth's avatar Manuel Guth
Browse files

Merge branch 'mguth-umami-train-njet-fix' into 'master'

fixing umami train with nJets

See merge request !241
parents b7016672 a70549dc
Pipeline #3226545 passed with stages
in 32 minutes and 18 seconds
......@@ -167,6 +167,9 @@ def TrainLargeFile(args, train_config, preprocess_config):
nJets, nFeatures = f["X_train"].shape
nJets, nDim = f["Y_train"].shape
if NN_structure["nJets_train"] is not None:
nJets = NN_structure["nJets_train"]
# Print how much jets are used
logger.info(f"Number of Jets used for training: {nJets}")
......
......@@ -186,6 +186,9 @@ def Dips(args, train_config, preprocess_config):
nJets, nTrks, nFeatures = f["X_trk_train"].shape
nJets, nDim = f["Y_train"].shape
if NN_structure["nJets_train"] is not None:
nJets = NN_structure["nJets_train"]
# Print how much jets are used
logger.info(f"Number of Jets used for training: {nJets}")
......@@ -285,8 +288,6 @@ def Dips(args, train_config, preprocess_config):
epochs=nEpochs,
validation_data=(X_valid, Y_valid),
callbacks=[dips_mChkPt, reduce_lr, my_callback],
# callbacks=[reduce_lr, my_callback],
# callbacks=[my_callback],
steps_per_epoch=nJets / NN_structure["batch_size"],
use_multiprocessing=True,
workers=8,
......
......@@ -235,6 +235,12 @@ def Umami(args, train_config, preprocess_config):
logger.info(f"nJets: {nJets}, nTrks: {nTrks}")
logger.info(f"nFeatures: {nFeatures}, njet_features: {njet_features}")
if NN_structure["nJets_train"] is not None:
logger.info(
f"Training only on {NN_structure['nJets_train']} jets as specified in the config."
)
nJets = NN_structure["nJets_train"]
umami, _ = Umami_model(
train_config=train_config,
input_shape=(nTrks, nFeatures),
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment