diff --git a/km3io/offline.py b/km3io/offline.py
index 19cf5aa55d83f46af5790127e9148f332eef6ad7..de606c6306e1407a103da996365a52f3cbf3bd0f 100644
--- a/km3io/offline.py
+++ b/km3io/offline.py
@@ -79,7 +79,15 @@ class OfflineReader:
         "mc_tracks": "mc_trks",
     }
 
-    def __init__(self, f, index_chain=None, step_size=2000, keys=None, aliases=None, event_ctor=None):
+    def __init__(
+        self,
+        f,
+        index_chain=None,
+        step_size=2000,
+        keys=None,
+        aliases=None,
+        event_ctor=None,
+    ):
         """OfflineReader class is an offline ROOT file wrapper
 
         Parameters
@@ -187,10 +195,12 @@ class OfflineReader:
                 step_size=self._step_size,
                 aliases=self.aliases,
                 keys=self.keys(),
-                event_ctor=self._event_ctor
+                event_ctor=self._event_ctor,
             )
 
-        if isinstance(key, str) and key.startswith("n_"):  # group counts, for e.g. n_events, n_hits etc.
+        if isinstance(key, str) and key.startswith(
+            "n_"
+        ):  # group counts, for e.g. n_events, n_hits etc.
             key = self._keyfor(key.split("n_")[1])
             arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4"))
             return unfold_indices(arr, self._index_chain)
@@ -207,9 +217,7 @@ class OfflineReader:
                 if from_field in branch[key].keys():
                     fields.append(to_field)
             log.debug(fields)
-            out = branch[key].arrays(
-                fields, aliases=self.special_branches[key]
-            )
+            out = branch[key].arrays(fields, aliases=self.special_branches[key])
         else:
             out = branch[self.aliases.get(key, key)].array()
 
@@ -220,9 +228,57 @@ class OfflineReader:
         return self
 
     def _event_generator(self):
-        for i in range(len(self)):
-            yield self[i]
-        return
+        events = self._fobj[self.event_path]
+        group_count_keys = set(
+            k for k in self.keys() if k.startswith("n_")
+        )  # special keys to make it easy to count subbranch lengths
+        log.debug("group_count_keys: %s", group_count_keys)
+        keys = set(
+            list(
+                set(self.keys())
+                - set(self.special_branches.keys())
+                - set(self.special_aliases)
+                - group_count_keys
+            )
+            + list(self.aliases.keys())
+        )  # all top-level keys for regular branches
+        log.debug("keys: %s", keys)
+        log.debug("aliases: %s", self.aliases)
+        events_it = events.iterate(
+            keys, aliases=self.aliases, step_size=self._step_size
+        )
+        specials = []
+        special_keys = (
+            self.special_branches.keys()
+        )  # dict-key ordering is an implementation detail
+        log.debug("special_keys: %s", special_keys)
+        for key in special_keys:
+            # print(f"adding {key} with keys {self.special_branches[key].keys()} and aliases {self.special_branches[key]}")
+
+            specials.append(
+                events[key].iterate(
+                    self.special_branches[key].keys(),
+                    aliases=self.special_branches[key],
+                    step_size=self._step_size,
+                )
+            )
+        group_counts = {}
+        for key in group_count_keys:
+            group_counts[key] = iter(self[key])
+
+        log.debug("group_counts: %s", group_counts)
+        for event_set, *special_sets in zip(events_it, *specials):
+            for _event, *special_items in zip(event_set, *special_sets):
+                data = {}
+                for k in keys:
+                    data[k] = _event[k]
+                for (k, i) in zip(special_keys, special_items):
+                    data[k] = i
+                for tokey, fromkey in self.special_aliases.items():
+                    data[tokey] = data[fromkey]
+                for key in group_counts:
+                    data[key] = next(group_counts[key])
+                yield self._event_ctor(**data)
 
     def __next__(self):
         return next(self._events)
@@ -246,7 +302,6 @@ class OfflineReader:
         """The raw number of events without any indexing/slicing magic"""
         return len(self._fobj[self.event_path]["id"].array())
 
-
     def __repr__(self):
         length = len(self)
         actual_length = self.__actual_len__()
diff --git a/km3io/tools.py b/km3io/tools.py
index d75ce95fed15d8cae5f2790ad5a660d1624daff4..2344e74a47e5b461e4f6e615429dfb8e7a9f4ea0 100644
--- a/km3io/tools.py
+++ b/km3io/tools.py
@@ -275,7 +275,7 @@ def best_track(tracks, startend=None, minmax=None, stages=None):
 
     tracks = tracks[m1]
 
-    rec_stage_lengths = ak.num(tracks.rec_stages, axis=axis+1)
+    rec_stage_lengths = ak.num(tracks.rec_stages, axis=axis + 1)
     max_rec_stage_length = ak.max(rec_stage_lengths, axis=axis)
     m2 = rec_stage_lengths == max_rec_stage_length
     tracks = tracks[m2]
@@ -284,7 +284,9 @@ def best_track(tracks, startend=None, minmax=None, stages=None):
 
     out = tracks[m3]
     if isinstance(out, ak.highlevel.Record):
-        return namedtuple("BestTrack", out.fields)(*[getattr(out, a)[0] for a in out.fields])
+        return namedtuple("BestTrack", out.fields)(
+            *[getattr(out, a)[0] for a in out.fields]
+        )
     return out
 
 
@@ -308,20 +310,22 @@ def mask(arr, sequence=None, startend=None, minmax=None, atleast=None):
     inputs = (sequence, startend, minmax, atleast)
 
     if all(v is None for v in inputs):
-        raise ValueError("either sequence, startend, minmax or atleast must be specified.")
+        raise ValueError(
+            "either sequence, startend, minmax or atleast must be specified."
+        )
 
     builder = ak.ArrayBuilder()
     _mask(arr, builder, sequence, startend, minmax, atleast)
     return builder.snapshot()
 
-#nb.njit  # TODO: not supported in awkward yet
-#                 see https://github.com/scikit-hep/awkward-1.0/issues/572
+
+# @nb.njit
 def _mask(arr, builder, sequence=None, startend=None, minmax=None, atleast=None):
     if arr.ndim == 2:  # recursion stop
         if startend is not None:
             start, end = startend
             for els in arr:
-                if ak.count(els) > 0 and els[0] == start and els[-1] == end:
+                if len(els) > 0 and els[0] == start and els[-1] == end:
                     builder.boolean(True)
                 else:
                     builder.boolean(False)
@@ -362,7 +366,6 @@ def _mask(arr, builder, sequence=None, startend=None, minmax=None, atleast=None)
         builder.end_list()
 
 
-
 def best_jmuon(tracks):
     """Select the best JMUON track."""
     return best_track(tracks, minmax=(krec.JMUONBEGIN, krec.JMUONEND))
diff --git a/tests/test_offline.py b/tests/test_offline.py
index eba3ccecb4902f6b6317dffdccee05f9787298a8..8c6d3cdde7fef3286526f980a9a8a603f6654aeb 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -183,17 +183,13 @@ class TestOfflineEvents(unittest.TestCase):
 
     def test_index_consistency(self):
         for i in [0, 2, 5]:
-            assert np.allclose(
-                self.events[i].n_hits, self.events.n_hits[i]
-            )
+            assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])
 
     def test_index_chaining(self):
         assert np.allclose(
             self.events[3:5].n_hits.tolist(), self.events.n_hits[3:5].tolist()
         )
-        assert np.allclose(
-            self.events[3:5][0].n_hits, self.events.n_hits[3:5][0]
-        )
+        assert np.allclose(self.events[3:5][0].n_hits, self.events.n_hits[3:5][0])
 
     @unittest.skip
     def test_index_chaining_on_nested_branches_aka_records(self):
@@ -361,7 +357,22 @@ class TestOfflineTracks(unittest.TestCase):
         self.n_events = 10
 
     def test_fields(self):
-        for field in ['id', 'pos_x', 'pos_y', 'pos_z', 'dir_x', 'dir_y', 'dir_z', 't', 'E', 'len', 'lik', 'rec_type', 'rec_stages', 'fitinf']:
+        for field in [
+            "id",
+            "pos_x",
+            "pos_y",
+            "pos_z",
+            "dir_x",
+            "dir_y",
+            "dir_z",
+            "t",
+            "E",
+            "len",
+            "lik",
+            "rec_type",
+            "rec_stages",
+            "fitinf",
+        ]:
             getattr(self.tracks, field)
 
     def test_item_selection(self):
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 8ae5a4f9266c8c4fa8ed23ebb9a929c02f4ccb78..3505b6a5299fe97902d966c273bf3bdad62ab1d4 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -230,7 +230,6 @@ class TestBestTrackSelection(unittest.TestCase):
 
         assert len(best) == 5
 
-        import pdb; pdb.set_trace()
         assert best.lik == ak.max(tracks_slice.lik)
         assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
 
@@ -419,6 +418,27 @@ class TestRecStagesMasks(unittest.TestCase):
             mask(self.tracks)
 
 
+class TestMask(unittest.TestCase):
+    def test_minmax_2dim_mask(self):
+        arr = ak.Array([[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]])
+        m = mask(arr, minmax=(1, 4))
+        self.assertListEqual(m.tolist(), [True, False, False])
+
+    def test_minmax_3dim_mask(self):
+        arr = ak.Array([[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]])
+        m = mask(arr, minmax=(1, 4))
+        self.assertListEqual(m.tolist(), [[True, False, False], [True]])
+
+    def test_minmax_4dim_mask(self):
+        arr = ak.Array(
+            [[[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]], [[[1, 9], [3, 3]]]]
+        )
+        m = mask(arr, minmax=(1, 4))
+        self.assertListEqual(
+            m.tolist(), [[[True, False, False], [True]], [[False, True]]]
+        )
+
+
 class TestUnique(unittest.TestCase):
     def run_random_test_with_dtype(self, dtype):
         max_range = 100