Skip to content
Snippets Groups Projects
Commit 3c7e71b3 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Make yapf

parent 4606ed32
No related branches found
No related tags found
1 merge request!27Refactor offline I/O
......@@ -9,8 +9,9 @@ MAIN_TREE_NAME = "E"
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2
BranchMapper = namedtuple("BranchMapper", ['name', 'key', 'extra', 'exclude', 'update', 'attrparser'])
BranchMapper = namedtuple(
"BranchMapper",
['name', 'key', 'extra', 'exclude', 'update', 'attrparser'])
def _nested_mapper(key):
......@@ -20,11 +21,22 @@ def _nested_mapper(key):
EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"])
BRANCH_MAPS = [
BranchMapper("tracks", "trks", {}, ['trks.usr_data', 'trks.usr'], {}, _nested_mapper),
BranchMapper("mc_tracks", "mc_trks", {}, ['mc_trks.usr_data', 'mc_trks.usr'], {}, _nested_mapper),
BranchMapper("tracks", "trks", {}, ['trks.usr_data', 'trks.usr'], {},
_nested_mapper),
BranchMapper("mc_tracks", "mc_trks", {},
['mc_trks.usr_data', 'mc_trks.usr'], {}, _nested_mapper),
BranchMapper("hits", "hits", {}, ['hits.usr'], {}, _nested_mapper),
BranchMapper("mc_hits", "mc_hits", {}, ['mc_hits.usr'], {}, _nested_mapper),
BranchMapper("events", "Evt", {'t_sec': 't.fSec', 't_ns': 't.fNanoSec'}, [], {'n_hits': 'hits', 'n_mc_hits': 'mc_hits', 'n_tracks': 'trks', 'n_mc_tracks': 'mc_trks'}, lambda a: a),
BranchMapper("mc_hits", "mc_hits", {}, ['mc_hits.usr'], {},
_nested_mapper),
BranchMapper("events", "Evt", {
't_sec': 't.fSec',
't_ns': 't.fNanoSec'
}, [], {
'n_hits': 'hits',
'n_mc_hits': 'mc_hits',
'n_tracks': 'trks',
'n_mc_tracks': 'mc_trks'
}, lambda a: a),
]
......@@ -42,7 +54,11 @@ class cached_property:
class OfflineReader:
"""reader for offline ROOT files"""
def __init__(self, file_path=None, fobj=None, data=None, index=slice(None)):
def __init__(self,
file_path=None,
fobj=None,
data=None,
index=slice(None)):
""" OfflineReader class is an offline ROOT file wrapper
Parameters
......@@ -65,7 +81,9 @@ class OfflineReader:
self._data = data
for mapper in BRANCH_MAPS:
setattr(self, mapper.name, BranchElement(self._tree, mapper=mapper, index=self._index))
setattr(
self, mapper.name,
BranchElement(self._tree, mapper=mapper, index=self._index))
@classmethod
def from_index(cls, source, index):
......@@ -78,7 +96,9 @@ class OfflineReader:
index: index or slice
The index or slice to create the subtree.
"""
instance = cls(fobj=source._fobj, data=source._data[index], index=index)
instance = cls(fobj=source._fobj,
data=source._data[index],
index=index)
return instance
def __getitem__(self, index):
......@@ -89,8 +109,8 @@ class OfflineReader:
if self._index == slice(None):
return len(tree)
else:
return len(tree.lazyarrays(
basketcache=uproot.cache.ThreadSafeArrayCache(
return len(
tree.lazyarrays(basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))[self.index])
@cached_property
......@@ -484,9 +504,9 @@ class Usr:
def _to_num(value):
"""Convert value to a numerical value if possible"""
"""Convert a value to a numerical one if possible"""
if value is None:
return None
return
try:
return int(value)
except ValueError:
......@@ -509,7 +529,9 @@ class Header:
Constructor = namedtuple(attribute, fields)
if len(values) < len(fields):
values += [None] * (len(fields) - len(values))
self._data[attribute] = Constructor(**{f: _to_num(v) for (f, v) in zip(fields, values)})
self._data[attribute] = Constructor(
**{f: _to_num(v)
for (f, v) in zip(fields, values)})
for attribute, value in self._data.items():
setattr(self, attribute, value)
......@@ -528,15 +550,19 @@ class BranchElement:
self._mapper = mapper
self._index = index
self._keymap = None
self._branch = tree[mapper.key]
self._initialise_keys()
def _initialise_keys(self):
"""Create the keymap and instance attributes"""
keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(self._mapper.exclude) - EXCLUDE_KEYS
self._keymap = {**{self._mapper.attrparser(k): k for k in keys}, **self._mapper.extra}
keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(
self._mapper.exclude) - EXCLUDE_KEYS
self._keymap = {
**{self._mapper.attrparser(k): k
for k in keys},
**self._mapper.extra
}
self._keymap.update(self._mapper.update)
for k in self._mapper.update.values():
del self._keymap[k]
......@@ -556,21 +582,24 @@ class BranchElement:
return self.__class__(self._tree, self._mapper, index=item)
if isinstance(item, int):
return {
key: self._branch[self._keymap[key]].array()[self._index, item] for key in self.keys()
}
key: self._branch[self._keymap[key]].array()[self._index, item]
for key in self.keys()
}
return self._branch[self._keymap[item]].lazyarray(
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))[self._index]
basketcache=uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE))[
self._index]
def __len__(self):
if self._index == slice(None):
return len(self._branch)
else:
return len(self._branch[self._keymap['id']].lazyarray()[self._index])
return len(
self._branch[self._keymap['id']].lazyarray()[self._index])
def __str__(self):
return "Number of elements: {}".format(len(self._branch))
def __repr__(self):
return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__, self._mapper.name,
len(self))
return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__,
self._mapper.name,
len(self))
......@@ -143,8 +143,14 @@ class TestOfflineEvents(unittest.TestCase):
self.det_id = [44] * self.n_events
self.n_hits = [176, 125, 318, 157, 83, 60, 71, 84, 255, 105]
self.n_tracks = [56, 55, 56, 56, 56, 56, 56, 56, 54, 56]
self.t_sec = [1567036818, 1567036818, 1567036820, 1567036816, 1567036816, 1567036816, 1567036822, 1567036818, 1567036818, 1567036820]
self.t_ns = [200000000, 300000000, 200000000, 500000000, 500000000, 500000000, 200000000, 500000000, 500000000, 400000000]
self.t_sec = [
1567036818, 1567036818, 1567036820, 1567036816, 1567036816,
1567036816, 1567036822, 1567036818, 1567036818, 1567036820
]
self.t_ns = [
200000000, 300000000, 200000000, 500000000, 500000000, 500000000,
200000000, 500000000, 500000000, 400000000
]
def test_len(self):
assert self.n_events == len(self.events)
......@@ -188,19 +194,30 @@ class TestOfflineHits(unittest.TestCase):
self.hits = OfflineReader(OFFLINE_FILE).hits
self.n_hits = 10
self.dom_id = {
0: [806451572, 806451572, 806451572, 806451572, 806455814, 806455814, 806455814, 806483369, 806483369, 806483369],
5: [806455814, 806487219, 806487219, 806487219, 806487226, 808432835, 808432835, 808432835, 808432835, 808432835]
0: [
806451572, 806451572, 806451572, 806451572, 806455814,
806455814, 806455814, 806483369, 806483369, 806483369
],
5: [
806455814, 806487219, 806487219, 806487219, 806487226,
808432835, 808432835, 808432835, 808432835, 808432835
]
}
self.t = {
0: [70104010., 70104016., 70104192., 70104123., 70103096., 70103797., 70103796., 70104191., 70104223., 70104181.],
5: [81861237., 81859608., 81860586., 81861062., 81860357., 81860627., 81860628., 81860625., 81860627., 81860629.]
0: [
70104010., 70104016., 70104192., 70104123., 70103096.,
70103797., 70103796., 70104191., 70104223., 70104181.
],
5: [
81861237., 81859608., 81860586., 81861062., 81860357.,
81860627., 81860628., 81860625., 81860627., 81860629.
]
}
def test_attributes_available(self):
for key in self.hits._keymap.keys():
getattr(self.hits, key)
def test_channel_ids(self):
self.assertTrue(all(c >= 0 for c in self.hits.channel_id.min()))
self.assertTrue(all(c < 31 for c in self.hits.channel_id.max()))
......@@ -213,7 +230,8 @@ class TestOfflineHits(unittest.TestCase):
def test_attributes(self):
for idx, dom_id in self.dom_id.items():
self.assertListEqual(dom_id, list(self.hits.dom_id[idx][:len(dom_id)]))
self.assertListEqual(dom_id,
list(self.hits.dom_id[idx][:len(dom_id)]))
for idx, t in self.t.items():
assert np.allclose(t, self.hits.t[idx][:len(t)])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment