diff --git a/km3io/tools.py b/km3io/tools.py index 6472790288240f89ce0ea50d8238993a270b90bb..d4ce035901bfe4bdae886a5a68b72818f6694aae 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -262,88 +262,37 @@ def _max_lik_track(tracks): return tracks[tracks.lik == ak1.max(tracks.lik, axis=tracks_nesting_level)] -def _best_track(tracks, start=None, end=None, stages=[]): - if (len(stages) > 0) and (start is None) and (end is None): +def best_track(tracks, start=None, end=None, stages=None): + if (stages is None) and (start is None) and (end is None): selected_tracks = tracks[mask(tracks, stages=stages)] - if (start is not None) and (end is not None) and (len(stages) == 0): + if (start is not None) and (end is not None) and (stages is None): selected_tracks = tracks[mask(tracks, start=start, end=end)] - if (start is None) and (end is None) and (len(stages) == 0): + if (start is None) and (end is None) and (stages is None): # this should be modified to a log print and not just a simple print - print( - "No reconstruction stages were specified. The longest reco stages are selected" - ) - - selected_tracks = tracks + raise ValueError("No reconstruction stages were specified") - if (len(stages) > 0) and ((start is not None) or (end is not None)): + if (stages is not None) and ((start is not None) or (end is not None)): raise ValueError("too many inputs are specified") return _max_lik_track(_longest_tracks(selected_tracks)) def _JShower_stages(): - return set( - (krec.JSHOWERPREFIT, krec.JSHOWERPOSITIONFIT, krec.JSHOWERCOMPLETEFIT, - krec.JSHOWER_BJORKEN_Y, krec.JSHOWERENERGYPREFIT, - krec.JSHOWERPOINTSIMPLEX, krec.JSHOWERDIRECTIONPREFIT)) + return set(range(krec.JSHOWERBEGIN, krec.JSHOWEREND)) def _JMuon_stages(): - return set((krec.JMUONPREFIT, krec.JMUONSIMPLEX, krec.JMUONGANDALF, - krec.JMUONENERGY, krec.JMUONSTART, krec.JLINEFIT)) + return set(range(krec.JMUONBEGIN, krec.JMUONEND)) def _AAShower_stages(): - return set((krec.AASHOWERFITPREFIT, krec.AASHOWERFITPOSITIONFIT, - krec.AASHOWERFITDIRECTIONENERGYFIT)) + return set(range(krec.AASHOWERBEGIN, krec.AASHOWEREND)) def _DUSJShower_stages(): - return set((krec.DUSJSHOWERPREFIT, krec.DUSJSHOWERPOSITIONFIT, - krec.DUSJSHOWERCOMPLETEFIT)) - - -def _reco_stages(reco): - valid_recos = set(("JSHOWER", "JMUON", "AASHOWER", "DUSJSHOWER")) - - if reco == "JSHOWER": - stages = _JShower_stages() - - if reco == "JMUON": - stages = _JMuon_stages() - - if reco == "AASHOWER": - stages = _AAShower_stages() - - if reco == "DUSJSHOWER": - stages == _DUSJShower_stages() - - if reco not in valid_recos: - raise KeyError( - f"{reco} must be either: 'JSHOWER', 'JMUON', 'AASHOWER', 'DUSJSHOWER'." - ) - - return stages - - -def best_track(tracks, reco, start=None, end=None, stages=[]): - - valid_stages = _reco_stages(reco) - - if (start is not None) and (end is not None): - if (start not in valid_stages) or (end not in valid_stages): - raise ValueError( - f" start and/or end are not in {reco} reconstruction stages") - - if len(stages) > 0: - if not set(stages).issubset(valid_stages): - raise KeyError( - f"one (or all) of the stages in {stages} are not in {reco} reconstruction stages" - ) - - return _best_track(tracks, start=start, end=end, stages=stages) + return set(range(krec.DUSJSHOWERBEGIN, krec.DUSJSHOWEREND)) @nb.jit(nopython=True) @@ -377,6 +326,7 @@ def _find_between(rec_stages, start, end, builder): builder.append(0) builder.end_list() + @nb.jit(nopython=True) def _find_between_single(rec_stages, start, end, builder): """construct an awkward1 array with the same structure as tracks.rec_stages. @@ -553,3 +503,130 @@ def mask(tracks, stages=None, start=None, end=None): if (stages is None) and (start is not None) and (end is not None): return _mask_rec_stages_between_start_end(tracks, start, end) + + +@nb.jit(nopython=True) +def _find_in_range(rec_stages, start, end, builder): + """construct an awkward1 array with the same structure as tracks.rec_stages. + When stages are between start and end, the Array is filled with value 1, otherwise it is filled + with value 0. + + Parameters + ---------- + rec_stages : awkward1 Array + tracks.rec_stages . + start : int + start of reconstruction stages of interest. + end : int + end of reconstruction stages of interest. + builder : awkward1.highlevel.ArrayBuilder + awkward1 Array builder. + """ + valid_stages = set(range(start, end + 1)) + for s in rec_stages: + builder.begin_list() + for i in s: + num_stages = len(i) + if num_stages != 0: + found = 0 + for j in i: + if j in valid_stages: + found += 1 + if found == num_stages: + builder.append(1) + else: + builder.append(0) + else: + builder.append(0) + builder.end_list() + + +@nb.jit(nopython=True) +def _find_in_range_single(rec_stages, start, end, builder): + """construct an awkward1 array with the same structure as tracks.rec_stages. + When stages are between start and end, the Array is filled with value 1, otherwise it is filled + with value 0. + + Parameters + ---------- + rec_stages : awkward1 Array + tracks.rec_stages . + start : int + start of reconstruction stages of interest. + end : int + end of reconstruction stages of interest. + builder : awkward1.highlevel.ArrayBuilder + awkward1 Array builder. + """ + + builder.begin_list() + valid_stages = set(range(start, end + 1)) + for s in rec_stages: + num_stages = len(s) + if num_stages != 0: + found = 0 + for i in s: + if i in valid_stages: + found += 1 + if found == num_stages: + builder.append(1) + else: + builder.append(0) + else: + builder.append(0) + builder.end_list() + + +def _mask_rec_stages_in_range_start_end(tracks, start, end): + """mask tracks where tracks.rec_stages are between start and end . + + Parameters + ---------- + rec_stages : awkward1 Array + tracks.rec_stages . + start : int + start of reconstruction stages of interest. + end : int + end of reconstruction stages of interest. + + Returns + ------- + awkward1 Array + an awkward1 Array mask where True corresponds to the positions + where stages were found. False otherwise. + """ + builder = ak1.ArrayBuilder() + if tracks.is_single: + _find_in_range_single(tracks.rec_stages, start, end, builder) + return (builder.snapshot() == 1)[0] + else: + _find_in_range(tracks.rec_stages, start, end, builder) + return builder.snapshot() == 1 + + +def best_jmuon(tracks): + mask = _mask_rec_stages_in_range_start_end(tracks, krec.JMUONBEGIN, + krec.JMUONEND) + + return _max_lik_track(_longest_tracks(tracks[mask])) + + +def best_jshower(tracks): + mask = _mask_rec_stages_in_range_start_end(tracks, krec.JSHOWERBEGIN, + krec.JSHOWEREND) + + return _max_lik_track(_longest_tracks(tracks[mask])) + + +def best_aashower(tracks): + mask = _mask_rec_stages_in_range_start_end(tracks, krec.AASHOWERBEGIN, + krec.AASHOWEREND) + + return _max_lik_track(_longest_tracks(tracks[mask])) + + +def best_dusjshower(tracks): + mask = _mask_rec_stages_in_range_start_end(tracks, krec.DUSJSHOWERBEGIN, + krec.DUSJSHOWEREND) + + return _max_lik_track(_longest_tracks(tracks[mask]))