diff --git a/km3io/offline.py b/km3io/offline.py index 82cf5f47b03194c27d0cbbcfc699653754a89f1f..14043bc88432082b4e7dcf3dfe741ae71e617f96 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -151,7 +151,9 @@ class Usr: def __getitem_flat__(self, item): if self._index_chain: - return _unfold_indices(self._usr_data, self._index_chain)[:, self._usr_idx_lookup[item]] + return _unfold_indices( + self._usr_data, self._index_chain)[:, + self._usr_idx_lookup[item]] else: return self._usr_data[:, self._usr_idx_lookup[item]] @@ -197,8 +199,8 @@ class OfflineReader: def events(self): """The `E` branch, containing all offline events.""" return OfflineBranch(self._tree, - mapper=EVENTS_MAP, - subbranchmaps=SUBBRANCH_MAPS) + mapper=EVENTS_MAP, + subbranchmaps=SUBBRANCH_MAPS) @cached_property def header(self): diff --git a/km3io/tools.py b/km3io/tools.py index c44af93de281861cdf5a6f532225225f0959d496..6775e1ee8ebc24e44191196e113ab1478fe467fd 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -26,9 +26,10 @@ def _unfold_indices(obj, indices): try: obj = obj[idx] except IndexError: - print("IndexError while accessing an item from '{}' at depth {} ({}) " - "using the index chain {}" - .format(repr(original_obj), depth, idx, indices)) + print( + "IndexError while accessing an item from '{}' at depth {} ({}) " + "using the index chain {}".format(repr(original_obj), depth, + idx, indices)) raise return obj @@ -62,7 +63,7 @@ class Branch: if subbranchmaps is not None: for mapper in subbranchmaps: subbranch = self.__class__(self._tree, - mapper=mapper, + mapper=mapper, index_chain=self._index_chain) self._subbranches.append(subbranch) for subbranch in self._subbranches: @@ -129,8 +130,10 @@ class Branch: elif isinstance(self._index_chain[-1], int): return 1 else: - return len(_unfold_indices(self._branch[self._keymap['id']].lazyarray( - basketcache=BASKET_CACHE), self._index_chain)) + return len( + _unfold_indices( + self._branch[self._keymap['id']].lazyarray( + basketcache=BASKET_CACHE), self._index_chain)) def __str__(self): return "Number of elements: {}".format(len(self._branch)) diff --git a/tests/test_offline.py b/tests/test_offline.py index ef5c798dd754181915729b8ece8e42412fe60127..4f47a4917e3cd77104dfaed1f4f2ebc058702cc1 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -154,7 +154,8 @@ class TestOfflineEvents(unittest.TestCase): def test_index_chaining(self): assert np.allclose(self.events[3:5].n_hits, self.events.n_hits[3:5]) - assert np.allclose(self.events[3:5][0].n_hits, self.events.n_hits[3:5][0]) + assert np.allclose(self.events[3:5][0].n_hits, + self.events.n_hits[3:5][0]) def test_str(self): assert str(self.n_events) in str(self.events) @@ -344,7 +345,6 @@ class TestUsr(unittest.TestCase): self.f.events.usr.DeltaPosZ) - class TestNestedMapper(unittest.TestCase): def test_nested_mapper(self): self.assertEqual('pos_x', _nested_mapper("trks.pos.x")) diff --git a/tests/test_tools.py b/tests/test_tools.py index 3c4b6559e0158af6414111b94ca3cce717f2b6bb..857f9c84a7939f0676560afc29cb74bc41d0c85a 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -30,7 +30,8 @@ class TestUnfoldIndices(unittest.TestCase): assert data[indices[0]][indices[1]] == _unfold_indices(data, indices) indices = [slice(1, 9, 2), slice(1, 4), 2] - assert data[indices[0]][indices[1]][indices[2]] == _unfold_indices(data, indices) + assert data[indices[0]][indices[1]][indices[2]] == _unfold_indices( + data, indices) def test_unfold_indices_raises_index_error(self): data = range(10)