# test_offline.py
import unittest
import numpy as np
from pathlib import Path
import uuid

import awkward as ak
from km3net_testdata import data_path

from km3io import OfflineReader
from km3io.offline import Header

def _open_sample(relative_path):
    """Return an ``OfflineReader`` for the file that ``data_path`` resolves."""
    return OfflineReader(data_path(relative_path))


# Shared readers: created once when this module is imported and reused by
# the test cases below.
OFFLINE_FILE = _open_sample("offline/km3net_offline.root")
OFFLINE_USR = _open_sample("offline/usr-sample.root")
OFFLINE_MC_TRACK_USR = _open_sample(
    "offline/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root"
)
OFFLINE_NUMUCC = _open_sample("offline/numucc.root")  # with mc data


class TestOfflineReader(unittest.TestCase):
    """Basic checks on the ``OfflineReader`` object itself."""

    def setUp(self):
        # Both readers are module-level objects shared between test cases.
        self.r = OFFLINE_FILE
        self.nu = OFFLINE_NUMUCC
        # Number of events the tests expect in ``OFFLINE_FILE``.
        self.n_events = 10

    def test_context_manager(self):
        """A reader can be used in a ``with`` statement and is truthy inside."""
        with OfflineReader(data_path("offline/km3net_offline.root")) as r:
            assert r

    def test_number_events(self):
        """``len`` of the events branch matches the expected event count."""
        assert self.n_events == len(self.r.events)

    def test_uuid(self):
        """The reader exposes the expected UUID for the sample file."""
        assert str(self.r.uuid) == "b192d888-fcc7-11e9-b430-6cf09e86beef"


class TestHeader(unittest.TestCase):
    """Checks on ``Header`` objects built from raw dicts and from a file."""

    def test_str_header(self):
        """The string form of the numucc header mentions "MC Header"."""
        rendered = str(OFFLINE_NUMUCC.header)
        assert "MC Header" in rendered

    def test_warning_if_unsupported_header(self):
        """Accessing ``OFFLINE_FILE.header`` is expected to warn."""
        # test the warning for unsupported fheader format
        with self.assertWarns(UserWarning):
            OFFLINE_FILE.header

    def test_missing_key_definitions(self):
        """Keys used here ("a", "b", "c") yield ``field_N`` or plain values."""
        parsed = Header({"a": "1 2 3", "b": "4", "c": "d"})

        # Multi-valued entry: values are reachable as field_0, field_1, ...
        multi = parsed.a
        assert 1 == multi.field_0
        assert 2 == multi.field_1
        assert 3 == multi.field_2

        # Single-valued entries are compared directly.
        assert 4 == parsed.b
        assert "d" == parsed.c

    def test_missing_values(self):
        """A "can" entry with one value leaves ``zmax`` and ``r`` as None."""
        parsed = Header({"can": "1"})
        can = parsed.can

        assert 1 == can.zmin
        assert can.zmax is None
        assert can.r is None

    def test_additional_values_compared_to_definition(self):
        """A fourth "can" value is expected under the name ``field_3``."""
        parsed = Header({"can": "1 2 3 4"})
        can = parsed.can

        assert 1 == can.zmin
        assert 2 == can.zmax
        assert 3 == can.r
        assert 4 == can.field_3

    def test_header(self):
        """Mixed raw header: named entries plus an entry called "undefined"."""
        raw = {
            "DAQ": "394",
            "PDF": "4",
            "can": "0 1027 888.4",
            "undefined": "1 2 test 3.4",
        }
        parsed = Header(raw)

        assert 394 == parsed.DAQ.livetime

        pdf = parsed.PDF
        assert 4 == pdf.i1
        assert pdf.i2 is None

        can = parsed.can
        assert 0 == can.zmin
        assert 1027 == can.zmax
        assert 888.4 == can.r

        undefined = parsed.undefined
        assert 1 == undefined.field_0
        assert 2 == undefined.field_1
        assert "test" == undefined.field_2
        assert 3.4 == undefined.field_3

    def test_reading_header_from_sample_file(self):
        """Spot-check header values read from the numucc sample file."""
        sample = OFFLINE_NUMUCC.header

        assert 394 == sample.DAQ.livetime

        pdf = sample.PDF
        assert 4 == pdf.i1
        assert 58 == pdf.i2

        origin = sample.coord_origin
        assert 0 == origin.x
        assert 0 == origin.y
        assert 0 == origin.z

        cut_nu = sample.cut_nu
        assert 100 == cut_nu.Emin
        assert 100000000.0 == cut_nu.Emax
        assert -1 == cut_nu.cosTmin
        assert 1 == cut_nu.cosTmax

        assert "diffuse" == sample.sourcemode
        assert 100000.0 == sample.ngen


