From 7358a9635d745b47712cd9850543ebd915afaca5 Mon Sep 17 00:00:00 2001
From: Zineb Aly <zaly@km3net.de>
Date: Mon, 5 Oct 2020 18:01:19 +0200
Subject: [PATCH] add first prototype of best_track

---
 km3io/tools.py | 161 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 161 insertions(+)

diff --git a/km3io/tools.py b/km3io/tools.py
index be07110..f856e50 100644
--- a/km3io/tools.py
+++ b/km3io/tools.py
@@ -353,3 +353,164 @@ def get_multiplicity(tracks, rec_stages):
         tracks branch with the desired reconstruction stages only.
     """
     return tracks[mask(tracks.rec_stages, rec_stages)]
+
+
+def _longest_tracks(tracks):
+    if tracks.is_single:
+        stages_nesting_level = 1
+        tracks_nesting_level = 0
+
+    else:
+        stages_nesting_level = 2
+        tracks_nesting_level = 1
+
+    len_stages = count_nested(tracks.rec_stages, axis=stages_nesting_level)
+    longest = tracks[len_stages == ak1.max(len_stages,
+                                           axis=tracks_nesting_level)]
+
+    return longest
+
+
+def _max_lik_track(tracks):
+    if tracks.is_single:
+        tracks_nesting_level = 0
+    else:
+        tracks_nesting_level = 1
+
+    return tracks[tracks.lik == ak1.max(tracks.lik, axis=tracks_nesting_level)]
+
+
+def _best_track(tracks, start=None, end=None, stages=[]):
+    if (len(stages) > 0) and (start is None) and (end is None):
+        selected_tracks = tracks[mask(tracks.rec_stages, stages)]
+
+    if (start is not None) and (end is not None) and (len(stages) == 0):
+        selected_tracks = tracks[mask_tracks(tracks.rec_stages, start, end)]
+
+    if (start is None) and (end is None) and (len(stages) == 0):
+        # this should be modified to a log print and not just a simple print
+        print(
+            "No reconstruction stages were specified. The longest reco stages are selected"
+        )
+
+        selected_tracks = tracks
+
+    if (len(stages) > 0) and ((start is not None) or (end is not None)):
+        raise ValueError("too many inputs are specified")
+
+    return _max_lik_track(_longest_tracks(selected_tracks))
+
+
+def _JShower_stages():
+    return set((krec.JSHOWERPREFIT, krec.JSHOWERPOSITIONFIT,
+               krec.JSHOWERCOMPLETEFIT, krec.JSHOWER_BJORKEN_Y,
+               krec.JSHOWERENERGYPREFIT, krec.JSHOWERPOINTSIMPLEX,
+               krec.JSHOWERDIRECTIONPREFIT))
+
+
+def _JMuon_stages():
+    return set((krec.JMUONPREFIT, krec.JMUONSIMPLEX, krec.JMUONGANDALF,
+               krec.JMUONENERGY, krec.JMUONSTART, krec.JLINEFIT))
+
+
+def _AAShower_stages():
+    return set((krec.AASHOWERFITPREFIT, krec.AASHOWERFITPOSITIONFIT,
+               krec.AASHOWERFITDIRECTIONENERGYFIT))
+
+
+def _DUSJShower_stages():
+    return set((krec.DUSJSHOWERPREFIT, krec.DUSJSHOWERPOSITIONFIT,
+               krec.DUSJSHOWERCOMPLETEFIT))
+
+
+def _reco_stages(reco):
+    if reco == "JSHOWER":
+        stages = _JShower_stages()
+
+    if reco == "JMUON":
+        stages = _JMuon_stages()
+
+    if reco == "AASHOWER":
+        stages = _AAShower_stages()
+
+    if reco == "DUSJSHOWER":
+        stages == _DUSJShower_stages()
+
+    else:
+        raise KeyError(
+            f"{reco} must be either: 'JSHOWER', 'JMUON', 'AASHOWER', 'DUSJSHOWER'."
+        )
+
+    return stages
+
+
+def best_JMuon(tracks, reco, start=None, end=None, stages=[]):
+
+    valid_stages = _reco_stages(reco)
+
+    if (start is not None) and (end is not None):
+        if (start not in valid_stages) or (end not in valid_stages):
+            raise KeyError(
+                f" start and/or end are not in JMuon reconstruction stages")
+
+    if len(stages) > 0:
+        if not set(stages).issubset(valid_stages):
+            raise KeyError(
+                f"one (or all) of the stages in {stages} are not in {reco} reconstruction stages"
+            )
+
+    return _best_track(tracks, start=start, end=end, stages=stages)
+
+
+@nb.jit(nopython=True)
+def _find_between(rec_stages, start, end, builder):
+    """construct an awkward1 array with the same structure as tracks.rec_stages.
+    When stages are between start and end, the Array is filled with value 1, otherwise it is filled
+    with value 0.
+
+    Parameters
+    ----------
+    rec_stages : awkward1 Array
+        tracks.rec_stages .
+    start : int
+        start of reconstruction stages of interest.
+    end : int
+        end of reconstruction stages of interest.
+    builder : awkward1.highlevel.ArrayBuilder
+        awkward1 Array builder.
+    """
+    for s in rec_stages:
+        builder.begin_list()
+        for i in s:
+            num_stages = len(i)
+            if num_stages != 0:
+                if (i[0] == start) and (i[-1] == end):
+                    builder.append(1)
+                else:
+                    builder.append(0)
+            else:
+                builder.append(0)
+        builder.end_list()
+
+
+def mask_tracks(rec_stages, start, end):
+    """mask tracks where tracks.rec_stages  are between start and end .
+
+    Parameters
+    ----------
+    rec_stages : awkward1 Array
+        tracks.rec_stages .
+    start : int
+        start of reconstruction stages of interest.
+    end : int
+        end of reconstruction stages of interest.
+
+    Returns
+    -------
+    awkward1 Array
+        an awkward1 Array mask where True corresponds to the positions
+        where stages were found. False otherwise.
+    """
+    builder = ak1.ArrayBuilder()
+    _find_between(rec_stages, start, end, builder)
+    return builder.snapshot() == 1
-- 
GitLab