From 6f08c38b2acae02731e873468563c906b1fedda5 Mon Sep 17 00:00:00 2001
From: Zineb Aly <zaly@km3net.de>
Date: Wed, 21 Oct 2020 22:53:51 +0200
Subject: [PATCH] debug fitinf + tests

---
 km3io/tools.py      | 20 ++++++++++++++------
 tests/test_tools.py | 17 ++++++++++++-----
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/km3io/tools.py b/km3io/tools.py
index a422264..c4c157a 100644
--- a/km3io/tools.py
+++ b/km3io/tools.py
@@ -145,14 +145,22 @@ def fitinf(fitparam, tracks):
     """
     fit = tracks.fitinf
     index = fitparam
-    try:
-        params = fit[count_nested(fit, axis=2) > index]
-        return ak1.Array([i[:, index] for i in params])
-    except ValueError:
-        # This is the case for tracks[:, 0] or any other selection.
+    if tracks.is_single and len(tracks) != 1:
         params = fit[count_nested(fit, axis=1) > index]
-        return params[:, index]
+        out = params[:, index]
 
+    if tracks.is_single and len(tracks) == 1:
+        out = fit[index]
+
+    else:
+        if len(tracks[0]) == 1:  # case of tracks slice with 1 track per event.
+            params = fit[count_nested(fit, axis=1) > index]
+            out = params[:, index]
+        else:
+            params = fit[count_nested(fit, axis=2) > index]
+            out = ak1.Array([i[:, index] for i in params])
+
+    return out
 
 def count_nested(arr, axis=0):
     """Count elements in a nested awkward Array.
diff --git a/tests/test_tools.py b/tests/test_tools.py
index d5cfd00..d576c67 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -39,17 +39,24 @@ class TestFitinf(unittest.TestCase):
         self.best = self.tracks[:, 0]
         self.best_fit = self.best.fitinf
 
-    def test_fitinf(self):
+    def test_fitinf_from_all_events(self):
         beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks)
-        best_beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.best)
 
         assert beta[0][0] == self.fit[0][0][0]
         assert beta[0][1] == self.fit[0][1][0]
         assert beta[0][2] == self.fit[0][2][0]
 
-        assert best_beta[0] == self.best_fit[0][0]
-        assert best_beta[1] == self.best_fit[1][0]
-        assert best_beta[2] == self.best_fit[2][0]
+    def test_fitinf_from_one_event(self):
+        beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.best)
+
+        assert beta[0] == self.best_fit[0][0]
+        assert beta[1] == self.best_fit[1][0]
+        assert beta[2] == self.best_fit[2][0]
+
+    def test_fitinf_from_one_event_and_one_track(self):
+        beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks[0][0])
+
+        assert beta == self.tracks[0][0].fitinf[0]
 
 
 class TestBestTrackSelection(unittest.TestCase):
-- 
GitLab