Skip to content
Snippets Groups Projects
Verified Commit 24550d78 authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Add Cut copy constructor

parent e38b6ed8
No related branches found
No related tags found
No related merge requests found
......@@ -74,8 +74,12 @@ class Cut:
the optional function is omitted, Every row in the dataframe is
accepted by this cut.
"""
self.func = func
self.label = label
if isinstance(func, Cut):
self.func = func.func
self.label = label or func.label
else:
self.func = func
self.label = label
def __call__(self, dataframe):
"""
......
......@@ -261,3 +261,30 @@ class CutTestCase(unittest.TestCase):
high_sale = Cut(lambda df: df.sale > 4, label="High sales volume")
self.assertEqual(high_sale.label, "High sales volume")
def test_init_cut(self):
"""
Check that a cut can be passed to the constructor.
"""
high_sale = Cut(lambda df: df.sale > 4)
high_sale2 = Cut(high_sale)
self.assertEqual(len(high_sale2(self.df)), 4)
self.assertEqual(len(high_sale2.idx_array(self.df)), 8)
def test_init_cut_name_inherit(self):
"""
Check that the name of a cut passed to the constructor is inherited.
"""
high_sale = Cut(lambda df: df.sale > 4, label="High sales volume")
high_sale2 = Cut(high_sale)
self.assertEqual(high_sale2.label, "High sales volume")
def test_init_cut_name_inherit_precedence(self):
"""
Check that the name argument has precedence over the given cut.
"""
high_sale = Cut(lambda df: df.sale > 4, label="High sales volume")
high_sale2 = Cut(high_sale, label="Other label")
self.assertEqual(high_sale2.label, "Other label")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment