diff --git a/km3io/offline.py b/km3io/offline.py index b8f860670b7b7c0b013a3257c65b14403b5fe96a..ec8614d97c793f89143b4cfdc667a29830ab22d9 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -11,7 +11,7 @@ BASKET_CACHE_SIZE = 110 * 1024**2 BranchMapper = namedtuple( "BranchMapper", - ['name', 'key', 'extra', 'exclude', 'update', 'attrparser']) + ['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat']) def _nested_mapper(key): @@ -22,13 +22,14 @@ def _nested_mapper(key): EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"]) BRANCH_MAPS = [ BranchMapper("tracks", "trks", {}, ['trks.usr_data', 'trks.usr'], {}, - _nested_mapper), + _nested_mapper, False), BranchMapper("mc_tracks", "mc_trks", {}, - ['mc_trks.usr_data', 'mc_trks.usr'], {}, _nested_mapper), - BranchMapper("hits", "hits", {}, ['hits.usr'], {}, _nested_mapper), + ['mc_trks.usr_data', 'mc_trks.usr'], {}, _nested_mapper, + False), + BranchMapper("hits", "hits", {}, ['hits.usr'], {}, _nested_mapper, False), BranchMapper("mc_hits", "mc_hits", {}, ['mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id'], {}, - _nested_mapper), + _nested_mapper, False), BranchMapper("events", "Evt", { 't_sec': 't.fSec', 't_ns': 't.fNanoSec' @@ -37,7 +38,7 @@ BRANCH_MAPS = [ 'n_mc_hits': 'mc_hits', 'n_tracks': 'trks', 'n_mc_tracks': 'mc_trks' - }, lambda a: a), + }, lambda a: a, True), ] @@ -124,10 +125,6 @@ class OfflineReader: else: warnings.warn("Your file header has an unsupported format") - @cached_property - def usr(self): - return Usr(self._tree) - def get_best_reco(self): """returns the best reconstructed track fit data. The best fit is defined as the track fit with the maximum reconstruction stages. When "nan" is @@ -469,13 +466,14 @@ class OfflineReader: class Usr: """Helper class to access AAObject usr stuff""" - def __init__(self, tree): + def __init__(self, name, tree, index=slice(None)): # Here, we assume that every event has the same names in the same order # to massively increase the performance. This needs triple check if it's # always the case; the usr-format is simply a very bad design. + self._name = name try: self._usr_names = [ - n.decode("utf-8") for n in tree['Evt']['usr_names'].array()[0] + n.decode("utf-8") for n in tree['usr_names'].array()[0] ] except (KeyError, IndexError): # e.g. old aanet files self._usr_names = [] @@ -484,9 +482,9 @@ class Usr: name: index for index, name in enumerate(self._usr_names) } - self._usr_data = tree['Evt']['usr'].lazyarray( + self._usr_data = tree['usr'].lazyarray( basketcache=uproot.cache.ThreadSafeArrayCache( - BASKET_CACHE_SIZE)) + BASKET_CACHE_SIZE))[index] for name in self._usr_names: setattr(self, name, self[name]) @@ -502,6 +500,9 @@ class Usr: entries.append("{}: {}".format(name, self[name])) return '\n'.join(entries) + def __repr__(self): + return "<{}[{}]>".format(self.__class__.__name__, self._name) + def _to_num(value): """Convert a value to a numerical one if possible""" @@ -576,17 +577,30 @@ class Branch: def keys(self): return self._keymap.keys() + @cached_property + def usr(self): + return Usr(self._mapper.name, self._branch, index=self._index) + def __getitem__(self, item): """Slicing magic a la numpy""" if isinstance(item, slice): return self.__class__(self._tree, self._mapper, index=item) if isinstance(item, int): - return BranchElement( - self._mapper.name, { - key: self._branch[self._keymap[key]].array()[self._index, - item] - for key in self.keys() - }) + if self._mapper.flat: + return BranchElement( + self._mapper.name, { + key: + self._branch[self._keymap[key]].array()[self._index] + for key in self.keys() + })[item] + else: + return BranchElement( + self._mapper.name, { + key: + self._branch[self._keymap[key]].array()[self._index, + item] + for key in self.keys() + }) if isinstance(item, tuple): return self[item[0]][item[1]] return self._branch[self._keymap[item]].lazyarray(