diff --git a/km3io/offline.py b/km3io/offline.py index bace252d0f8a88bc800f6a360c11faff10028170..b8f860670b7b7c0b013a3257c65b14403b5fe96a 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -26,7 +26,8 @@ BRANCH_MAPS = [ BranchMapper("mc_tracks", "mc_trks", {}, ['mc_trks.usr_data', 'mc_trks.usr'], {}, _nested_mapper), BranchMapper("hits", "hits", {}, ['hits.usr'], {}, _nested_mapper), - BranchMapper("mc_hits", "mc_hits", {}, ['mc_hits.usr'], {}, + BranchMapper("mc_hits", "mc_hits", {}, + ['mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id'], {}, _nested_mapper), BranchMapper("events", "Evt", { 't_sec': 't.fSec', @@ -81,9 +82,8 @@ class OfflineReader: self._data = data for mapper in BRANCH_MAPS: - setattr( - self, mapper.name, - BranchElement(self._tree, mapper=mapper, index=self._index)) + setattr(self, mapper.name, + Branch(self._tree, mapper=mapper, index=self._index)) @classmethod def from_index(cls, source, index): @@ -543,8 +543,8 @@ class Header: return "\n".join(lines) -class BranchElement: - """wrapper for offline tracks""" +class Branch: + """Branch accessor class""" def __init__(self, tree, mapper, index=slice(None)): self._tree = tree self._mapper = mapper @@ -581,10 +581,14 @@ class BranchElement: if isinstance(item, slice): return self.__class__(self._tree, self._mapper, index=item) if isinstance(item, int): - return { - key: self._branch[self._keymap[key]].array()[self._index, item] - for key in self.keys() - } + return BranchElement( + self._mapper.name, { + key: self._branch[self._keymap[key]].array()[self._index, + item] + for key in self.keys() + }) + if isinstance(item, tuple): + return self[item[0]][item[1]] return self._branch[self._keymap[item]].lazyarray( basketcache=uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE))[ self._index] @@ -603,3 +607,35 @@ class BranchElement: return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__, self._mapper.name, len(self)) + + +class BranchElement: + """Represents a single branch element + + Parameters + ---------- + name: str + The name of the branch + dct: dict (keys=attributes, values=arrays of values) + The data + index: slice + The slice mask to be applied to the sub-arrays + """ + def __init__(self, name, dct, index=slice(None)): + self._dct = dct + self._name = name + self._index = index + self.ItemConstructor = namedtuple(self._name[:-1], dct.keys()) + for key, values in dct.items(): + setattr(self, key, values[index]) + + def __getitem__(self, item): + if isinstance(item, slice): + return self.__class__(self._name, self._dct, index=item) + if isinstance(item, int): + return self.ItemConstructor( + **{k: v[self._index][item] + for k, v in self._dct.items()}) + + def __repr__(self): + return "<{}[{}]>".format(self.__class__.__name__, self._name)