test_offline.py 15.91 KiB
import unittest
import numpy as np
from pathlib import Path
import uuid
import awkward as ak
from km3net_testdata import data_path
from km3io import OfflineReader
from km3io.offline import Header
# Module-level readers shared by the test classes in this file. Each one is
# opened once, at import time, from a path resolved with km3net_testdata's
# data_path, and the same instance is reused by every test that refers to it.
OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root"))
# Sample file used by TestUsr for the event-level ``usr`` fields.
OFFLINE_USR = OfflineReader(data_path("offline/usr-sample.root"))
# File used by TestMcTrackUsr for mc_tracks ``usr`` / ``usr_names``.
OFFLINE_MC_TRACK_USR = OfflineReader(
    data_path(
        "offline/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root"
    )
)
OFFLINE_NUMUCC = OfflineReader(data_path("offline/numucc.root"))  # with mc data
class TestOfflineReader(unittest.TestCase):
    """Reader-level checks on the shared km3net_offline.root sample."""

    def setUp(self):
        self.r = OFFLINE_FILE
        # Not used by the tests of this class; kept for compatibility.
        self.nu = OFFLINE_NUMUCC
        self.n_events = 10

    def test_context_manager(self):
        # Open a fresh reader instead of reusing OFFLINE_FILE so that the
        # ``with`` statement enters and exits its own instance.
        with OfflineReader(data_path("offline/km3net_offline.root")) as r:
            assert r

    def test_number_events(self):
        assert self.n_events == len(self.r.events)

    def test_uuid(self):
        assert str(self.r.uuid) == "b192d888-fcc7-11e9-b430-6cf09e86beef"
class TestHeader(unittest.TestCase):
    """Parsing of header entries into attribute-accessible fields."""

    def test_str_header(self):
        rendered = str(OFFLINE_NUMUCC.header)
        assert "MC Header" in rendered

    def test_warning_if_unsupported_header(self):
        # Accessing the header of a file with an unsupported header format
        # is expected to emit a UserWarning.
        with self.assertWarns(UserWarning):
            OFFLINE_FILE.header

    def test_missing_key_definitions(self):
        # Keys without a definition: multi-value entries are exposed as
        # field_0, field_1, ...; single values are exposed directly.
        header = Header({"a": "1 2 3", "b": "4", "c": "d"})
        for position, expected in enumerate([1, 2, 3]):
            assert expected == getattr(header.a, f"field_{position}")
        assert 4 == header.b
        assert "d" == header.c

    def test_missing_values(self):
        # Defined fields that have no value in the string come back as None.
        header = Header({"can": "1"})
        can = header.can
        assert 1 == can.zmin
        assert can.zmax is None
        assert can.r is None

    def test_additional_values_compared_to_definition(self):
        # Surplus values beyond the definition get generic field_N names.
        header = Header({"can": "1 2 3 4"})
        can = header.can
        for attribute, expected in (
            ("zmin", 1),
            ("zmax", 2),
            ("r", 3),
            ("field_3", 4),
        ):
            assert expected == getattr(can, attribute)

    def test_header(self):
        header = Header(
            {
                "DAQ": "394",
                "PDF": "4",
                "can": "0 1027 888.4",
                "undefined": "1 2 test 3.4",
            }
        )
        assert 394 == header.DAQ.livetime
        assert 4 == header.PDF.i1
        assert header.PDF.i2 is None
        for attribute, expected in (("zmin", 0), ("zmax", 1027), ("r", 888.4)):
            assert expected == getattr(header.can, attribute)
        for position, expected in enumerate([1, 2, "test", 3.4]):
            assert expected == getattr(header.undefined, f"field_{position}")

    def test_reading_header_from_sample_file(self):
        head = OFFLINE_NUMUCC.header
        # entry name -> {field name: expected value}, checked in order
        expectations = {
            "DAQ": {"livetime": 394},
            "PDF": {"i1": 4, "i2": 58},
            "coord_origin": {"x": 0, "y": 0, "z": 0},
            "cut_nu": {
                "Emin": 100,
                "Emax": 100000000.0,
                "cosTmin": -1,
                "cosTmax": 1,
            },
        }
        for entry, fields in expectations.items():
            for field, expected in fields.items():
                assert expected == getattr(getattr(head, entry), field)
        assert "diffuse" == head.sourcemode
        assert 100000.0 == head.ngen
