diff --git a/km3io/offline.py b/km3io/offline.py
index dac9f8159522f52d53ac8ec381eff7fb32d7b295..981de8316462793eb1d9f5639b49ba41c8186d86 100644
--- a/km3io/offline.py
+++ b/km3io/offline.py
@@ -1,3 +1,4 @@
+from collections import namedtuple
 import uproot
 import numpy as np
 import warnings
@@ -10,6 +11,22 @@ MAIN_TREE_NAME = "E"
 BASKET_CACHE_SIZE = 110 * 1024**2
 
 
+BranchMapper = namedtuple("BranchMapper", ['name', 'key', 'extra_keys', 'attrparser'])
+
+def _nested_mapper(key):
+    """Maps a key in the ROOT file to another key (e.g. trks.pos.x -> pos_x)"""
+    return '_'.join(key.split('.')[1:])
+
+
+BRANCH_MAPS = [
+    BranchMapper("tracks", "trks", {}, _nested_mapper),
+    BranchMapper("mc_tracks", "mc_trks", {}, _nested_mapper),
+    BranchMapper("hits", "hits", {}, _nested_mapper),
+    BranchMapper("mc_hits", "mc_hits", {}, _nested_mapper),
+    BranchMapper("events", "Evt", {'t_sec': 't.fSec', 't_ns': 't.fNanoSec'}, lambda a: a),
+]
+
+
 class cached_property:
     """A simple cache decorator for properties."""
     def __init__(self, function):
@@ -22,189 +39,9 @@ class cached_property:
         return prop
 
 
-def _get_keys(tree, fake_branches=None):
-    """Get tree keys except those in fake_branches
-
-    Parameters
-    ----------
-    tree : uproot.Tree
-        The tree to look for keys
-    fake_branches : list of str or None
-        The fake branches to ignore
-
-    Returns
-    -------
-    list of str
-        The keys of the tree.
-    """
-    keys = []
-    for key in tree.keys():
-        key = key.decode('utf-8')
-        if fake_branches is not None and key in fake_branches:
-            continue
-        keys.append(key)
-    return keys
-
-
-class OfflineKeys:
-    """wrapper for offline keys"""
-    def __init__(self, tree):
-        """OfflineKeys is a class that reads all the available keys in an offline
-        file and adapts the keys format to Python format.
-
-        Parameters
-        ----------
-        tree : uproot.TTree
-            The main ROOT tree.
-        """
-        self._tree = tree
-
-    def __str__(self):
-        return '\n'.join([
-            "Events keys are:\n\t" + "\n\t".join(self.events_keys),
-            "Hits keys are:\n\t" + '\n\t'.join(self.hits_keys),
-            "Tracks keys are:\n\t" + '\n\t'.join(self.tracks_keys),
-            "Mc hits keys are:\n\t" + '\n\t'.join(self.mc_hits_keys),
-            "Mc tracks keys are:\n\t" + '\n\t'.join(self.mc_tracks_keys)
-        ])
-
-    def __repr__(self):
-        return "<{}>".format(self.__class__.__name__)
-
-    @cached_property
-    def events_keys(self):
-        """reads events keys from an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all events keys found in an offline file,
-            except those found in fake branches.
-        """
-        fake_branches = ['Evt', 'AAObject', 'TObject', 't']
-        t_baskets = ['t.fSec', 't.fNanoSec']
-        tree = self._tree['Evt']
-        return _get_keys(self._tree['Evt'], fake_branches) + t_baskets
-
-    @cached_property
-    def hits_keys(self):
-        """reads hits keys from an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all hits keys found in an offline file,
-            except those found in fake branches.
-        """
-        fake_branches = ['hits.usr', 'hits.usr_names']
-        return _get_keys(self._tree['hits'], fake_branches)
-
-    @cached_property
-    def tracks_keys(self):
-        """reads tracks keys from an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all tracks keys found in an offline file,
-            except those found in fake branches.
-        """
-        # a solution can be tree['trks.usr_data'].array(
-        # uproot.asdtype(">i4"))
-        fake_branches = ['trks.usr_data', 'trks.usr', 'trks.usr_names']
-        return _get_keys(self._tree['Evt']['trks'], fake_branches)
-
-    @cached_property
-    def mc_hits_keys(self):
-        """reads mc hits keys from an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all mc hits keys found in an offline file,
-            except those found in fake branches.
-        """
-        fake_branches = ['mc_hits.usr', 'mc_hits.usr_names']
-        return _get_keys(self._tree['Evt']['mc_hits'], fake_branches)
-
-    @cached_property
-    def mc_tracks_keys(self):
-        """reads mc tracks keys from an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all mc tracks keys found in an offline file,
-            except those found in fake branches.
-        """
-        fake_branches = [
-            'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names'
-        ]
-        return _get_keys(self._tree['Evt']['mc_trks'], fake_branches)
-
-    @cached_property
-    def valid_keys(self):
-        """constructs a list of all valid keys to be read from an offline event file.
-        Returns
-        -------
-        list of str
-            list of all valid keys.
-        """
-        return (self.events_keys + self.hits_keys + self.tracks_keys +
-                self.mc_tracks_keys + self.mc_hits_keys)
-
-    @cached_property
-    def fit_keys(self):
-        """constructs a list of fit parameters, not yet outsourced in an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all "trks.fitinf" keys.
-        """
-        return sorted(km3io.definitions.fitparameters.data,
-                      key=km3io.definitions.fitparameters.data.get,
-                      reverse=False)
-
-    @cached_property
-    def cut_hits_keys(self):
-        """adapts hits keys for instance variables format in a Python class.
-
-        Returns
-        -------
-        list of str
-            list of adapted hits keys.
-        """
-        return [k.split('hits.')[1].replace('.', '_') for k in self.hits_keys]
-
-    @cached_property
-    def cut_tracks_keys(self):
-        """adapts tracks keys for instance variables format in a Python class.
-
-        Returns
-        -------
-        list of str
-            list of adapted tracks keys.
-        """
-        return [
-            k.split('trks.')[1].replace('.', '_') for k in self.tracks_keys
-        ]
-
-    @cached_property
-    def cut_events_keys(self):
-        """adapts events keys for instance variables format in a Python class.
-
-        Returns
-        -------
-        list of str
-            list of adapted events keys.
-        """
-        return [k.replace('.', '_') for k in self.events_keys]
-
-
 class OfflineReader:
     """reader for offline ROOT files"""
