Skip to content
Snippets Groups Projects
Commit 2a6f1b71 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Add iteration

parent ecf86b18
No related branches found
No related tags found
1 merge request!47Resolve "uproot4 integration"
Pipeline #16219 failed
......@@ -216,60 +216,13 @@ class OfflineReader:
return unfold_indices(out, self._index_chain)
def __iter__(self):
self._iterator_index = 0
self._events = self._event_generator()
return self
def _event_generator(self):
events = self._fobj[self.event_path]
group_count_keys = set(k for k in self.keys() if k.startswith("n_")) # special keys to make it easy to count subbranch lengths
log.debug("group_count_keys: %s", group_count_keys)
keys = set(
list(
set(self.keys())
- set(self.special_branches.keys())
- set(self.special_aliases)
- group_count_keys
)
+ list(self.aliases.keys())
) # all top-level keys for regular branches
log.debug("keys: %s", keys)
log.debug("aliases: %s", self.aliases)
events_it = events.iterate(
keys, aliases=self.aliases, step_size=self._step_size
)
specials = []
special_keys = (
self.special_branches.keys()
) # dict-key ordering is an implementation detail
log.debug("special_keys: %s", special_keys)
for key in special_keys:
# print(f"adding {key} with keys {self.special_branches[key].keys()} and aliases {self.special_branches[key]}")
specials.append(
events[key].iterate(
self.special_branches[key].keys(),
aliases=self.special_branches[key],
step_size=self._step_size,
)
)
group_counts = {}
for key in group_count_keys:
group_counts[key] = iter(self[key])
log.debug("group_counts: %s", group_counts)
for event_set, *special_sets in zip(events_it, *specials):
for _event, *special_items in zip(event_set, *special_sets):
data = {}
for k in keys:
data[k] = _event[k]
for (k, i) in zip(special_keys, special_items):
data[k] = i
for tokey, fromkey in self.special_aliases.items():
data[tokey] = data[fromkey]
for key in group_counts:
data[key] = next(group_counts[key])
yield self._event_ctor(**data)
for i in range(len(self)):
yield self[i]
return
def __next__(self):
return next(self._events)
......
......@@ -214,14 +214,12 @@ class TestOfflineEvents(unittest.TestCase):
assert 8 == len(first_tracks.rec_stages)
assert 8 == len(first_tracks.lik)
@unittest.skip
def test_iteration(self):
i = 0
for event in self.events:
i += 1
assert 10 == i
@unittest.skip
def test_iteration_2(self):
n_hits = [len(e.hits.id) for e in self.events]
assert np.allclose(n_hits, ak.num(self.events.hits.id, axis=1).tolist())
......@@ -494,7 +492,6 @@ class TestMcTrackUsr(unittest.TestCase):
def setUp(self):
self.f = OFFLINE_MC_TRACK_USR
@unittest.skip
def test_usr_names(self):
n_tracks = len(self.f.events)
for i in range(3):
......@@ -507,7 +504,6 @@ class TestMcTrackUsr(unittest.TestCase):
self.f.events.mc_tracks.usr_names[i][1].tolist(),
)
@unittest.skip
def test_usr(self):
assert np.allclose(
[0.0487, 0.0588, 3, 2],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment