diff --git a/tests/test_tools.py b/tests/test_tools.py
index 50315527de89ea97316f5ca42d34d27d58e39cf7..53518f91adda1f9cfd40c47e6775a0035aad5d6f 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -159,22 +159,22 @@ class TestCountNested(unittest.TestCase):
 
 class TestRecStagesMasks(unittest.TestCase):
     def setUp(self):
-        # self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
-        #                         [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
+        self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
+                                [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
 
         self.tracks = OFFLINE_FILE.events.tracks
 
-    # def test_find(self):
-    #     builder = ak.ArrayBuilder()
-    #     _find(self.nested, ak.Array([1, 2, 3]), builder)
-    #     labels = builder.snapshot()
+    def test_find(self):
+        builder = ak.ArrayBuilder()
+        _find(self.nested, ak.Array([1, 2, 3]), builder)
+        labels = builder.snapshot()
 
-    #     assert labels[0][0] == 1
-    #     assert labels[0][1] == 1
-    #     assert labels[0][2] == 0
-    #     assert labels[1][0] == 0
+        assert labels[0][0] == 1
+        assert labels[0][1] == 1
+        assert labels[0][2] == 0
+        assert labels[1][0] == 0
 
-    def test_mask_with_explicit_rec_stages(self):
+    def test_mask_with_explicit_rec_stages_with_multiple_events(self):
         rec_stages = self.tracks.rec_stages
         stages = [1, 3, 5, 4]
         masks = mask(self.tracks, stages=stages)
@@ -183,7 +183,7 @@ class TestRecStagesMasks(unittest.TestCase):
         assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
         assert masks[0][1] == False
 
-    def test_mask_with_start_and_end_of_rec_stages(self):
+    def test_mask_with_start_and_end_of_rec_stages_with_multiple_events(self):
         rec_stages = self.tracks.rec_stages
         stages = [1, 3, 5, 4]
         masks = mask(self.tracks, start=1, end=4)
@@ -192,6 +192,24 @@ class TestRecStagesMasks(unittest.TestCase):
         assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
         assert masks[0][1] == False
 
+    def test_mask_with_start_and_end_of_rec_stages_signle_event(self):
+        rec_stages = self.tracks.rec_stages[0][0]
+        stages = [1, 3, 5, 4]
+        track = self.tracks[0]
+        masks = mask(track, start=1, end=4)
+
+        assert track[masks].rec_stages[0][0] == 1
+        assert track[masks].rec_stages[0][-1] == 4
+
+    def test_mask_with_explicit_rec_stages_with_single_event(self):
+        rec_stages = self.tracks.rec_stages[0][0]
+        stages = [1, 3]
+        track = self.tracks[0]
+        masks = mask(track, stages=stages)
+
+        assert track[masks].rec_stages[0][0] == stages[0]
+        assert track[masks].rec_stages[0][1] == stages[1]
+
 
 class TestUnique(unittest.TestCase):
     def run_random_test_with_dtype(self, dtype):