-    def __init__(self, file_path=None, fobj=None, data=None):
+    def __init__(self, file_path=None, fobj=None, data=None, index=slice(-1)):
         """ OfflineReader class is an offline ROOT file wrapper
         Parameters
         ----------
@@ -214,6 +51,7 @@ class OfflineReader:
             path-like object that points to the file.
 
         """
+        self._index = index
         if file_path is not None:
             self._fobj = uproot.open(file_path)
             self._tree = self._fobj[MAIN_TREE_NAME]
@@ -225,6 +63,9 @@ class OfflineReader:
             self._tree = self._fobj[MAIN_TREE_NAME]
             self._data = data
 
+        for mapper in BRANCH_MAPS:
+            setattr(self, mapper.name, BranchElement(self._tree, mapper=mapper, index=self._index))
+
     @classmethod
     def from_index(cls, source, index):
         """Create an instance with a subtree of a given index
@@ -232,18 +73,24 @@ class OfflineReader:
         Parameters
         ----------
         source: ROOTDirectory
-            The source file.
+            The source file object.
         index: index or slice
             The index or slice to create the subtree.
         """
-        instance = cls(fobj=source._fobj, data=source._data[index])
+        instance = cls(fobj=source._fobj, data=source._data[index], index=index)
         return instance
 
     def __getitem__(self, index):
         return OfflineReader.from_index(source=self, index=index)
 
     def __len__(self):