class TestOfflineEvents(unittest.TestCase):
    """Tests for the ``events`` branch of ``OFFLINE_FILE``."""

    def setUp(self):
        self.events = OFFLINE_FILE.events
        # Reference values the tests below compare the events against.
        self.n_events = 10
        self.det_id = [44] * self.n_events
        self.n_hits = [176, 125, 318, 157, 83, 60, 71, 84, 255, 105]
        self.n_tracks = [56, 55, 56, 56, 56, 56, 56, 56, 54, 56]
        self.t_sec = [
            1567036818,
            1567036818,
            1567036820,
            1567036816,
            1567036816,
            1567036816,
            1567036822,
            1567036818,
            1567036818,
            1567036820,
        ]
        self.t_ns = [
            200000000,
            300000000,
            200000000,
            500000000,
            500000000,
            500000000,
            200000000,
            500000000,
            500000000,
            400000000,
        ]

    def test_len(self):
        """``len(events)`` equals the expected number of events."""
        assert self.n_events == len(self.events)

    def test_attributes(self):
        """Per-event quantities are reachable as attributes."""
        assert self.n_events == len(self.events.det_id)
        self.assertListEqual(self.det_id, list(self.events.det_id))
        self.assertListEqual(self.n_hits, list(self.events.n_hits))
        self.assertListEqual(self.n_tracks, list(self.events.n_tracks))
        self.assertListEqual(self.t_sec, list(self.events.t_sec))
        self.assertListEqual(self.t_ns, list(self.events.t_ns))

    def test_keys(self):
        """Per-event quantities are reachable via string keys."""
        assert np.allclose(self.n_hits, self.events["n_hits"].tolist())
        assert np.allclose(self.n_tracks, self.events["n_tracks"].tolist())
        assert np.allclose(self.t_sec, self.events["t_sec"].tolist())
        assert np.allclose(self.t_ns, self.events["t_ns"].tolist())

    def test_slicing(self):
        """Slicing the events yields the matching slice of each quantity."""
        s = slice(2, 8, 2)
        s_events = self.events[s]
        assert 3 == len(s_events)
        self.assertListEqual(self.n_hits[s], list(s_events.n_hits))
        self.assertListEqual(self.n_tracks[s], list(s_events.n_tracks))
        self.assertListEqual(self.t_sec[s], list(s_events.t_sec))
        self.assertListEqual(self.t_ns[s], list(s_events.t_ns))

    def test_slicing_consistency(self):
        """Slice-then-attribute equals attribute-then-slice."""
        for s in [slice(1, 3), slice(2, 7, 3)]:
            assert np.allclose(
                self.events[s].n_hits.tolist(), self.events.n_hits[s].tolist()
            )

    def test_index_consistency(self):
        """Index-then-attribute equals attribute-then-index."""
        for i in [0, 2, 5]:
            assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])

    def test_index_chaining(self):
        """A slice followed by an index can be applied at either level."""
        assert np.allclose(
            self.events[3:5].n_hits.tolist(), self.events.n_hits[3:5].tolist()
        )
        assert np.allclose(self.events[3:5][0].n_hits, self.events.n_hits[3:5][0])

    def test_index_chaining_on_nested_branches_aka_records(self):
        """Chained indexing gives the same value on the nested hits branch."""
        assert np.allclose(
            self.events[3:5].hits[1].dom_id[4],
            self.events.hits[3:5][1].dom_id[4],
        )
        assert np.allclose(
            self.events.hits[3:5][1].dom_id[4],
            self.events[3:5][1].hits.dom_id[4],
        )

    def test_fancy_indexing(self):
        """A boolean mask on events selects the matching tracks entries."""
        mask = self.events.n_tracks > 55
        tracks = self.events.tracks[mask]
        first_tracks = tracks[:, 0]
        assert 8 == len(first_tracks)
        assert 8 == len(first_tracks.rec_stages)
        assert 8 == len(first_tracks.lik)

    # ``unittest.skip`` is documented as taking a reason string; the original
    # applied it bare, so the actual motivation for skipping is not recorded.
    @unittest.skip("event iteration test disabled (reason not recorded)")
    def test_iteration(self):
        """Iterating over the events visits every event once."""
        assert self.n_events == sum(1 for _ in self.events)

    @unittest.skip("event iteration test disabled (reason not recorded)")
    def test_iteration_2(self):
        """Per-event hit counts from iteration match ``ak.num`` on the branch."""
        n_hits = [len(e.hits.id) for e in self.events]
        assert np.allclose(n_hits, ak.num(self.events.hits.id, axis=1).tolist())

    def test_str(self):
        """``str(events)`` mentions the number of events."""
        assert str(self.n_events) in str(self.events)

    def test_repr(self):
        """``repr(events)`` mentions the number of events."""
        assert str(self.n_events) in repr(self.events)


