From f85a6cff691d0f999dc15a6b9a3bbcc21018e446 Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Fri, 6 Mar 2020 17:18:04 +0100
Subject: [PATCH] Fix indexing

---
 km3io/offline.py      | 45 ++++++++++++++++++++++++++++++++-----------
 tests/test_offline.py | 37 ++++++++++++++++++++++++++++++++++-
 2 files changed, 70 insertions(+), 12 deletions(-)

diff --git a/km3io/offline.py b/km3io/offline.py
index 4368ee9..7aa8708 100644
--- a/km3io/offline.py
+++ b/km3io/offline.py
@@ -594,21 +594,35 @@ class Branch:
         if isinstance(item, slice):
             return self.__class__(self._tree, self._mapper, index=item)
         if isinstance(item, int):
+            # A bit ugly, but whatever works
             if self._mapper.flat:
-                return BranchElement(
-                    self._mapper.name, {
+                if self._index is None:
+                    dct = {
+                        key: self._branch[self._keymap[key]].array()
+                        for key in self.keys()
+                    }
+                else:
+                    dct = {
                         key:
                         self._branch[self._keymap[key]].array()[self._index]
                         for key in self.keys()
-                    })[item]
+                    }
+                return BranchElement(self._mapper.name, dct)[item]
             else:
-                return BranchElement(
-                    self._mapper.name, {
+                if self._index is None:
+                    dct = {
+                        key: self._branch[self._keymap[key]].array()[item]
+                        for key in self.keys()
+                    }
+                else:
+                    dct = {
                         key:
                         self._branch[self._keymap[key]].array()[self._index,
                                                                 item]
                         for key in self.keys()
-                    })
+                    }
+                return BranchElement(self._mapper.name, dct)
+
         if isinstance(item, tuple):
             return self[item[0]][item[1]]
 
@@ -657,16 +671,25 @@ class BranchElement:
         self._name = name
         self._index = index
         self.ItemConstructor = namedtuple(self._name[:-1], dct.keys())
-        for key, values in dct.items():
-            setattr(self, key, values[index])
+        if index is None:
+            for key, values in dct.items():
+                setattr(self, key, values)
+        else:
+            for key, values in dct.items():
+                setattr(self, key, values[index])
 
     def __getitem__(self, item):
         if isinstance(item, slice):
             return self.__class__(self._name, self._dct, index=item)
         if isinstance(item, int):
-            return self.ItemConstructor(
-                **{k: v[self._index][item]
-                   for k, v in self._dct.items()})
+            if self._index is None:
+                return self.ItemConstructor(
+                    **{k: v[item]
+                    for k, v in self._dct.items()})
+            else:
+                return self.ItemConstructor(
+                    **{k: v[self._index][item]
+                    for k, v in self._dct.items()})
 
     def __repr__(self):
         return "<{}[{}]>".format(self.__class__.__name__, self._name)
diff --git a/tests/test_offline.py b/tests/test_offline.py
index fb8ed19..fee41d7 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -182,6 +182,18 @@ class TestOfflineEvents(unittest.TestCase):
         self.assertListEqual(self.t_sec[s], list(s_events.t_sec))
         self.assertListEqual(self.t_ns[s], list(s_events.t_ns))
 
+    def test_slicing_consistency(self):
+        for s in [slice(1, 3), slice(2, 7, 3)]:
+            assert np.allclose(OFFLINE_FILE[s].events.n_hits,
+                               self.events.n_hits[s])
+            assert np.allclose(self.events[s].n_hits, self.events.n_hits[s])
+
+    def test_index_consistency(self):
+        for i in range(self.n_events):
+            assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])
+            assert np.allclose(OFFLINE_FILE[i].events.n_hits,
+                               self.events.n_hits[i])
+
     def test_str(self):
         assert str(self.n_events) in str(self.events)
 
@@ -235,9 +247,32 @@ class TestOfflineHits(unittest.TestCase):
         for idx, t in self.t.items():
             assert np.allclose(t, self.hits.t[idx][:len(t)])
 
+    def test_slicing(self):
+        s = slice(2, 8, 2)
+        s_hits = self.hits[s]
+        assert 3 == len(s_hits)
+        for idx, dom_id in self.dom_id.items():
+            self.assertListEqual(dom_id[s], list(self.hits.dom_id[idx][s]))
+        for idx, t in self.t.items():
+            self.assertListEqual(t[s], list(self.hits.t[idx][s]))
+
+    def test_slicing_consistency(self):
+        for s in [slice(1, 3), slice(2, 7, 3)]:
+            for idx in range(3):
+                assert np.allclose(self.hits.dom_id[idx][s],
+                                   self.hits[idx].dom_id[s])
+                assert np.allclose(OFFLINE_FILE[idx].hits.dom_id[s],
+                                   self.hits.dom_id[idx][s])
 
-class TestOfflineTracks(unittest.TestCase):
     @unittest.skip
+    def test_index_consistency(self):
+        for i in range(self.n_events):
+            assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])
+            assert np.allclose(OFFLINE_FILE[i].events.n_hits,
+                               self.events.n_hits[i])
+
+
+class TestOfflineTracks(unittest.TestCase):
     def setUp(self):
         self.tracks = OFFLINE_FILE.tracks
         self.r_mc = OFFLINE_NUMUCC
-- 
GitLab