From abda4ddb730571e7d6a8dd9c214f1277101aad5c Mon Sep 17 00:00:00 2001 From: Frank Sauerburger <f.sauerburger@cern.ch> Date: Wed, 3 Jul 2019 15:09:14 +0200 Subject: [PATCH] Move call to idx_array, redirect to old interface --- nnfwtbn/cut.py | 12 +++++++++--- nnfwtbn/tests/test_cut.py | 34 +++++++++++++++++----------------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/nnfwtbn/cut.py b/nnfwtbn/cut.py index 299c4d7..cc29e72 100644 --- a/nnfwtbn/cut.py +++ b/nnfwtbn/cut.py @@ -69,9 +69,15 @@ class Cut: def __call__(self, dataframe): """ - Applies the internally stored cut to the given dataframe, this means - the method returns an index array, specifying which event passed the - event selection. + Applies the internally stored cut to the given dataframe and returns a + new dataframe containing only entries passing the event selection. + """ + return self.idx_array(dataframe) + + def idx_array(self, dataframe): + """ + Applies the internally stored cut to the given dataframe and returns + an index array, specifying which event passed the event selection. """ if self.func is None: return pd.Series(np.ones(len(dataframe), dtype='bool')) diff --git a/nnfwtbn/tests/test_cut.py b/nnfwtbn/tests/test_cut.py index a74512c..5302fcd 100644 --- a/nnfwtbn/tests/test_cut.py +++ b/nnfwtbn/tests/test_cut.py @@ -28,7 +28,7 @@ class CutTestCase(unittest.TestCase): Make sure that the default cut accepts very event in the dataframe. """ default = Cut() - selected = default(self.df) + selected = default.idx_array(self.df) self.assertTrue((selected).all()) @@ -38,7 +38,7 @@ class CutTestCase(unittest.TestCase): uses the lambda to filter the dataframe. """ high_sale = Cut(lambda df: df.sale > 4) - selected = high_sale(self.df) + selected = high_sale.idx_array(self.df) self.assertEqual(list(selected), [False, False, True, True, True, False, False, True]) @@ -53,7 +53,7 @@ class CutTestCase(unittest.TestCase): old = Cut(lambda df: df.year < 2015) combined = high_sale & old - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [False, False, True, True, True, False, False, False]) @@ -67,7 +67,7 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4) combined = high_sale & (lambda df: df.year < 2015) - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [False, False, True, True, True, False, False, False]) @@ -81,7 +81,7 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4) combined = (lambda df: df.year < 2015) & high_sale - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [False, False, True, True, True, False, False, False]) @@ -95,14 +95,14 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4) combined = high_sale & 1 - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [False, False, True, True, True, False, False, True]) self.assertEqual(list(self.df[selected].year), [2012, 2013, 2014, 2017]) combined = high_sale & 0 - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertFalse((selected).any()) def test_or(self): @@ -113,7 +113,7 @@ class CutTestCase(unittest.TestCase): old = Cut(lambda df: df.year < 2015) combined = high_sale | old - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [True, True, True, True, True, False, False, True]) @@ -127,7 +127,7 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4) combined = high_sale | (lambda df: df.year < 2015) - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [True, True, True, True, True, False, False, True]) @@ -141,7 +141,7 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4) combined = (lambda df: df.year < 2015) | high_sale - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [True, True, True, True, True, False, False, True]) @@ -155,11 +155,11 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4) combined = high_sale | 1 - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertTrue(selected.all()) combined = high_sale | 0 - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [False, False, True, True, True, False, False, True]) self.assertEqual(list(self.df[selected].year), @@ -174,7 +174,7 @@ class CutTestCase(unittest.TestCase): old = Cut(lambda df: df.year < 2015) combined = high_sale ^ old - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [True, True, False, False, False, False, False, True]) @@ -188,7 +188,7 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4) combined = high_sale ^ (lambda df: df.year < 2015) - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [True, True, False, False, False, False, False, True]) @@ -216,14 +216,14 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4) combined = high_sale ^ 1 - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [True, True, False, False, False, True, True, False]) self.assertEqual(list(self.df[selected].year), [2010, 2011, 2015, 2016]) combined = high_sale ^ 0 - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [False, False, True, True, True, False, False, True]) self.assertEqual(list(self.df[selected].year), @@ -236,7 +236,7 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4) combined = ~high_sale - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [True, True, False, False, False, True, True, False]) self.assertEqual(list(self.df[selected].year), -- GitLab