class TestOfflineHits(unittest.TestCase):
    """Tests for the ``hits`` branch of the events in ``OFFLINE_FILE``."""

    def setUp(self):
        self.hits = OFFLINE_FILE.events.hits
        # Length used by the tests when truncating per-event hit arrays.
        self.n_hits = 10
        # Reference values keyed by event index (events 0 and 5): the leading
        # dom_id and t entries the tests compare against.
        self.dom_id = {
            0: [
                806451572,
                806451572,
                806451572,
                806451572,
                806455814,
                806455814,
                806455814,
                806483369,
                806483369,
                806483369,
            ],
            5: [
                806455814,
                806487219,
                806487219,
                806487219,
                806487226,
                808432835,
                808432835,
                808432835,
                808432835,
                808432835,
            ],
        }
        self.t = {
            0: [
                70104010.0,
                70104016.0,
                70104192.0,
                70104123.0,
                70103096.0,
                70103797.0,
                70103796.0,
                70104191.0,
                70104223.0,
                70104181.0,
            ],
            5: [
                81861237.0,
                81859608.0,
                81860586.0,
                81861062.0,
                81860357.0,
                81860627.0,
                81860628.0,
                81860625.0,
                81860627.0,
                81860629.0,
            ],
        }

    def test_fields_work_as_keys_and_attributes(self):
        """Every listed field is reachable via attribute and via key."""
        for field in self.hits.fields:
            getattr(self.hits, field)
            self.hits[field]

    def test_channel_ids(self):
        """Per-event channel ids stay within ``0 <= id < 31``."""
        channel_ids = self.hits.channel_id
        lowest = ak.min(channel_ids, axis=1)
        self.assertTrue(all(value >= 0 for value in lowest))
        highest = ak.max(channel_ids, axis=1)
        self.assertTrue(all(value < 31 for value in highest))

    def test_repr(self):
        """``repr(hits)`` contains the string form of ``self.n_hits``."""
        needle = str(self.n_hits)
        assert needle in repr(self.hits)

    def test_attributes(self):
        """Leading ``dom_id`` and ``t`` values match the reference tables."""
        for event_idx, expected_ids in self.dom_id.items():
            found = self.hits.dom_id[event_idx][: len(expected_ids)]
            self.assertListEqual(expected_ids, list(found))
        for event_idx, expected_t in self.t.items():
            found = self.hits.t[event_idx][: len(expected_t)]
            assert np.allclose(expected_t, found.tolist())

    def test_slicing(self):
        """Slices of the hits agree with slices of the reference tables."""
        step_slice = slice(2, 8, 2)
        assert 3 == len(self.hits[step_slice])
        for event_idx, expected_ids in self.dom_id.items():
            found = self.hits.dom_id[event_idx][step_slice]
            self.assertListEqual(expected_ids[step_slice], list(found))
        for event_idx, expected_t in self.t.items():
            found = self.hits.t[event_idx][step_slice]
            self.assertListEqual(expected_t[step_slice], list(found))

    def test_slicing_consistency(self):
        """Different access orders yield the same sliced ``dom_id`` values."""
        for hit_slice in (slice(1, 3), slice(2, 7, 3)):
            for event_idx in range(3):
                branch_first = self.hits.dom_id[event_idx][hit_slice].tolist()
                event_first = self.hits[event_idx].dom_id[hit_slice].tolist()
                assert np.allclose(branch_first, event_first)

                via_events = OFFLINE_FILE.events[event_idx].hits.dom_id[hit_slice]
                branch_again = self.hits.dom_id[event_idx][hit_slice].tolist()
                assert np.allclose(via_events.tolist(), branch_again)

    def test_index_consistency(self):
        """Indexing via ``hits`` or via ``events`` matches the references."""
        limit = self.n_hits
        for event_idx, expected_ids in self.dom_id.items():
            from_hits = self.hits[event_idx].dom_id[:limit].tolist()
            assert np.allclose(from_hits, expected_ids[:limit])
            from_events = OFFLINE_FILE.events[event_idx].hits.dom_id[:limit].tolist()
            assert np.allclose(from_events, expected_ids[:limit])
        for event_idx, expected_t in self.t.items():
            from_hits = self.hits[event_idx].t[:limit].tolist()
            assert np.allclose(from_hits, expected_t[:limit])
            from_events = OFFLINE_FILE.events[event_idx].hits.t[:limit].tolist()
            assert np.allclose(from_events, expected_t[:limit])

    def test_fields(self):
        """The hits branch lists the expected field names."""
        for name in ("dom_id", "channel_id", "t", "tot", "trig", "id"):
            assert name in self.hits.fields


