diff --git a/km3io/tools.py b/km3io/tools.py index 2261ad375effd448105c023c5c6b1602c1c27f64..f9409a5c9bfb4717a95fe7e1b8018c8bc44d0dcd 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -218,124 +218,6 @@ def count_nested(Array, axis=0): return ak1.count(Array, axis=2) -@nb.jit(nopython=True) -def _find(rec_stages, stages, builder): - """construct an awkward1 array with the same structure as tracks.rec_stages. - When stages are found, the Array is filled with value 1, otherwise it is filled - with value 0. - - Parameters - ---------- - rec_stages : awkward1 Array - tracks.rec_stages . - stages : awkward1 Array - reconstruction stages of interest. - builder : awkward1.highlevel.ArrayBuilder - awkward1 Array builder. - """ - for s in rec_stages: - builder.begin_list() - for i in s: - num_stages = len(i) - if num_stages == len(stages): - found = 0 - for j in range(num_stages): - if i[j] == stages[j]: - found += 1 - if found == num_stages: - builder.append(1) - else: - builder.append(0) - else: - builder.append(0) - builder.end_list() - - -def mask(rec_stages, stages): - """create a mask on tracks.rec_stages . - - Parameters - ---------- - rec_stages : awkward1 Array - tracks.rec_stages . - stages : list - reconstruction stages of interest. - - Returns - ------- - awkward1 Array - an awkward1 Array mask where True corresponds to the positions - where stages were found. False otherwise. - """ - builder = ak1.ArrayBuilder() - _find(rec_stages, ak1.Array(stages), builder) - return builder.snapshot() == 1 - - -# def best_track(tracks, strategy="default", rec_type=None): -# """best track selection based on different strategies - -# Parameters -# ---------- -# tracks : class km3io.offline.OfflineBranch -# a subset of reconstructed tracks where `events.n_tracks > 0` is always true. -# strategy : str -# the trategy desired to select the best tracks. It is either: -# - "first" : to select the first tracks. -# - "default": to select the best tracks (the first ones) corresponding to a specific -# reconstruction algorithm (JGandalf, Jshowerfit, etc). This requires rec_type input. -# Example: best_track(my_tracks, strategy="default", rec_type="JPP_RECONSTRUCTION_TYPE"). -# rec_type : str, optional -# reconstruction type as defined in the official KM3NeT-Dataformat. - -# Returns -# ------- -# class km3io.offline.OfflineBranch -# tracks class with the desired "best tracks" selection. - -# Raises -# ------ -# ValueError -# ValueError raised when: -# - an invalid strategy is requested. -# - a subset of events with empty tracks is used. -# """ -# options = ['first', 'default'] -# if strategy not in options: -# raise ValueError("{} not in {}".format(strategy, options)) - -# n_events = 1 if tracks.is_single else len(tracks) - -# if n_events > 1 and any(count_nested(tracks.lik, axis=1) == 0): -# raise ValueError( -# "'events' should not contain empty tracks. Consider applying the mask: events.n_tracks>0" -# ) - -# if strategy == "first": -# if n_events == 1: -# out = tracks[0] -# else: -# out = tracks[:, 0] - -# if strategy == "default" and rec_type is None: -# raise ValueError( -# "rec_type must be provided when the default strategy is used.") - -# if strategy == "default" and rec_type is not None: -# if n_events == 1: -# rec_types = tracks[tracks.rec_type == krec[rec_type]] -# len_stages = count_nested(rec_types.rec_stages, axis=1) -# longest = rec_types[len_stages == ak1.max(len_stages, axis=0)] -# out = longest[longest.lik == ak1.max(longest.lik, axis=0)] -# else: -# rec_types = tracks[tracks.rec_type == krec[rec_type]] -# len_stages = count_nested(rec_types.rec_stages, axis=2) -# longest = rec_types[len_stages == ak1.max(len_stages, axis=1)] -# out = longest[longest.lik == ak1.max(longest.lik, axis=1)] - -# return out - - def get_multiplicity(tracks, rec_stages): """tracks selection based on specific reconstruction stages (for multiplicity calculations). @@ -382,10 +264,10 @@ def _max_lik_track(tracks): def _best_track(tracks, start=None, end=None, stages=[]): if (len(stages) > 0) and (start is None) and (end is None): - selected_tracks = tracks[mask(tracks.rec_stages, stages)] + selected_tracks = tracks[mask(tracks.rec_stages, stages=stages)] if (start is not None) and (end is not None) and (len(stages) == 0): - selected_tracks = tracks[mask_tracks(tracks.rec_stages, start, end)] + selected_tracks = tracks[mask(tracks.rec_stages, start=start, end=end)] if (start is None) and (end is None) and (len(stages) == 0): # this should be modified to a log print and not just a simple print @@ -424,6 +306,8 @@ def _DUSJShower_stages(): def _reco_stages(reco): + valid_recos = set(("JSHOWER", "JMUON", "AASHOWER", "DUSJSHOWER")) + if reco == "JSHOWER": stages = _JShower_stages() @@ -436,7 +320,7 @@ def _reco_stages(reco): if reco == "DUSJSHOWER": stages == _DUSJShower_stages() - else: + if reco not in valid_recos: raise KeyError( f"{reco} must be either: 'JSHOWER', 'JMUON', 'AASHOWER', 'DUSJSHOWER'." ) @@ -451,7 +335,7 @@ def best_track(tracks, reco, start=None, end=None, stages=[]): if (start is not None) and (end is not None): if (start not in valid_stages) or (end not in valid_stages): raise KeyError( - f" start and/or end are not in JMuon reconstruction stages") + f" start and/or end are not in {reco} reconstruction stages") if len(stages) > 0: if not set(stages).issubset(valid_stages): @@ -479,6 +363,7 @@ def _find_between(rec_stages, start, end, builder): builder : awkward1.highlevel.ArrayBuilder awkward1 Array builder. """ + for s in rec_stages: builder.begin_list() for i in s: @@ -493,7 +378,7 @@ def _find_between(rec_stages, start, end, builder): builder.end_list() -def mask_tracks(rec_stages, start, end): +def _mask_rec_stages_between_start_end(rec_stages, start, end): """mask tracks where tracks.rec_stages are between start and end . Parameters @@ -514,3 +399,86 @@ def mask_tracks(rec_stages, start, end): builder = ak1.ArrayBuilder() _find_between(rec_stages, start, end, builder) return builder.snapshot() == 1 + + +@nb.jit(nopython=True) +def _find(rec_stages, stages, builder): + """construct an awkward1 array with the same structure as tracks.rec_stages. + When stages are found, the Array is filled with value 1, otherwise it is filled + with value 0. + + Parameters + ---------- + rec_stages : awkward1 Array + tracks.rec_stages . + stages : awkward1 Array + reconstruction stages of interest. + builder : awkward1.highlevel.ArrayBuilder + awkward1 Array builder. + """ + for s in rec_stages: + builder.begin_list() + for i in s: + num_stages = len(i) + if num_stages == len(stages): + found = 0 + for j in range(num_stages): + if i[j] == stages[j]: + found += 1 + if found == num_stages: + builder.append(1) + else: + builder.append(0) + else: + builder.append(0) + builder.end_list() + + +def _mask_explicit_rec_stages(rec_stages, stages): + """create a mask on tracks.rec_stages . + + Parameters + ---------- + rec_stages : awkward1 Array + tracks.rec_stages . + stages : list + reconstruction stages of interest. + + Returns + ------- + awkward1 Array + an awkward1 Array mask where True corresponds to the positions + where stages were found. False otherwise. + """ + builder = ak1.ArrayBuilder() + _find(rec_stages, ak1.Array(stages), builder) + return builder.snapshot() == 1 + + +def mask(rec_stages, stages=None, start=None, end=None): + """create a mask on tracks.rec_stages . + + Parameters + ---------- + rec_stages : awkward1 Array + tracks.rec_stages . + stages : list + reconstruction stages of interest. + + Returns + ------- + awkward1 Array + an awkward1 Array mask where True corresponds to the positions + where stages were found. False otherwise. + """ + if (stages is None) and (start is None) and (end is None): + raise KeyError("either stages or (start and end) must be specified") + + if (stages is not None) and (start is not None) and (end is not None): + raise ValueError("too many inputs are specified") + + if (stages is not None) and (start is None) and (end is None): + return _mask_explicit_rec_stages(rec_stages, stages) + + if (stages is None) and (start is not None) and (end is not None): + return _mask_rec_stages_between_start_end(rec_stages, start, end) \ No newline at end of file