From 96d4987633ee7305596af91734a7a644809947bd Mon Sep 17 00:00:00 2001 From: Tamas Gal <tgal@km3net.de> Date: Wed, 1 Apr 2020 15:19:06 +0200 Subject: [PATCH] Use index chain --- km3io/offline.py | 21 +++++++++------------ km3io/tools.py | 34 +++++++++++++++++----------------- tests/test_offline.py | 10 +++++++--- 3 files changed, 33 insertions(+), 32 deletions(-) diff --git a/km3io/offline.py b/km3io/offline.py index d3f66fb..82cf5f4 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -2,7 +2,7 @@ from collections import namedtuple import uproot import warnings from .definitions import mc_header -from .tools import Branch, BranchMapper, cached_property, _to_num +from .tools import Branch, BranchMapper, cached_property, _to_num, _unfold_indices MAIN_TREE_NAME = "E" EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"] @@ -79,15 +79,15 @@ SUBBRANCH_MAPS = [ class OfflineBranch(Branch): @cached_property def usr(self): - return Usr(self._mapper, self._branch, index=self._index) + return Usr(self._mapper, self._branch, index_chain=self._index_chain) class Usr: """Helper class to access AAObject `usr` stuff""" - def __init__(self, mapper, branch, index=None): + def __init__(self, mapper, branch, index_chain=None): self._mapper = mapper self._name = mapper.name - self._index = index + self._index_chain = [] if index_chain is None else index_chain self._branch = branch self._usr_names = [] self._usr_idx_lookup = {} @@ -125,8 +125,8 @@ class Usr: data = self._branch[self._usr_key].lazyarray() - if self._index is not None: - data = data[self._index] + if self._index_chain: + data = _unfold_indices(data, self._index_chain) self._usr_data = data @@ -150,8 +150,8 @@ class Usr: return self.__getitem_nested__(item) def __getitem_flat__(self, item): - if self._index is not None: - return self._usr_data[self._index][:, self._usr_idx_lookup[item]] + if self._index_chain: + return _unfold_indices(self._usr_data, self._index_chain)[:, self._usr_idx_lookup[item]] else: return self._usr_data[:, self._usr_idx_lookup[item]] @@ -163,10 +163,7 @@ class Usr: uproot.SimpleArray(uproot.STLVector(uproot.STLString())), self._branch[self._usr_key + '_names']._context, 6), basketcache=BASKET_CACHE) - if self._index is None: - return data - else: - return data[self._index] + return _unfold_indices(data, self._index_chain) def keys(self): return self._usr_names diff --git a/km3io/tools.py b/km3io/tools.py index 418aafb..c44af93 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -21,12 +21,14 @@ class cached_property: def _unfold_indices(obj, indices): """Unfolds an index chain and returns the corresponding item""" + original_obj = obj for depth, idx in enumerate(indices): try: obj = obj[idx] except IndexError: - print("IndexError while accessing item '{}' at depth {} ({}) of " - "the index chain {}".format(repr(obj), depth, idx, indices)) + print("IndexError while accessing an item from '{}' at depth {} ({}) " + "using the index chain {}" + .format(repr(original_obj), depth, idx, indices)) raise return obj @@ -41,12 +43,12 @@ class Branch: def __init__(self, tree, mapper, - index=None, + index_chain=None, subbranchmaps=None, keymap=None): self._tree = tree self._mapper = mapper - self._index = index + self._index_chain = [] if index_chain is None else index_chain self._keymap = None self._branch = tree[mapper.key] self._subbranches = [] @@ -61,7 +63,7 @@ class Branch: for mapper in subbranchmaps: subbranch = self.__class__(self._tree, mapper=mapper, - index=self._index) + index_chain=self._index_chain) self._subbranches.append(subbranch) for subbranch in self._subbranches: setattr(self, subbranch._mapper.name, subbranch) @@ -98,39 +100,37 @@ class Branch: def __getkey__(self, key): out = self._branch[self._keymap[key]].lazyarray( basketcache=BASKET_CACHE) - if self._index is not None: - out = out[self._index] - return out + return _unfold_indices(out, self._index_chain) def __getitem__(self, item): """Slicing magic""" - if isinstance(item, (int, slice)): + if isinstance(item, (int, slice, tuple)): return self.__class__(self._tree, self._mapper, - index=item, + index_chain=self._index_chain + [item], keymap=self._keymap, subbranchmaps=self._subbranchmaps) - if isinstance(item, tuple): - return self[item[0]][item[1]] + # if isinstance(item, tuple): + # return self[item[0]][item[1]] if isinstance(item, str): return self.__getkey__(item) return self.__class__(self._tree, self._mapper, - index=np.array(item), + index_chain=self._index_chain + [np.array(item)], keymap=self._keymap, subbranchmaps=self._subbranchmaps) def __len__(self): - if self._index is None: + if not self._index_chain: return len(self._branch) - elif isinstance(self._index, int): + elif isinstance(self._index_chain[-1], int): return 1 else: - return len(self._branch[self._keymap['id']].lazyarray( - basketcache=BASKET_CACHE)[self._index]) + return len(_unfold_indices(self._branch[self._keymap['id']].lazyarray( + basketcache=BASKET_CACHE), self._index_chain)) def __str__(self): return "Number of elements: {}".format(len(self._branch)) diff --git a/tests/test_offline.py b/tests/test_offline.py index 8e1f685..ef5c798 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -152,6 +152,10 @@ class TestOfflineEvents(unittest.TestCase): for i in [0, 2, 5]: assert np.allclose(self.events[i].n_hits, self.events.n_hits[i]) + def test_index_chaining(self): + assert np.allclose(self.events[3:5].n_hits, self.events.n_hits[3:5]) + assert np.allclose(self.events[3:5][0].n_hits, self.events.n_hits[3:5][0]) + def test_str(self): assert str(self.n_events) in str(self.events) @@ -295,9 +299,9 @@ class TestBranchIndexingMagic(unittest.TestCase): assert np.allclose(self.events[3:6].tracks.pos_y[:, 0], self.events.tracks.pos_y[3:6, 0]) - # test slicing with a tuple - assert np.allclose(self.events[0].hits[1].dom_id[0:10], - self.events.hits[(0, 1)].dom_id[0:10]) + # # test slicing with a tuple + # assert np.allclose(self.events[0].hits[1].dom_id[0:10], + # self.events.hits[(0, 1)].dom_id[0:10]) # test selecting with a list self.assertEqual(3, len(self.events[[0, 2, 3]])) -- GitLab