From 20819a7de2045e94d198202f69997a65d5d45dee Mon Sep 17 00:00:00 2001
From: Tamas Gal <himself@tamasgal.com>
Date: Thu, 10 Feb 2022 10:45:56 +0100
Subject: [PATCH] Add getitem to SummarysliceReader

---
 km3io/online.py      | 25 ++++++++++++++++++++++++-
 tests/test_online.py | 40 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/km3io/online.py b/km3io/online.py
index 438e845..ad080f9 100644
--- a/km3io/online.py
+++ b/km3io/online.py
@@ -96,6 +96,22 @@ class SummarysliceReader:
                 *[getattr(chunk, bc.branch_address) for bc in self._subbranches]
             )
 
+    def __getitem__(self, idx):
+        if idx >= len(self) or idx < -len(self):
+            raise IndexError("Chunk index out of range")
+
+        s = self._step_size
+
+        if idx < 0:
+            idx = len(self) + idx
+
+        chunk = self._branch.arrays(
+            dict(self._subbranches), entry_start=idx * s, entry_stop=(idx + 1) * s
+        )
+        return self.ChunksConstructor(
+            *[getattr(chunk, bc.branch_address) for bc in self._subbranches]
+        )
+
     def __iter__(self):
         self._chunks = self._chunks_generator()
         return self
@@ -107,7 +123,14 @@ class SummarysliceReader:
         return int(np.ceil(self._branch.num_entries / self._step_size))
 
     def __repr__(self):
-        return f"<{self.__class__.__name__} {self._branch.num_entries} items, step_size={self._step_size} ({len(self)} chunks)>"
+        step_size = self._step_size
+        n_items = self._branch.num_entries
+        cls_name = self.__class__.__name__
+        n_chunks = len(self)
+        return (
+            f"<{cls_name} {n_items} items, step_size={step_size} "
+            f"({n_chunks} chunk{'' if n_chunks == 1 else 's'})>"
+        )
 
 
 @nb.vectorize(
diff --git a/tests/test_online.py b/tests/test_online.py
index 09fb9ac..e7efc6e 100644
--- a/tests/test_online.py
+++ b/tests/test_online.py
@@ -749,6 +749,46 @@ class TestSummarysliceReader(unittest.TestCase):
         sr = SummarysliceReader(data_path("online/km3net_online.root"), step_size=3)
         assert 1 == len(sr)
 
+    def test_getitem_raises_when_out_of_range(self):
+        sr = SummarysliceReader(data_path("online/km3net_online.root"), step_size=1)
+        with self.assertRaises(IndexError):
+            sr[123]
+        with self.assertRaises(IndexError):
+            sr[-123]
+        with self.assertRaises(IndexError):
+            sr[3]
+        sr[-3]  # this should still work, gives the first element in this case
+        with self.assertRaises(IndexError):
+            sr[-4]
+
+    def test_getitem(self):
+        sr = SummarysliceReader(data_path("online/km3net_online.root"), step_size=1)
+        for idx in range(len(sr)):
+            assert len(sr[idx].headers) == 1
+            assert len(sr[idx].slices) == 1
+
+        first_frame_index = sr[0].headers.frame_index  # 126
+        last_frame_index = sr[2].headers.frame_index  # 128
+
+        assert 126 == first_frame_index
+        assert 128 == last_frame_index
+
+        sr = SummarysliceReader(data_path("online/km3net_online.root"), step_size=2)
+        assert len(sr[0].headers) == 2
+        assert len(sr[0].slices) == 2
+        assert len(sr[1].headers) == 1
+        assert len(sr[1].slices) == 1
+        with self.assertRaises(IndexError):
+            assert len(sr[2].headers) == 0
+            assert len(sr[2].slices) == 0
+
+        assert first_frame_index == sr[0].headers[0].frame_index
+        assert last_frame_index == sr[1].headers[0].frame_index
+
+
+        assert last_frame_index == sr[-1].headers[0].frame_index
+        assert first_frame_index == sr[-2].headers[0].frame_index
+
     def test_iterate_with_step_size_one(self):
         sr = SummarysliceReader(data_path("online/km3net_online.root"), step_size=1)
         i = 0
-- 
GitLab