From 976019b84c2718e4d9b018cd5d08bf23af87ade3 Mon Sep 17 00:00:00 2001 From: Tamas Gal <tgal@km3net.de> Date: Wed, 9 Dec 2020 10:35:16 +0100 Subject: [PATCH] Final implementation of the best track masking --- km3io/tools.py | 181 +++++++++++++++++++++++++------------------- tests/test_tools.py | 1 - 2 files changed, 104 insertions(+), 78 deletions(-) diff --git a/km3io/tools.py b/km3io/tools.py index b66866e..27fa1d4 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -224,37 +224,38 @@ def best_track(tracks, startend=None, minmax=None, stages=None): Parameters ---------- - tracks : km3io.offline.OfflineBranch - Array of tracks or jagged array of tracks (multiple events). + tracks : awkward.Array + A list of tracks or doubly nested tracks, usually from + OfflineReader.events.tracks or subarrays of that, containing reconstructed + tracks. startend: tuple(int, int), optional - The required first and last stage in tracks.rec_stages. + The required first and last stage in tracks.rec_stages. minmax: tuple(int, int), optional - The range (minimum and maximum) value of rec_stages to take into account. + The range (minimum and maximum) value of rec_stages to take into account. stages : list or set, optional - - list: the order of the rec_stages is respected. - - set: a subset of required stages; the order is irrelevant. + - list: the order of the rec_stages is respected. + - set: a subset of required stages; the order is irrelevant. Returns ------- - km3io.offline.OfflineBranch - The best tracks based on the selection. + awkward.Array or namedtuple + Be aware that the dimensions are kept, which means that the final + track attributes are nested when multiple events are passed in. + If a single event (just a list of tracks) is provided, a named tuple + with a single track and flat attributes is created. Raises ------ ValueError - - too many inputs specified. - - no inputs are specified. + When invalid inputs are specified. 
""" inputs = (stages, startend, minmax) - if all(v is None for v in inputs): + if sum(v is not None for v in inputs) != 1: raise ValueError("either stages, startend or minmax must be specified.") - if stages is not None and (startend is not None or minmax is not None): - raise ValueError("Please specify either a range or a set of rec stages.") - - if stages is not None and startend is None and minmax is None: + if stages is not None: if isinstance(stages, list): m1 = mask(tracks.rec_stages, sequence=stages) elif isinstance(stages, set): @@ -262,10 +263,10 @@ def best_track(tracks, startend=None, minmax=None, stages=None): else: raise ValueError("stages must be a list or a set of integers") - if startend is not None and minmax is None and stages is None: + if startend is not None: m1 = mask(tracks.rec_stages, startend=startend) - if minmax is not None and startend is None and stages is None: + if minmax is not None: m1 = mask(tracks.rec_stages, minmax=minmax) try: @@ -291,7 +292,7 @@ def best_track(tracks, startend=None, minmax=None, stages=None): def mask(arr, sequence=None, startend=None, minmax=None, atleast=None): - """Return a boolean mask which check each nested sub-array for a condition. + """Return a boolean mask which mask each nested sub-array for a condition. Parameters ---------- @@ -306,93 +307,119 @@ def mask(arr, sequence=None, startend=None, minmax=None, atleast=None): order) atleast : list(int), optional True for entries where at least the provided elements are present. + + An extensive discussion about this implementation can be found at: + https://github.com/scikit-hep/awkward-1.0/issues/580 + Many thanks for Jim for the fruitful discussion and the final implementation. """ inputs = (sequence, startend, minmax, atleast) - if all(v is None for v in inputs): + if sum(v is not None for v in inputs) != 1: raise ValueError( "either sequence, startend, minmax or atleast must be specified." 
) - builder = ak.ArrayBuilder() - # Numba has very limited recursion support, so this is hardcoded - if arr.ndim == 2: - _mask2d(arr, builder, sequence, startend, minmax, atleast) - else: - _mask3d(arr, builder, sequence, startend, minmax, atleast) - return builder.snapshot() - + def recurse(layout): + if layout.purelist_depth == 2: + if startend is not None: + np_array = _mask_startend(ak.Array(layout), *startend) + elif minmax is not None: + np_array = _mask_minmax(ak.Array(layout), *minmax) + elif sequence is not None: + np_array = _mask_sequence(ak.Array(layout), np.array(sequence)) + elif atleast is not None: + np_array = _mask_atleast(ak.Array(layout), np.array(atleast)) + + return ak.layout.NumpyArray(np_array) + + elif isinstance( + layout, + ( + ak.layout.ListArray32, + ak.layout.ListArrayU32, + ak.layout.ListArray64, + ), + ): + if len(layout.stops) == 0: + content = recurse(layout.content) + else: + content = recurse(layout.content[: np.max(layout.stops)]) + return type(layout)(layout.starts, layout.stops, content) + + elif isinstance( + layout, + ( + ak.layout.ListOffsetArray32, + ak.layout.ListOffsetArrayU32, + ak.layout.ListOffsetArray64, + ), + ): + content = recurse(layout.content[: layout.offsets[-1]]) + return type(layout)(layout.offsets, content) + + elif isinstance(layout, ak.layout.RegularArray): + content = recurse(layout.content) + return ak.layout.RegularArray(content, layout.size) -def mask_alt(arr, start, end): - nonempty = ak.num(arr, axis=-1) > 0 - mask = ((arr.mask[nonempty][..., 0] == start) & (arr.mask[nonempty][..., -1] == end)) - return ak.fill_none(mask, False) + else: + raise NotImplementedError(repr(arr)) - -@nb.njit -def _mask3d(arr, builder, sequence=None, startend=None, minmax=None, atleast=None): - for subarray in arr: - builder.begin_list() - _mask2d(subarray, builder, sequence, startend, minmax, atleast) - builder.end_list() + layout = ak.to_layout(arr, allow_record=True, allow_other=False) + return 
ak.Array(recurse(layout)) @nb.njit -def _mask_startend(arr, builder, start, end): - for els in arr: - if len(els) > 0 and els[0] == start and els[-1] == end: - builder.boolean(True) - else: - builder.boolean(False) +def _mask_startend(arr, start, end): + out = np.empty(len(arr), np.bool_) + for i, subarr in enumerate(arr): + out[i] = len(subarr) > 0 and subarr[0] == start and subarr[-1] == end + return out @nb.njit -def _mask_minmax(arr, builder, min, max): - for els in arr: - for el in els: - if el < min or el > max: - builder.boolean(False) - break +def _mask_minmax(arr, min, max): + out = np.empty(len(arr), np.bool_) + for i, subarr in enumerate(arr): + if len(subarr) == 0: + out[i] = False else: - builder.boolean(True) + for el in subarr: + if el < min or el > max: + out[i] = False + break + else: + out[i] = True + return out @nb.njit -def _mask_sequence(arr, builder, sequence): +def _mask_sequence(arr, sequence): + out = np.empty(len(arr), np.bool_) n = len(sequence) - for els in arr: - if len(els) != n: - builder.boolean(False) + for i, subarr in enumerate(arr): + if len(subarr) != n: + out[i] = False else: - for i in range(n): - if els[i] != sequence[i]: - builder.boolean(False) + for j in range(n): + if subarr[j] != sequence[j]: + out[i] = False break else: - builder.boolean(True) + out[i] = True + return out @nb.njit -def _mask_atleast(arr, builder, atleast): - for els in arr: - for e in atleast: - if e not in els: - builder.boolean(False) +def _mask_atleast(arr, atleast): + out = np.empty(len(arr), np.bool_) + for i, subarr in enumerate(arr): + for req_el in atleast: + if req_el not in subarr: + out[i] = False break else: - builder.boolean(True) - - -@nb.njit -def _mask2d(arr, builder, sequence=None, startend=None, minmax=None, atleast=None): - if startend is not None: - _mask_startend(arr, builder, *startend) - elif minmax is not None: - _mask_minmax(arr, builder, *minmax) - elif sequence is not None: - _mask_sequence(arr, builder, sequence) - elif 
atleast is not None: - _mask_atleast(arr, builder, atleast) + out[i] = True + return out def best_jmuon(tracks): diff --git a/tests/test_tools.py b/tests/test_tools.py index c66ea3c..3505b6a 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -429,7 +429,6 @@ class TestMask(unittest.TestCase): m = mask(arr, minmax=(1, 4)) self.assertListEqual(m.tolist(), [[True, False, False], [True]]) - @unittest.skip def test_minmax_4dim_mask(self): arr = ak.Array( [[[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]], [[[1, 9], [3, 3]]]] -- GitLab