diff --git a/km3io/tools.py b/km3io/tools.py index 47d5d1f0d7463d5a5b1a29cdb2d186478b3b8c74..0c60f4061bff114ddda520353505234f92270e7d 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -264,10 +264,10 @@ def _max_lik_track(tracks): def _best_track(tracks, start=None, end=None, stages=[]): if (len(stages) > 0) and (start is None) and (end is None): - selected_tracks = tracks[mask(tracks.rec_stages, stages=stages)] + selected_tracks = tracks[mask(tracks, stages=stages)] if (start is not None) and (end is not None) and (len(stages) == 0): - selected_tracks = tracks[mask(tracks.rec_stages, start=start, end=end)] + selected_tracks = tracks[mask(tracks, start=start, end=end)] if (start is None) and (end is None) and (len(stages) == 0): # this should be modified to a log print and not just a simple print @@ -378,7 +378,38 @@ def _find_between(rec_stages, start, end, builder): builder.end_list() -def _mask_rec_stages_between_start_end(rec_stages, start, end): +@nb.jit(nopython=True) +def _find_between_single(rec_stages, start, end, builder): + """construct an awkward1 array with the same structure as tracks.rec_stages. + When stages are between start and end, the Array is filled with value 1, otherwise it is filled + with value 0. + + Parameters + ---------- + rec_stages : awkward1 Array + tracks.rec_stages . + start : int + start of reconstruction stages of interest. + end : int + end of reconstruction stages of interest. + builder : awkward1.highlevel.ArrayBuilder + awkward1 Array builder. + """ + + builder.begin_list() + for s in rec_stages: + num_stages = len(s) + if num_stages != 0: + if (s[0] == start) and (s[-1] == end): + builder.append(1) + else: + builder.append(0) + else: + builder.append(0) + builder.end_list() + + +def _mask_rec_stages_between_start_end(tracks, start, end): """mask tracks where tracks.rec_stages are between start and end . Parameters @@ -397,8 +428,12 @@ def _mask_rec_stages_between_start_end(rec_stages, start, end): where stages were found. False otherwise. """ builder = ak1.ArrayBuilder() - _find_between(rec_stages, start, end, builder) - return builder.snapshot() == 1 + if tracks.is_single: + _find_between_single(tracks.rec_stages, start, end, builder) + return (builder.snapshot() == 1)[0] + else: + _find_between(tracks.rec_stages, start, end, builder) + return builder.snapshot() == 1 @nb.jit(nopython=True) @@ -434,7 +469,39 @@ def _find(rec_stages, stages, builder): builder.end_list() -def _mask_explicit_rec_stages(rec_stages, stages): +@nb.jit(nopython=True) +def _find_single(rec_stages, stages, builder): + """construct an awkward1 array with the same structure as tracks.rec_stages. + When stages are found, the Array is filled with value 1, otherwise it is filled + with value 0. + + Parameters + ---------- + rec_stages : awkward1 Array + tracks.rec_stages . + stages : awkward1 Array + reconstruction stages of interest. + builder : awkward1.highlevel.ArrayBuilder + awkward1 Array builder. + """ + builder.begin_list() + for s in rec_stages: + num_stages = len(s) + if num_stages == len(stages): + found = 0 + for j in range(num_stages): + if s[j] == stages[j]: + found += 1 + if found == num_stages: + builder.append(1) + else: + builder.append(0) + else: + builder.append(0) + builder.end_list() + + +def _mask_explicit_rec_stages(tracks, stages): """create a mask on tracks.rec_stages . Parameters @@ -450,12 +517,17 @@ def _mask_explicit_rec_stages(rec_stages, stages): an awkward1 Array mask where True corresponds to the positions where stages were found. False otherwise. """ + # rec_stages = tracks.rec_stages builder = ak1.ArrayBuilder() - _find(rec_stages, ak1.Array(stages), builder) - return builder.snapshot() == 1 + if tracks.is_single: + _find_single(tracks.rec_stages, ak1.Array(stages), builder) + return (builder.snapshot() == 1)[0] + else: + _find(tracks.rec_stages, ak1.Array(stages), builder) + return builder.snapshot() == 1 -def mask(rec_stages, stages=None, start=None, end=None): +def mask(tracks, stages=None, start=None, end=None): """create a mask on tracks.rec_stages . Parameters @@ -478,7 +550,7 @@ def mask(rec_stages, stages=None, start=None, end=None): raise ValueError("too many inputs are specified") if (stages is not None) and (start is None) and (end is None): - return _mask_explicit_rec_stages(rec_stages, stages) + return _mask_explicit_rec_stages(tracks, stages) if (stages is None) and (start is not None) and (end is not None): - return _mask_rec_stages_between_start_end(rec_stages, start, end) + return _mask_rec_stages_between_start_end(tracks, start, end)