From 6bdee7cbe1ca48e547b7b046cf08b3e2ea452c53 Mon Sep 17 00:00:00 2001 From: Tamas Gal <tgal@km3net.de> Date: Fri, 11 Dec 2020 09:17:28 +0100 Subject: [PATCH] Add iteration support over slices --- km3io/rootio.py | 33 +++++++++++++++++++++++++++++---- tests/test_offline.py | 16 ++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/km3io/rootio.py b/km3io/rootio.py index f0b9e3f..46c599f 100644 --- a/km3io/rootio.py +++ b/km3io/rootio.py @@ -173,11 +173,34 @@ class EventReader: else: return unfold_indices(branch[self.aliases.get(key, key)].array(), self._index_chain) - def __iter__(self): - self._events = self._event_generator() + def __iter__(self, chunkwise=False): + self._events = self._event_generator(chunkwise=chunkwise) return self - def _event_generator(self): + def _get_iterator_limits(self): + """Determines start and stop, used for event iteration""" + if len(self._index_chain) > 1: + raise NotImplementedError("iteration is currently not supported with nested slices") + if self._index_chain: + s = self._index_chain[0] + if not isinstance(s, slice): + raise NotImplementedError("iteration is only supported with slices") + if s.step is None or s.step == 1: + start = s.start + stop = s.stop + else: + raise NotImplementedError("iteration is only supported with single steps") + else: + start = None + stop = None + return start, stop + + def _event_generator(self, chunkwise=False): + start, stop = self._get_iterator_limits() + + if chunkwise: + raise NotImplementedError("iterating over chunks is not implemented yet") + events = self._fobj[self.event_path] group_count_keys = set( k for k in self.keys() if k.startswith("n_") @@ -195,7 +218,7 @@ class EventReader: log.debug("keys: %s", keys) log.debug("aliases: %s", self.aliases) events_it = events.iterate( - keys, aliases=self.aliases, step_size=self._step_size + keys, aliases=self.aliases, step_size=self._step_size, entry_start=start, entry_stop=stop ) nested = [] nested_keys = ( @@ -208,6 +231,8 @@ class EventReader: self.nested_branches[key].keys(), aliases=self.nested_branches[key], step_size=self._step_size, + entry_start=start, + entry_stop=stop ) ) group_counts = {} diff --git a/tests/test_offline.py b/tests/test_offline.py index 99f6ef6..84d0cb6 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -219,6 +219,22 @@ class TestOfflineEvents(unittest.TestCase): n_hits = [len(e.hits.id) for e in self.events] assert np.allclose(n_hits, ak.num(self.events.hits.id, axis=1).tolist()) + def test_iteration_over_slices(self): + ids = [e.id for e in self.events[2:5]] + self.assertListEqual([3, 4, 5], ids) + + def test_iteration_over_slices_raises_when_stepsize_not_supported(self): + with self.assertRaises(NotImplementedError): + [e.id for e in self.events[2:8:2]] + + def test_iteration_over_slices_raises_when_single_item(self): + with self.assertRaises(NotImplementedError): + [e.id for e in self.events[0]] + + def test_iteration_over_slices_raises_when_multiple_slices(self): + with self.assertRaises(NotImplementedError): + [e.id for e in self.events[2:8][2:4]] + def test_str(self): assert str(self.n_events) in str(self.events) -- GitLab