From 8f2d835755be4c18b1e338ea739747187cb489b9 Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Wed, 9 Dec 2020 18:15:04 +0100
Subject: [PATCH] Fix doubly nested best tracks

---
 km3io/tools.py      |  2 +-
 tests/test_tools.py | 43 +++++++++++++++++++++----------------------
 2 files changed, 22 insertions(+), 23 deletions(-)

diff --git a/km3io/tools.py b/km3io/tools.py
index 8db9a5b..2e00758 100644
--- a/km3io/tools.py
+++ b/km3io/tools.py
@@ -288,7 +288,7 @@ def best_track(tracks, startend=None, minmax=None, stages=None):
         return namedtuple("BestTrack", out.fields)(
             *[getattr(out, a)[0] for a in out.fields]
         )
-    return out
+    return out[:, 0]
 
 
 def mask(arr, sequence=None, startend=None, minmax=None, atleast=None):
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 69bcdca..9ea78fa 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -86,20 +86,20 @@ class TestBestTrackSelection(unittest.TestCase):
         assert len(best) == 10
 
         # TODO: nested items, no idea how to solve this...
-        assert best.rec_stages[0].tolist() == [[1, 3, 5, 4]]
-        assert best.rec_stages[1].tolist() == [[1, 3, 5, 4]]
-        assert best.rec_stages[2].tolist() == [[1, 3, 5, 4]]
-        assert best.rec_stages[3].tolist() == [[1, 3, 5, 4]]
+        assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
+        assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
+        assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
+        assert best.rec_stages[3].tolist() == [1, 3, 5, 4]
 
         # test with a shorter set of rec_stages
         best2 = best_track(self.events.tracks, stages=[1, 3])
 
         assert len(best2) == 10
 
-        assert best2.rec_stages[0].tolist() == [[1, 3]]
-        assert best2.rec_stages[1].tolist() == [[1, 3]]
-        assert best2.rec_stages[2].tolist() == [[1, 3]]
-        assert best2.rec_stages[3].tolist() == [[1, 3]]
+        assert best2.rec_stages[0].tolist() == [1, 3]
+        assert best2.rec_stages[1].tolist() == [1, 3]
+        assert best2.rec_stages[2].tolist() == [1, 3]
+        assert best2.rec_stages[3].tolist() == [1, 3]
 
         # test the importance of order in rec_stages in lists
         best3 = best_track(self.events.tracks, stages=[3, 1])
@@ -119,10 +119,10 @@ class TestBestTrackSelection(unittest.TestCase):
         assert len(best) == 10
 
         # TODO: nested items, no idea how to solve this...
-        assert best.rec_stages[0].tolist() == [[1, 3, 5, 4]]
-        assert best.rec_stages[1].tolist() == [[1, 3, 5, 4]]
-        assert best.rec_stages[2].tolist() == [[1, 3, 5, 4]]
-        assert best.rec_stages[3].tolist() == [[1, 3, 5, 4]]
+        assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
+        assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
+        assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
+        assert best.rec_stages[3].tolist() == [1, 3, 5, 4]
 
         # test with a shorter set of rec_stages
         best2 = best_track(self.events.tracks, stages={1, 3})
@@ -130,9 +130,8 @@ class TestBestTrackSelection(unittest.TestCase):
         assert len(best2) == 10
 
         for rec_stages in best2.rec_stages:
-            rs = rec_stages[0]  # nested
             for stage in {1, 3}:
-                assert stage in rs
+                assert stage in rec_stages
 
     def test_best_track_selection_from_multiple_events_with_start_end(self):
         best = best_track(self.events.tracks, startend=(1, 4))
@@ -140,20 +139,20 @@ class TestBestTrackSelection(unittest.TestCase):
         assert len(best) == 10
 
         # TODO: nested items, no idea how to solve this...
-        assert best.rec_stages[0].tolist() == [[1, 3, 5, 4]]
-        assert best.rec_stages[1].tolist() == [[1, 3, 5, 4]]
-        assert best.rec_stages[2].tolist() == [[1, 3, 5, 4]]
-        assert best.rec_stages[3].tolist() == [[1, 3, 5, 4]]
+        assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
+        assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
+        assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
+        assert best.rec_stages[3].tolist() == [1, 3, 5, 4]
 
         # test with shorter stages
         best2 = best_track(self.events.tracks, startend=(1, 3))
 
         assert len(best2) == 10
 
-        assert best2.rec_stages[0].tolist() == [[1, 3]]
-        assert best2.rec_stages[1].tolist() == [[1, 3]]
-        assert best2.rec_stages[2].tolist() == [[1, 3]]
-        assert best2.rec_stages[3].tolist() == [[1, 3]]
+        assert best2.rec_stages[0].tolist() == [1, 3]
+        assert best2.rec_stages[1].tolist() == [1, 3]
+        assert best2.rec_stages[2].tolist() == [1, 3]
+        assert best2.rec_stages[3].tolist() == [1, 3]
 
         # test the importance of start as a real start of rec_stages
         best3 = best_track(self.events.tracks, startend=(0, 3))
-- 
GitLab