From dc8ae68872bece8552307c8db483e9aff77f685b Mon Sep 17 00:00:00 2001 From: Tamas Gal <tgal@km3net.de> Date: Sat, 7 Mar 2020 13:37:19 +0100 Subject: [PATCH] Lightning fast slicing --- km3io/offline.py | 51 ++++++++++++++++++++++++------------------- tests/test_offline.py | 4 ---- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/km3io/offline.py b/km3io/offline.py index 52e67ac..d4bfee2 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -8,6 +8,7 @@ from .definitions import mc_header MAIN_TREE_NAME = "E" # 110 MB based on the size of the largest basket found so far in km3net BASKET_CACHE_SIZE = 110 * 1024**2 +BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE) BranchMapper = namedtuple( "BranchMapper", @@ -81,9 +82,7 @@ class OfflineReader: if file_path is not None: self._fobj = uproot.open(file_path) self._tree = self._fobj[MAIN_TREE_NAME] - self._data = self._tree.lazyarrays( - basketcache=uproot.cache.ThreadSafeArrayCache( - BASKET_CACHE_SIZE)) + self._data = self._tree.lazyarrays(basketcache=BASKET_CACHE) else: self._fobj = fobj self._tree = self._fobj[MAIN_TREE_NAME] @@ -121,8 +120,7 @@ class OfflineReader: return len(tree) else: return len( - tree.lazyarrays(basketcache=uproot.cache.ThreadSafeArrayCache( - BASKET_CACHE_SIZE))[self.index]) + tree.lazyarrays(basketcache=BASKET_CACHE)[self.index]) @cached_property def header(self): @@ -480,8 +478,6 @@ class Usr: # Here, we assume that every event has the same names in the same order # to massively increase the performance. This needs triple check if it's # always the case; the usr-format is simply a very bad design. - # print("initialising usr for {}".format(name)) - # print("Setting up usr") self._name = name try: tree['usr'] # This will raise a KeyError in old aanet files @@ -493,20 +489,15 @@ class Usr: except (KeyError, IndexError): # e.g. old aanet files self._usr_names = [] else: - # print(" checking usr data") self._usr_idx_lookup = { name: index for index, name in enumerate(self._usr_names) } - data = tree['usr'].lazyarray( - basketcache=uproot.cache.ThreadSafeArrayCache( - BASKET_CACHE_SIZE)) + data = tree['usr'].lazyarray(basketcache=BASKET_CACHE) if index is not None: data = data[index] self._usr_data = data - # print(" adding attributes") for name in self._usr_names: - # print(" setting {}".format(name)) setattr(self, name, self[name]) def __getitem__(self, item): @@ -567,12 +558,14 @@ class Header: class Branch: """Branch accessor class""" + # @profile def __init__(self, tree, mapper, index=None, subbranches=None, - subbranchmaps=None): + subbranchmaps=None, + keymap=None): self._tree = tree self._mapper = mapper self._index = index @@ -580,7 +573,10 @@ class Branch: self._branch = tree[mapper.key] self._subbranches = [] - self._initialise_keys() # + if keymap is None: + self._initialise_keys() # + else: + self._keymap = keymap if subbranches is not None: self._subbranches = subbranches @@ -593,6 +589,7 @@ class Branch: for subbranch in self._subbranches: setattr(self, subbranch._mapper.name, subbranch) + # @profile def _initialise_keys(self): """Create the keymap and instance attributes for branch keys""" keys = set(k.decode('utf-8') for k in self._branch.keys()) - set( @@ -607,8 +604,7 @@ class Branch: del self._keymap[k] for key in self._keymap.keys(): - # print("setting", self._mapper.name, key) - setattr(self, key, self[key]) + setattr(self, key, None) def keys(self): return self._keymap.keys() @@ -617,16 +613,29 @@ class Branch: def usr(self): return Usr(self._mapper.name, self._branch, index=self._index) + def __getattribute__(self, attr): + if attr.startswith("_"): # let all private and magic methods pass + return object.__getattribute__(self, attr) + if attr in self._keymap.keys(): # intercept branch key lookups + item = self._keymap[attr] + + out = self._branch[item].lazyarray( + basketcache=BASKET_CACHE) + if self._index is not None: + out = out[self._index] + return out + return object.__getattribute__(self, attr) + + # @profile def __getitem__(self, item): """Slicing magic a la numpy""" - print("Getting item '{}'".format(item)) if isinstance(item, slice): return self.__class__(self._tree, self._mapper, index=item, subbranches=self._subbranches) if isinstance(item, int): - # A bit ugly, but whatever works + # TODO refactor this if self._mapper.flat: if self._index is None: dct = { @@ -665,8 +674,7 @@ class Branch: item = self._keymap[item] out = self._branch[item].lazyarray( - basketcache=uproot.cache.ThreadSafeArrayCache( - BASKET_CACHE_SIZE)) + basketcache=BASKET_CACHE) if self._index is not None: out = out[self._index] return out @@ -705,7 +713,6 @@ class BranchElement: The slice mask to be applied to the sub-arrays """ def __init__(self, name, dct, index=None, subbranches=[]): - print("Creating branch element '{}'".format(name)) self._dct = dct self._name = name self._index = index diff --git a/tests/test_offline.py b/tests/test_offline.py index 4d409f4..c5ec14c 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -184,15 +184,11 @@ class TestOfflineEvents(unittest.TestCase): def test_slicing_consistency(self): for s in [slice(1, 3), slice(2, 7, 3)]: - assert np.allclose(OFFLINE_FILE[s].events.n_hits, - self.events.n_hits[s]) assert np.allclose(self.events[s].n_hits, self.events.n_hits[s]) def test_index_consistency(self): for i in [0,2,5]: assert np.allclose(self.events[i].n_hits, self.events.n_hits[i]) - assert np.allclose(OFFLINE_FILE[i].events.n_hits, - self.events.n_hits[i]) def test_str(self): assert str(self.n_events) in str(self.events) -- GitLab