Skip to content
Snippets Groups Projects
Commit dc8ae688 authored by Tamas Gal :speech_balloon:
Browse files

Lightning fast slicing

parent 37b13c8a
No related branches found
No related tags found
1 merge request: !27 Refactor offline I/O
...@@ -8,6 +8,7 @@ from .definitions import mc_header ...@@ -8,6 +8,7 @@ from .definitions import mc_header
MAIN_TREE_NAME = "E" MAIN_TREE_NAME = "E"
# 110 MB based on the size of the largest basket found so far in km3net # 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2 BASKET_CACHE_SIZE = 110 * 1024**2
BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
BranchMapper = namedtuple( BranchMapper = namedtuple(
"BranchMapper", "BranchMapper",
...@@ -81,9 +82,7 @@ class OfflineReader: ...@@ -81,9 +82,7 @@ class OfflineReader:
if file_path is not None: if file_path is not None:
self._fobj = uproot.open(file_path) self._fobj = uproot.open(file_path)
self._tree = self._fobj[MAIN_TREE_NAME] self._tree = self._fobj[MAIN_TREE_NAME]
self._data = self._tree.lazyarrays( self._data = self._tree.lazyarrays(basketcache=BASKET_CACHE)
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))
else: else:
self._fobj = fobj self._fobj = fobj
self._tree = self._fobj[MAIN_TREE_NAME] self._tree = self._fobj[MAIN_TREE_NAME]
...@@ -121,8 +120,7 @@ class OfflineReader: ...@@ -121,8 +120,7 @@ class OfflineReader:
return len(tree) return len(tree)
else: else:
return len( return len(
tree.lazyarrays(basketcache=uproot.cache.ThreadSafeArrayCache( tree.lazyarrays(basketcache=BASKET_CACHE)[self.index])
BASKET_CACHE_SIZE))[self.index])
@cached_property @cached_property
def header(self): def header(self):
...@@ -480,8 +478,6 @@ class Usr: ...@@ -480,8 +478,6 @@ class Usr:
# Here, we assume that every event has the same names in the same order # Here, we assume that every event has the same names in the same order
# to massively increase the performance. This needs triple check if it's # to massively increase the performance. This needs triple check if it's
# always the case; the usr-format is simply a very bad design. # always the case; the usr-format is simply a very bad design.
# print("initialising usr for {}".format(name))
# print("Setting up usr")
self._name = name self._name = name
try: try:
tree['usr'] # This will raise a KeyError in old aanet files tree['usr'] # This will raise a KeyError in old aanet files
...@@ -493,20 +489,15 @@ class Usr: ...@@ -493,20 +489,15 @@ class Usr:
except (KeyError, IndexError): # e.g. old aanet files except (KeyError, IndexError): # e.g. old aanet files
self._usr_names = [] self._usr_names = []
else: else:
# print(" checking usr data")
self._usr_idx_lookup = { self._usr_idx_lookup = {
name: index name: index
for index, name in enumerate(self._usr_names) for index, name in enumerate(self._usr_names)
} }
data = tree['usr'].lazyarray( data = tree['usr'].lazyarray(basketcache=BASKET_CACHE)
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))
if index is not None: if index is not None:
data = data[index] data = data[index]
self._usr_data = data self._usr_data = data
# print(" adding attributes")
for name in self._usr_names: for name in self._usr_names:
# print(" setting {}".format(name))
setattr(self, name, self[name]) setattr(self, name, self[name])
def __getitem__(self, item): def __getitem__(self, item):
...@@ -567,12 +558,14 @@ class Header: ...@@ -567,12 +558,14 @@ class Header:
class Branch: class Branch:
"""Branch accessor class""" """Branch accessor class"""
# @profile
def __init__(self, def __init__(self,
tree, tree,
mapper, mapper,
index=None, index=None,
subbranches=None, subbranches=None,
subbranchmaps=None): subbranchmaps=None,
keymap=None):
self._tree = tree self._tree = tree
self._mapper = mapper self._mapper = mapper
self._index = index self._index = index
...@@ -580,7 +573,10 @@ class Branch: ...@@ -580,7 +573,10 @@ class Branch:
self._branch = tree[mapper.key] self._branch = tree[mapper.key]
self._subbranches = [] self._subbranches = []
self._initialise_keys() # if keymap is None:
self._initialise_keys() #
else:
self._keymap = keymap
if subbranches is not None: if subbranches is not None:
self._subbranches = subbranches self._subbranches = subbranches
...@@ -593,6 +589,7 @@ class Branch: ...@@ -593,6 +589,7 @@ class Branch:
for subbranch in self._subbranches: for subbranch in self._subbranches:
setattr(self, subbranch._mapper.name, subbranch) setattr(self, subbranch._mapper.name, subbranch)
# @profile
def _initialise_keys(self): def _initialise_keys(self):
"""Create the keymap and instance attributes for branch keys""" """Create the keymap and instance attributes for branch keys"""
keys = set(k.decode('utf-8') for k in self._branch.keys()) - set( keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(
...@@ -607,8 +604,7 @@ class Branch: ...@@ -607,8 +604,7 @@ class Branch:
del self._keymap[k] del self._keymap[k]
for key in self._keymap.keys(): for key in self._keymap.keys():
# print("setting", self._mapper.name, key) setattr(self, key, None)
setattr(self, key, self[key])
def keys(self): def keys(self):
return self._keymap.keys() return self._keymap.keys()
...@@ -617,16 +613,29 @@ class Branch: ...@@ -617,16 +613,29 @@ class Branch:
def usr(self): def usr(self):
return Usr(self._mapper.name, self._branch, index=self._index) return Usr(self._mapper.name, self._branch, index=self._index)
def __getattribute__(self, attr):
    """Resolve attribute access, reading branch keys lazily on demand.

    Names starting with an underscore are handed straight to the default
    lookup. A name present in ``self._keymap`` is read from the underlying
    branch as a lazy array, sliced with ``self._index`` when one is set.
    Every other name falls back to the default lookup.
    """
    default_lookup = object.__getattribute__
    # Private and magic names bypass the interception; this is also what
    # keeps the ``self._keymap`` access below from recursing into here.
    if attr.startswith("_"):
        return default_lookup(self, attr)
    keymap = self._keymap
    if attr not in keymap:
        return default_lookup(self, attr)
    # Branch key lookup: build the lazy array using the shared basket cache.
    lazy = self._branch[keymap[attr]].lazyarray(basketcache=BASKET_CACHE)
    index = self._index
    return lazy if index is None else lazy[index]
# @profile
def __getitem__(self, item): def __getitem__(self, item):
"""Slicing magic a la numpy""" """Slicing magic a la numpy"""
print("Getting item '{}'".format(item))
if isinstance(item, slice): if isinstance(item, slice):
return self.__class__(self._tree, return self.__class__(self._tree,
self._mapper, self._mapper,
index=item, index=item,
subbranches=self._subbranches) subbranches=self._subbranches)
if isinstance(item, int): if isinstance(item, int):
# A bit ugly, but whatever works # TODO refactor this
if self._mapper.flat: if self._mapper.flat:
if self._index is None: if self._index is None:
dct = { dct = {
...@@ -665,8 +674,7 @@ class Branch: ...@@ -665,8 +674,7 @@ class Branch:
item = self._keymap[item] item = self._keymap[item]
out = self._branch[item].lazyarray( out = self._branch[item].lazyarray(
basketcache=uproot.cache.ThreadSafeArrayCache( basketcache=BASKET_CACHE)
BASKET_CACHE_SIZE))
if self._index is not None: if self._index is not None:
out = out[self._index] out = out[self._index]
return out return out
...@@ -705,7 +713,6 @@ class BranchElement: ...@@ -705,7 +713,6 @@ class BranchElement:
The slice mask to be applied to the sub-arrays The slice mask to be applied to the sub-arrays
""" """
def __init__(self, name, dct, index=None, subbranches=[]): def __init__(self, name, dct, index=None, subbranches=[]):
print("Creating branch element '{}'".format(name))
self._dct = dct self._dct = dct
self._name = name self._name = name
self._index = index self._index = index
......
...@@ -184,15 +184,11 @@ class TestOfflineEvents(unittest.TestCase): ...@@ -184,15 +184,11 @@ class TestOfflineEvents(unittest.TestCase):
def test_slicing_consistency(self): def test_slicing_consistency(self):
for s in [slice(1, 3), slice(2, 7, 3)]: for s in [slice(1, 3), slice(2, 7, 3)]:
assert np.allclose(OFFLINE_FILE[s].events.n_hits,
self.events.n_hits[s])
assert np.allclose(self.events[s].n_hits, self.events.n_hits[s]) assert np.allclose(self.events[s].n_hits, self.events.n_hits[s])
def test_index_consistency(self): def test_index_consistency(self):
for i in [0,2,5]: for i in [0,2,5]:
assert np.allclose(self.events[i].n_hits, self.events.n_hits[i]) assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])
assert np.allclose(OFFLINE_FILE[i].events.n_hits,
self.events.n_hits[i])
def test_str(self): def test_str(self):
assert str(self.n_events) in str(self.events) assert str(self.n_events) in str(self.events)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment