From 5e147e2f34c62793b4ae9f44ccc5813162d7ee8b Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Sun, 29 Mar 2020 11:29:55 +0200
Subject: [PATCH] Use subclassed Branch

---
 km3io/offline.py | 107 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 106 insertions(+), 1 deletion(-)

diff --git a/km3io/offline.py b/km3io/offline.py
index 1d970d4..c9a8d51 100644
--- a/km3io/offline.py
+++ b/km3io/offline.py
@@ -72,6 +72,111 @@ SUBBRANCH_MAPS = [
 ]
 
 
+class OfflineBranch(Branch):
+    @cached_property
+    def usr(self):
+        return Usr(self._mapper, self._branch, index=self._index)
+
+
+class Usr:
+    """Helper class to access AAObject `usr` stuff"""
+    def __init__(self, mapper, branch, index=None):
+        self._mapper = mapper
+        self._name = mapper.name
+        self._index = index
+        self._branch = branch
+        self._usr_names = []
+        self._usr_idx_lookup = {}
+
+        self._usr_key = 'usr' if mapper.flat else mapper.key + '.usr'
+
+        self._initialise()
+
+    def _initialise(self):
+        try:
+            self._branch[self._usr_key]
+            # This will raise a KeyError in old aanet files
+            # which have a different structure and key (usr_data)
+            # We do not support those (yet)
+        except (KeyError, IndexError):
+            print("The `usr` fields could not be parsed for the '{}' branch.".
+                  format(self._name))
+            return
+
+        if self._mapper.flat:
+            self._initialise_flat()
+
+    def _initialise_flat(self):
+        # Here, we assume that every event has the same names in the same order
+        # to massively increase the performance. It still needs to be verified
+        # that this is always the case.
+        self._usr_names = [
+            n.decode("utf-8")
+            for n in self._branch[self._usr_key + '_names'].lazyarray()[0]
+        ]
+        self._usr_idx_lookup = {
+            name: index
+            for index, name in enumerate(self._usr_names)
+        }
+
+        data = self._branch[self._usr_key].lazyarray()
+
+        if self._index is not None:
+            data = data[self._index]
+
+        self._usr_data = data
+
+        for name in self._usr_names:
+            setattr(self, name, self[name])
+
+    # def _initialise_nested(self):
+    #     self._usr_names = [
+    #         n.decode("utf-8") for n in self.branch['usr_names'].lazyarray(
+    #             # TODO this will be fixed soon in uproot,
+    #             # see https://github.com/scikit-hep/uproot/issues/465
+    #             uproot.asgenobj(
+    #                 uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
+    #                 self.branch['usr_names']._context, 6),
+    #             basketcache=BASKET_CACHE)[0]
+    #     ]
+
+    def __getitem__(self, item):
+        if self._mapper.flat:
+            return self.__getitem_flat__(item)
+        return self.__getitem_nested__(item)
+
+    def __getitem_flat__(self, item):
+        # NOTE: `self._usr_data` was already restricted to `self._index` in
+        # `_initialise_flat()`; indexing it with `self._index` a second time
+        # would select the wrong events (or fail), so only pick the column.
+        return self._usr_data[:, self._usr_idx_lookup[item]]
+
+    def __getitem_nested__(self, item):
+        data = self._branch[self._usr_key + '_names'].lazyarray(
+            # TODO this will be fixed soon in uproot,
+            # see https://github.com/scikit-hep/uproot/issues/465
+            uproot.asgenobj(
+                uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
+                self._branch[self._usr_key + '_names']._context, 6),
+            basketcache=BASKET_CACHE)
+        if self._index is None:
+            return data
+        else:
+            return data[self._index]
+
+    def keys(self):
+        return self._usr_names
+
+    def __str__(self):
+        entries = []
+        for name in self.keys():
+            entries.append("{}: {}".format(name, self[name]))
+        return '\n'.join(entries)
+
+    def __repr__(self):
+        return "<{}[{}]>".format(self.__class__.__name__, self._name)
+
+
 class OfflineReader:
     """reader for offline ROOT files"""
     def __init__(self, file_path=None):
@@ -90,7 +195,7 @@ class OfflineReader:
     @cached_property
     def events(self):
         """The `E` branch, containing all offline events."""
-        return Branch(self._tree,
+        return OfflineBranch(self._tree,
                       mapper=EVENTS_MAP,
                       subbranchmaps=SUBBRANCH_MAPS)
 
-- 
GitLab