diff --git a/tests/test_tools.py b/tests/test_tools.py index 743ba8298d7a23cabe3ee2b1a48fbec447b4c1b2..17f1fd1aee6f490980a9782dabd35a036530e816 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -8,7 +8,7 @@ from pathlib import Path from km3io import OfflineReader from km3io.tools import (to_num, cached_property, unfold_indices, unique, uniquecount, fitinf, fitparams, count_nested, _find, - mask, best_track, rec_types, get_w2list_param) + mask, best_track, rec_types, get_w2list_param, get_multiplicity) SAMPLES_DIR = Path(__file__).parent / 'samples' OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'km3net_offline.root') @@ -54,13 +54,12 @@ class TestRecoTypes(unittest.TestCase): class TestBestTrack(unittest.TestCase): def setUp(self): self.events = OFFLINE_FILE.events + self.one_event = OFFLINE_FILE.events[0] def test_best_tracks(self): + # test selection from multiple events events = self.events[self.events.n_tracks > 0] first_tracks = best_track(events, strategy="first") - rec_stages_tracks = best_track(events, - strategy="rec_stages", - rec_stages=[1, 3, 5, 4]) default_best = best_track(events, strategy="default", rec_type="JPP_RECONSTRUCTION_TYPE") @@ -68,17 +67,34 @@ class TestBestTrack(unittest.TestCase): assert first_tracks.dir_z[0] == events.tracks.dir_z[0][0] assert first_tracks.dir_x[1] == events.tracks.dir_x[1][0] - assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4] - assert rec_stages_tracks.rec_stages[1] == [1, 3, 5, 4] - assert default_best.lik[0] == ak.max(events.tracks.lik[0]) assert default_best.lik[1] == ak.max(events.tracks.lik[1]) assert default_best.rec_type[0] == 4000 + # test selection from one event + first_track = best_track(self.one_event, strategy="first") + best = best_track(self.one_event, + strategy="default", + rec_type="JPP_RECONSTRUCTION_TYPE") + + assert first_track.dir_z == self.one_event.tracks.dir_z[0] + assert first_track.lik == self.one_event.tracks.lik[0] + + assert best.lik == ak.max(self.one_event.tracks.lik) + assert best.rec_type == 4000 + + # test raising ValueError with self.assertRaises(ValueError): best_track(events, strategy="Zineb") +class TestGetMultiplicity(unittest.TestCase): + def test_get_multiplicity(self): + rec_stages_tracks = get_multiplicity(OFFLINE_FILE.events.tracks, [1, 3, 5, 4]) + + assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4] + assert rec_stages_tracks.rec_stages[1] == [1, 3, 5, 4] + class TestCountNested(unittest.TestCase): def test_count_nested(self): fit = OFFLINE_FILE.events.tracks.fitinf