-        return len(self._data)
+        tree = self._fobj[MAIN_TREE_NAME]
+        if self._index == slice(-1):
+            return len(tree)
+        else:
+            return len(tree.lazyarrays(
+                basketcache=uproot.cache.ThreadSafeArrayCache(
+                    BASKET_CACHE_SIZE))[self._index])
 
     @cached_property
     def header(self):
@@ -256,76 +103,6 @@ class OfflineReader:
         else:
             warnings.warn("Your file header has an unsupported format")
 
-    @cached_property
-    def keys(self):
-        """wrapper for all keys in an offline file.
-
-        Returns
-        -------
-        Class
-            OfflineKeys.
-        """
-        return OfflineKeys(self._tree)
-
-    @cached_property
-    def events(self):
-        """wrapper for offline events.
-
-        Returns
-        -------
-        Class
-            OfflineEvents.
-        """
-        return OfflineEvents(
-            self.keys.cut_events_keys,
-            [self._data[key] for key in self.keys.events_keys])
-
-    @cached_property
-    def hits(self):
-        """wrapper for offline hits.
-
-        Returns
-        -------
-        Class
-            OfflineHits.
-        """
-        return OfflineHits(self.keys.cut_hits_keys,
-                           [self._data[key] for key in self.keys.hits_keys])
-
-    @cached_property
-    def tracks(self):
-        """wrapper for offline tracks.
-
-        Returns
-        -------
-        Class
-            OfflineTracks.
-        """
-        return OfflineTracks(self._tree['trks'])
-
-    @cached_property
-    def mc_hits(self):
-        """wrapper for offline mc hits.
-
-        Returns
-        -------
-        Class
-            OfflineHits.
-        """
-        return OfflineHits(self.keys.cut_hits_keys,
-                           [self._data[key] for key in self.keys.mc_hits_keys])
-
-    @cached_property
-    def mc_tracks(self):
-        """wrapper for offline mc tracks.
-
-        Returns
-        -------
-        Class
-            OfflineTracks.
-        """
-        return OfflineTracks(self._tree['mc_trks'])
-
     @cached_property
     def usr(self):
         return Usr(self._tree)
@@ -705,137 +482,23 @@ class Usr:
         return '\n'.join(entries)
 
 
