From 0c6f0bb130cb25fabc1a9f236261f1f99f14687b Mon Sep 17 00:00:00 2001
From: Zineb Aly <zaly@km3net.de>
Date: Tue, 6 Oct 2020 12:01:53 +0200
Subject: [PATCH] adapt rec stages mask and best track to one track

---
 km3io/tools.py | 94 ++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 83 insertions(+), 11 deletions(-)

diff --git a/km3io/tools.py b/km3io/tools.py
index 47d5d1f..0c60f40 100644
--- a/km3io/tools.py
+++ b/km3io/tools.py
@@ -264,10 +264,10 @@ def _max_lik_track(tracks):
 
 def _best_track(tracks, start=None, end=None, stages=[]):
     if (len(stages) > 0) and (start is None) and (end is None):
-        selected_tracks = tracks[mask(tracks.rec_stages, stages=stages)]
+        selected_tracks = tracks[mask(tracks, stages=stages)]
 
     if (start is not None) and (end is not None) and (len(stages) == 0):
-        selected_tracks = tracks[mask(tracks.rec_stages, start=start, end=end)]
+        selected_tracks = tracks[mask(tracks, start=start, end=end)]
 
     if (start is None) and (end is None) and (len(stages) == 0):
         # this should be modified to a log print and not just a simple print
@@ -378,7 +378,38 @@ def _find_between(rec_stages, start, end, builder):
         builder.end_list()
 
 
-def _mask_rec_stages_between_start_end(rec_stages, start, end):
+@nb.jit(nopython=True)
+def _find_between_single(rec_stages, start, end, builder):
+    """construct an awkward1 array with the same structure as tracks.rec_stages.
+    When stages are between start and end, the Array is filled with value 1, otherwise it is filled
+    with value 0.
+
+    Parameters
+    ----------
+    rec_stages : awkward1 Array
+        tracks.rec_stages .
+    start : int
+        start of reconstruction stages of interest.
+    end : int
+        end of reconstruction stages of interest.
+    builder : awkward1.highlevel.ArrayBuilder
+        awkward1 Array builder.
+    """
+
+    builder.begin_list()
+    for s in rec_stages:
+        num_stages = len(s)
+        if num_stages != 0:
+            if (s[0] == start) and (s[-1] == end):
+                builder.append(1)
+            else:
+                builder.append(0)
+        else:
+            builder.append(0)
+    builder.end_list()
+
+
+def _mask_rec_stages_between_start_end(tracks, start, end):
     """mask tracks where tracks.rec_stages  are between start and end .
 
     Parameters
@@ -397,8 +428,12 @@ def _mask_rec_stages_between_start_end(rec_stages, start, end):
         where stages were found. False otherwise.
     """
     builder = ak1.ArrayBuilder()
-    _find_between(rec_stages, start, end, builder)
-    return builder.snapshot() == 1
+    if tracks.is_single:
+        _find_between_single(tracks.rec_stages, start, end, builder)
+        return (builder.snapshot() == 1)[0]
+    else:
+        _find_between(tracks.rec_stages, start, end, builder)
+        return builder.snapshot() == 1
 
 
 @nb.jit(nopython=True)
@@ -434,7 +469,39 @@ def _find(rec_stages, stages, builder):
         builder.end_list()
 
 
-def _mask_explicit_rec_stages(rec_stages, stages):
+@nb.jit(nopython=True)
+def _find_single(rec_stages, stages, builder):
+    """construct an awkward1 array with the same structure as tracks.rec_stages.
+    When stages are found, the Array is filled with value 1, otherwise it is filled
+    with value 0.
+
+    Parameters
+    ----------
+    rec_stages : awkward1 Array
+        tracks.rec_stages .
+    stages : awkward1 Array
+        reconstruction stages of interest.
+    builder : awkward1.highlevel.ArrayBuilder
+        awkward1 Array builder.
+    """
+    builder.begin_list()
+    for s in rec_stages:
+        num_stages = len(s)
+        if num_stages == len(stages):
+            found = 0
+            for j in range(num_stages):
+                if s[j] == stages[j]:
+                    found += 1
+            if found == num_stages:
+                builder.append(1)
+            else:
+                builder.append(0)
+        else:
+            builder.append(0)
+    builder.end_list()
+
+
+def _mask_explicit_rec_stages(tracks, stages):
     """create a mask on tracks.rec_stages .
 
     Parameters
@@ -450,12 +517,17 @@ def _mask_explicit_rec_stages(rec_stages, stages):
         an awkward1 Array mask where True corresponds to the positions
         where stages were found. False otherwise.
     """
+    # rec_stages = tracks.rec_stages
     builder = ak1.ArrayBuilder()
-    _find(rec_stages, ak1.Array(stages), builder)
-    return builder.snapshot() == 1
+    if tracks.is_single:
+        _find_single(tracks.rec_stages, ak1.Array(stages), builder)
+        return (builder.snapshot() == 1)[0]
+    else:
+        _find(tracks.rec_stages, ak1.Array(stages), builder)
+        return builder.snapshot() == 1
 
 
-def mask(rec_stages, stages=None, start=None, end=None):
+def mask(tracks, stages=None, start=None, end=None):
     """create a mask on tracks.rec_stages .
 
     Parameters
@@ -478,7 +550,7 @@ def mask(rec_stages, stages=None, start=None, end=None):
         raise ValueError("too many inputs are specified")
 
     if (stages is not None) and (start is None) and (end is None):
-        return _mask_explicit_rec_stages(rec_stages, stages)
+        return _mask_explicit_rec_stages(tracks, stages)
 
     if (stages is None) and (start is not None) and (end is not None):
-        return _mask_rec_stages_between_start_end(rec_stages, start, end)
+        return _mask_rec_stages_between_start_end(tracks, start, end)
-- 
GitLab