Skip to content
Snippets Groups Projects
Commit 2a3c1cdf authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Quick fixes

parent 0362b3e4
No related branches found
No related tags found
1 merge request!39WIP: Resolve "uproot4 integration"
Pipeline #16039 failed
......@@ -13,22 +13,22 @@ class OfflineReader:
item_name = "OfflineEvent"
skip_keys = ["t", "AAObject"]
aliases = {
"t_s": "t.fSec",
"t_sec": "t.fSec",
"t_ns": "t.fNanoSec",
"usr": "AAObject/usr",
"usr_names": "AAObject/usr_names",
}
special_branches = {
"hits": {
"id": "hits.id",
"channel_id": "hits.channel_id",
"dom_id": "hits.dom_id",
"time": "hits.t",
"t": "hits.t",
"tot": "hits.tot",
"triggered": "hits.trig", # non-zero if the hit is a triggered hit
"trig": "hits.trig", # non-zero if the hit is a triggered hit
},
"mc_hits": {
"id": "mc_hits.id",
"pmt_id": "mc_hits.pmt_id",
"time": "mc_hits.t", # hit time (MC truth)
"t": "mc_hits.t", # hit time (MC truth)
"a": "mc_hits.a", # hit amplitude (in p.e.)
"origin": "mc_hits.origin", # track id of the track that created this hit
"pure_t": "mc_hits.pure_t", # photon time before pmt simultion
......@@ -92,6 +92,13 @@ class OfflineReader:
self._keys = None
self._grouped_counts = {} # TODO: e.g. {"events": [3, 66, 34]}
if "E/Evt/AAObject/usr" in self._fobj:
if ak.count(f["E/Evt/AAObject/usr"].array()) > 0:
self.aliases.update({
"usr": "AAObject/usr",
"usr_names": "AAObject/usr_names",
})
self._initialise_keys()
self._event_ctor = namedtuple(
......@@ -184,10 +191,11 @@ class OfflineReader:
group_counts[key] = iter(self[key])
for event_set, *special_sets in zip(events_it, *specials):
for _event, *special_items in zip(event_set, *special_sets):
data = {
**{k: _event[k] for k in keys},
**{k: i for (k, i) in zip(special_keys, special_items)},
}
data = {}
for k in keys:
data[k] = _event[k]
for (k, i) in zip(special_keys, special_items):
data[k] = i
for tokey, fromkey in self.special_aliases.items():
data[tokey] = data[fromkey]
for key in group_counts:
......
......@@ -160,8 +160,8 @@ class TestOfflineEvents(unittest.TestCase):
self.assertListEqual(self.det_id, list(self.events.det_id))
print(self.n_hits)
print(self.events.hits)
self.assertListEqual(self.n_hits, len(self.events.hits))
self.assertListEqual(self.n_tracks, len(self.events.tracks))
self.assertListEqual(self.n_hits, list(self.events.n_hits))
self.assertListEqual(self.n_tracks, list(self.events.n_tracks))
self.assertListEqual(self.t_sec, list(self.events.t_sec))
self.assertListEqual(self.t_ns, list(self.events.t_ns))
......@@ -296,13 +296,14 @@ class TestOfflineHits(unittest.TestCase):
],
}
def test_attributes_available(self):
for key in self.hits._keymap.keys():
def test_fields_work_as_keys_and_attributes(self):
for key in self.hits.fields:
getattr(self.hits, key)
self.hits[key]
def test_channel_ids(self):
self.assertTrue(all(c >= 0 for c in self.hits.channel_id.min()))
self.assertTrue(all(c < 31 for c in self.hits.channel_id.max()))
self.assertTrue(all(c >= 0 for c in ak.min(self.hits.channel_id, axis=1)))
self.assertTrue(all(c < 31 for c in ak.max(self.hits.channel_id, axis=1)))
def test_str(self):
assert str(self.n_hits) in str(self.hits)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment