From 2a6f1b713ec1d7781ac2f3bfb7e5be4919228481 Mon Sep 17 00:00:00 2001 From: Tamas Gal <tgal@km3net.de> Date: Sat, 5 Dec 2020 10:22:06 +0100 Subject: [PATCH] Add iteration --- km3io/offline.py | 53 +++---------------------------------------- tests/test_offline.py | 4 ---- 2 files changed, 3 insertions(+), 54 deletions(-) diff --git a/km3io/offline.py b/km3io/offline.py index 9a5ac79..19cf5aa 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -216,60 +216,13 @@ class OfflineReader: return unfold_indices(out, self._index_chain) def __iter__(self): - self._iterator_index = 0 self._events = self._event_generator() return self def _event_generator(self): - events = self._fobj[self.event_path] - group_count_keys = set(k for k in self.keys() if k.startswith("n_")) # special keys to make it easy to count subbranch lengths - log.debug("group_count_keys: %s", group_count_keys) - keys = set( - list( - set(self.keys()) - - set(self.special_branches.keys()) - - set(self.special_aliases) - - group_count_keys - ) - + list(self.aliases.keys()) - ) # all top-level keys for regular branches - log.debug("keys: %s", keys) - log.debug("aliases: %s", self.aliases) - events_it = events.iterate( - keys, aliases=self.aliases, step_size=self._step_size - ) - specials = [] - special_keys = ( - self.special_branches.keys() - ) # dict-key ordering is an implementation detail - log.debug("special_keys: %s", special_keys) - for key in special_keys: - # print(f"adding {key} with keys {self.special_branches[key].keys()} and aliases {self.special_branches[key]}") - - specials.append( - events[key].iterate( - self.special_branches[key].keys(), - aliases=self.special_branches[key], - step_size=self._step_size, - ) - ) - group_counts = {} - for key in group_count_keys: - group_counts[key] = iter(self[key]) - - log.debug("group_counts: %s", group_counts) - for event_set, *special_sets in zip(events_it, *specials): - for _event, *special_items in zip(event_set, *special_sets): - data = {} - for k in keys: - data[k] = _event[k] - for (k, i) in zip(special_keys, special_items): - data[k] = i - for tokey, fromkey in self.special_aliases.items(): - data[tokey] = data[fromkey] - for key in group_counts: - data[key] = next(group_counts[key]) - yield self._event_ctor(**data) + for i in range(len(self)): + yield self[i] + return def __next__(self): return next(self._events) diff --git a/tests/test_offline.py b/tests/test_offline.py index 592b741..eba3cce 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -214,14 +214,12 @@ class TestOfflineEvents(unittest.TestCase): assert 8 == len(first_tracks.rec_stages) assert 8 == len(first_tracks.lik) - @unittest.skip def test_iteration(self): i = 0 for event in self.events: i += 1 assert 10 == i - @unittest.skip def test_iteration_2(self): n_hits = [len(e.hits.id) for e in self.events] assert np.allclose(n_hits, ak.num(self.events.hits.id, axis=1).tolist()) @@ -494,7 +492,6 @@ class TestMcTrackUsr(unittest.TestCase): def setUp(self): self.f = OFFLINE_MC_TRACK_USR - @unittest.skip def test_usr_names(self): n_tracks = len(self.f.events) for i in range(3): @@ -507,7 +504,6 @@ class TestMcTrackUsr(unittest.TestCase): self.f.events.mc_tracks.usr_names[i][1].tolist(), ) - @unittest.skip def test_usr(self): assert np.allclose( [0.0487, 0.0588, 3, 2], -- GitLab