diff --git a/km3io/offline.py b/km3io/offline.py index 0e986f07a2b7555c1d3ed37a26dd5a1543e439c8..72d0a8e13f15997f5774006059909b0fc777fcc2 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -21,6 +21,17 @@ def _nested_mapper(key): EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"]) +EVENTS_MAP = BranchMapper("events", "Evt", { + 't_sec': 't.fSec', + 't_ns': 't.fNanoSec' + }, [], { + 'n_hits': 'hits', + 'n_mc_hits': 'mc_hits', + 'n_tracks': 'trks', + 'n_mc_tracks': 'mc_trks' + }, lambda a: a, True) + + SUBBRANCH_MAPS = [ BranchMapper("tracks", "trks", {}, ['trks.usr_data', 'trks.usr'], {}, _nested_mapper, False), @@ -42,16 +53,6 @@ SUBBRANCH_MAPS = [ }, lambda a: a, True), ] -EVENTS_MAP = BranchMapper("events", "Evt", { - 't_sec': 't.fSec', - 't_ns': 't.fNanoSec' - }, [], { - 'n_hits': 'hits', - 'n_mc_hits': 'mc_hits', - 'n_tracks': 'trks', - 'n_mc_tracks': 'mc_trks' - }, lambda a: a, True) - class cached_property: """A simple cache decorator for properties.""" def __init__(self, function): @@ -90,7 +91,7 @@ class OfflineReader: @cached_property def events(self): - return Branch(self._tree, mapper=EVENTS_MAP, index=self._index, subbranches=SUBBRANCH_MAPS) + return Branch(self._tree, mapper=EVENTS_MAP, index=self._index, subbranchmaps=SUBBRANCH_MAPS) @classmethod def from_index(cls, source, index): @@ -563,24 +564,27 @@ class Header: class Branch: """Branch accessor class""" - def __init__(self, tree, mapper, index=None, subbranches=[]): + def __init__(self, tree, mapper, index=None, subbranches=None, subbranchmaps=None): self._tree = tree self._mapper = mapper self._index = index self._keymap = None self._branch = tree[mapper.key] - self._subbranches = subbranches - self._subbranch_keys = [] + self._subbranches = [] - self._initialise_keys() + self._initialise_keys() # - for mapper in subbranches: - setattr(self, mapper.name, - Branch(self._tree, mapper=mapper, index=self._index)) - self._subbranch_keys.append(mapper.name) + if subbranches is not None: + self._subbranches = subbranches + if subbranchmaps is not None: + for mapper in subbranchmaps: + subbranch = Branch(self._tree, mapper=mapper, index=self._index) + self._subbranches.append(subbranch) + for subbranch in self._subbranches: + setattr(self, subbranch._mapper.name, subbranch) def _initialise_keys(self): - """Create the keymap and instance attributes""" + """Create the keymap and instance attributes for branch keys""" keys = set(k.decode('utf-8') for k in self._branch.keys()) - set( self._mapper.exclude) - EXCLUDE_KEYS self._keymap = { @@ -592,8 +596,6 @@ class Branch: for k in self._mapper.update.values(): del self._keymap[k] - # self._EntryType = namedtuple(mapper.name[:-1], self.keys()) - for key in self._keymap.keys(): # print("setting", self._mapper.name, key) setattr(self, key, self[key]) @@ -623,6 +625,8 @@ class Branch: self._branch[self._keymap[key]].array()[self._index] for key in self.keys() } + for subbranch in self._subbranches: + dct[subbranch._mapper.name] = subbranch return BranchElement(self._mapper.name, dct)[item] else: if self._index is None: @@ -637,6 +641,8 @@ class Branch: item] for key in self.keys() } + for subbranch in self._subbranches: + dct[subbranch._mapper.name] = subbranch return BranchElement(self._mapper.name, dct) if isinstance(item, tuple): diff --git a/tests/test_offline.py b/tests/test_offline.py index 6eadc00fa9fe05c6e66888886aa4541bd677a2a2..4d409f44644cb7ab16d112c2692360c4ce57f69d 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -188,9 +188,8 @@ class TestOfflineEvents(unittest.TestCase): self.events.n_hits[s]) assert np.allclose(self.events[s].n_hits, self.events.n_hits[s]) - @unittest.skip def test_index_consistency(self): - for i in range(self.n_events): + for i in [0,2,5]: assert np.allclose(self.events[i].n_hits, self.events.n_hits[i]) assert np.allclose(OFFLINE_FILE[i].events.n_hits, self.events.n_hits[i]) @@ -257,7 +256,6 @@ class TestOfflineHits(unittest.TestCase): for idx, t in self.t.items(): self.assertListEqual(t[s], list(self.hits.t[idx][s])) - @unittest.skip def test_slicing_consistency(self): for s in [slice(1, 3), slice(2, 7, 3)]: for idx in range(3): @@ -266,12 +264,13 @@ class TestOfflineHits(unittest.TestCase): assert np.allclose(OFFLINE_FILE.events[idx].hits.dom_id[s], self.hits.dom_id[idx][s]) - @unittest.skip def test_index_consistency(self): - for i in range(self.n_events): - assert np.allclose(self.events[i].n_hits, self.events.n_hits[i]) - assert np.allclose(OFFLINE_FILE[i].events.n_hits, - self.events.n_hits[i]) + for idx, dom_ids in self.dom_id.items(): + assert np.allclose(self.hits[idx].dom_id[:self.n_hits], dom_ids[:self.n_hits]) + assert np.allclose(OFFLINE_FILE.events[idx].hits.dom_id[:self.n_hits], dom_ids[:self.n_hits]) + for idx, ts in self.t.items(): + assert np.allclose(self.hits[idx].t[:self.n_hits], ts[:self.n_hits]) + assert np.allclose(OFFLINE_FILE.events[idx].hits.t[:self.n_hits], ts[:self.n_hits]) class TestOfflineTracks(unittest.TestCase):