class TestOfflineTracks(unittest.TestCase):
    """Tests for the ``tracks`` branch of the events in ``OFFLINE_FILE``."""

    def setUp(self):
        self.f = OFFLINE_FILE
        self.tracks = OFFLINE_FILE.events.tracks
        # NOTE: despite its name this holds the whole numucc reader; it is
        # not used by the tests in this class.
        self.tracks_numucc = OFFLINE_NUMUCC
        # Number of events the tests expect in ``OFFLINE_FILE``.
        self.n_events = 10

    def test_fields(self):
        """Each expected track field is reachable as an attribute."""
        for field in [
            "id",
            "pos_x",
            "pos_y",
            "pos_z",
            "dir_x",
            "dir_y",
            "dir_z",
            "t",
            "E",
            "len",
            "lik",
            "rec_type",
            "rec_stages",
            "fitinf",
        ]:
            getattr(self.tracks, field)

    def test_item_selection(self):
        """The first two ``dir_z`` values of event 0 match the reference."""
        self.assertListEqual(
            list(self.tracks[0].dir_z[:2]), [-0.872885221293917, -0.872885221293917]
        )

    def test_repr(self):
        """``repr(tracks)`` mentions the number of events."""
        assert str(self.n_events) in repr(self.tracks)

    def test_slicing(self):
        """Slicing tracks by event keeps lengths and values consistent."""
        tracks = self.tracks
        self.assertEqual(self.n_events, len(tracks))  # one entry per event
        self.assertEqual(56, len(tracks[0].id))  # number of tracks in first event
        track_selection = tracks[2:7]
        assert 5 == len(track_selection)
        track_selection_2 = tracks[1:3]
        assert 2 == len(track_selection_2)
        for _slice in [
            slice(0, 1),
            slice(0, 2),
            slice(1, 5),
            slice(3, -2),
        ]:
            # The failing slice is reported through ``msg`` instead of being
            # printed on every iteration.
            self.assertListEqual(
                list(tracks.E[:, 0][_slice]),
                list(tracks[_slice].E[:, 0]),
                msg=f"mismatch for {_slice}",
            )

    def test_nested_indexing(self):
        """A nested ``fitinf`` value is the same through each access order."""
        self.assertAlmostEqual(
            self.f.events.tracks.fitinf[3:5][1][9][2],
            self.f.events[3:5].tracks[1].fitinf[9][2],
        )
        self.assertAlmostEqual(
            self.f.events.tracks.fitinf[3:5][1][9][2],
            self.f.events[3:5][1].tracks.fitinf[9][2],
        )


