diff --git a/Reconstruction/MVAUtils/util/convertLGBMToRootTree.py b/Reconstruction/MVAUtils/util/convertLGBMToRootTree.py index 5373876caf74988ad1bd720a09ee58087265445b..962a363f69fa82cad923b21e1d8478fcbd2bc328 100755 --- a/Reconstruction/MVAUtils/util/convertLGBMToRootTree.py +++ b/Reconstruction/MVAUtils/util/convertLGBMToRootTree.py @@ -221,16 +221,23 @@ def test(model_file, tree_file, mva_utils = ROOT.MVAUtils.BDT(tree) - objective = booster.dump_model()['objective'] - if 'multiclass' in objective: - logging.info("testing multiclass") + objective = booster.dump_model()['objective'].strip() + # binary and xentropy are not the exact same thing when training but the output value is the same + # (https://lightgbm.readthedocs.io/en/latest/Parameters.html) + binary_aliases = ('binary', 'cross_entropy', 'xentropy') + regression_aliases = ('regression_l2', 'l2', 'mean_squared_error', 'mse', 'l2_root', 'root_mean_squared_error', 'rmse') + multiclass_aliases = ('multiclass', 'softmax') + if objective in multiclass_aliases: + logging.info("assuming multiclass, testing") return test_multiclass(booster, mva_utils, ntests, test_file) - elif 'binary' in objective: - logging.info("testing binary") + elif objective in binary_aliases: + logging.info("assuming binary classification, testing") return test_binary(booster, mva_utils, ntests, test_file) - else: - logging.info("testing regression") + elif objective in regression_aliases: + logging.info("assuming regression, testing") return test_regression(booster, mva_utils, ntests, test_file) + else: + print("cannot understand objective '%s'" % objective) def test_regression(booster, mva_utils, ntests=10000, test_file=None):