From 27e9d8cd70a9816d36fa647e4157cdfe30937bfd Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Fri, 11 Dec 2020 10:59:36 +0100
Subject: [PATCH] Fix fitinf

---
 km3io/tools.py      | 31 +++++-------------
 tests/test_tools.py | 76 +++++++++++++++++++++++++++------------------
 2 files changed, 54 insertions(+), 53 deletions(-)

diff --git a/km3io/tools.py b/km3io/tools.py
index 1246cd1..80b612e 100644
--- a/km3io/tools.py
+++ b/km3io/tools.py
@@ -134,34 +134,19 @@ def fitinf(fitparam, tracks):
     ----------
     fitparam : int
         the fit parameter key according to fitparameters defined in
-        KM3NeT-Dataformat.
-    tracks : km3io.offline.OfflineBranch
-        the tracks class. both full tracks branch or a slice of the
-        tracks branch (example tracks[:, 0]) work.
+        KM3NeT-Dataformat (see km3io.definitions.fitparameters).
+    tracks : ak.Array or km3io.rootio.Branch
+        reconstructed tracks with .fitinf attribute
 
     Returns
     -------
     awkward1.Array
-        awkward array of the values of the fit parameter requested.
+        awkward array of the values of the fit parameter requested. Missing
+        values are set to NaN.
     """
     fit = tracks.fitinf
-    index = fitparam
-    if tracks.is_single and len(tracks) != 1:
-        params = fit[count_nested(fit, axis=1) > index]
-        out = params[:, index]
-
-    if tracks.is_single and len(tracks) == 1:
-        out = fit[index]
-
-    else:
-        if len(tracks[0]) == 1:  # case of tracks slice with 1 track per event.
-            params = fit[count_nested(fit, axis=1) > index]
-            out = params[:, index]
-        else:
-            params = fit[count_nested(fit, axis=2) > index]
-            out = ak.Array([i[:, index] for i in params])
-
-    return out
+    nonempty = ak.num(fit, axis=-1) > 0
+    return ak.fill_none(fit.mask[nonempty][..., 0], np.nan)
 
 
 def count_nested(arr, axis=0):
@@ -487,7 +472,7 @@ def is_cc(fobj):
             cc_flag = w2list[:, kw2gen.W2LIST_GENHEN_CC]
             out = cc_flag > 0
         else:
-            raise ValueError(f"simulation program {program} is not implemented.")
+            raise NotImplementedError(f"don't know how to determine the CCness of {program} files.")
 
     return out
 
diff --git a/tests/test_tools.py b/tests/test_tools.py
index f7aa9de..6610a4b 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -50,7 +50,6 @@ class TestFitinf(unittest.TestCase):
         self.best = self.tracks[:, 0]
         self.best_fit = self.best.fitinf
 
-    @unittest.skip
     def test_fitinf_from_all_events(self):
         beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks)
 
@@ -58,7 +57,6 @@ class TestFitinf(unittest.TestCase):
         assert beta[0][1] == self.fit[0][1][0]
         assert beta[0][2] == self.fit[0][2][0]
 
-    @unittest.skip
     def test_fitinf_from_one_event(self):
         beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.best)
 
@@ -66,12 +64,6 @@ class TestFitinf(unittest.TestCase):
         assert beta[1] == self.best_fit[1][0]
         assert beta[2] == self.best_fit[2][0]
 
-    @unittest.skip
-    def test_fitinf_from_one_event_and_one_track(self):
-        beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks[0][0])
-
-        assert beta == self.tracks[0][0].fitinf[0]
-
 
 class TestBestTrackSelection(unittest.TestCase):
     def setUp(self):
@@ -85,7 +77,6 @@ class TestBestTrackSelection(unittest.TestCase):
 
         assert len(best) == 10
 
-        # TODO: nested items, no idea how to solve this...
         assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
         assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
         assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
@@ -118,7 +109,6 @@ class TestBestTrackSelection(unittest.TestCase):
 
         assert len(best) == 10
 
-        # TODO: nested items, no idea how to solve this...
         assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
         assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
         assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
@@ -138,7 +128,6 @@ class TestBestTrackSelection(unittest.TestCase):
 
         assert len(best) == 10
 
-        # TODO: nested items, no idea how to solve this...
         assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
         assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
         assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
@@ -208,52 +197,81 @@ class TestBestTrackSelection(unittest.TestCase):
         assert best2.lik == ak.max(tracks_slice.lik)
         assert best2.rec_stages.tolist() == [1, 3, 5, 4]
 
-    @unittest.skip
     def test_best_track_on_slices_with_start_end_one_event(self):
         tracks_slice = self.one_event.tracks[0:5]
         best = best_track(tracks_slice, startend=(1, 4))
 
-        assert len(best.lik) == 1
         assert best.lik == ak.max(tracks_slice.lik)
-        assert best.rec_stages[0][0] == 1
-        assert best.rec_stages[0][-1] == 4
+        assert best.rec_stages[0] == 1
+        assert best.rec_stages[-1] == 4
 
-    @unittest.skip
     def test_best_track_on_slices_with_explicit_rec_stages_one_event(self):
         tracks_slice = self.one_event.tracks[0:5]
         best = best_track(tracks_slice, stages=[1, 3, 5, 4])
 
         assert best.lik == ak.max(tracks_slice.lik)
-        assert best.rec_stages[0][0] == 1
-        assert best.rec_stages[0][-1] == 4
+        assert best.rec_stages[0] == 1
+        assert best.rec_stages[-1] == 4
 
-    @unittest.skip
     def test_best_track_on_slices_multiple_events(self):
-        tracks_slice = self.events.tracks[0:5]
+        tracks_slice = self.events[0:5].tracks
 
         # stages in list
         best = best_track(tracks_slice, stages=[1, 3, 5, 4])
 
         assert len(best) == 5
 
-        assert best.lik == ak.max(tracks_slice.lik)
-        assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
+        assert np.allclose(
+            best.lik.tolist(),
+            [
+                294.6407542676734,
+                96.75133289411137,
+                560.2775306614813,
+                278.2872951665753,
+                99.59098153341449,
+            ],
+        )
+        for i in range(len(best)):
+            assert best.rec_stages[i].tolist() == [1, 3, 5, 4]
 
         # stages in set
-        best = best_track(tracks_slice, stages={1, 3, 4, 5})
+        best = best_track(tracks_slice, stages={3, 4, 5})
 
         assert len(best) == 5
 
-        assert best.lik == ak.max(tracks_slice.lik)
-        assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
+        assert np.allclose(
+            best.lik.tolist(),
+            [
+                294.6407542676734,
+                96.75133289411137,
+                560.2775306614813,
+                278.2872951665753,
+                99.59098153341449,
+            ],
+        )
+        for i in range(len(best)):
+            assert best.rec_stages[i].tolist() == [1, 3, 5, 4]
 
         # using start and end
-        best = best_track(tracks_slice, startend=(1, 4))
+        start, end = (1, 4)
+        best = best_track(tracks_slice, startend=(start, end))
 
         assert len(best) == 5
 
-        assert best.lik == ak.max(tracks_slice.lik)
-        assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
+        assert np.allclose(
+            best.lik.tolist(),
+            [
+                294.6407542676734,
+                96.75133289411137,
+                560.2775306614813,
+                278.2872951665753,
+                99.59098153341449,
+            ],
+        )
+        for i in range(len(best)):
+            rs = best.rec_stages[i].tolist()
+            assert rs[0] == start
+            assert rs[-1] == end
 
     def test_best_track_raises_when_unknown_stages(self):
         with self.assertRaises(ValueError):
@@ -265,7 +283,6 @@ class TestBestTrackSelection(unittest.TestCase):
 
 
 class TestBestJmuon(unittest.TestCase):
-    @unittest.skip
     def test_best_jmuon(self):
         best = best_jmuon(OFFLINE_FILE.events.tracks)
 
@@ -537,7 +554,6 @@ class TestUnfoldIndices(unittest.TestCase):
 
 
 class TestIsCC(unittest.TestCase):
-    @unittest.skip
     def test_is_cc(self):
         NC_file = is_cc(GENHEN_OFFLINE_FILE)
         CC_file = is_cc(GSEAGEN_OFFLINE_FILE)
-- 
GitLab