diff --git a/km3io/rootio.py b/km3io/rootio.py index f0b9e3f8aa9a9a74239c0a67d89c161d83a23e5e..46c599f4ae86aa2277b8b567c39dba71b829d5d9 100644 --- a/km3io/rootio.py +++ b/km3io/rootio.py @@ -173,11 +173,34 @@ class EventReader: else: return unfold_indices(branch[self.aliases.get(key, key)].array(), self._index_chain) - def __iter__(self): - self._events = self._event_generator() + def __iter__(self, chunkwise=False): + self._events = self._event_generator(chunkwise=chunkwise) return self - def _event_generator(self): + def _get_iterator_limits(self): + """Determines start and stop, used for event iteration""" + if len(self._index_chain) > 1: + raise NotImplementedError("iteration is currently not supported with nested slices") + if self._index_chain: + s = self._index_chain[0] + if not isinstance(s, slice): + raise NotImplementedError("iteration is only supported with slices") + if s.step is None or s.step == 1: + start = s.start + stop = s.stop + else: + raise NotImplementedError("iteration is only supported with single steps") + else: + start = None + stop = None + return start, stop + + def _event_generator(self, chunkwise=False): + start, stop = self._get_iterator_limits() + + if chunkwise: + raise NotImplementedError("iterating over chunks is not implemented yet") + events = self._fobj[self.event_path] group_count_keys = set( k for k in self.keys() if k.startswith("n_") @@ -195,7 +218,7 @@ class EventReader: log.debug("keys: %s", keys) log.debug("aliases: %s", self.aliases) events_it = events.iterate( - keys, aliases=self.aliases, step_size=self._step_size + keys, aliases=self.aliases, step_size=self._step_size, entry_start=start, entry_stop=stop ) nested = [] nested_keys = ( @@ -208,6 +231,8 @@ class EventReader: self.nested_branches[key].keys(), aliases=self.nested_branches[key], step_size=self._step_size, + entry_start=start, + entry_stop=stop ) ) group_counts = {} diff --git a/tests/test_offline.py b/tests/test_offline.py index 99f6ef64823e8a0aa8e9d2eb2a85732b658d386a..84d0cb6338d70b75e5d6e5126c35ea3ad4ab052c 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -219,6 +219,22 @@ class TestOfflineEvents(unittest.TestCase): n_hits = [len(e.hits.id) for e in self.events] assert np.allclose(n_hits, ak.num(self.events.hits.id, axis=1).tolist()) + def test_iteration_over_slices(self): + ids = [e.id for e in self.events[2:5]] + self.assertListEqual([3, 4, 5], ids) + + def test_iteration_over_slices_raises_when_stepsize_not_supported(self): + with self.assertRaises(NotImplementedError): + [e.id for e in self.events[2:8:2]] + + def test_iteration_over_slices_raises_when_single_item(self): + with self.assertRaises(NotImplementedError): + [e.id for e in self.events[0]] + + def test_iteration_over_slices_raises_when_multiple_slices(self): + with self.assertRaises(NotImplementedError): + [e.id for e in self.events[2:8][2:4]] + def test_str(self): assert str(self.n_events) in str(self.events)