diff --git a/nnfwtbn/process.py b/nnfwtbn/process.py index 535eaed00a0c70b7a5a252df1a3f037c6cb5bcb3..a8dc22079bfed70deebae216bc0f896615d74dc5 100644 --- a/nnfwtbn/process.py +++ b/nnfwtbn/process.py @@ -35,6 +35,9 @@ class Process: If the range_var argument is omitted, the value of Process.DEFAULT_RANGE_VAR is used, this defaults to 'fpid'. + + A process behaves like a cut in many ways. For example, the call() and + idx_array methods are identical. """ ####################################################### # Selection @@ -72,6 +75,19 @@ class Process: # Label self.label = label + + def __call__(self, dataframe): + """ + Returns a dataframe containing only the events of this process. + """ + return self.selection(dataframe) + + def idx_array(self, dataframe): + """ + Returns the index array of the given dataframe which selects all + events of this process. + """ + return self.selection.idx_array(dataframe) def __repr__(self): """ diff --git a/nnfwtbn/tests/test_process.py b/nnfwtbn/tests/test_process.py index 96aa2ba8babf03e321ca398fe966b4a88b26d989..eea9ebebff3cd37445c58a19d3a4fb03fb2b634d 100644 --- a/nnfwtbn/tests/test_process.py +++ b/nnfwtbn/tests/test_process.py @@ -1,5 +1,8 @@ import unittest + +import pandas as pd + from nnfwtbn import Process, Cut class ProcessTestCase(unittest.TestCase): @@ -102,3 +105,33 @@ class ProcessTestCase(unittest.TestCase): self.assertRaises(ValueError, Process, "Top", range=(5, 10, 20)) self.assertRaises(ValueError, Process, "Top", range=(5, )) self.assertRaises(ValueError, Process, "Top", range=3) + + def generate_df(self): + """ + Generate a toy dataframe. + """ + + return pd.DataFrame({'process_id': [1, 2, 1], + 'momentum_t': [100, 90, 110]}) + + def test_call(self): + """ + Check that calling a process returns a dataframe with selected events. + """ + process = Process("Signal", range=(1, 1), range_var="process_id") + + df = self.generate_df() + process_df = process(df) + + self.assertEqual(list(process_df.momentum_t), [100, 110]) + + def test_idx_array(self): + """ + Check that calling idx_array() returns an index array with selected events. + """ + process = Process("Signal", range=(1, 1), range_var="process_id") + + df = self.generate_df() + selected = process.idx_array(df) + + self.assertEqual(list(selected), [True, False, True])