diff --git a/freeforestml/tests/test_model.py b/freeforestml/tests/test_model.py
index d019f4288ce0f8f8c5cb624003a81ff85b9dd028..c596b988e1340719c46e29afb1ba90eb3d82b231 100644
--- a/freeforestml/tests/test_model.py
+++ b/freeforestml/tests/test_model.py
@@ -908,7 +908,7 @@ class HepNetTestCase(unittest.TestCase):
         df["is_sig"] = (df.fpid == 1)
         df["is_ztt"] = (df.fpid == 0)
 
-        net.fit(df.compute(), epochs=5, verbose=0)
+        net.fit(df, epochs=5, verbose=0)
 
         fd, path = tempfile.mkstemp()