diff --git a/km3io/offline.py b/km3io/offline.py index bc1d0e3a318134f5df1074b54cdab9525517feb3..5b3f06c884249ad3d63f5c2020c7164eeed04c87 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -56,11 +56,7 @@ class cached_property: class OfflineReader: """reader for offline ROOT files""" - def __init__(self, - file_path=None, - fobj=None, - data=None, - index=slice(None)): + def __init__(self, file_path=None, fobj=None, data=None, index=None): """ OfflineReader class is an offline ROOT file wrapper Parameters @@ -83,6 +79,7 @@ class OfflineReader: self._data = data for mapper in BRANCH_MAPS: + # print("setting mapper {}".format(mapper.name)) setattr(self, mapper.name, Branch(self._tree, mapper=mapper, index=self._index)) @@ -107,7 +104,7 @@ class OfflineReader: def __len__(self): tree = self._fobj[MAIN_TREE_NAME] - if self._index == slice(None): + if self._index is None: return len(tree) else: return len( @@ -466,29 +463,37 @@ class OfflineReader: class Usr: """Helper class to access AAObject usr stuff""" - def __init__(self, name, tree, index=slice(None)): + def __init__(self, name, tree, index=None): # Here, we assume that every event has the same names in the same order # to massively increase the performance. This needs triple check if it's # always the case; the usr-format is simply a very bad design. + # print("initialising usr for {}".format(name)) + # print("Setting up usr") self._name = name try: tree['usr'] # This will raise a KeyError in old aanet files - # which has a different strucuter and key (usr_data) - # We do not support those... + # which has a different strucuter and key (usr_data) + # We do not support those... self._usr_names = [ - n.decode("utf-8") for n in tree['usr_names'].array()[0] + n.decode("utf-8") for n in tree['usr_names'].lazyarray()[0] ] except (KeyError, IndexError): # e.g. old aanet files self._usr_names = [] else: + # print(" checking usr data") self._usr_idx_lookup = { name: index for index, name in enumerate(self._usr_names) } - self._usr_data = tree['usr'].lazyarray( + data = tree['usr'].lazyarray( basketcache=uproot.cache.ThreadSafeArrayCache( - BASKET_CACHE_SIZE))[index] + BASKET_CACHE_SIZE)) + if index is not None: + data = data[index] + self._usr_data = data + # print(" adding attributes") for name in self._usr_names: + # print(" setting {}".format(name)) setattr(self, name, self[name]) def __getitem__(self, item): @@ -549,7 +554,7 @@ class Header: class Branch: """Branch accessor class""" - def __init__(self, tree, mapper, index=slice(None)): + def __init__(self, tree, mapper, index=None): self._tree = tree self._mapper = mapper self._index = index @@ -606,13 +611,21 @@ class Branch: }) if isinstance(item, tuple): return self[item[0]][item[1]] - out = self._branch[self._keymap[item]].lazyarray( - basketcache=uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)) - if self._index != slice(None): - out[self._index] + + if isinstance(item, str): + item = self._keymap[item] + + out = self._branch[item].lazyarray( + basketcache=uproot.cache.ThreadSafeArrayCache( + BASKET_CACHE_SIZE)) + if self._index is not None: + out = out[self._index] + return out + + return self.__class__(self._tree, self._mapper, index=np.array(item)) def __len__(self): - if self._index == slice(None): + if self._index is None: return len(self._branch) else: return len( @@ -639,7 +652,7 @@ class BranchElement: index: slice The slice mask to be applied to the sub-arrays """ - def __init__(self, name, dct, index=slice(None)): + def __init__(self, name, dct, index=None): self._dct = dct self._name = name self._index = index