diff --git a/km3io/tools.py b/km3io/tools.py
index 0c60f4061bff114ddda520353505234f92270e7d..6472790288240f89ce0ea50d8238993a270b90bb 100644
--- a/km3io/tools.py
+++ b/km3io/tools.py
@@ -234,7 +234,7 @@ def get_multiplicity(tracks, rec_stages):
     class km3io.offline.OfflineBranch
         tracks branch with the desired reconstruction stages only.
     """
-    return tracks[mask(tracks.rec_stages, rec_stages)]
+    return tracks[mask(tracks, rec_stages)]
 
 
 def _longest_tracks(tracks):
@@ -377,7 +377,6 @@ def _find_between(rec_stages, start, end, builder):
                 builder.append(0)
         builder.end_list()
 
-
 @nb.jit(nopython=True)
 def _find_between_single(rec_stages, start, end, builder):
     """construct an awkward1 array with the same structure as tracks.rec_stages.
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 8168a3bed4ec2b211ec548210b2959530de8257e..50315527de89ea97316f5ca42d34d27d58e39cf7 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -159,25 +159,25 @@ class TestCountNested(unittest.TestCase):
 
 class TestRecStagesMasks(unittest.TestCase):
     def setUp(self):
-        self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
-                                [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
+        # self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
+        #                         [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
 
         self.tracks = OFFLINE_FILE.events.tracks
 
-    def test_find(self):
-        builder = ak.ArrayBuilder()
-        _find(self.nested, ak.Array([1, 2, 3]), builder)
-        labels = builder.snapshot()
+    # def test_find(self):
+    #     builder = ak.ArrayBuilder()
+    #     _find(self.nested, ak.Array([1, 2, 3]), builder)
+    #     labels = builder.snapshot()
 
-        assert labels[0][0] == 1
-        assert labels[0][1] == 1
-        assert labels[0][2] == 0
-        assert labels[1][0] == 0
+    #     assert labels[0][0] == 1
+    #     assert labels[0][1] == 1
+    #     assert labels[0][2] == 0
+    #     assert labels[1][0] == 0
 
     def test_mask_with_explicit_rec_stages(self):
         rec_stages = self.tracks.rec_stages
         stages = [1, 3, 5, 4]
-        masks = mask(rec_stages, stages=stages)
+        masks = mask(self.tracks, stages=stages)
 
         assert masks[0][0] == all(rec_stages[0][0] == ak.Array(stages))
         assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
@@ -186,14 +186,13 @@ class TestRecStagesMasks(unittest.TestCase):
     def test_mask_with_start_and_end_of_rec_stages(self):
         rec_stages = self.tracks.rec_stages
         stages = [1, 3, 5, 4]
-        masks = mask(rec_stages, start=1, end=4)
+        masks = mask(self.tracks, start=1, end=4)
 
         assert masks[0][0] == all(rec_stages[0][0] == ak.Array(stages))
         assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
         assert masks[0][1] == False
 
 
-
 class TestUnique(unittest.TestCase):
     def run_random_test_with_dtype(self, dtype):
         max_range = 100