class TestOfflineEvents(unittest.TestCase):
    """Event-level branches of OFFLINE_FILE: values, slicing and indexing."""

    def setUp(self):
        self.events = OFFLINE_FILE.events
        self.n_events = 10
        # Expected per-event values, one entry per event in file order.
        self.det_id = [44] * self.n_events
        self.n_hits = [176, 125, 318, 157, 83, 60, 71, 84, 255, 105]
        self.n_tracks = [56, 55, 56, 56, 56, 56, 56, 56, 54, 56]
        self.t_sec = [
            1567036818,
            1567036818,
            1567036820,
            1567036816,
            1567036816,
            1567036816,
            1567036822,
            1567036818,
            1567036818,
            1567036820,
        ]
        self.t_ns = [
            200000000,
            300000000,
            200000000,
            500000000,
            500000000,
            500000000,
            200000000,
            500000000,
            500000000,
            400000000,
        ]

    def test_len(self):
        assert self.n_events == len(self.events)

    def test_attributes(self):
        assert self.n_events == len(self.events.det_id)
        self.assertListEqual(self.det_id, list(self.events.det_id))
        self.assertListEqual(self.n_hits, list(self.events.n_hits))
        self.assertListEqual(self.n_tracks, list(self.events.n_tracks))
        self.assertListEqual(self.t_sec, list(self.events.t_sec))
        self.assertListEqual(self.t_ns, list(self.events.t_ns))

    def test_keys(self):
        # Same fields as test_attributes, accessed via item lookup.
        assert np.allclose(self.n_hits, self.events["n_hits"].tolist())
        assert np.allclose(self.n_tracks, self.events["n_tracks"].tolist())
        assert np.allclose(self.t_sec, self.events["t_sec"].tolist())
        assert np.allclose(self.t_ns, self.events["t_ns"].tolist())

    def test_slicing(self):
        s = slice(2, 8, 2)
        s_events = self.events[s]
        assert 3 == len(s_events)
        self.assertListEqual(self.n_hits[s], list(s_events.n_hits))
        self.assertListEqual(self.n_tracks[s], list(s_events.n_tracks))
        self.assertListEqual(self.t_sec[s], list(s_events.t_sec))
        self.assertListEqual(self.t_ns[s], list(s_events.t_ns))

    def test_slicing_consistency(self):
        # Slicing events then taking a branch == taking a branch then slicing.
        for s in [slice(1, 3), slice(2, 7, 3)]:
            assert np.allclose(
                self.events[s].n_hits.tolist(), self.events.n_hits[s].tolist()
            )

    def test_index_consistency(self):
        for i in [0, 2, 5]:
            assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])

    def test_index_chaining(self):
        assert np.allclose(
            self.events[3:5].n_hits.tolist(), self.events.n_hits[3:5].tolist()
        )
        assert np.allclose(self.events[3:5][0].n_hits, self.events.n_hits[3:5][0])

    def test_index_chaining_on_nested_branches_aka_records(self):
        # The order in which event slicing and the nested ``hits`` branch are
        # applied must not change the selected value.
        assert np.allclose(
            self.events[3:5].hits[1].dom_id[4],
            self.events.hits[3:5][1].dom_id[4],
        )
        assert np.allclose(
            self.events.hits[3:5][1].dom_id[4],
            self.events[3:5][1].hits.dom_id[4],
        )

    def test_fancy_indexing(self):
        mask = self.events.n_tracks > 55
        tracks = self.events.tracks[mask]
        first_tracks = tracks[:, 0]
        # 8 of the 10 events have more than 55 tracks (see self.n_tracks).
        assert 8 == len(first_tracks)
        assert 8 == len(first_tracks.rec_stages)
        assert 8 == len(first_tracks.lik)

    @unittest.skip
    def test_iteration(self):
        i = 0
        for event in self.events:
            i += 1
        assert 10 == i

    @unittest.skip
    def test_iteration_2(self):
        n_hits = [len(e.hits.id) for e in self.events]
        assert np.allclose(n_hits, ak.num(self.events.hits.id, axis=1).tolist())

    def test_str(self):
        assert str(self.n_events) in str(self.events)

    def test_repr(self):
        assert str(self.n_events) in repr(self.events)
class TestOfflineHits(unittest.TestCase):
    """Hit-level fields of OFFLINE_FILE, checked against reference values for
    the first ten hits of events 0 and 5."""

    def setUp(self):
        self.hits = OFFLINE_FILE.events.hits
        # Used both as the number looked for in repr(self.hits) and as the
        # count of leading hits compared per event in test_index_consistency.
        self.n_hits = 10
        # Expected dom_id of the first ten hits, keyed by event index.
        self.dom_id = {
            0: [806451572] * 4 + [806455814] * 3 + [806483369] * 3,
            5: [806455814] + [806487219] * 3 + [806487226] + [808432835] * 5,
        }
        # Expected hit times of the first ten hits, keyed by event index.
        self.t = {
            0: [
                70104010.0, 70104016.0, 70104192.0, 70104123.0, 70103096.0,
                70103797.0, 70103796.0, 70104191.0, 70104223.0, 70104181.0,
            ],
            5: [
                81861237.0, 81859608.0, 81860586.0, 81861062.0, 81860357.0,
                81860627.0, 81860628.0, 81860625.0, 81860627.0, 81860629.0,
            ],
        }

    def test_fields_work_as_keys_and_attributes(self):
        hits = self.hits
        for name in hits.fields:
            getattr(hits, name)
            hits[name]

    def test_channel_ids(self):
        lowest = ak.min(self.hits.channel_id, axis=1)
        self.assertTrue(all(value >= 0 for value in lowest))
        highest = ak.max(self.hits.channel_id, axis=1)
        self.assertTrue(all(value < 31 for value in highest))

    def test_repr(self):
        assert str(self.n_hits) in repr(self.hits)

    def test_attributes(self):
        for event_idx, expected in self.dom_id.items():
            head = self.hits.dom_id[event_idx][: len(expected)]
            self.assertListEqual(expected, list(head))
        for event_idx, expected in self.t.items():
            head = self.hits.t[event_idx][: len(expected)]
            assert np.allclose(expected, head.tolist())

    def test_slicing(self):
        every_other = slice(2, 8, 2)
        assert 3 == len(self.hits[every_other])
        for table, branch in (
            (self.dom_id, self.hits.dom_id),
            (self.t, self.hits.t),
        ):
            for event_idx, expected in table.items():
                self.assertListEqual(
                    expected[every_other], list(branch[event_idx][every_other])
                )

    def test_slicing_consistency(self):
        combos = [
            (window, event_idx)
            for window in (slice(1, 3), slice(2, 7, 3))
            for event_idx in range(3)
        ]
        for window, event_idx in combos:
            from_branch = self.hits.dom_id[event_idx][window].tolist()
            assert np.allclose(
                from_branch, self.hits[event_idx].dom_id[window].tolist()
            )
            assert np.allclose(
                OFFLINE_FILE.events[event_idx].hits.dom_id[window].tolist(),
                from_branch,
            )

    def test_index_consistency(self):
        n = self.n_hits
        for attribute, table in (("dom_id", self.dom_id), ("t", self.t)):
            for event_idx, expected in table.items():
                assert np.allclose(
                    getattr(self.hits[event_idx], attribute)[:n].tolist(),
                    expected[:n],
                )
                assert np.allclose(
                    getattr(OFFLINE_FILE.events[event_idx].hits, attribute)[
                        :n
                    ].tolist(),
                    expected[:n],
                )

    def test_fields(self):
        for name in ("dom_id", "channel_id", "t", "tot", "trig", "id"):
            assert name in self.hits.fields
class TestOfflineTracks(unittest.TestCase):
    """Reconstructed-track branch of OFFLINE_FILE."""

    def setUp(self):
        self.f = OFFLINE_FILE
        self.tracks = OFFLINE_FILE.events.tracks
        # Holds the NUMUCC reader itself (not its tracks); unused in this class.
        self.tracks_numucc = OFFLINE_NUMUCC
        self.n_events = 10

    def test_fields(self):
        # Every listed field must be reachable as an attribute.
        for field in [
            "id",
            "pos_x",
            "pos_y",
            "pos_z",
            "dir_x",
            "dir_y",
            "dir_z",
            "t",
            "E",
            "len",
            "lik",
            "rec_type",
            "rec_stages",
            "fitinf",
        ]:
            getattr(self.tracks, field)

    def test_item_selection(self):
        self.assertListEqual(
            list(self.tracks[0].dir_z[:2]), [-0.872885221293917, -0.872885221293917]
        )

    def test_repr(self):
        assert "10" in repr(self.tracks)

    def test_slicing(self):
        tracks = self.tracks
        self.assertEqual(10, len(tracks))  # 10 events
        self.assertEqual(56, len(tracks[0].id))  # number of tracks in first event
        track_selection = tracks[2:7]
        assert 5 == len(track_selection)
        track_selection_2 = tracks[1:3]
        assert 2 == len(track_selection_2)
        for _slice in [
            slice(0, 1),
            slice(0, 2),
            slice(1, 5),
            slice(3, -2),
        ]:
            # Slicing events then taking the first track's energy must equal
            # taking the first track's energy then slicing. The slice is
            # reported via ``msg`` only when the assertion fails.
            self.assertListEqual(
                list(tracks.E[:, 0][_slice]),
                list(tracks[_slice].E[:, 0]),
                msg=f"checking {_slice}",
            )

    def test_nested_indexing(self):
        self.assertAlmostEqual(
            self.f.events.tracks.fitinf[3:5][1][9][2],
            self.f.events[3:5].tracks[1].fitinf[9][2],
        )
        self.assertAlmostEqual(
            self.f.events.tracks.fitinf[3:5][1][9][2],
            self.f.events[3:5][1].tracks.fitinf[9][2],
        )
class TestBranchIndexingMagic(unittest.TestCase):
    """Event-first vs. branch-first indexing, and index-array selection."""

    def setUp(self):
        self.events = OFFLINE_FILE.events

    def test_slicing_magic(self):
        # Element 0 of events[2:4] is event 2, which has 318 hits.
        self.assertEqual(318, self.events[2:4].n_hits[0])
        # Selecting the event first must match multi-dimensional indexing
        # directly on the branch.
        assert np.allclose(
            self.events[3].tracks.dir_z[10], self.events.tracks.dir_z[3, 10]
        )
        assert np.allclose(
            self.events[3:6].tracks.pos_y[:, 0].tolist(),
            self.events.tracks.pos_y[3:6, 0].tolist(),
        )

    def test_selecting_specific_items_via_a_list(self):
        # select events with a plain Python list of indices
        self.assertEqual(3, len(self.events[[0, 2, 3]]))

    def test_selecting_specific_items_via_a_numpy_array(self):
        # select events with a NumPy integer array of indices
        self.assertEqual(3, len(self.events[np.array([0, 2, 3])]))

    def test_selecting_specific_items_via_a_awkward_array(self):
        # select events with an Awkward Array of indices
        self.assertEqual(3, len(self.events[ak.Array([0, 2, 3])]))
class TestUsr(unittest.TestCase):
    """Event-level ``usr`` fields of the usr-sample file (all tests skipped)."""

    def setUp(self):
        self.f = OFFLINE_USR

    @unittest.skip
    def test_str_flat(self):
        usr = self.f.events.usr
        print(usr)

    @unittest.skip
    def test_keys_flat(self):
        # Key names are compared exactly as stored in the file.
        expected_keys = [
            "RecoQuality",
            "RecoNDF",
            "CoC",
            "ToT",
            "ChargeAbove",
            "ChargeBelow",
            "ChargeRatio",
            "DeltaPosZ",
            "FirstPartPosZ",
            "LastPartPosZ",
            "NSnapHits",
            "NTrigHits",
            "NTrigDOMs",
            "NTrigLines",
            "NSpeedVetoHits",
            "NGeometryVetoHits",
            "ClassficationScore",
        ]
        actual_keys = self.f.events.usr.keys().tolist()
        self.assertListEqual(expected_keys, actual_keys)

    @unittest.skip
    def test_getitem_flat(self):
        usr = self.f.events.usr
        for key, expected in (
            ("CoC", [118.6302815337638, 44.33580521344907, 99.93916717621543]),
            (
                "DeltaPosZ",
                [37.51967774166617, -10.280346193553832, 13.67595659707355],
            ),
        ):
            assert np.allclose(expected, usr[key].tolist())

    @unittest.skip
    def test_attributes_flat(self):
        usr = self.f.events.usr
        for key, expected in (
            ("CoC", [118.6302815337638, 44.33580521344907, 99.93916717621543]),
            (
                "DeltaPosZ",
                [37.51967774166617, -10.280346193553832, 13.67595659707355],
            ),
        ):
            assert np.allclose(expected, getattr(usr, key).tolist())
class TestMcTrackUsr(unittest.TestCase):
    """``usr`` / ``usr_names`` of the MC tracks in OFFLINE_MC_TRACK_USR."""

    def setUp(self):
        self.f = OFFLINE_MC_TRACK_USR

    def test_usr_names(self):
        # presumably usr_names[event][track] — first two MC tracks of each of
        # the first three events; TODO confirm the axis order.
        for i in range(3):
            self.assertListEqual(
                ["bx", "by", "ichan", "cc"],
                self.f.events.mc_tracks.usr_names[i][0].tolist(),
            )
            self.assertListEqual(
                ["energy_lost_in_can"],
                self.f.events.mc_tracks.usr_names[i][1].tolist(),
            )

    def test_usr(self):
        # Values are compared with an absolute tolerance because the reference
        # numbers are rounded.
        assert np.allclose(
            [0.0487, 0.0588, 3, 2],
            self.f.events.mc_tracks.usr[0][0].tolist(),
            atol=0.0001,
        )
        assert np.allclose(
            [0.147, 0.4, 3, 2], self.f.events.mc_tracks.usr[1][0].tolist(), atol=0.001
        )