diff --git a/km3io/tools.py b/km3io/tools.py
index b66866ef32f12b0f43c8cbacf4e645a7106aae48..27fa1d4e215aa63102ebd4427e5d9e69d29d6e71 100644
--- a/km3io/tools.py
+++ b/km3io/tools.py
@@ -224,37 +224,38 @@ def best_track(tracks, startend=None, minmax=None, stages=None):
 
     Parameters
     ----------
-    tracks : km3io.offline.OfflineBranch
-        Array of tracks or jagged array of tracks (multiple events).
+    tracks : awkward.Array
+      A list of tracks or doubly nested tracks, usually from
+      OfflineReader.events.tracks or subarrays of that, containing recunstructed
+      tracks.
     startend: tuple(int, int), optional
-        The required first and last stage in tracks.rec_stages.
+      The required first and last stage in tracks.rec_stages.
     minmax: tuple(int, int), optional
-        The range (minimum and maximum) value of rec_stages to take into account.
+      The range (minimum and maximum) value of rec_stages to take into account.
     stages : list or set, optional
-        - list: the order of the rec_stages is respected.
-        - set: a subset of required stages; the order is irrelevant.
+      - list: the order of the rec_stages is respected.
+      - set: a subset of required stages; the order is irrelevant.
 
     Returns
     -------
-    km3io.offline.OfflineBranch
-        The best tracks based on the selection.
+    awkward.Array or namedtuple
+      Be aware that the dimensions are kept, which means that the final
+      track attributes are nested when multiple events are passed in.
+      If a single event (just a list of tracks) is provided, a named tuple
+      with a single track and flat attributes is created.
 
     Raises
     ------
     ValueError
-        - too many inputs specified.
-        - no inputs are specified.
+      When invalid inputs are specified.
 
     """
     inputs = (stages, startend, minmax)
 
-    if all(v is None for v in inputs):
+    if sum(v is not None for v in inputs) != 1:
         raise ValueError("either stages, startend or minmax must be specified.")
 
-    if stages is not None and (startend is not None or minmax is not None):
-        raise ValueError("Please specify either a range or a set of rec stages.")
-
-    if stages is not None and startend is None and minmax is None:
+    if stages is not None:
         if isinstance(stages, list):
             m1 = mask(tracks.rec_stages, sequence=stages)
         elif isinstance(stages, set):
@@ -262,10 +263,10 @@ def best_track(tracks, startend=None, minmax=None, stages=None):
         else:
             raise ValueError("stages must be a list or a set of integers")
 
-    if startend is not None and minmax is None and stages is None:
+    if startend is not None:
         m1 = mask(tracks.rec_stages, startend=startend)
 
-    if minmax is not None and startend is None and stages is None:
+    if minmax is not None:
         m1 = mask(tracks.rec_stages, minmax=minmax)
 
     try:
@@ -291,7 +292,7 @@ def best_track(tracks, startend=None, minmax=None, stages=None):
 
 
 def mask(arr, sequence=None, startend=None, minmax=None, atleast=None):
-    """Return a boolean mask which check each nested sub-array for a condition.
+    """Return a boolean mask which mask each nested sub-array for a condition.
 
     Parameters
     ----------
@@ -306,93 +307,119 @@ def mask(arr, sequence=None, startend=None, minmax=None, atleast=None):
         order)
     atleast : list(int), optional
         True for entries where at least the provided elements are present.
+
+    An extensive discussion about this implementation can be found at:
+    https://github.com/scikit-hep/awkward-1.0/issues/580
+    Many thanks for Jim for the fruitful discussion and the final implementation.
     """
     inputs = (sequence, startend, minmax, atleast)
 
-    if all(v is None for v in inputs):
+    if sum(v is not None for v in inputs) != 1:
         raise ValueError(
             "either sequence, startend, minmax or atleast must be specified."
         )
 
-    builder = ak.ArrayBuilder()
-    # Numba has very limited recursion support, so this is hardcoded
-    if arr.ndim == 2:
-        _mask2d(arr, builder, sequence, startend, minmax, atleast)
-    else:
-        _mask3d(arr, builder, sequence, startend, minmax, atleast)
-    return builder.snapshot()
-
+    def recurse(layout):
+        if layout.purelist_depth == 2:
+            if startend is not None:
+                np_array = _mask_startend(ak.Array(layout), *startend)
+            elif minmax is not None:
+                np_array = _mask_minmax(ak.Array(layout), *minmax)
+            elif sequence is not None:
+                np_array = _mask_sequence(ak.Array(layout), np.array(sequence))
+            elif atleast is not None:
+                np_array = _mask_atleast(ak.Array(layout), np.array(atleast))
+
+            return ak.layout.NumpyArray(np_array)
+
+        elif isinstance(
+            layout,
+            (
+                ak.layout.ListArray32,
+                ak.layout.ListArrayU32,
+                ak.layout.ListArray64,
+            ),
+        ):
+            if len(layout.stops) == 0:
+                content = recurse(layout.content)
+            else:
+                content = recurse(layout.content[: np.max(layout.stops)])
+            return type(layout)(layout.starts, layout.stops, content)
+
+        elif isinstance(
+            layout,
+            (
+                ak.layout.ListOffsetArray32,
+                ak.layout.ListOffsetArrayU32,
+                ak.layout.ListOffsetArray64,
+            ),
+        ):
+            content = recurse(layout.content[: layout.offsets[-1]])
+            return type(layout)(layout.offsets, content)
+
+        elif isinstance(layout, ak.layout.RegularArray):
+            content = recurse(layout.content)
+            return ak.layout.RegularArray(content, layout.size)
 
