Skip to content
Snippets Groups Projects
Commit 8345901c authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Black

parent 5b6112b0
No related branches found
No related tags found
1 merge request!39WIP: Resolve "uproot4 integration"
Pipeline #15931 failed
......@@ -13,7 +13,7 @@ MAIN_TREE_NAME = "E"
EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"]
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2
BASKET_CACHE_SIZE = 110 * 1024 ** 2
BASKET_CACHE = uproot.cache.LRUArrayCache(BASKET_CACHE_SIZE)
......@@ -46,7 +46,7 @@ class Usr:
)
return
self._usr_names = self._branch[self._usr_key + '_names'].array()[0]
self._usr_names = self._branch[self._usr_key + "_names"].array()[0]
self._usr_idx_lookup = {
name: index for index, name in enumerate(self._usr_names)
}
......@@ -86,37 +86,42 @@ class OfflineReader:
"""reader for offline ROOT files"""
event_path = "E/Evt"
skip_keys = ['mc_trks', 'trks', 't', 'AAObject']
item_name = "OfflineEvent"
skip_keys = ["mc_trks", "trks", "t", "AAObject"]
aliases = {"t_s": "t.fSec", "t_ns": "t.fNanoSec"}
special_keys = {
'hits': {
'channel_id': 'hits.channel_id',
'dom_id': 'hits.dom_id',
'time': 'hits.t',
'tot': 'hits.tot',
'triggered': 'hits.trig'
"hits": {
"channel_id": "hits.channel_id",
"dom_id": "hits.dom_id",
"time": "hits.t",
"tot": "hits.tot",
"triggered": "hits.trig",
},
'mc_hits': {
'pmt_id': 'mc_hits.pmt_id',
'time': 'mc_hits.t',
'a': 'mc_hits.a',
"mc_hits": {
"pmt_id": "mc_hits.pmt_id",
"time": "mc_hits.t",
"a": "mc_hits.a",
},
'trks': {
'dir_x': 'trks.dir.x',
'dir_y': 'trks.dir.y',
'dir_z': 'trks.dir.z',
'rec_stages': 'trks.rec_stages',
'fitinf': 'trks.fitinf'
"trks": {
"dir_x": "trks.dir.x",
"dir_y": "trks.dir.y",
"dir_z": "trks.dir.z",
"rec_stages": "trks.rec_stages",
"fitinf": "trks.fitinf",
},
'mc_trks': {
'dir_x': 'mc_trks.dir.x',
'dir_y': 'mc_trks.dir.y',
'dir_z': 'mc_trks.dir.z',
"mc_trks": {
"dir_x": "mc_trks.dir.x",
"dir_y": "mc_trks.dir.y",
"dir_z": "mc_trks.dir.z",
},
}
# TODO: this is fishy
special_aliases = {'trks': 'tracks', 'hits': "hits", "mc_hits": "mc_hits", "mc_trks": "mc_tracks"}
special_aliases = {
"trks": "tracks",
"hits": "hits",
"mc_hits": "mc_hits",
"mc_trks": "mc_tracks",
}
def __init__(self, file_path, step_size=2000):
"""OfflineReader class is an offline ROOT file wrapper
......@@ -138,7 +143,14 @@ class OfflineReader:
self._uuid = self._fobj._file.uuid
self._iterator_index = 0
self._subbranches = None
self._event_ctor = namedtuple("OfflineEvent", set(list(self.keys()) + list(self.aliases.keys()) + list(self.special_aliases[k] for k in self.special_keys)))
self._event_ctor = namedtuple(
self.item_name,
set(
list(self.keys())
+ list(self.aliases.keys())
+ list(self.special_aliases[k] for k in self.special_keys)
),
)
def keys(self):
if self._subbranches is None:
......@@ -168,25 +180,28 @@ class OfflineReader:
def _event_generator(self):
events = self._fobj[self.event_path]
keys = list(set(self.keys()) - set(self.special_keys.keys())) + list(self.aliases.keys())
events_it = events.iterate(
keys,
aliases=self.aliases,
step_size=self.step_size)
keys = list(set(self.keys()) - set(self.special_keys.keys())) + list(
self.aliases.keys()
)
events_it = events.iterate(keys, aliases=self.aliases, step_size=self.step_size)
specials = []
special_keys = self.special_keys.keys() # dict-key ordering is an implementation detail
special_keys = (
self.special_keys.keys()
) # dict-key ordering is an implementation detail
for key in special_keys:
specials.append(
events[key].iterate(
self.special_keys[key].keys(),
aliases=self.special_keys[key],
step_size=self.step_size
step_size=self.step_size,
)
)
for event_set, *special_sets in zip(events_it, *specials):
for _event, *special_items in zip(event_set, *special_sets):
yield self._event_ctor(**{k: _event[k] for k in keys},
**{k: i for (k, i) in zip(special_keys, special_items)})
yield self._event_ctor(
**{k: _event[k] for k in keys},
**{k: i for (k, i) in zip(special_keys, special_items)}
)
def __next__(self):
return next(self._events)
......@@ -211,7 +226,7 @@ class OfflineReader:
def header(self):
"""The file header"""
if "Head" in self._fobj:
return Header(self._fobj['Head'].tojson()['map<string,string>'])
return Header(self._fobj["Head"].tojson()["map<string,string>"])
else:
warnings.warn("Your file header has an unsupported format")
......
......@@ -5,9 +5,9 @@ import numpy as np
import numba as nb
TIMESLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024**2 # [byte]
SUMMARYSLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024**2 # [byte]
BASKET_CACHE_SIZE = 110 * 1024**2
TIMESLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024 ** 2 # [byte]
SUMMARYSLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024 ** 2 # [byte]
BASKET_CACHE_SIZE = 110 * 1024 ** 2
BASKET_CACHE = uproot.cache.LRUArrayCache(BASKET_CACHE_SIZE)
# Parameters for PMT rate conversions, since the rates in summary slices are
......@@ -215,23 +215,35 @@ class SummarySlices:
def _read_summaryslices(self):
"""Reads the summary slices"""
tree = self._fobj[b'KM3NET_SUMMARYSLICE'][b'KM3NET_SUMMARYSLICE']
return tree[b'vector<KM3NETDAQ::JDAQSummaryFrame>'].array(
uproot.asjagged(uproot.astable(
uproot.asdtype([("dom_id", "i4"), ("dq_status", "u4"),
("hrv", "u4"), ("fifo", "u4"),
("status3", "u4"), ("status4", "u4")] +
[(c, "u1") for c in self._ch_selector])),
skipbytes=10),
tree = self._fobj[b"KM3NET_SUMMARYSLICE"][b"KM3NET_SUMMARYSLICE"]
return tree[b"vector<KM3NETDAQ::JDAQSummaryFrame>"].array(
uproot.asjagged(
uproot.astable(
uproot.asdtype(
[
("dom_id", "i4"),
("dq_status", "u4"),
("hrv", "u4"),
("fifo", "u4"),
("status3", "u4"),
("status4", "u4"),
]
+ [(c, "u1") for c in self._ch_selector]
)
),
skipbytes=10,
),
basketcache=uproot.cache.LRUArrayCache(
SUMMARYSLICE_FRAME_BASKET_CACHE_SIZE))
SUMMARYSLICE_FRAME_BASKET_CACHE_SIZE
),
)
def _read_headers(self):
"""Reads the summary slice headers"""
tree = self._fobj[b'KM3NET_SUMMARYSLICE'][b'KM3NET_SUMMARYSLICE']
return tree[b'KM3NETDAQ::JDAQSummarysliceHeader'].array(
uproot.interpret(tree[b'KM3NETDAQ::JDAQSummarysliceHeader'],
cntvers=True))
tree = self._fobj[b"KM3NET_SUMMARYSLICE"][b"KM3NET_SUMMARYSLICE"]
return tree[b"KM3NETDAQ::JDAQSummarysliceHeader"].array(
uproot.interpret(tree[b"KM3NETDAQ::JDAQSummarysliceHeader"], cntvers=True)
)
def __str__(self):
return "Number of summaryslices: {}".format(len(self.headers))
......@@ -264,15 +276,25 @@ class Timeslices:
superframes = tree[b"vector<KM3NETDAQ::JDAQSuperFrame>"]
hits_dtype = np.dtype([("pmt", "u1"), ("tdc", "<u4"), ("tot", "u1")])
hits_buffer = superframes[
b'vector<KM3NETDAQ::JDAQSuperFrame>.buffer'].array(
uproot.asjagged(uproot.astable(uproot.asdtype(hits_dtype)),
skipbytes=6),
basketcache=uproot.cache.LRUArrayCache(
TIMESLICE_FRAME_BASKET_CACHE_SIZE))
self._timeslices[stream.decode("ascii")] = (headers, superframes,
hits_buffer)
setattr(self, stream.decode("ascii"),
TimesliceStream(headers, superframes, hits_buffer))
b"vector<KM3NETDAQ::JDAQSuperFrame>.buffer"
].array(
uproot.asjagged(
uproot.astable(uproot.asdtype(hits_dtype)), skipbytes=6
),
basketcache=uproot.cache.LRUArrayCache(
TIMESLICE_FRAME_BASKET_CACHE_SIZE
),
)
self._timeslices[stream.decode("ascii")] = (
headers,
superframes,
hits_buffer,
)
setattr(
self,
stream.decode("ascii"),
TimesliceStream(headers, superframes, hits_buffer),
)
def stream(self, stream, idx):
ts = self._timeslices[stream]
......@@ -335,12 +357,12 @@ class Timeslice:
"""Populate a dictionary of frames with the module ID as key"""
hits_buffer = self._hits_buffer[self._idx]
n_hits = self._superframe[
b'vector<KM3NETDAQ::JDAQSuperFrame>.numberOfHits'].array(
basketcache=BASKET_CACHE)[self._idx]
b"vector<KM3NETDAQ::JDAQSuperFrame>.numberOfHits"
].array(basketcache=BASKET_CACHE)[self._idx]
try:
module_ids = self._superframe[
b'vector<KM3NETDAQ::JDAQSuperFrame>.id'].array(
basketcache=BASKET_CACHE)[self._idx]
b"vector<KM3NETDAQ::JDAQSuperFrame>.id"
].array(basketcache=BASKET_CACHE)[self._idx]
except KeyError:
raise
# module_ids = self._superframe[
......@@ -361,8 +383,10 @@ class Timeslice:
def __len__(self):
if self._n_frames is None:
self._n_frames = len(
self._superframe[b'vector<KM3NETDAQ::JDAQSuperFrame>.id'].
array(basketcache=BASKET_CACHE)[self._idx])
self._superframe[b"vector<KM3NETDAQ::JDAQSuperFrame>.id"].array(
basketcache=BASKET_CACHE
)[self._idx]
)
return self._n_frames
def __str__(self):
......
......@@ -6,7 +6,7 @@ import uproot4 as uproot
from .tools import unfold_indices
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2
BASKET_CACHE_SIZE = 110 * 1024 ** 2
BASKET_CACHE = uproot.cache.LRUArrayCache(BASKET_CACHE_SIZE)
......@@ -148,8 +148,7 @@ class Branch:
# 6,
# )
#
out = self._branch[self._keymap[key]].array(
interpretation=interpretation)
out = self._branch[self._keymap[key]].array(interpretation=interpretation)
# if self._index_chain is not None and key in self._mapper.toawkward:
# cache_key = self._mapper.name + "/" + key
# if cache_key not in self._awkward_cache:
......@@ -192,7 +191,7 @@ class Branch:
else:
return len(
unfold_indices(
self._branch[self._keymap['id']].array(), self._index_chain
self._branch[self._keymap["id"]].array(), self._index_chain
)
)
......
......@@ -11,7 +11,7 @@ from km3io.definitions import w2list_genhen as kw2gen
from km3io.definitions import w2list_gseagen as kw2gsg
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2
BASKET_CACHE_SIZE = 110 * 1024 ** 2
BASKET_CACHE = uproot.cache.LRUArrayCache(BASKET_CACHE_SIZE)
......
......@@ -34,7 +34,7 @@ class TestOfflineReader(unittest.TestCase):
assert self.n_events == len(self.r.events)
def test_uuid(self):
assert str(self.r.uuid) == 'b192d888-fcc7-11e9-b430-6cf09e86beef'
assert str(self.r.uuid) == "b192d888-fcc7-11e9-b430-6cf09e86beef"
class TestHeader(unittest.TestCase):
......@@ -183,22 +183,32 @@ class TestOfflineEvents(unittest.TestCase):
@unittest.skip
def test_slicing_consistency(self):
for s in [slice(1, 3), slice(2, 7, 3)]:
assert np.allclose(self.events[s].n_hits.tolist(), self.events.n_hits[s].tolist())
assert np.allclose(
self.events[s].n_hits.tolist(), self.events.n_hits[s].tolist()
)
@unittest.skip
def test_index_consistency(self):
for i in [0, 2, 5]:
assert np.allclose(self.events[i].n_hits.tolist(), self.events.n_hits[i].tolist())
assert np.allclose(
self.events[i].n_hits.tolist(), self.events.n_hits[i].tolist()
)
@unittest.skip
def test_index_chaining(self):
assert np.allclose(self.events[3:5].n_hits.tolist(), self.events.n_hits[3:5].tolist())
assert np.allclose(self.events[3:5][0].n_hits.tolist(), self.events.n_hits[3:5][0].tolist())
assert np.allclose(
self.events[3:5].hits[1].dom_id[4].tolist(), self.events.hits[3:5][1][4].dom_id.tolist()
self.events[3:5].n_hits.tolist(), self.events.n_hits[3:5].tolist()
)
assert np.allclose(
self.events[3:5][0].n_hits.tolist(), self.events.n_hits[3:5][0].tolist()
)
assert np.allclose(
self.events.hits[3:5][1][4].dom_id.tolist(), self.events[3:5][1][4].hits.dom_id.tolist()
self.events[3:5].hits[1].dom_id[4].tolist(),
self.events.hits[3:5][1][4].dom_id.tolist(),
)
assert np.allclose(
self.events.hits[3:5][1][4].dom_id.tolist(),
self.events[3:5][1][4].hits.dom_id.tolist(),
)
@unittest.skip
......@@ -316,9 +326,12 @@ class TestOfflineHits(unittest.TestCase):
def test_slicing_consistency(self):
for s in [slice(1, 3), slice(2, 7, 3)]:
for idx in range(3):
assert np.allclose(self.hits.dom_id[idx][s].tolist(), self.hits[idx].dom_id[s].tolist())
assert np.allclose(
OFFLINE_FILE.events[idx].hits.dom_id[s].tolist(), self.hits.dom_id[idx][s].tolist()
self.hits.dom_id[idx][s].tolist(), self.hits[idx].dom_id[s].tolist()
)
assert np.allclose(
OFFLINE_FILE.events[idx].hits.dom_id[s].tolist(),
self.hits.dom_id[idx][s].tolist(),
)
def test_index_consistency(self):
......@@ -331,9 +344,12 @@ class TestOfflineHits(unittest.TestCase):
dom_ids[: self.n_hits].tolist(),
)
for idx, ts in self.t.items():
assert np.allclose(self.hits[idx].t[: self.n_hits].tolist(), ts[: self.n_hits].tolist())
assert np.allclose(
OFFLINE_FILE.events[idx].hits.t[: self.n_hits].tolist(), ts[: self.n_hits].tolist()
self.hits[idx].t[: self.n_hits].tolist(), ts[: self.n_hits].tolist()
)
assert np.allclose(
OFFLINE_FILE.events[idx].hits.t[: self.n_hits].tolist(),
ts[: self.n_hits].tolist(),
)
def test_keys(self):
......@@ -414,7 +430,8 @@ class TestBranchIndexingMagic(unittest.TestCase):
self.events[3].tracks.dir_z[10], self.events.tracks.dir_z[3, 10]
)
assert np.allclose(
self.events[3:6].tracks.pos_y[:, 0].tolist(), self.events.tracks.pos_y[3:6, 0].tolist()
self.events[3:6].tracks.pos_y[:, 0].tolist(),
self.events.tracks.pos_y[3:6, 0].tolist(),
)
# test selecting with a list
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment