diff --git a/km3io/offline.py b/km3io/offline.py index 993b87931b1938913016551a94b1cc5c2381bcfc..0f49224be3dae1b405589c40755ab5827841a412 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -13,22 +13,22 @@ class OfflineReader: item_name = "OfflineEvent" skip_keys = ["t", "AAObject"] aliases = { - "t_s": "t.fSec", + "t_sec": "t.fSec", "t_ns": "t.fNanoSec", - "usr": "AAObject/usr", - "usr_names": "AAObject/usr_names", } special_branches = { "hits": { + "id": "hits.id", "channel_id": "hits.channel_id", "dom_id": "hits.dom_id", - "time": "hits.t", + "t": "hits.t", "tot": "hits.tot", - "triggered": "hits.trig", # non-zero if the hit is a triggered hit + "trig": "hits.trig", # non-zero if the hit is a triggered hit }, "mc_hits": { + "id": "mc_hits.id", "pmt_id": "mc_hits.pmt_id", - "time": "mc_hits.t", # hit time (MC truth) + "t": "mc_hits.t", # hit time (MC truth) "a": "mc_hits.a", # hit amplitude (in p.e.) "origin": "mc_hits.origin", # track id of the track that created this hit "pure_t": "mc_hits.pure_t", # photon time before pmt simultion @@ -92,6 +92,13 @@ class OfflineReader: self._keys = None self._grouped_counts = {} # TODO: e.g. {"events": [3, 66, 34]} + if "E/Evt/AAObject/usr" in self._fobj: + if ak.count(f["E/Evt/AAObject/usr"].array()) > 0: + self.aliases.update({ + "usr": "AAObject/usr", + "usr_names": "AAObject/usr_names", + }) + self._initialise_keys() self._event_ctor = namedtuple( @@ -184,10 +191,11 @@ class OfflineReader: group_counts[key] = iter(self[key]) for event_set, *special_sets in zip(events_it, *specials): for _event, *special_items in zip(event_set, *special_sets): - data = { - **{k: _event[k] for k in keys}, - **{k: i for (k, i) in zip(special_keys, special_items)}, - } + data = {} + for k in keys: + data[k] = _event[k] + for (k, i) in zip(special_keys, special_items): + data[k] = i for tokey, fromkey in self.special_aliases.items(): data[tokey] = data[fromkey] for key in group_counts: diff --git a/tests/test_offline.py b/tests/test_offline.py index 8407705c87003d2fb5c1cff6b3e503446d026c44..a39796064e7624e74a5db0f35640ce5d66997add 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -160,8 +160,8 @@ class TestOfflineEvents(unittest.TestCase): self.assertListEqual(self.det_id, list(self.events.det_id)) print(self.n_hits) print(self.events.hits) - self.assertListEqual(self.n_hits, len(self.events.hits)) - self.assertListEqual(self.n_tracks, len(self.events.tracks)) + self.assertListEqual(self.n_hits, list(self.events.n_hits)) + self.assertListEqual(self.n_tracks, list(self.events.n_tracks)) self.assertListEqual(self.t_sec, list(self.events.t_sec)) self.assertListEqual(self.t_ns, list(self.events.t_ns)) @@ -296,13 +296,14 @@ class TestOfflineHits(unittest.TestCase): ], } - def test_attributes_available(self): - for key in self.hits._keymap.keys(): + def test_fields_work_as_keys_and_attributes(self): + for key in self.hits.fields: getattr(self.hits, key) + self.hits[key] def test_channel_ids(self): - self.assertTrue(all(c >= 0 for c in self.hits.channel_id.min())) - self.assertTrue(all(c < 31 for c in self.hits.channel_id.max())) + self.assertTrue(all(c >= 0 for c in ak.min(self.hits.channel_id, axis=1))) + self.assertTrue(all(c < 31 for c in ak.max(self.hits.channel_id, axis=1))) def test_str(self): assert str(self.n_hits) in str(self.hits)