-class OfflineEvents:
-    """wrapper for offline events"""
-    def __init__(self, keys, values):
-        """wrapper for offline events.
-
-        Parameters
-        ----------
-        keys : list of str
-            list of valid events keys.
-        values : list of arrays
-            list of arrays containting events data.
-        """
-        self._keys = keys
-        self._values = values
-        for k, v in zip(self._keys, self._values):
-            setattr(self, k, v)
-
-    def __getitem__(self, item):
-        return OfflineEvent(self._keys, [v[item] for v in self._values])
-
-    def __len__(self):
-        try:
-            return len(self._values[0])
-        except IndexError:
-            return 0
-
-    def __str__(self):
-        return "Number of events: {}".format(len(self))
-
-    def __repr__(self):
-        return "<{}: {} parsed events>".format(self.__class__.__name__,
-                                               len(self))
-
-
-class OfflineEvent:
-    """wrapper for an offline event"""
-    def __init__(self, keys, values):
-        """wrapper for one offline event.
-
-        Parameters
-        ----------
-        keys : list of str
-            list of valid events keys.
-        values : list of arrays
-            list of arrays containting event data.
-        """
-        self._keys = keys
-        self._values = values
-        for k, v in zip(self._keys, self._values):
-            setattr(self, k, v)
-
-    def __str__(self):
-        return "offline event:\n\t" + "\n\t".join([
-            "{:15} {:^10} {:>10}".format(k, ':', str(v))
-            for k, v in zip(self._keys, self._values)
-        ])
-
-
-class OfflineHits:
-    """wrapper for offline hits"""
-    def __init__(self, keys, values):
-        """wrapper for offline hits.
-
-        Parameters
-        ----------
-        keys : list of str
-            list of cropped hits keys.
-        values : list of arrays
-            list of arrays containting hits data.
-        """
-        self._keys = keys
-        self._values = values
-        for k, v in zip(self._keys, self._values):
-            setattr(self, k, v)
-
-    def __getitem__(self, item):
-        return OfflineHit(self._keys, [v[item] for v in self._values])
-
-    def __len__(self):
-        try:
-            return len(self._values[0])
-        except IndexError:
-            return 0
-
-    def __str__(self):
-        return "Number of hits: {}".format(len(self))
-
-    def __repr__(self):
-        return "<{}: {} parsed elements>".format(self.__class__.__name__,
-                                                 len(self))
-
-
-class OfflineHit:
-    """wrapper for an offline hit"""
-    def __init__(self, keys, values):
-        """wrapper for one offline hit.
-
-        Parameters
-        ----------
-        keys : list of str
-            list of cropped hits keys.
-        values : list of arrays
-            list of arrays containting hit data.
-        """
-        self._keys = keys
-        self._values = values
-        for k, v in zip(self._keys, self._values):
-            setattr(self, k, v)
-
-    def __str__(self):
-        return "offline hit:\n\t" + "\n\t".join([
-            "{:15} {:^10} {:>10}".format(k, ':', str(v))
-            for k, v in zip(self._keys, self._values)
-        ])
-
-    def __getitem__(self, item):
-        return self._values[item]
-
-
-class OfflineTracks:
-    """wrapper for offline tracks"""
-    def __init__(self, branch, index=slice(-1)):
-        keys = [k.decode('utf-8') for k in branch.keys()]
-        self._keymap = {k[5:].replace('.', '_'): k for k in keys}
-        self._branch = branch
-        self._keys = keys
+class BranchElement:
+    """wrapper for one branch of the offline tree, described by a BranchMapper"""
+    def __init__(self, tree, mapper, index=slice(-1)):
+        self.mapper = mapper
+        self.name = mapper.name
+        self._tree = tree
+        self._branch = tree[mapper.key]
+        keys = [k.decode('utf-8') for k in self._branch.keys()]
+        self._keymap = {**{mapper.attrparser(k): k for k in keys}, **mapper.extra_keys}
         self._index = index
 
+        # for key in keys:
+        #    setattr(self, key, cached_property(self[key]))
+
     def __getitem__(self, item):
         if isinstance(item, slice):
-            return OfflineTracks(self._branch, index=item)
+            return self.__class__(self._tree, self.mapper, index=item)
         return self._branch[self._keymap[item]].lazyarray(
             basketcache=uproot.cache.ThreadSafeArrayCache(
                 BASKET_CACHE_SIZE))[self._index]
@@ -846,41 +509,12 @@ class OfflineTracks:
         else:
             return len(self._branch[self._keymap['id']].lazyarray()[self._index])
 
+    def keys(self):
+        return self._keymap.keys()
+
     def __str__(self):
-        return "Number of tracks: {}".format(len(self._branch))
+        return "Number of elements: {}".format(len(self._branch))
 
     def __repr__(self):
-        return "<{}: {} parsed elements>".format(self.__class__.__name__,
+        return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__, self.name,
                                                  len(self))
-
-
-class OfflineTrack:
-    """wrapper for an offline track"""
-    def __init__(self, keys, values):
-        """wrapper for one offline track.
-
-        Parameters
-        ----------
-        keys : list of str
-            list of cropped tracks keys.
-        values : list of arrays
-            list of arrays containting track data.
-        """
-        self._keys = keys
-        self._values = values
-        for k, v in zip(self._keys, self._values):
-            setattr(self, k, v)
-
-    def __str__(self):
-        return "offline track:\n\t" + "\n\t".join([
-            "{:30} {:^2} {:>26}".format(k, ':', str(v))
-            for k, v in zip(self._keys, self._values) if k not in ['fitinf']
-        ]) + "\n\t" + "\n\t".join([
-            "{:30} {:^2} {:>26}".format(k, ':', str(
-                getattr(self, 'fitinf')[v]))
-            for k, v in km3io.definitions.fitparameters.data.items()
-            if len(getattr(self, 'fitinf')) > v
-        ])  # I don't like 18 being explicit here
-
-    def __getitem__(self, item):
-        return self._values[item]