From 2a6f1b713ec1d7781ac2f3bfb7e5be4919228481 Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Sat, 5 Dec 2020 10:22:06 +0100
Subject: [PATCH] Add iteration

---
 km3io/offline.py      | 53 +++----------------------------------------
 tests/test_offline.py |  4 ----
 2 files changed, 3 insertions(+), 54 deletions(-)

diff --git a/km3io/offline.py b/km3io/offline.py
index 9a5ac79..19cf5aa 100644
--- a/km3io/offline.py
+++ b/km3io/offline.py
@@ -216,60 +216,13 @@ class OfflineReader:
         return unfold_indices(out, self._index_chain)
 
     def __iter__(self):
-        self._iterator_index = 0
         self._events = self._event_generator()
         return self
 
     def _event_generator(self):
-        events = self._fobj[self.event_path]
-        group_count_keys = set(k for k in self.keys() if k.startswith("n_"))  # special keys to make it easy to count subbranch lengths
-        log.debug("group_count_keys: %s", group_count_keys)
-        keys = set(
-            list(
-                set(self.keys())
-                - set(self.special_branches.keys())
-                - set(self.special_aliases)
-                - group_count_keys
-            )
-            + list(self.aliases.keys())
-        )  # all top-level keys for regular branches
-        log.debug("keys: %s", keys)
-        log.debug("aliases: %s", self.aliases)
-        events_it = events.iterate(
-            keys, aliases=self.aliases, step_size=self._step_size
-        )
-        specials = []
-        special_keys = (
-            self.special_branches.keys()
-        )  # dict-key ordering is an implementation detail
-        log.debug("special_keys: %s", special_keys)
-        for key in special_keys:
-            # print(f"adding {key} with keys {self.special_branches[key].keys()} and aliases {self.special_branches[key]}")
-
-            specials.append(
-                events[key].iterate(
-                    self.special_branches[key].keys(),
-                    aliases=self.special_branches[key],
-                    step_size=self._step_size,
-                )
-            )
-        group_counts = {}
-        for key in group_count_keys:
-            group_counts[key] = iter(self[key])
-
-        log.debug("group_counts: %s", group_counts)
-        for event_set, *special_sets in zip(events_it, *specials):
-            for _event, *special_items in zip(event_set, *special_sets):
-                data = {}
-                for k in keys:
-                    data[k] = _event[k]
-                for (k, i) in zip(special_keys, special_items):
-                    data[k] = i
-                for tokey, fromkey in self.special_aliases.items():
-                    data[tokey] = data[fromkey]
-                for key in group_counts:
-                    data[key] = next(group_counts[key])
-                yield self._event_ctor(**data)
+        for i in range(len(self)):
+            yield self[i]
+        return
 
     def __next__(self):
         return next(self._events)
diff --git a/tests/test_offline.py b/tests/test_offline.py
index 592b741..eba3cce 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -214,14 +214,12 @@ class TestOfflineEvents(unittest.TestCase):
         assert 8 == len(first_tracks.rec_stages)
         assert 8 == len(first_tracks.lik)
 
-    @unittest.skip
     def test_iteration(self):
         i = 0
         for event in self.events:
             i += 1
         assert 10 == i
 
-    @unittest.skip
     def test_iteration_2(self):
         n_hits = [len(e.hits.id) for e in self.events]
         assert np.allclose(n_hits, ak.num(self.events.hits.id, axis=1).tolist())
@@ -494,7 +492,6 @@ class TestMcTrackUsr(unittest.TestCase):
     def setUp(self):
         self.f = OFFLINE_MC_TRACK_USR
 
-    @unittest.skip
     def test_usr_names(self):
         n_tracks = len(self.f.events)
         for i in range(3):
@@ -507,7 +504,6 @@ class TestMcTrackUsr(unittest.TestCase):
                 self.f.events.mc_tracks.usr_names[i][1].tolist(),
             )
 
-    @unittest.skip
     def test_usr(self):
         assert np.allclose(
             [0.0487, 0.0588, 3, 2],
-- 
GitLab