-def mask_alt(arr, start, end):
-    nonempty = ak.num(arr, axis=-1) > 0
-    mask = ((arr.mask[nonempty][..., 0] == start) & (arr.mask[nonempty][..., -1] == end))
-    return ak.fill_none(mask, False)
+        else:
+            raise NotImplementedError(repr(arr))
 
-
-@nb.njit
-def _mask3d(arr, builder, sequence=None, startend=None, minmax=None, atleast=None):
-    for subarray in arr:
-        builder.begin_list()
-        _mask2d(subarray, builder, sequence, startend, minmax, atleast)
-        builder.end_list()
+    layout = ak.to_layout(arr, allow_record=True, allow_other=False)
+    return ak.Array(recurse(layout))
 
 
 @nb.njit
-def _mask_startend(arr, builder, start, end):
-    for els in arr:
-        if len(els) > 0 and els[0] == start and els[-1] == end:
-            builder.boolean(True)
-        else:
-            builder.boolean(False)
+def _mask_startend(arr, start, end):
+    out = np.empty(len(arr), np.bool_)
+    for i, subarr in enumerate(arr):
+        out[i] = len(subarr) > 0 and subarr[0] == start and subarr[-1] == end
+    return out
 
 
 @nb.njit
-def _mask_minmax(arr, builder, min, max):
-    for els in arr:
-        for el in els:
-            if el < min or el > max:
-                builder.boolean(False)
-                break
+def _mask_minmax(arr, min, max):
+    out = np.empty(len(arr), np.bool_)
+    for i, subarr in enumerate(arr):
+        if len(subarr) == 0:
+            out[i] = False
         else:
-            builder.boolean(True)
+            for el in subarr:
+                if el < min or el > max:
+                    out[i] = False
+                    break
+            else:
+                out[i] = True
+    return out
 
 
 @nb.njit
-def _mask_sequence(arr, builder, sequence):
+def _mask_sequence(arr, sequence):
+    out = np.empty(len(arr), np.bool_)
     n = len(sequence)
-    for els in arr:
-        if len(els) != n:
-            builder.boolean(False)
+    for i, subarr in enumerate(arr):
+        if len(subarr) != n:
+            out[i] = False
         else:
-            for i in range(n):
-                if els[i] != sequence[i]:
-                    builder.boolean(False)
+            for j in range(n):
+                if subarr[j] != sequence[j]:
+                    out[i] = False
                     break
             else:
-                builder.boolean(True)
+                out[i] = True
+    return out
 
 
 @nb.njit
-def _mask_atleast(arr, builder, atleast):
-    for els in arr:
-        for e in atleast:
-            if e not in els:
-                builder.boolean(False)
+def _mask_atleast(arr, atleast):
+    out = np.empty(len(arr), np.bool_)
+    for i, subarr in enumerate(arr):
+        for req_el in atleast:
+            if req_el not in subarr:
+                out[i] = False
                 break
         else:
-            builder.boolean(True)
-
-
-@nb.njit
-def _mask2d(arr, builder, sequence=None, startend=None, minmax=None, atleast=None):
-    if startend is not None:
-        _mask_startend(arr, builder, *startend)
-    elif minmax is not None:
-        _mask_minmax(arr, builder, *minmax)
-    elif sequence is not None:
-        _mask_sequence(arr, builder, sequence)
-    elif atleast is not None:
-        _mask_atleast(arr, builder, atleast)
+            out[i] = True
+    return out
 
 
 def best_jmuon(tracks):
diff --git a/tests/test_tools.py b/tests/test_tools.py
index c66ea3c6c8248928e3c42433e1a9b338e07a3c44..3505b6a5299fe97902d966c273bf3bdad62ab1d4 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -429,7 +429,6 @@ class TestMask(unittest.TestCase):
         m = mask(arr, minmax=(1, 4))
         self.assertListEqual(m.tolist(), [[True, False, False], [True]])
 
-    @unittest.skip
     def test_minmax_4dim_mask(self):
         arr = ak.Array(
             [[[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]], [[[1, 9], [3, 3]]]]