diff --git a/km3io/offline.py b/km3io/offline.py index ee643c69d44554ad2356c21fe5859d9250b081f6..053c8b07252cce02b6636d3a0cee9d70b5a577e8 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -2,6 +2,8 @@ from collections import namedtuple import uproot import warnings import awkward1 as ak1 +import numba as nb + from .definitions import mc_header, fitparameters from .tools import Branch, BranchMapper, cached_property, _to_num, _unfold_indices @@ -83,6 +85,32 @@ def count_nested(Array, axis=0): return ak1.count(Array, axis=2) +@nb.jit(nopython=True) +def _find(rec_stages, stages, builder): + for s in rec_stages: + builder.begin_list() + for i in s: + num_stages = len(i) + if num_stages == len(stages): + found = 0 + for j in range(num_stages): + if i[j] == stages[j]: + found += 1 + if found == num_stages: + builder.append(1) + else: + builder.append(0) + else: + builder.append(0) + builder.end_list() + + +def mask(rec_stages, stages): + builder = ak1.ArrayBuilder() + _find(rec_stages, ak1.Array(stages), builder) + return builder.snapshot() == 1 + + def best_track(tracks, strategy="first", rec_stages=None): """best track selection based on different strategies