From 1c6883ab65458714114d0f10623c0f60c1013542 Mon Sep 17 00:00:00 2001
From: zineb aly <aly.zineb.az@gmail.com>
Date: Fri, 5 Jun 2020 10:35:05 +0200
Subject: [PATCH] reorganise tools

---
 km3io/offline.py      |   2 +
 km3io/tools.py        | 155 ++++++++++++++++++++++++++++++++++++++++++
 tests/test_offline.py |  93 +------------------------
 tests/test_tools.py   |  99 ++++++++++++++++++++++++++-
 4 files changed, 256 insertions(+), 93 deletions(-)

diff --git a/km3io/offline.py b/km3io/offline.py
index e51f319..9026e68 100644
--- a/km3io/offline.py
+++ b/km3io/offline.py
@@ -164,8 +164,10 @@ def best_track(tracks, strategy="first", rec_type=None, rec_stages=None):
     """
     if strategy == "first":
         return tracks[:, 0]
+
     if strategy == "rec_stages" and rec_stages is not None:
         return tracks[mask(tracks.rec_stages, rec_stages)]
+
     if strategy == "default" and rec_type is not None:
         return tracks[tracks.rec_type == reconstruction[rec_type]][
             tracks.lik == ak1.max(tracks.lik, axis=1)][:, 0]
diff --git a/km3io/tools.py b/km3io/tools.py
index f664db0..acbb258 100644
--- a/km3io/tools.py
+++ b/km3io/tools.py
@@ -1,8 +1,11 @@
 #!/usr/bin/env python3
 import numba as nb
 import numpy as np
+import awkward1 as ak1
 import uproot
 
+from .definitions import fitparameters, reconstruction
+
 # 110 MB based on the size of the largest basket found so far in km3net
 BASKET_CACHE_SIZE = 110 * 1024**2
 BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
@@ -85,3 +88,155 @@ def uniquecount(array, dtype=np.int64):
         else:
             out[i] = len(unique(sub_array))
     return out
+
+
+def rec_types():
+    """name of the reconstruction type as defined in the official
+    KM3NeT-Dataformat.
+
+    Returns
+    -------
+    dict_keys
+        reconstruction types.
+    """
+    return reconstruction.keys()
+
+
+def fitinf(fitparam, tracks):
+    """Access fit parameters in tracks.fitinf.
+
+    Parameters
+    ----------
+    fitparam : str
+        the fit parameter name according to fitparameters defined in
+        KM3NeT-Dataformat.
+    tracks : class km3io.offline.OfflineBranch
+        the tracks class. both full tracks branch or a slice of the
+        tracks branch (example tracks[:, 0]) work.
+
+    Returns
+    -------
+    awkward array
+        awkward array of the values of the fit parameter requested.
+    """
+    fit = tracks.fitinf
+    index = fitparameters[fitparam]
+    try:
+        params = fit[count_nested(fit, axis=2) > index]
+        return ak1.Array([i[:, index] for i in params])
+    except ValueError:
+        # This is the case for tracks[:, 0] or any other selection.
+        params = fit[count_nested(fit, axis=1) > index]
+        return params[:, index]
+
+
+def fitparams():
+    """name of the fit parameters as defined in the official
+    KM3NeT-Dataformat.
+
+    Returns
+    -------
+    dict_keys
+        fit parameters keys.
+    """
+    return fitparameters.keys()
+
+
+def count_nested(Array, axis=0):
+    """count elements in a nested awkward Array.
+
+    Parameters
+    ----------
+    Array : Awkward1 Array
+        Array of data. Example tracks.fitinf or tracks.rec_stages.
+    axis : int, optional
+        axis = 0: to count elements in the outmost level of nesting.
+        axis = 1: to count elements in the first level of nesting.
+        axis = 2: to count elements in the second level of nesting.
+
+    Returns
+    -------
+    awkward1 Array or int
+        counts of elements found in a nested awkward1 Array.
+    """
+    if axis == 0:
+        return ak1.num(Array, axis=0)
+    if axis == 1:
+        return ak1.num(Array, axis=1)
+    if axis == 2:
+        return ak1.count(Array, axis=2)
+
+
+@nb.jit(nopython=True)
+def _find(rec_stages, stages, builder):
+    """construct an awkward1 array with the same structure as tracks.rec_stages.
+    When stages are found, the Array is filled with value 1, otherwise it is filled
+    with value 0.
+
+    Parameters
+    ----------
+    rec_stages : awkward1 Array
+        tracks.rec_stages .
+    stages : awkward1 Array
+        reconstruction stages of interest.
+    builder : awkward1.highlevel.ArrayBuilder
+        awkward1 Array builder.
+    """
+    for s in rec_stages:
+        builder.begin_list()
+        for i in s:
+            num_stages = len(i)
+            if num_stages == len(stages):
+                found = 0
+                for j in range(num_stages):
+                    if i[j] == stages[j]:
+                        found += 1
+                if found == num_stages:
+                    builder.append(1)
+                else:
+                    builder.append(0)
+            else:
+                builder.append(0)
+        builder.end_list()
+
+
+def mask(rec_stages, stages):
+    """create a mask on tracks.rec_stages .
+
+    Parameters
+    ----------
+    rec_stages : awkward1 Array
+        tracks.rec_stages .
+    stages : list
+        reconstruction stages of interest.
+
+    Returns
+    -------
+    awkward1 Array
+        an awkward1 Array mask where True corresponds to the positions
+        where stages were found. False otherwise.
+    """
+    builder = ak1.ArrayBuilder()
+    _find(rec_stages, ak1.Array(stages), builder)
+    return builder.snapshot() == 1
+
+
+def best_track(tracks, strategy="first", rec_type=None, rec_stages=None):
+    """best track selection based on different strategies
+
+    Parameters
+    ----------
+    tracks : class km3io.offline.OfflineBranch
+        the tracks branch.
+    strategy : str
+        the trategy desired to select the best tracks.
+    """
+    if strategy == "first":
+        return tracks[:, 0]
+
+    if strategy == "rec_stages" and rec_stages is not None:
+        return tracks[mask(tracks.rec_stages, rec_stages)]
+
+    if strategy == "default" and rec_type is not None:
+        return tracks[tracks.rec_type == reconstruction[rec_type]][
+            tracks.lik == ak1.max(tracks.lik, axis=1)][:, 0]
diff --git a/tests/test_offline.py b/tests/test_offline.py
index 0f4f42b..4fb616b 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -1,10 +1,9 @@
 import unittest
 import numpy as np
-import awkward1 as ak1
 from pathlib import Path
 
 from km3io import OfflineReader
-from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask, best_track, rec_types
+from km3io.offline import _nested_mapper, Header
 
 SAMPLES_DIR = Path(__file__).parent / 'samples'
 OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'km3net_offline.root')
@@ -16,96 +15,6 @@ OFFLINE_MC_TRACK_USR = OfflineReader(
 OFFLINE_NUMUCC = OfflineReader(SAMPLES_DIR / "numucc.root")  # with mc data
 
 
-class TestFitinf(unittest.TestCase):
-    def setUp(self):
-        self.tracks = OFFLINE_FILE.events.tracks
-        self.fit = self.tracks.fitinf
-        self.best = self.tracks[:, 0]
-        self.best_fit = self.best.fitinf
-
-    def test_fitinf(self):
-        beta = fitinf('JGANDALF_BETA0_RAD', self.tracks)
-        best_beta = fitinf('JGANDALF_BETA0_RAD', self.best)
-
-        assert beta[0][0] == self.fit[0][0][0]
-        assert beta[0][1] == self.fit[0][1][0]
-        assert beta[0][2] == self.fit[0][2][0]
-
-        assert best_beta[0] == self.best_fit[0][0]
-        assert best_beta[1] == self.best_fit[1][0]
-        assert best_beta[2] == self.best_fit[2][0]
-
-    def test_fitparams(self):
-        keys = set(fitparams())
-
-        assert "JGANDALF_BETA0_RAD" in keys
-
-
-class TestRecoTypes(unittest.TestCase):
-    def test_reco_types(self):
-        keys = set(rec_types())
-
-        assert "JPP_RECONSTRUCTION_TYPE" in keys
-
-
-class TestBestTrack(unittest.TestCase):
-    def setUp(self):
-        self.tracks = OFFLINE_FILE.events.tracks
-
-    def test_best_tracks(self):
-        first_tracks = best_track(self.tracks, strategy="first")
-        rec_stages_tracks = best_track(self.tracks,
-                                       strategy="rec_stages",
-                                       rec_stages=[1, 3, 5, 4])
-        default_best = best_track(self.tracks,
-                                  strategy="default",
-                                  rec_type="JPP_RECONSTRUCTION_TYPE")
-
-        assert first_tracks.dir_z[0] == self.tracks.dir_z[0][0]
-        assert first_tracks.dir_x[1] == self.tracks.dir_x[1][0]
-
-        assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4]
-        assert rec_stages_tracks.rec_stages[1] == [1, 3, 5, 4]
-
-        assert default_best.lik[0] == ak1.max(self.tracks.lik[0])
-        assert default_best.lik[1] == ak1.max(self.tracks.lik[1])
-        assert default_best.rec_type[0] == 4000
-
-
-class TestCountNested(unittest.TestCase):
-    def test_count_nested(self):
-        fit = OFFLINE_FILE.events.tracks.fitinf
-
-        assert count_nested(fit, axis=0) == 10
-        assert count_nested(fit, axis=1)[0:4] == ak1.Array([56, 55, 56, 56])
-        assert count_nested(fit, axis=2)[0][0:4] == ak1.Array([17, 11, 8, 8])
-
-
-class TestRecStagesMasks(unittest.TestCase):
-    def setUp(self):
-        self.nested = ak1.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
-                                 [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
-
-    def test_find(self):
-        builder = ak1.ArrayBuilder()
-        _find(self.nested, ak1.Array([1, 2, 3]), builder)
-        labels = builder.snapshot()
-
-        assert labels[0][0] == 1
-        assert labels[0][1] == 1
-        assert labels[0][2] == 0
-        assert labels[1][0] == 0
-
-    def test_mask(self):
-        rec_stages = OFFLINE_FILE.events.tracks.rec_stages
-        stages = [1, 3, 5, 4]
-        masks = mask(rec_stages, stages)
-
-        assert masks[0][0] == all(rec_stages[0][0] == ak1.Array(stages))
-        assert masks[1][0] == all(rec_stages[1][0] == ak1.Array(stages))
-        assert masks[0][1] == False
-
-
 class TestOfflineReader(unittest.TestCase):
     def setUp(self):
         self.r = OFFLINE_FILE
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 5306f42..325e681 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -3,8 +3,105 @@
 import unittest
 import awkward1 as ak
 import numpy as np
+
+from pathlib import Path
+from km3io import OfflineReader
 from km3io.tools import (to_num, cached_property, unfold_indices, unique,
-                         uniquecount)
+                         uniquecount, fitinf, fitparams, count_nested, _find,
+                         mask, best_track, rec_types)
+
+SAMPLES_DIR = Path(__file__).parent / 'samples'
+OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'km3net_offline.root')
+
+
+class TestFitinf(unittest.TestCase):
+    def setUp(self):
+        self.tracks = OFFLINE_FILE.events.tracks
+        self.fit = self.tracks.fitinf
+        self.best = self.tracks[:, 0]
+        self.best_fit = self.best.fitinf
+
+    def test_fitinf(self):
+        beta = fitinf('JGANDALF_BETA0_RAD', self.tracks)
+        best_beta = fitinf('JGANDALF_BETA0_RAD', self.best)
+
+        assert beta[0][0] == self.fit[0][0][0]
+        assert beta[0][1] == self.fit[0][1][0]
+        assert beta[0][2] == self.fit[0][2][0]
+
+        assert best_beta[0] == self.best_fit[0][0]
+        assert best_beta[1] == self.best_fit[1][0]
+        assert best_beta[2] == self.best_fit[2][0]
+
+    def test_fitparams(self):
+        keys = set(fitparams())
+
+        assert "JGANDALF_BETA0_RAD" in keys
+
+
+class TestRecoTypes(unittest.TestCase):
+    def test_reco_types(self):
+        keys = set(rec_types())
+
+        assert "JPP_RECONSTRUCTION_TYPE" in keys
+
+
+class TestBestTrack(unittest.TestCase):
+    def setUp(self):
+        self.tracks = OFFLINE_FILE.events.tracks
+
+    def test_best_tracks(self):
+        first_tracks = best_track(self.tracks, strategy="first")
+        rec_stages_tracks = best_track(self.tracks,
+                                       strategy="rec_stages",
+                                       rec_stages=[1, 3, 5, 4])
+        default_best = best_track(self.tracks,
+                                  strategy="default",
+                                  rec_type="JPP_RECONSTRUCTION_TYPE")
+
+        assert first_tracks.dir_z[0] == self.tracks.dir_z[0][0]
+        assert first_tracks.dir_x[1] == self.tracks.dir_x[1][0]
+
+        assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4]
+        assert rec_stages_tracks.rec_stages[1] == [1, 3, 5, 4]
+
+        assert default_best.lik[0] == ak.max(self.tracks.lik[0])
+        assert default_best.lik[1] == ak.max(self.tracks.lik[1])
+        assert default_best.rec_type[0] == 4000
+
+
+class TestCountNested(unittest.TestCase):
+    def test_count_nested(self):
+        fit = OFFLINE_FILE.events.tracks.fitinf
+
+        assert count_nested(fit, axis=0) == 10
+        assert count_nested(fit, axis=1)[0:4] == ak.Array([56, 55, 56, 56])
+        assert count_nested(fit, axis=2)[0][0:4] == ak.Array([17, 11, 8, 8])
+
+
+class TestRecStagesMasks(unittest.TestCase):
+    def setUp(self):
+        self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
+                                [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
+
+    def test_find(self):
+        builder = ak.ArrayBuilder()
+        _find(self.nested, ak.Array([1, 2, 3]), builder)
+        labels = builder.snapshot()
+
+        assert labels[0][0] == 1
+        assert labels[0][1] == 1
+        assert labels[0][2] == 0
+        assert labels[1][0] == 0
+
+    def test_mask(self):
+        rec_stages = OFFLINE_FILE.events.tracks.rec_stages
+        stages = [1, 3, 5, 4]
+        masks = mask(rec_stages, stages)
+
+        assert masks[0][0] == all(rec_stages[0][0] == ak.Array(stages))
+        assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
+        assert masks[0][1] == False
 
 
 class TestUnique(unittest.TestCase):
-- 
GitLab