diff --git a/km3io/tools.py b/km3io/tools.py index 2344e74a47e5b461e4f6e615429dfb8e7a9f4ea0..9d5e9aff69b9da874b6e501e6a97e1e573f0a6f9 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -315,57 +315,80 @@ def mask(arr, sequence=None, startend=None, minmax=None, atleast=None): ) builder = ak.ArrayBuilder() - _mask(arr, builder, sequence, startend, minmax, atleast) + # Numba has very limited recursion support, so this is hardcoded + if arr.ndim == 2: + _mask2d(arr, builder, sequence, startend, minmax, atleast) + else: + _mask3d(arr, builder, sequence, startend, minmax, atleast) return builder.snapshot() -# @nb.njit -def _mask(arr, builder, sequence=None, startend=None, minmax=None, atleast=None): - if arr.ndim == 2: # recursion stop - if startend is not None: - start, end = startend - for els in arr: - if len(els) > 0 and els[0] == start and els[-1] == end: - builder.boolean(True) - else: - builder.boolean(False) - elif minmax is not None: - min, max = minmax - for els in arr: - for el in els: - if el < min or el > max: - builder.boolean(False) - break - else: - builder.boolean(True) - elif sequence is not None: - n = len(sequence) - for els in arr: - if len(els) != n: - builder.boolean(False) - else: - for i in range(n): - if els[i] != sequence[i]: - builder.boolean(False) - break - else: - builder.boolean(True) - elif atleast is not None: - for els in arr: - for e in atleast: - if e not in els: - builder.boolean(False) - break - else: - builder.boolean(True) - return - +@nb.njit +def _mask3d(arr, builder, sequence=None, startend=None, minmax=None, atleast=None): for subarray in arr: builder.begin_list() - _mask(subarray, builder, sequence, startend, minmax, atleast) + _mask2d(subarray, builder, sequence, startend, minmax, atleast) builder.end_list() +@nb.njit +def _mask_startend(arr, builder, start, end): + for els in arr: + if len(els) > 0 and els[0] == start and els[-1] == end: + builder.boolean(True) + else: + builder.boolean(False) + + +@nb.njit +def _mask_minmax(arr, builder, min, max): + for els in arr: + for el in els: + if el < min or el > max: + builder.boolean(False) + break + else: + builder.boolean(True) + + +@nb.njit +def _mask_sequence(arr, builder, sequence): + n = len(sequence) + for els in arr: + if len(els) != n: + builder.boolean(False) + else: + for i in range(n): + if els[i] != sequence[i]: + builder.boolean(False) + break + else: + builder.boolean(True) + + +@nb.njit +def _mask_atleast(arr, builder, atleast): + for els in arr: + for e in atleast: + if e not in els: + builder.boolean(False) + break + else: + builder.boolean(True) + + +@nb.njit +def _mask2d(arr, builder, sequence=None, startend=None, minmax=None, atleast=None): + if startend is not None: + _mask_startend(arr, builder, *startend) + elif minmax is not None: + _mask_minmax(arr, builder, *minmax) + elif sequence is not None: + _mask_sequence(arr, builder, sequence) + elif atleast is not None: + _mask_atleast(arr, builder, atleast) + + def best_jmuon(tracks): """Select the best JMUON track.""" return best_track(tracks, minmax=(krec.JMUONBEGIN, krec.JMUONEND)) diff --git a/tests/test_tools.py b/tests/test_tools.py index 3505b6a5299fe97902d966c273bf3bdad62ab1d4..c66ea3c6c8248928e3c42433e1a9b338e07a3c44 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -429,6 +429,7 @@ class TestMask(unittest.TestCase): m = mask(arr, minmax=(1, 4)) self.assertListEqual(m.tolist(), [[True, False, False], [True]]) + @unittest.skip def test_minmax_4dim_mask(self): arr = ak.Array( [[[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]], [[[1, 9], [3, 3]]]]