Skip to content
Snippets Groups Projects
Commit 96d49876 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Use index chain

parent 34346f84
No related branches found
No related tags found
1 merge request!27Refactor offline I/O
......@@ -2,7 +2,7 @@ from collections import namedtuple
import uproot
import warnings
from .definitions import mc_header
from .tools import Branch, BranchMapper, cached_property, _to_num
from .tools import Branch, BranchMapper, cached_property, _to_num, _unfold_indices
MAIN_TREE_NAME = "E"
EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"]
......@@ -79,15 +79,15 @@ SUBBRANCH_MAPS = [
class OfflineBranch(Branch):
@cached_property
def usr(self):
return Usr(self._mapper, self._branch, index=self._index)
return Usr(self._mapper, self._branch, index_chain=self._index_chain)
class Usr:
"""Helper class to access AAObject `usr` stuff"""
def __init__(self, mapper, branch, index=None):
def __init__(self, mapper, branch, index_chain=None):
self._mapper = mapper
self._name = mapper.name
self._index = index
self._index_chain = [] if index_chain is None else index_chain
self._branch = branch
self._usr_names = []
self._usr_idx_lookup = {}
......@@ -125,8 +125,8 @@ class Usr:
data = self._branch[self._usr_key].lazyarray()
if self._index is not None:
data = data[self._index]
if self._index_chain:
data = _unfold_indices(data, self._index_chain)
self._usr_data = data
......@@ -150,8 +150,8 @@ class Usr:
return self.__getitem_nested__(item)
def __getitem_flat__(self, item):
if self._index is not None:
return self._usr_data[self._index][:, self._usr_idx_lookup[item]]
if self._index_chain:
return _unfold_indices(self._usr_data, self._index_chain)[:, self._usr_idx_lookup[item]]
else:
return self._usr_data[:, self._usr_idx_lookup[item]]
......@@ -163,10 +163,7 @@ class Usr:
uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
self._branch[self._usr_key + '_names']._context, 6),
basketcache=BASKET_CACHE)
if self._index is None:
return data
else:
return data[self._index]
return _unfold_indices(data, self._index_chain)
def keys(self):
return self._usr_names
......
......@@ -21,12 +21,14 @@ class cached_property:
def _unfold_indices(obj, indices):
"""Unfolds an index chain and returns the corresponding item"""
original_obj = obj
for depth, idx in enumerate(indices):
try:
obj = obj[idx]
except IndexError:
print("IndexError while accessing item '{}' at depth {} ({}) of "
"the index chain {}".format(repr(obj), depth, idx, indices))
print("IndexError while accessing an item from '{}' at depth {} ({}) "
"using the index chain {}"
.format(repr(original_obj), depth, idx, indices))
raise
return obj
......@@ -41,12 +43,12 @@ class Branch:
def __init__(self,
tree,
mapper,
index=None,
index_chain=None,
subbranchmaps=None,
keymap=None):
self._tree = tree
self._mapper = mapper
self._index = index
self._index_chain = [] if index_chain is None else index_chain
self._keymap = None
self._branch = tree[mapper.key]
self._subbranches = []
......@@ -61,7 +63,7 @@ class Branch:
for mapper in subbranchmaps:
subbranch = self.__class__(self._tree,
mapper=mapper,
index=self._index)
index_chain=self._index_chain)
self._subbranches.append(subbranch)
for subbranch in self._subbranches:
setattr(self, subbranch._mapper.name, subbranch)
......@@ -98,39 +100,37 @@ class Branch:
def __getkey__(self, key):
out = self._branch[self._keymap[key]].lazyarray(
basketcache=BASKET_CACHE)
if self._index is not None:
out = out[self._index]
return out
return _unfold_indices(out, self._index_chain)
def __getitem__(self, item):
"""Slicing magic"""
if isinstance(item, (int, slice)):
if isinstance(item, (int, slice, tuple)):
return self.__class__(self._tree,
self._mapper,
index=item,
index_chain=self._index_chain + [item],
keymap=self._keymap,
subbranchmaps=self._subbranchmaps)
if isinstance(item, tuple):
return self[item[0]][item[1]]
# if isinstance(item, tuple):
# return self[item[0]][item[1]]
if isinstance(item, str):
return self.__getkey__(item)
return self.__class__(self._tree,
self._mapper,
index=np.array(item),
index_chain=self._index_chain + [np.array(item)],
keymap=self._keymap,
subbranchmaps=self._subbranchmaps)
def __len__(self):
if self._index is None:
if not self._index_chain:
return len(self._branch)
elif isinstance(self._index, int):
elif isinstance(self._index_chain[-1], int):
return 1
else:
return len(self._branch[self._keymap['id']].lazyarray(
basketcache=BASKET_CACHE)[self._index])
return len(_unfold_indices(self._branch[self._keymap['id']].lazyarray(
basketcache=BASKET_CACHE), self._index_chain))
def __str__(self):
return "Number of elements: {}".format(len(self._branch))
......
......@@ -152,6 +152,10 @@ class TestOfflineEvents(unittest.TestCase):
for i in [0, 2, 5]:
assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])
def test_index_chaining(self):
assert np.allclose(self.events[3:5].n_hits, self.events.n_hits[3:5])
assert np.allclose(self.events[3:5][0].n_hits, self.events.n_hits[3:5][0])
def test_str(self):
assert str(self.n_events) in str(self.events)
......@@ -295,9 +299,9 @@ class TestBranchIndexingMagic(unittest.TestCase):
assert np.allclose(self.events[3:6].tracks.pos_y[:, 0],
self.events.tracks.pos_y[3:6, 0])
# test slicing with a tuple
assert np.allclose(self.events[0].hits[1].dom_id[0:10],
self.events.hits[(0, 1)].dom_id[0:10])
# # test slicing with a tuple
# assert np.allclose(self.events[0].hits[1].dom_id[0:10],
# self.events.hits[(0, 1)].dom_id[0:10])
# test selecting with a list
self.assertEqual(3, len(self.events[[0, 2, 3]]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment