Skip to content
Snippets Groups Projects
test_aanet.py 10.5 KiB
Newer Older
Tamas Gal's avatar
Tamas Gal committed
import unittest
from pathlib import Path
Tamas Gal's avatar
Tamas Gal committed

Zineb Aly's avatar
Zineb Aly committed
from km3io.aanet import Reader, AanetEvents, AanetHits, AanetTracks
Tamas Gal's avatar
Tamas Gal committed
from km3io import AanetReader
Tamas Gal's avatar
Tamas Gal committed

# Sample ROOT files shipped next to this test module.
SAMPLES_DIR = Path(__file__).parent / 'samples'
AANET_FILE = SAMPLES_DIR / 'aanet_v2.0.0.root'
AANET_NUMUCC = SAMPLES_DIR / 'numucc.root'  # with mc data
Tamas Gal's avatar
Tamas Gal committed


Zineb Aly's avatar
Zineb Aly committed
class TestAanetKeys(unittest.TestCase):
    """Check the key bookkeeping exposed by ``AanetReader(...).keys``.

    All expected counts below are specific to the ``aanet_v2.0.0.root``
    sample file referenced by ``AANET_FILE``.
    """

    def setUp(self):
        self.keys = AanetReader(AANET_FILE).keys

    def test_repr(self):
        reader_repr = repr(self.keys)

        # check that there are 106 keys + 5 extra str
        self.assertEqual(len(reader_repr.split('\n')), 111)

    def test_events_keys(self):
        # there are 22 "valid" events keys
        self.assertEqual(len(self.keys.events_keys), 22)
        self.assertEqual(len(self.keys.cut_events_keys), 22)

    def test_hits_keys(self):
        # there are 20 "valid" hits keys (same count for mc hits)
        self.assertEqual(len(self.keys.hits_keys), 20)
        self.assertEqual(len(self.keys.mc_hits_keys), 20)
        self.assertEqual(len(self.keys.cut_hits_keys), 20)

    def test_tracks_keys(self):
        # there are 22 "valid" tracks keys (same count for mc tracks)
        self.assertEqual(len(self.keys.tracks_keys), 22)
        self.assertEqual(len(self.keys.mc_tracks_keys), 22)
        self.assertEqual(len(self.keys.cut_tracks_keys), 22)

    def test_valid_keys(self):
        # there are 106 valid keys: 22*2 + 22 + 20*2
        # (fit keys are excluded)
        self.assertEqual(len(self.keys.valid_keys), 106)

    def test_fit_keys(self):
        # there are 18 fit keys
        self.assertEqual(len(self.keys.fit_keys), 18)


class TestReader(unittest.TestCase):
    """Tests for the low-level ``Reader`` on the ``aanet_v2.0.0.root`` sample."""

    def setUp(self):
        self.r = Reader(AANET_FILE)
        # expected number of hits for selected events (-1 = last event)
        self.lengths = {0: 176, 1: 125, -1: 105}
        # expected sum of the per-event counts (``.count()``) over the file
        self.total_item_count = 1434

    def _check_hit_counts(self, values):
        """Assert per-event lengths and the total item count of a hits branch.

        ``values`` is whatever ``self.r["hits.<field>"]`` returns; it must
        support integer indexing per event and a ``.count()`` method.
        """
        for event_id, length in self.lengths.items():
            # subTest reports which event failed instead of stopping silently
            # at the first mismatch without context
            with self.subTest(event_id=event_id):
                self.assertEqual(length, len(values[event_id]))

        self.assertEqual(self.total_item_count, sum(values.count()))

    def test_reading_dom_id(self):
        dom_ids = self.r["hits.dom_id"]
        self._check_hit_counts(dom_ids)

        self.assertListEqual([806451572, 806451572, 806451572],
                             list(dom_ids[0][:3]))

    def test_reading_channel_id(self):
        channel_ids = self.r["hits.channel_id"]
        self._check_hit_counts(channel_ids)

        self.assertListEqual([8, 9, 14], list(channel_ids[0][:3]))

        # channel IDs are always between [0, 30]
        self.assertTrue(all(c >= 0 for c in channel_ids.min()))
        self.assertTrue(all(c < 31 for c in channel_ids.max()))

    def test_reading_times(self):
        ts = self.r["hits.t"]
        self._check_hit_counts(ts)

        self.assertListEqual([70104010.0, 70104016.0, 70104192.0],
                             list(ts[0][:3]))

    def test_reading_keys(self):
        # there are 106 "valid" keys in Aanet file
        self.assertEqual(len(self.r.keys.valid_keys), 106)

        # there are 20 hits keys
        self.assertEqual(len(self.r.keys.hits_keys), 20)
        self.assertEqual(len(self.r.keys.mc_hits_keys), 20)

        # there are 22 tracks keys
        self.assertEqual(len(self.r.keys.tracks_keys), 22)
        self.assertEqual(len(self.r.keys.mc_tracks_keys), 22)

    def test_raising_KeyError(self):
        # non valid keys must raise a KeyError
        with self.assertRaises(KeyError):
            self.r['whatever']

    def test_number_events(self):
        # check that there are 10 events
        self.assertEqual(len(self.r), 10)
Zineb Aly's avatar
Zineb Aly committed


class TestAanetReader(unittest.TestCase):
    """Tests for ``AanetReader`` construction, indexing and length."""

    def setUp(self):
        self.r = AanetReader(AANET_FILE)
        self.Nevents = 10
        # a second reader built from an explicit ``data=`` argument: the
        # first event's slice of the full reader's internal data
        self.selected_data = AanetReader(AANET_FILE,
                                         data=self.r._data[0])._data

    def test_item_selection(self):
        # a reader created with an explicit ``data=`` slice exposes that
        # slice as its ``_data``
        self.assertEqual(len(self.selected_data), len(self.r._data[0]))

        # indexing the full reader must agree with the explicit-data reader
        # (event 0 has 176 hits)
        self.assertEqual(self.r[0].events.hits, self.selected_data['hits'])

    def test_number_events(self):
        Nevents = len(self.r)

        # check that there are 10 events
        self.assertEqual(Nevents, self.Nevents)


class TestAanetEvents(unittest.TestCase):
    """Tests for the ``AanetEvents`` container of the aanet sample file."""

    def setUp(self):
        self.events = AanetReader(AANET_FILE).events
        # expected number of hits for selected events (-1 = last event)
        self.hits = {0: 176, 1: 125, -1: 105}
        self.Nevents = 10

    def test_reading_hits(self):
        # test item selection
        for event_id, hit in self.hits.items():
            with self.subTest(event_id=event_id):
                self.assertEqual(hit, self.events.hits[event_id])

    def test_reading_tracks(self):
        # expected ``trks`` values of the first three events; the method
        # name must start with ``test_`` for unittest to collect it
        self.assertListEqual(list(self.events.trks[:3]), [56, 55, 56])

    def test_item_selection(self):
        for event_id, hit in self.hits.items():
            with self.subTest(event_id=event_id):
                self.assertEqual(hit, self.events[event_id].hits)

    def test_len(self):
        self.assertEqual(len(self.events), self.Nevents)

    def test_IndexError(self):
        # test handling IndexError with empty lists/arrays
        self.assertEqual(len(AanetEvents(['whatever'], [])), 0)

    def test_str(self):
        self.assertEqual(str(self.events), 'Number of events: 10')

    def test_repr(self):
        self.assertEqual(repr(self.events), '<AanetEvents: 10 parsed events>')


class TestAanetEvent(unittest.TestCase):
    """Tests for a single event obtained by indexing ``AanetReader.events``."""

    def setUp(self):
        reader = AanetReader(AANET_FILE)
        self.event = reader.events[0]

    def test_str(self):
        # the repr is split on newline+tab into a header and field rows
        rows = repr(self.event).split('\n\t')
        self.assertEqual(rows[0], 'Aanet event:')
        self.assertEqual(rows[2],
                         'det_id              :              44')


class TestAanetHits(unittest.TestCase):
    """Tests for the ``AanetHits`` container (regular and MC hits)."""

    def setUp(self):
        self.hits = AanetReader(AANET_FILE).hits
        # expected hits for selected events (-1 = last event)
        self.lengths = {0: 176, 1: 125, -1: 105}
        self.total_item_count = 1434
        self.r_mc = AanetReader(AANET_NUMUCC)
        self.Nevents = 10

    def _assert_hit_counts(self, column):
        """Compare per-event lengths and the overall count of ``column``."""
        for idx, expected in self.lengths.items():
            self.assertEqual(expected, len(column[idx]))
        self.assertEqual(self.total_item_count, sum(column.count()))

    def test_item_selection(self):
        first_dom_ids = list(self.hits[0].dom_id[:3])
        self.assertListEqual(first_dom_ids,
                             [806451572, 806451572, 806451572])

    def test_IndexError(self):
        # an empty data list must yield an empty container, not an IndexError
        empty = AanetHits(['whatever'], [])
        self.assertEqual(len(empty), 0)

    def test_repr(self):
        expected = '<AanetHits: 10 parsed elements>'
        self.assertEqual(repr(self.hits), expected)

    def test_str(self):
        expected = 'Number of hits: 10'
        self.assertEqual(str(self.hits), expected)

    def test_reading_dom_id(self):
        column = self.hits.dom_id
        self._assert_hit_counts(column)
        self.assertListEqual([806451572, 806451572, 806451572],
                             list(column[0][:3]))

    def test_reading_channel_id(self):
        column = self.hits.channel_id
        self._assert_hit_counts(column)
        self.assertListEqual([8, 9, 14], list(column[0][:3]))

        # every per-event minimum/maximum must lie in [0, 30]
        self.assertTrue(all(lo >= 0 for lo in column.min()))
        self.assertTrue(all(hi < 31 for hi in column.max()))

    def test_reading_times(self):
        column = self.hits.t
        self._assert_hit_counts(column)
        self.assertListEqual([70104010.0, 70104016.0, 70104192.0],
                             list(column[0][:3]))

    def test_reading_mc_pmt_id(self):
        pmt_ids = self.r_mc.mc_hits.pmt_id
        # expected number of MC hits for selected events of the numucc file
        mc_lengths = {0: 58, 2: 28, -1: 48}

        for event_idx, expected in mc_lengths.items():
            self.assertEqual(expected, len(pmt_ids[event_idx]))

        self.assertEqual(self.Nevents, len(pmt_ids))
        self.assertListEqual([677, 687, 689], list(pmt_ids[0][:3]))


class TestAanetHit(unittest.TestCase):
    """Tests for a single hit, ``AanetReader(...)[0].hits[0]``."""

    def setUp(self):
        first_event = AanetReader(AANET_FILE)[0]
        self.hit = first_event.hits[0]

    def test_item_selection(self):
        # positional index 0 is the ``id`` field, index 1 the ``dom_id``
        hit = self.hit
        self.assertEqual(hit[0], hit.id)
        self.assertEqual(hit[1], hit.dom_id)

    def test_str(self):
        rows = repr(self.hit).split('\n\t')
        self.assertEqual(rows[0], 'Aanet hit:')
        self.assertEqual(rows[2],
                         'dom_id              :       806451572')


class TestAanetTracks(unittest.TestCase):
    """Tests for the ``AanetTracks`` container (regular and MC tracks)."""

    def setUp(self):
        self.tracks = AanetReader(AANET_FILE).tracks
        self.r_mc = AanetReader(AANET_NUMUCC)
        self.Nevents = 10

    def test_item_selection(self):
        first_two = list(self.tracks[0].dir_z[:2])
        self.assertListEqual(first_two,
                             [-0.872885221293917, -0.872885221293917])

    def test_IndexError(self):
        # an empty data list must yield an empty container, not an IndexError
        empty = AanetTracks(['whatever'], [])
        self.assertEqual(len(empty), 0)

    def test_repr(self):
        expected = '<AanetTracks: 10 parsed elements>'
        self.assertEqual(repr(self.tracks), expected)

    def test_str(self):
        expected = 'Number of tracks: 10'
        self.assertEqual(str(self.tracks), expected)

    def test_reading_tracks_dir_z(self):
        dir_z = self.tracks.dir_z
        # expected number of dir_z values in selected events
        expected_per_event = {0: 56, 1: 55, 8: 54}

        for event_idx, n_entries in expected_per_event.items():
            self.assertEqual(n_entries, len(dir_z[event_idx]))

        # one dir_z array per event
        self.assertEqual(len(dir_z), self.Nevents)

    def test_reading_mc_tracks_dir_z(self):
        dir_z = self.r_mc.mc_tracks.dir_z
        # expected number of MC dir_z values in selected events
        expected_per_event = {0: 11, 1: 25, 8: 13}

        for event_idx, n_entries in expected_per_event.items():
            self.assertEqual(n_entries, len(dir_z[event_idx]))

        # one dir_z array per event
        self.assertEqual(len(dir_z), self.Nevents)

        self.assertListEqual([0.230189, 0.230189, 0.218663],
                             list(dir_z[0][:3]))


class TestAanetTrack(unittest.TestCase):
    """Tests for a single track, ``AanetReader(...)[0].tracks[0]``."""

    def setUp(self):
        first_event = AanetReader(AANET_FILE)[0]
        self.track = first_event.tracks[0]

    def test_item_selection(self):
        # positional indices 0 and 10 correspond to fUniqueID and E
        trk = self.track
        self.assertEqual(trk[0], trk.fUniqueID)
        self.assertEqual(trk[10], trk.E)

    def test_str(self):
        rows = repr(self.track).split('\n\t')
        self.assertEqual(rows[0], 'Aanet track:')
        self.assertEqual(
            rows[28],
            'JGANDALF_LAMBDA                :      4.2409761837248484e-12')