diff --git a/km3io/offline.py b/km3io/offline.py index 50d30d81e9d895326ec3d602bf17b85d0d51e837..6fb9a1576975120ceb6afc4cb579336fdb765cd6 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -12,7 +12,7 @@ EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"]) BranchMapper = namedtuple( "BranchMapper", - ['name', 'key', 'extra', 'exclude', 'update', 'attrparser']) + ['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat']) def _nested_mapper(key): @@ -33,7 +33,8 @@ EVENTS_MAP = BranchMapper(name="events", 'n_tracks': 'trks', 'n_mc_tracks': 'mc_trks' }, - attrparser=lambda a: a) + attrparser=lambda a: a, + flat=True) SUBBRANCH_MAPS = [ BranchMapper( @@ -42,7 +43,8 @@ SUBBRANCH_MAPS = [ extra={}, exclude=['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'], update={}, - attrparser=_nested_mapper), + attrparser=_nested_mapper, + flat=False), BranchMapper(name="mc_tracks", key="mc_trks", extra={}, @@ -51,7 +53,8 @@ SUBBRANCH_MAPS = [ 'mc_trks.fitinf', 'mc_trks.fUniqueID', 'mc_trks.fBits' ], update={}, - attrparser=_nested_mapper), + attrparser=_nested_mapper, + flat=False), BranchMapper(name="hits", key="hits", extra={}, @@ -60,7 +63,8 @@ SUBBRANCH_MAPS = [ 'hits.pure_a', 'hits.fUniqueID', 'hits.fBits' ], update={}, - attrparser=_nested_mapper), + attrparser=_nested_mapper, + flat=False), BranchMapper(name="mc_hits", key="mc_hits", extra={}, @@ -70,7 +74,8 @@ SUBBRANCH_MAPS = [ 'mc_hits.fUniqueID', 'mc_hits.fBits' ], update={}, - attrparser=_nested_mapper), + attrparser=_nested_mapper, + flat=False), ] @@ -123,40 +128,72 @@ class OfflineReader: class Usr: """Helper class to access AAObject `usr`` stuff""" - def __init__(self, mapper, tree, index=None): - # Here, we assume that every event has the same names in the same order - # to massively increase the performance. This needs triple check if - # it's always the case; the usr-format is simply a very bad design. - self._name = mapper.key + + def __init__(self, mapper, branch, index=None): + self._mapper = mapper + self._name = mapper.name self._index = index + self._branch = branch + self._usr_names = [] + self._usr_idx_lookup = {} + try: - tree[mapper.key + - '.usr'] # This will raise a KeyError in old aanet files + branch['usr'] + # This will raise a KeyError in old aanet files # which has a different strucuter and key (usr_data) - # We do not support those... - self._usr_names = [ - n.decode("utf-8") - for n in tree[mapper.key + '.usr_names'].lazyarray( - basketcache=BASKET_CACHE)[0] - ] - except (KeyError, IndexError): # e.g. old aanet files + # We do not support those + except (KeyError, IndexError): print("The `usr` fields could not be parsed for the '{}' branch.". format(self._name)) - self._usr_names = [] + return + + if mapper.flat: + self._initialise_flat() else: - self._usr_idx_lookup = { - name: index - for index, name in enumerate(self._usr_names) - } - data = tree[mapper.key + - '.usr'].lazyarray(basketcache=BASKET_CACHE) - if index is not None: - data = data[index] - self._usr_data = data - for name in self._usr_names: - setattr(self, name, self[name]) + # self._initialise_nested() + # branch[self._mapper.key + '.usr'] + pass + + def _initialise_flat(self): + # Here, we assume that every event has the same names in the same order + # to massively increase the performance. This needs triple check if + # it's always the case. + self._usr_names = [ + n.decode("utf-8") + for n in self._branch['usr_names'].lazyarray( + basketcache=BASKET_CACHE)[0] + ] + self._usr_idx_lookup = { + name: index + for index, name in enumerate(self._usr_names) + } + + data = self._branch['usr'].lazyarray(basketcache=BASKET_CACHE) + + if self._index is not None: + data = data[self._index] + + self._usr_data = data + + for name in self._usr_names: + setattr(self, name, self[name]) + + def _initialise_nested(self): + self._usr_names = [ + n.decode("utf-8") for n in self.branch['usr_names'].lazyarray( + # TODO this will be fixed soon in uproot, + # see https://github.com/scikit-hep/uproot/issues/465 + uproot.asgenobj(uproot.SimpleArray(uproot.STLVector(uproot.STLString())), self.branch['usr_names']._context, 6), + basketcache=BASKET_CACHE)[0] + ] + self.__getitem__ = self.__getitem_nested__ def __getitem__(self, item): + if self._mapper.flat: + return self.__getitem_flat__(item) + return self.__getitem_nested__(item) + + def __getitem_flat__(self, item): if self._index is not None: return self._usr_data[self._index][:, self._usr_idx_lookup[item]] else: