Skip to content
Snippets Groups Projects
Commit 6bdee7cb authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Add iteration support over slices

parent f1af357f
No related branches found
No related tags found
1 merge request!47Resolve "uproot4 integration"
Pipeline #16336 failed
......@@ -173,11 +173,34 @@ class EventReader:
else:
return unfold_indices(branch[self.aliases.get(key, key)].array(), self._index_chain)
def __iter__(self):
self._events = self._event_generator()
def __iter__(self, chunkwise=False):
self._events = self._event_generator(chunkwise=chunkwise)
return self
def _event_generator(self):
def _get_iterator_limits(self):
"""Determines start and stop, used for event iteration"""
if len(self._index_chain) > 1:
raise NotImplementedError("iteration is currently not supported with nested slices")
if self._index_chain:
s = self._index_chain[0]
if not isinstance(s, slice):
raise NotImplementedError("iteration is only supported with slices")
if s.step is None or s.step == 1:
start = s.start
stop = s.stop
else:
raise NotImplementedError("iteration is only supported with single steps")
else:
start = None
stop = None
return start, stop
def _event_generator(self, chunkwise=False):
start, stop = self._get_iterator_limits()
if chunkwise:
raise NotImplementedError("iterating over chunks is not implemented yet")
events = self._fobj[self.event_path]
group_count_keys = set(
k for k in self.keys() if k.startswith("n_")
......@@ -195,7 +218,7 @@ class EventReader:
log.debug("keys: %s", keys)
log.debug("aliases: %s", self.aliases)
events_it = events.iterate(
keys, aliases=self.aliases, step_size=self._step_size
keys, aliases=self.aliases, step_size=self._step_size, entry_start=start, entry_stop=stop
)
nested = []
nested_keys = (
......@@ -208,6 +231,8 @@ class EventReader:
self.nested_branches[key].keys(),
aliases=self.nested_branches[key],
step_size=self._step_size,
entry_start=start,
entry_stop=stop
)
)
group_counts = {}
......
......@@ -219,6 +219,22 @@ class TestOfflineEvents(unittest.TestCase):
n_hits = [len(e.hits.id) for e in self.events]
assert np.allclose(n_hits, ak.num(self.events.hits.id, axis=1).tolist())
def test_iteration_over_slices(self):
ids = [e.id for e in self.events[2:5]]
self.assertListEqual([3, 4, 5], ids)
def test_iteration_over_slices_raises_when_stepsize_not_supported(self):
with self.assertRaises(NotImplementedError):
[e.id for e in self.events[2:8:2]]
def test_iteration_over_slices_raises_when_single_item(self):
with self.assertRaises(NotImplementedError):
[e.id for e in self.events[0]]
def test_iteration_over_slices_raises_when_multiple_slices(self):
with self.assertRaises(NotImplementedError):
[e.id for e in self.events[2:8][2:4]]
def test_str(self):
assert str(self.n_events) in str(self.events)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment