From 6bdee7cbe1ca48e547b7b046cf08b3e2ea452c53 Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Fri, 11 Dec 2020 09:17:28 +0100
Subject: [PATCH] Add iteration support over slices

---
 km3io/rootio.py       | 33 +++++++++++++++++++++++++++++----
 tests/test_offline.py | 16 ++++++++++++++++
 2 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/km3io/rootio.py b/km3io/rootio.py
index f0b9e3f..46c599f 100644
--- a/km3io/rootio.py
+++ b/km3io/rootio.py
@@ -173,11 +173,34 @@ class EventReader:
         else:
             return unfold_indices(branch[self.aliases.get(key, key)].array(), self._index_chain)
 
-    def __iter__(self):
-        self._events = self._event_generator()
+    def __iter__(self, chunkwise=False):
+        self._events = self._event_generator(chunkwise=chunkwise)
         return self
 
-    def _event_generator(self):
+    def _get_iterator_limits(self):
+        """Return the (start, stop) entry limits for iteration, (None, None) if unsliced."""
+        if len(self._index_chain) > 1:
+            raise NotImplementedError("iteration is currently not supported with nested slices")
+        if self._index_chain:
+            s = self._index_chain[0]
+            if not isinstance(s, slice):
+                raise NotImplementedError("iteration is only supported with slices")
+            if s.step is None or s.step == 1:  # only start/stop are forwarded to iterate()
+                start = s.start
+                stop = s.stop
+            else:
+                raise NotImplementedError("iteration is only supported with single steps")
+        else:
+            start = None
+            stop = None
+        return start, stop
+
+    def _event_generator(self, chunkwise=False):
+        start, stop = self._get_iterator_limits()
+
+        if chunkwise:
+            raise NotImplementedError("iterating over chunks is not implemented yet")
+
         events = self._fobj[self.event_path]
         group_count_keys = set(
             k for k in self.keys() if k.startswith("n_")
@@ -195,7 +218,7 @@ class EventReader:
         log.debug("keys: %s", keys)
         log.debug("aliases: %s", self.aliases)
         events_it = events.iterate(
-            keys, aliases=self.aliases, step_size=self._step_size
+            keys, aliases=self.aliases, step_size=self._step_size, entry_start=start, entry_stop=stop
         )
         nested = []
         nested_keys = (
@@ -208,6 +231,8 @@ class EventReader:
                     self.nested_branches[key].keys(),
                     aliases=self.nested_branches[key],
                     step_size=self._step_size,
+                    entry_start=start,
+                    entry_stop=stop
                 )
             )
         group_counts = {}
diff --git a/tests/test_offline.py b/tests/test_offline.py
index 99f6ef6..84d0cb6 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -219,6 +219,22 @@ class TestOfflineEvents(unittest.TestCase):
         n_hits = [len(e.hits.id) for e in self.events]
         assert np.allclose(n_hits, ak.num(self.events.hits.id, axis=1).tolist())
 
+    def test_iteration_over_slices(self):
+        ids = [e.id for e in self.events[2:5]]
+        self.assertListEqual([3, 4, 5], ids)
+
+    def test_iteration_over_slices_raises_when_stepsize_not_supported(self):
+        with self.assertRaises(NotImplementedError):
+            [e.id for e in self.events[2:8:2]]
+
+    def test_iteration_over_slices_raises_when_single_item(self):
+        with self.assertRaises(NotImplementedError):
+            [e.id for e in self.events[0]]
+
+    def test_iteration_over_slices_raises_when_multiple_slices(self):
+        with self.assertRaises(NotImplementedError):
+            [e.id for e in self.events[2:8][2:4]]
+
     def test_str(self):
         assert str(self.n_events) in str(self.events)
 
-- 
GitLab