From 494195bd602f4bf2fde8f618282b23f2635e142c Mon Sep 17 00:00:00 2001
From: Ruggero Turra <ruggero.turra@cern.ch>
Date: Tue, 17 Nov 2020 14:09:32 +0000
Subject: [PATCH] Fix bug in txt to ROOT file conversion for lightGBM training
 for MVAUtils

---
 .../MVAUtils/util/convertLGBMToRootTree.py    | 21 ++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/Reconstruction/MVAUtils/util/convertLGBMToRootTree.py b/Reconstruction/MVAUtils/util/convertLGBMToRootTree.py
index 5373876caf7..962a363f69f 100755
--- a/Reconstruction/MVAUtils/util/convertLGBMToRootTree.py
+++ b/Reconstruction/MVAUtils/util/convertLGBMToRootTree.py
@@ -221,16 +221,23 @@ def test(model_file, tree_file,
 
     mva_utils = ROOT.MVAUtils.BDT(tree)
 
-    objective = booster.dump_model()['objective']
-    if 'multiclass' in objective:
-        logging.info("testing multiclass")
+    objective = booster.dump_model()['objective'].strip()
+    # binary and xentropy are not the exact same thing when training but the output value is the same
+    # (https://lightgbm.readthedocs.io/en/latest/Parameters.html)
+    binary_aliases = ('binary', 'cross_entropy', 'xentropy')
+    regression_aliases = ('regression_l2', 'l2', 'mean_squared_error', 'mse', 'l2_root', 'root_mean_squared_error', 'rmse')
+    multiclass_aliases = ('multiclass', 'softmax')
+    if objective in multiclass_aliases:
+        logging.info("assuming multiclass, testing")
         return test_multiclass(booster, mva_utils, ntests, test_file)
-    elif 'binary' in objective:
-        logging.info("testing binary")
+    elif objective in binary_aliases:
+        logging.info("assuming binary classification, testing")
         return test_binary(booster, mva_utils, ntests, test_file)
-    else:
-        logging.info("testing regression")
+    elif objective in regression_aliases:
+        logging.info("assuming regression, testing")
         return test_regression(booster, mva_utils, ntests, test_file)
+    else:
+        print("cannot understand objective '%s'" % objective)
 
 
 def test_regression(booster, mva_utils, ntests=10000, test_file=None):
-- 
GitLab