Skip to content
Snippets Groups Projects
Verified Commit abda4ddb authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Move call to idx_array, redirect to old interface

parent 4f5dabf9
No related branches found
No related tags found
No related merge requests found
...@@ -69,9 +69,15 @@ class Cut: ...@@ -69,9 +69,15 @@ class Cut:
def __call__(self, dataframe): def __call__(self, dataframe):
""" """
Applies the internally stored cut to the given dataframe, this means Applies the internally stored cut to the given dataframe and returns a
the method returns an index array, specifying which event passed the new dataframe containing only entries passing the event selection.
event selection. """
return self.idx_array(dataframe)
def idx_array(self, dataframe):
"""
Applies the internally stored cut to the given dataframe and returns
an index array, specifying which event passed the event selection.
""" """
if self.func is None: if self.func is None:
return pd.Series(np.ones(len(dataframe), dtype='bool')) return pd.Series(np.ones(len(dataframe), dtype='bool'))
......
...@@ -28,7 +28,7 @@ class CutTestCase(unittest.TestCase): ...@@ -28,7 +28,7 @@ class CutTestCase(unittest.TestCase):
Make sure that the default cut accepts very event in the dataframe. Make sure that the default cut accepts very event in the dataframe.
""" """
default = Cut() default = Cut()
selected = default(self.df) selected = default.idx_array(self.df)
self.assertTrue((selected).all()) self.assertTrue((selected).all())
...@@ -38,7 +38,7 @@ class CutTestCase(unittest.TestCase): ...@@ -38,7 +38,7 @@ class CutTestCase(unittest.TestCase):
uses the lambda to filter the dataframe. uses the lambda to filter the dataframe.
""" """
high_sale = Cut(lambda df: df.sale > 4) high_sale = Cut(lambda df: df.sale > 4)
selected = high_sale(self.df) selected = high_sale.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[False, False, True, True, True, False, False, True]) [False, False, True, True, True, False, False, True])
...@@ -53,7 +53,7 @@ class CutTestCase(unittest.TestCase): ...@@ -53,7 +53,7 @@ class CutTestCase(unittest.TestCase):
old = Cut(lambda df: df.year < 2015) old = Cut(lambda df: df.year < 2015)
combined = high_sale & old combined = high_sale & old
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[False, False, True, True, True, False, False, False]) [False, False, True, True, True, False, False, False])
...@@ -67,7 +67,7 @@ class CutTestCase(unittest.TestCase): ...@@ -67,7 +67,7 @@ class CutTestCase(unittest.TestCase):
high_sale = Cut(lambda df: df.sale > 4) high_sale = Cut(lambda df: df.sale > 4)
combined = high_sale & (lambda df: df.year < 2015) combined = high_sale & (lambda df: df.year < 2015)
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[False, False, True, True, True, False, False, False]) [False, False, True, True, True, False, False, False])
...@@ -81,7 +81,7 @@ class CutTestCase(unittest.TestCase): ...@@ -81,7 +81,7 @@ class CutTestCase(unittest.TestCase):
high_sale = Cut(lambda df: df.sale > 4) high_sale = Cut(lambda df: df.sale > 4)
combined = (lambda df: df.year < 2015) & high_sale combined = (lambda df: df.year < 2015) & high_sale
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[False, False, True, True, True, False, False, False]) [False, False, True, True, True, False, False, False])
...@@ -95,14 +95,14 @@ class CutTestCase(unittest.TestCase): ...@@ -95,14 +95,14 @@ class CutTestCase(unittest.TestCase):
high_sale = Cut(lambda df: df.sale > 4) high_sale = Cut(lambda df: df.sale > 4)
combined = high_sale & 1 combined = high_sale & 1
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[False, False, True, True, True, False, False, True]) [False, False, True, True, True, False, False, True])
self.assertEqual(list(self.df[selected].year), self.assertEqual(list(self.df[selected].year),
[2012, 2013, 2014, 2017]) [2012, 2013, 2014, 2017])
combined = high_sale & 0 combined = high_sale & 0
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertFalse((selected).any()) self.assertFalse((selected).any())
def test_or(self): def test_or(self):
...@@ -113,7 +113,7 @@ class CutTestCase(unittest.TestCase): ...@@ -113,7 +113,7 @@ class CutTestCase(unittest.TestCase):
old = Cut(lambda df: df.year < 2015) old = Cut(lambda df: df.year < 2015)
combined = high_sale | old combined = high_sale | old
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[True, True, True, True, True, False, False, True]) [True, True, True, True, True, False, False, True])
...@@ -127,7 +127,7 @@ class CutTestCase(unittest.TestCase): ...@@ -127,7 +127,7 @@ class CutTestCase(unittest.TestCase):
high_sale = Cut(lambda df: df.sale > 4) high_sale = Cut(lambda df: df.sale > 4)
combined = high_sale | (lambda df: df.year < 2015) combined = high_sale | (lambda df: df.year < 2015)
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[True, True, True, True, True, False, False, True]) [True, True, True, True, True, False, False, True])
...@@ -141,7 +141,7 @@ class CutTestCase(unittest.TestCase): ...@@ -141,7 +141,7 @@ class CutTestCase(unittest.TestCase):
high_sale = Cut(lambda df: df.sale > 4) high_sale = Cut(lambda df: df.sale > 4)
combined = (lambda df: df.year < 2015) | high_sale combined = (lambda df: df.year < 2015) | high_sale
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[True, True, True, True, True, False, False, True]) [True, True, True, True, True, False, False, True])
...@@ -155,11 +155,11 @@ class CutTestCase(unittest.TestCase): ...@@ -155,11 +155,11 @@ class CutTestCase(unittest.TestCase):
high_sale = Cut(lambda df: df.sale > 4) high_sale = Cut(lambda df: df.sale > 4)
combined = high_sale | 1 combined = high_sale | 1
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertTrue(selected.all()) self.assertTrue(selected.all())
combined = high_sale | 0 combined = high_sale | 0
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[False, False, True, True, True, False, False, True]) [False, False, True, True, True, False, False, True])
self.assertEqual(list(self.df[selected].year), self.assertEqual(list(self.df[selected].year),
...@@ -174,7 +174,7 @@ class CutTestCase(unittest.TestCase): ...@@ -174,7 +174,7 @@ class CutTestCase(unittest.TestCase):
old = Cut(lambda df: df.year < 2015) old = Cut(lambda df: df.year < 2015)
combined = high_sale ^ old combined = high_sale ^ old
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[True, True, False, False, False, False, False, True]) [True, True, False, False, False, False, False, True])
...@@ -188,7 +188,7 @@ class CutTestCase(unittest.TestCase): ...@@ -188,7 +188,7 @@ class CutTestCase(unittest.TestCase):
high_sale = Cut(lambda df: df.sale > 4) high_sale = Cut(lambda df: df.sale > 4)
combined = high_sale ^ (lambda df: df.year < 2015) combined = high_sale ^ (lambda df: df.year < 2015)
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[True, True, False, False, False, False, False, True]) [True, True, False, False, False, False, False, True])
...@@ -216,14 +216,14 @@ class CutTestCase(unittest.TestCase): ...@@ -216,14 +216,14 @@ class CutTestCase(unittest.TestCase):
high_sale = Cut(lambda df: df.sale > 4) high_sale = Cut(lambda df: df.sale > 4)
combined = high_sale ^ 1 combined = high_sale ^ 1
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[True, True, False, False, False, True, True, False]) [True, True, False, False, False, True, True, False])
self.assertEqual(list(self.df[selected].year), self.assertEqual(list(self.df[selected].year),
[2010, 2011, 2015, 2016]) [2010, 2011, 2015, 2016])
combined = high_sale ^ 0 combined = high_sale ^ 0
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[False, False, True, True, True, False, False, True]) [False, False, True, True, True, False, False, True])
self.assertEqual(list(self.df[selected].year), self.assertEqual(list(self.df[selected].year),
...@@ -236,7 +236,7 @@ class CutTestCase(unittest.TestCase): ...@@ -236,7 +236,7 @@ class CutTestCase(unittest.TestCase):
high_sale = Cut(lambda df: df.sale > 4) high_sale = Cut(lambda df: df.sale > 4)
combined = ~high_sale combined = ~high_sale
selected = combined(self.df) selected = combined.idx_array(self.df)
self.assertEqual(list(selected), self.assertEqual(list(selected),
[True, True, False, False, False, True, True, False]) [True, True, False, False, False, True, True, False])
self.assertEqual(list(self.df[selected].year), self.assertEqual(list(self.df[selected].year),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment