diff --git a/nnfwtbn/__init__.py b/nnfwtbn/__init__.py index f09831f6463c91f2f66ab29b2997491138b5e055..9ddd9dbbdac94a30f489ffda1771ce762853f837 100644 --- a/nnfwtbn/__init__.py +++ b/nnfwtbn/__init__.py @@ -9,3 +9,4 @@ from nnfwtbn.model import CrossValidator, ClassicalCV, MixedCV, \ Normalizer, EstimatorNormalizer, \ HepNet from nnfwtbn.stack import Stack, McStack, DataStack +from nnfwtbn.interface import TmvaBdt diff --git a/nnfwtbn/interface.py b/nnfwtbn/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..1a701afa50c085d837c5b9a2e019756bd7d04b38 --- /dev/null +++ b/nnfwtbn/interface.py @@ -0,0 +1,89 @@ +""" +This module provides classes to interface between classifiers from other +frameworks. +""" + +from abc import ABC, abstractmethod + +from lxml import etree +import numpy as np + +class Classifier(ABC): + """ + Abstract classifier train with another framework and loaded into nnfwtbn. + """ + + @abstractmethod + def predict(dataframe): + """ + Returns an array with the predicted values. + """ + +class TmvaBdt(Classifier): + """ + Experimental class to use BDT's from TMVA. The class has the following + limitations:. + - The XML file must contain exactly one classifier. + - The boosting method must be AdaBoost. + - Fisher cuts cannot be used. + """ + + def __init__(self, filename): + """ + Loads the BDT from an XML file. + """ + with open(filename) as xml_file: + xml = etree.parse(xml_file) + + # Checks against unsupported features + boost_type = xml.xpath("//Option[@name='BoostType']")[0].text + if boost_type != "AdaBoost": + raise Exception("Cannot handle boost type %r." % boost_type) + + fisher_cuts = xml.xpath("//Option[@name='UseFisherCuts']")[0].text + if fisher_cuts != "False": + raise Exception("Cannot handle Fisher cuts.") + + self.xml = xml + + def predict(self, dataframe): + """ + Evaluate the BDT on the given dataframe. The method returns an array + with the BDT scores. + """ + # Prepare input variables + variables = {int(_.get("VarIndex")): dataframe[_.get("Expression")] + for _ in self.xml.xpath("//Variable")} + + # Prepare result array + response = np.zeros(len(dataframe)) + sum_weights = 0 + + # Loop over trees + trees = self.xml.xpath("//BinaryTree") + for tree in trees: + tree_weight = float(tree.get("boostWeight")) + sum_weights += tree_weight + + # Loop over terminal notes of tree + leafs = tree.xpath(".//Node[@nType!=0]") + for leaf in leafs: + ancestors = leaf.xpath("ancestor::Node") + mask = np.ones(len(dataframe), dtype='bool') + + # Trace path from root to leaf and record surviving events + for node, next_node in zip(ancestors, ancestors[1:] + [leaf]): + variable = variables[int(node.get("IVar"))] + cut = float(node.get("Cut")) + cut_type = int(node.get("cType")) + next_type = {"l": 0, "r": 1}[next_node.get("pos")] + + # Actual evaluation of node cut + mask &= (next_type == cut_type) ^ (variable < cut) + + leaf_type = int(leaf.get("nType")) + + # Record prediction of tree + response[mask] += tree_weight * leaf_type + + return response / sum_weights diff --git a/requirements.txt b/requirements.txt index ae69b35f1bb2b2e12966a1acf674aef015cdc2b7..0f4a33312e32915e6d5b12d8349ca9b72ca299b9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ tables pandas pylorentz https://gitlab.cern.ch/fsauerbu/atlasify/-/archive/v0.2.0/atlasify-v0.2.0.zip +lxml diff --git a/setup.py b/setup.py index 7a33e8db732ff298790c21a65a3a5bafe2b25ea3..1c446f4689695b76a121e6b8369ed5ffd61138d3 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,8 @@ setup(name='nnfwtbn', "tables", "pandas", "pylorentz", - "atlasify>=0.2.0"], + "atlasify>=0.2.0", + "lxml"], test_suite='nnfwtbn.tests', description='Experimental neural network framework to be named.', url="https://gitlab.cern.ch/fsauerbu/nnfwtbn",