class TestBranchIndexingMagic(unittest.TestCase):
    """Indexing the events branch with slices, tuples and index collections."""

    def setUp(self):
        self.events = OFFLINE_FILE.events

    def test_slicing_magic(self):
        """Index/slice on events agrees with index/slice on the sub-branch."""
        events = self.events
        self.assertEqual(318, events[2:4].n_hits[0])

        single_via_event = events[3].tracks.dir_z[10]
        single_via_branch = events.tracks.dir_z[3, 10]
        assert np.allclose(single_via_event, single_via_branch)

        column_via_events = events[3:6].tracks.pos_y[:, 0].tolist()
        column_via_branch = events.tracks.pos_y[3:6, 0].tolist()
        assert np.allclose(column_via_events, column_via_branch)

    def test_selecting_specific_items_via_a_list(self):
        """Selecting with a plain Python list of indices."""
        picked = self.events[[0, 2, 3]]
        self.assertEqual(3, len(picked))

    def test_selecting_specific_items_via_a_numpy_array(self):
        """Selecting with a numpy array of indices."""
        picked = self.events[np.array([0, 2, 3])]
        self.assertEqual(3, len(picked))

    def test_selecting_specific_items_via_a_awkward_array(self):
        """Selecting with an awkward array of indices."""
        picked = self.events[ak.Array([0, 2, 3])]
        self.assertEqual(3, len(picked))


class TestUsr(unittest.TestCase):
    """Tests for flat ``usr`` data on the events of ``OFFLINE_USR``.

    All tests are currently skipped. ``unittest.skip`` is documented as
    taking a reason string; the original applied it bare, so the actual
    motivation for skipping is not recorded here.
    """

    def setUp(self):
        self.f = OFFLINE_USR

    @unittest.skip("flat usr test disabled (reason not recorded)")
    def test_str_flat(self):
        """Printing the usr data does not raise."""
        print(self.f.events.usr)

    @unittest.skip("flat usr test disabled (reason not recorded)")
    def test_keys_flat(self):
        """The usr data exposes the expected keys in order."""
        self.assertListEqual(
            [
                "RecoQuality",
                "RecoNDF",
                "CoC",
                "ToT",
                "ChargeAbove",
                "ChargeBelow",
                "ChargeRatio",
                "DeltaPosZ",
                "FirstPartPosZ",
                "LastPartPosZ",
                "NSnapHits",
                "NTrigHits",
                "NTrigDOMs",
                "NTrigLines",
                "NSpeedVetoHits",
                "NGeometryVetoHits",
                "ClassficationScore",
            ],
            self.f.events.usr.keys().tolist(),
        )

    @unittest.skip("flat usr test disabled (reason not recorded)")
    def test_getitem_flat(self):
        """usr values are reachable via string keys."""
        assert np.allclose(
            [118.6302815337638, 44.33580521344907, 99.93916717621543],
            self.f.events.usr["CoC"].tolist(),
        )
        assert np.allclose(
            [37.51967774166617, -10.280346193553832, 13.67595659707355],
            self.f.events.usr["DeltaPosZ"].tolist(),
        )

    @unittest.skip("flat usr test disabled (reason not recorded)")
    def test_attributes_flat(self):
        """usr values are reachable as attributes."""
        assert np.allclose(
            [118.6302815337638, 44.33580521344907, 99.93916717621543],
            self.f.events.usr.CoC.tolist(),
        )
        assert np.allclose(
            [37.51967774166617, -10.280346193553832, 13.67595659707355],
            self.f.events.usr.DeltaPosZ.tolist(),
        )


class TestMcTrackUsr(unittest.TestCase):
    """Tests for ``usr`` data attached to the ``mc_tracks`` branch."""

    def setUp(self):
        self.f = OFFLINE_MC_TRACK_USR

    def test_usr_names(self):
        """The first two MC tracks of events 0-2 carry the expected names."""
        usr_names = self.f.events.mc_tracks.usr_names
        for i in range(3):
            self.assertListEqual(
                ["bx", "by", "ichan", "cc"],
                usr_names[i][0].tolist(),
            )
            self.assertListEqual(
                ["energy_lost_in_can"],
                usr_names[i][1].tolist(),
            )

    def test_usr(self):
        """usr values of the first MC track in events 0 and 1 match."""
        assert np.allclose(
            [0.0487, 0.0588, 3, 2],
            self.f.events.mc_tracks.usr[0][0].tolist(),
            atol=0.0001,
        )
        assert np.allclose(
            [0.147, 0.4, 3, 2], self.f.events.mc_tracks.usr[1][0].tolist(), atol=0.001
        )