diff --git a/km3io/offline.py b/km3io/offline.py index c473a094eddd27c2c149991fa5121b58a458355d..16b94fc43d9cf0e9196596712dedde4bb16bfab5 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -1,5 +1,5 @@ import binascii -from collections import namedtuple +from collections import namedtuple, defaultdict import uproot4 as uproot import warnings import numba as nb @@ -17,87 +17,6 @@ BASKET_CACHE_SIZE = 110 * 1024**2 BASKET_CACHE = uproot.cache.LRUArrayCache(BASKET_CACHE_SIZE) -def _nested_mapper(key): - """Maps a key in the ROOT file to another key (e.g. trks.pos.x -> pos_x)""" - return "_".join(key.split(".")[1:]) - - -EVENTS_MAP = BranchMapper( - name="events", - key="Evt", - extra={"t_sec": "t.fSec", "t_ns": "t.fNanoSec"}, - exclude=EXCLUDE_KEYS, -) - -SUBBRANCH_MAPS = [ - BranchMapper( - name="tracks", - key="trks", - extra={}, - exclude=EXCLUDE_KEYS - + ["trks.usr_data", "trks.usr", "trks.fUniqueID", "trks.fBits", - "trks.usr_names" # TODO: this we might need! uproot4 chokes on empty ones - ], - attrparser=_nested_mapper, - flat=False, - ), - BranchMapper( - name="mc_tracks", - key="mc_trks", - exclude=EXCLUDE_KEYS - + [ - "mc_trks.rec_stages", - "mc_trks.fitinf", - "mc_trks.fUniqueID", - "mc_trks.fBits", - "mc_trks.comment", - "mc_trks" - ], - attrparser=_nested_mapper, - flat=False, - ), - BranchMapper( - name="hits", - key="hits", - exclude=EXCLUDE_KEYS - + [ - "hits.usr", - "hits.pmt_id", - "hits.origin", - "hits.a", - "hits.pure_a", - "hits.fUniqueID", - "hits.fBits", - ], - attrparser=_nested_mapper, - flat=False, - ), - BranchMapper( - name="mc_hits", - key="mc_hits", - exclude=EXCLUDE_KEYS - + [ - "mc_hits.usr", - "mc_hits.dom_id", - "mc_hits.channel_id", - "mc_hits.tdc", - "mc_hits.tot", - "mc_hits.trig", - "mc_hits.fUniqueID", - "mc_hits.fBits", - ], - attrparser=_nested_mapper, - flat=False, - ), -] - - -class OfflineBranch(Branch): - @cached_property - def usr(self): - return Usr(self._mapper, self._branch, index_chain=self._index_chain) - - class Usr: """Helper class to access AAObject `usr` stuff (only for events.usr)""" @@ -166,7 +85,40 @@ class Usr: class OfflineReader: """reader for offline ROOT files""" - def __init__(self, file_path=None): + event_path = "E/Evt" + skip_keys = ['mc_trks', 'trks', 't', 'AAObject'] + aliases = {"t_s": "t.fSec", "t_ns": "t.fNanoSec"} + special_keys = { + 'hits': { + 'channel_id': 'hits.channel_id', + 'dom_id': 'hits.dom_id', + 'time': 'hits.t', + 'tot': 'hits.tot', + 'triggered': 'hits.trig' + }, + 'mc_hits': { + 'pmt_id': 'mc_hits.pmt_id', + 'time': 'mc_hits.t', + 'a': 'mc_hits.a', + }, + 'trks': { + 'dir_x': 'trks.dir.x', + 'dir_y': 'trks.dir.y', + 'dir_z': 'trks.dir.z', + 'rec_stages': 'trks.rec_stages', + 'fitinf': 'trks.fitinf' + }, + 'mc_trks': { + 'dir_x': 'mc_trks.dir.x', + 'dir_y': 'mc_trks.dir.y', + 'dir_z': 'mc_trks.dir.z', + }, + + } + # TODO: this is fishy + special_aliases = {'trks': 'tracks', 'hits': "hits", "mc_hits": "mc_hits", "mc_trks": "mc_tracks"} + + def __init__(self, file_path, step_size=2000): """OfflineReader class is an offline ROOT file wrapper Parameters @@ -174,12 +126,73 @@ class OfflineReader: file_path : path-like object Path to the file of interest. It can be a str or any python path-like object that points to the file. + step_size: int, optional + Number of events to read into the cache when iterating. + Choosing higher numbers may improve the speed but also increases + the memory overhead. """ self._fobj = uproot.open(file_path) + self.step_size = step_size self._filename = file_path - self._tree = self._fobj[MAIN_TREE_NAME] self._uuid = self._fobj._file.uuid + self._iterator_index = 0 + self._subbranches = None + self._event_ctor = namedtuple("OfflineEvent", set(list(self.keys()) + list(self.aliases.keys()) + list(self.special_aliases[k] for k in self.special_keys))) + + def keys(self): + if self._subbranches is None: + subbranches = defaultdict(list) + for key in self._fobj[self.event_path].keys(): + toplevel, *remaining = key.split("/") + if remaining: + subbranches[toplevel].append("/".join(remaining)) + else: + subbranches[toplevel] = [] + for key in self.skip_keys: + del subbranches[key] + self._subbranches = subbranches + return self._subbranches.keys() + + @property + def events(self): + return iter(self) + + def __getitem__(self, key): + return self._fobj[self.event_path][key].array() + + def __iter__(self): + self._iterator_index = 0 + self._events = self._event_generator() + return self + + def _event_generator(self): + events = self._fobj[self.event_path] + keys = list(set(self.keys()) - set(self.special_keys.keys())) + list(self.aliases.keys()) + events_it = events.iterate( + keys, + aliases=self.aliases, + step_size=self.step_size) + specials = [] + special_keys = self.special_keys.keys() # dict-key ordering is an implementation detail + for key in special_keys: + specials.append( + events[key].iterate( + self.special_keys[key].keys(), + aliases=self.special_keys[key], + step_size=self.step_size + ) + ) + for event_set, *special_sets in zip(events_it, *specials): + for _event, *special_items in zip(event_set, *special_sets): + yield self._event_ctor(**{k: _event[k] for k in keys}, + **{k: i for (k, i) in zip(special_keys, special_items)}) + + def __next__(self): + return next(self._events) + + def __len__(self): + return self._fobj[self.event_path].num_entries @property def uuid(self): @@ -194,13 +207,6 @@ class OfflineReader: def __exit__(self, *args): self.close() - @cached_property - def events(self): - """The `E` branch, containing all offline events.""" - return OfflineBranch( - self._tree, mapper=EVENTS_MAP, subbranchmaps=SUBBRANCH_MAPS - ) - @cached_property def header(self): """The file header""" diff --git a/tests/test_offline.py b/tests/test_offline.py index 00c6062506a4ff57a5bd947526fa3065804aead0..ded53c33df5e13bee4b710e0ef239fb71ae1ccf9 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -7,7 +7,7 @@ import awkward1 as ak from km3net_testdata import data_path from km3io import OfflineReader -from km3io.offline import _nested_mapper, Header +from km3io.offline import Header OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root")) OFFLINE_USR = OfflineReader(data_path("offline/usr-sample.root")) @@ -498,8 +498,3 @@ class TestMcTrackUsr(unittest.TestCase): assert np.allclose( [0.147, 0.4, 3, 2], self.f.events.mc_tracks.usr[1][0].tolist(), atol=0.001 ) - - -class TestNestedMapper(unittest.TestCase): - def test_nested_mapper(self): - self.assertEqual("pos_x", _nested_mapper("trks.pos.x"))