diff --git a/km3io/offline.py b/km3io/offline.py index 37f52d69565c4baefc7af8c41e89235e9ea4b646..bfd5b0aff76d02ab115bceae60ab257dd7c1dcf8 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -9,8 +9,9 @@ MAIN_TREE_NAME = "E" # 110 MB based on the size of the largest basket found so far in km3net BASKET_CACHE_SIZE = 110 * 1024**2 - -BranchMapper = namedtuple("BranchMapper", ['name', 'key', 'extra', 'exclude', 'update', 'attrparser']) +BranchMapper = namedtuple( + "BranchMapper", + ['name', 'key', 'extra', 'exclude', 'update', 'attrparser']) def _nested_mapper(key): @@ -20,11 +21,22 @@ def _nested_mapper(key): EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"]) BRANCH_MAPS = [ - BranchMapper("tracks", "trks", {}, ['trks.usr_data', 'trks.usr'], {}, _nested_mapper), - BranchMapper("mc_tracks", "mc_trks", {}, ['mc_trks.usr_data', 'mc_trks.usr'], {}, _nested_mapper), + BranchMapper("tracks", "trks", {}, ['trks.usr_data', 'trks.usr'], {}, + _nested_mapper), + BranchMapper("mc_tracks", "mc_trks", {}, + ['mc_trks.usr_data', 'mc_trks.usr'], {}, _nested_mapper), BranchMapper("hits", "hits", {}, ['hits.usr'], {}, _nested_mapper), - BranchMapper("mc_hits", "mc_hits", {}, ['mc_hits.usr'], {}, _nested_mapper), - BranchMapper("events", "Evt", {'t_sec': 't.fSec', 't_ns': 't.fNanoSec'}, [], {'n_hits': 'hits', 'n_mc_hits': 'mc_hits', 'n_tracks': 'trks', 'n_mc_tracks': 'mc_trks'}, lambda a: a), + BranchMapper("mc_hits", "mc_hits", {}, ['mc_hits.usr'], {}, + _nested_mapper), + BranchMapper("events", "Evt", { + 't_sec': 't.fSec', + 't_ns': 't.fNanoSec' + }, [], { + 'n_hits': 'hits', + 'n_mc_hits': 'mc_hits', + 'n_tracks': 'trks', + 'n_mc_tracks': 'mc_trks' + }, lambda a: a), ] @@ -42,7 +54,11 @@ class cached_property: class OfflineReader: """reader for offline ROOT files""" - def __init__(self, file_path=None, fobj=None, data=None, index=slice(None)): + def __init__(self, + file_path=None, + fobj=None, + data=None, + index=slice(None)): """ OfflineReader class is an offline ROOT file wrapper Parameters @@ -65,7 +81,9 @@ class OfflineReader: self._data = data for mapper in BRANCH_MAPS: - setattr(self, mapper.name, BranchElement(self._tree, mapper=mapper, index=self._index)) + setattr( + self, mapper.name, + BranchElement(self._tree, mapper=mapper, index=self._index)) @classmethod def from_index(cls, source, index): @@ -78,7 +96,9 @@ class OfflineReader: index: index or slice The index or slice to create the subtree. """ - instance = cls(fobj=source._fobj, data=source._data[index], index=index) + instance = cls(fobj=source._fobj, + data=source._data[index], + index=index) return instance def __getitem__(self, index): @@ -89,8 +109,8 @@ class OfflineReader: if self._index == slice(None): return len(tree) else: - return len(tree.lazyarrays( - basketcache=uproot.cache.ThreadSafeArrayCache( + return len( + tree.lazyarrays(basketcache=uproot.cache.ThreadSafeArrayCache( BASKET_CACHE_SIZE))[self.index]) @cached_property @@ -484,9 +504,9 @@ class Usr: def _to_num(value): - """Convert value to a numerical value if possible""" + """Convert a value to a numerical one if possible""" if value is None: - return None + return try: return int(value) except ValueError: @@ -509,7 +529,9 @@ class Header: Constructor = namedtuple(attribute, fields) if len(values) < len(fields): values += [None] * (len(fields) - len(values)) - self._data[attribute] = Constructor(**{f: _to_num(v) for (f, v) in zip(fields, values)}) + self._data[attribute] = Constructor( + **{f: _to_num(v) + for (f, v) in zip(fields, values)}) for attribute, value in self._data.items(): setattr(self, attribute, value) @@ -528,15 +550,19 @@ class BranchElement: self._mapper = mapper self._index = index self._keymap = None - self._branch = tree[mapper.key] self._initialise_keys() def _initialise_keys(self): """Create the keymap and instance attributes""" - keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(self._mapper.exclude) - EXCLUDE_KEYS - self._keymap = {**{self._mapper.attrparser(k): k for k in keys}, **self._mapper.extra} + keys = set(k.decode('utf-8') for k in self._branch.keys()) - set( + self._mapper.exclude) - EXCLUDE_KEYS + self._keymap = { + **{self._mapper.attrparser(k): k + for k in keys}, + **self._mapper.extra + } self._keymap.update(self._mapper.update) for k in self._mapper.update.values(): del self._keymap[k] @@ -556,21 +582,24 @@ class BranchElement: return self.__class__(self._tree, self._mapper, index=item) if isinstance(item, int): return { - key: self._branch[self._keymap[key]].array()[self._index, item] for key in self.keys() - } + key: self._branch[self._keymap[key]].array()[self._index, item] + for key in self.keys() + } return self._branch[self._keymap[item]].lazyarray( - basketcache=uproot.cache.ThreadSafeArrayCache( - BASKET_CACHE_SIZE))[self._index] + basketcache=uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE))[ + self._index] def __len__(self): if self._index == slice(None): return len(self._branch) else: - return len(self._branch[self._keymap['id']].lazyarray()[self._index]) + return len( + self._branch[self._keymap['id']].lazyarray()[self._index]) def __str__(self): return "Number of elements: {}".format(len(self._branch)) def __repr__(self): - return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__, self._mapper.name, - len(self)) + return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__, + self._mapper.name, + len(self)) diff --git a/tests/test_offline.py b/tests/test_offline.py index ce9451ec917bae17dc5b14c9b727276d8aca4372..bc8b1c8dd52be1a7bc089af9956f680c7a6a42b0 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -143,8 +143,14 @@ class TestOfflineEvents(unittest.TestCase): self.det_id = [44] * self.n_events self.n_hits = [176, 125, 318, 157, 83, 60, 71, 84, 255, 105] self.n_tracks = [56, 55, 56, 56, 56, 56, 56, 56, 54, 56] - self.t_sec = [1567036818, 1567036818, 1567036820, 1567036816, 1567036816, 1567036816, 1567036822, 1567036818, 1567036818, 1567036820] - self.t_ns = [200000000, 300000000, 200000000, 500000000, 500000000, 500000000, 200000000, 500000000, 500000000, 400000000] + self.t_sec = [ + 1567036818, 1567036818, 1567036820, 1567036816, 1567036816, + 1567036816, 1567036822, 1567036818, 1567036818, 1567036820 + ] + self.t_ns = [ + 200000000, 300000000, 200000000, 500000000, 500000000, 500000000, + 200000000, 500000000, 500000000, 400000000 + ] def test_len(self): assert self.n_events == len(self.events) @@ -188,19 +194,30 @@ class TestOfflineHits(unittest.TestCase): self.hits = OfflineReader(OFFLINE_FILE).hits self.n_hits = 10 self.dom_id = { - 0: [806451572, 806451572, 806451572, 806451572, 806455814, 806455814, 806455814, 806483369, 806483369, 806483369], - 5: [806455814, 806487219, 806487219, 806487219, 806487226, 808432835, 808432835, 808432835, 808432835, 808432835] + 0: [ + 806451572, 806451572, 806451572, 806451572, 806455814, + 806455814, 806455814, 806483369, 806483369, 806483369 + ], + 5: [ + 806455814, 806487219, 806487219, 806487219, 806487226, + 808432835, 808432835, 808432835, 808432835, 808432835 + ] } self.t = { - 0: [70104010., 70104016., 70104192., 70104123., 70103096., 70103797., 70103796., 70104191., 70104223., 70104181.], - 5: [81861237., 81859608., 81860586., 81861062., 81860357., 81860627., 81860628., 81860625., 81860627., 81860629.] + 0: [ + 70104010., 70104016., 70104192., 70104123., 70103096., + 70103797., 70103796., 70104191., 70104223., 70104181. + ], + 5: [ + 81861237., 81859608., 81860586., 81861062., 81860357., + 81860627., 81860628., 81860625., 81860627., 81860629. + ] } def test_attributes_available(self): for key in self.hits._keymap.keys(): getattr(self.hits, key) - def test_channel_ids(self): self.assertTrue(all(c >= 0 for c in self.hits.channel_id.min())) self.assertTrue(all(c < 31 for c in self.hits.channel_id.max())) @@ -213,7 +230,8 @@ class TestOfflineHits(unittest.TestCase): def test_attributes(self): for idx, dom_id in self.dom_id.items(): - self.assertListEqual(dom_id, list(self.hits.dom_id[idx][:len(dom_id)])) + self.assertListEqual(dom_id, + list(self.hits.dom_id[idx][:len(dom_id)])) for idx, t in self.t.items(): assert np.allclose(t, self.hits.t[idx][:len(t)])