From 7358a9635d745b47712cd9850543ebd915afaca5 Mon Sep 17 00:00:00 2001 From: Zineb Aly <zaly@km3net.de> Date: Mon, 5 Oct 2020 18:01:19 +0200 Subject: [PATCH] add first prototype of best_track --- km3io/tools.py | 161 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/km3io/tools.py b/km3io/tools.py index be07110..f856e50 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -353,3 +353,164 @@ def get_multiplicity(tracks, rec_stages): tracks branch with the desired reconstruction stages only. """ return tracks[mask(tracks.rec_stages, rec_stages)] + + +def _longest_tracks(tracks): + if tracks.is_single: + stages_nesting_level = 1 + tracks_nesting_level = 0 + + else: + stages_nesting_level = 2 + tracks_nesting_level = 1 + + len_stages = count_nested(tracks.rec_stages, axis=stages_nesting_level) + longest = tracks[len_stages == ak1.max(len_stages, + axis=tracks_nesting_level)] + + return longest + + +def _max_lik_track(tracks): + if tracks.is_single: + tracks_nesting_level = 0 + else: + tracks_nesting_level = 1 + + return tracks[tracks.lik == ak1.max(tracks.lik, axis=tracks_nesting_level)] + + +def _best_track(tracks, start=None, end=None, stages=[]): + if (len(stages) > 0) and (start is None) and (end is None): + selected_tracks = tracks[mask(tracks.rec_stages, stages)] + + if (start is not None) and (end is not None) and (len(stages) == 0): + selected_tracks = tracks[mask_tracks(tracks.rec_stages, start, end)] + + if (start is None) and (end is None) and (len(stages) == 0): + # this should be modified to a log print and not just a simple print + print( + "No reconstruction stages were specified. The longest reco stages are selected" + ) + + selected_tracks = tracks + + if (len(stages) > 0) and ((start is not None) or (end is not None)): + raise ValueError("too many inputs are specified") + + return _max_lik_track(_longest_tracks(selected_tracks)) + + +def _JShower_stages(): + return set((krec.JSHOWERPREFIT, krec.JSHOWERPOSITIONFIT, + krec.JSHOWERCOMPLETEFIT, krec.JSHOWER_BJORKEN_Y, + krec.JSHOWERENERGYPREFIT, krec.JSHOWERPOINTSIMPLEX, + krec.JSHOWERDIRECTIONPREFIT)) + + +def _JMuon_stages(): + return set((krec.JMUONPREFIT, krec.JMUONSIMPLEX, krec.JMUONGANDALF, + krec.JMUONENERGY, krec.JMUONSTART, krec.JLINEFIT)) + + +def _AAShower_stages(): + return set((krec.AASHOWERFITPREFIT, krec.AASHOWERFITPOSITIONFIT, + krec.AASHOWERFITDIRECTIONENERGYFIT)) + + +def _DUSJShower_stages(): + return set((krec.DUSJSHOWERPREFIT, krec.DUSJSHOWERPOSITIONFIT, + krec.DUSJSHOWERCOMPLETEFIT)) + + +def _reco_stages(reco): + if reco == "JSHOWER": + stages = _JShower_stages() + + if reco == "JMUON": + stages = _JMuon_stages() + + if reco == "AASHOWER": + stages = _AAShower_stages() + + if reco == "DUSJSHOWER": + stages == _DUSJShower_stages() + + else: + raise KeyError( + f"{reco} must be either: 'JSHOWER', 'JMUON', 'AASHOWER', 'DUSJSHOWER'." + ) + + return stages + + +def best_JMuon(tracks, reco, start=None, end=None, stages=[]): + + valid_stages = _reco_stages(reco) + + if (start is not None) and (end is not None): + if (start not in valid_stages) or (end not in valid_stages): + raise KeyError( + f" start and/or end are not in JMuon reconstruction stages") + + if len(stages) > 0: + if not set(stages).issubset(valid_stages): + raise KeyError( + f"one (or all) of the stages in {stages} are not in {reco} reconstruction stages" + ) + + return _best_track(tracks, start=start, end=end, stages=stages) + + +@nb.jit(nopython=True) +def _find_between(rec_stages, start, end, builder): + """construct an awkward1 array with the same structure as tracks.rec_stages. + When stages are between start and end, the Array is filled with value 1, otherwise it is filled + with value 0. + + Parameters + ---------- + rec_stages : awkward1 Array + tracks.rec_stages . + start : int + start of reconstruction stages of interest. + end : int + end of reconstruction stages of interest. + builder : awkward1.highlevel.ArrayBuilder + awkward1 Array builder. + """ + for s in rec_stages: + builder.begin_list() + for i in s: + num_stages = len(i) + if num_stages != 0: + if (i[0] == start) and (i[-1] == end): + builder.append(1) + else: + builder.append(0) + else: + builder.append(0) + builder.end_list() + + +def mask_tracks(rec_stages, start, end): + """mask tracks where tracks.rec_stages are between start and end . + + Parameters + ---------- + rec_stages : awkward1 Array + tracks.rec_stages . + start : int + start of reconstruction stages of interest. + end : int + end of reconstruction stages of interest. + + Returns + ------- + awkward1 Array + an awkward1 Array mask where True corresponds to the positions + where stages were found. False otherwise. + """ + builder = ak1.ArrayBuilder() + _find_between(rec_stages, start, end, builder) + return builder.snapshot() == 1 -- GitLab