diff --git a/km3io/tools.py b/km3io/tools.py index 0cf10f9370ea1ff6663b74ccb312348dafd3f0ce..e443357aa1ec72278b7e44bc9f14ccc4175d0728 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -55,6 +55,8 @@ class Branch: self._subbranches = [] self._subbranchmaps = subbranchmaps + self._iterator_index = 0 + if keymap is None: self._initialise_keys() # else: @@ -125,8 +127,22 @@ class Branch: self._branch[self._keymap['id']].lazyarray( basketcache=BASKET_CACHE), self._index_chain)) + def __iter__(self): + self._iterator_index = 0 + return self + + def __next__(self): + idx = self._iterator_index + self._iterator_index += 1 + if idx >= len(self): + raise StopIteration + return self[idx] + def __str__(self): - return "Number of elements: {}".format(len(self._branch)) + length = len(self) + return "{} ({}) with {} element{}".format(self.__class__.__name__, + self._mapper.name, length, + 's' if length > 1 else '') def __repr__(self): length = len(self) diff --git a/tests/test_offline.py b/tests/test_offline.py index dfd8e7f2382dc3889ec707b054bc13554fb606f7..2dd1ee7fe2f4e0a07fb30ae94919ea4fcb53e8e7 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -161,6 +161,16 @@ class TestOfflineEvents(unittest.TestCase): assert np.allclose(self.events.hits[3:5][1][4].dom_id, self.events[3:5][1][4].hits.dom_id) + def test_iteration(self): + i = 0 + for event in self.events: + i += 1 + assert 10 == i + + def test_iteration_2(self): + n_hits = [e.n_hits for e in self.events] + assert np.allclose(n_hits, self.events.n_hits) + def test_str(self): assert str(self.n_events) in str(self.events)