From 2a4176236efb82eb85696323eb873510c8819ef8 Mon Sep 17 00:00:00 2001 From: Zineb Aly <zaly@km3net.de> Date: Tue, 6 Oct 2020 11:24:42 +0200 Subject: [PATCH] add tests for best_track --- tests/test_tools.py | 80 ++++++++++++++++++++++++++++++++------------- 1 file changed, 57 insertions(+), 23 deletions(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index 7de8aec..66bac91 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -15,11 +15,6 @@ from km3io.tools import (to_num, cached_property, unfold_indices, unique, OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root")) -# class TestGetw2listParam(unittest.TestCase): -# def test_get_w2list_param(self): -# xsec_mean = get_w2list_param(OFFLINE_FILE.events, "gseagen", "W2LIST_GSEAGEN_XSEC_MEAN") -# print(xsec_mean) - class TestFitinf(unittest.TestCase): def setUp(self): @@ -53,24 +48,61 @@ class TestRecoTypes(unittest.TestCase): assert "JPP_RECONSTRUCTION_TYPE" in keys -# class TestBestTrack(unittest.TestCase): -# def setUp(self): -# self.events = OFFLINE_FILE.events -# self.one_event = OFFLINE_FILE.events[0] +class TestBestTrack(unittest.TestCase): + def setUp(self): + self.events = OFFLINE_FILE.events + self.one_event = OFFLINE_FILE.events[0] + + def test_best_track_selection_from_multiple_events_with_no_stages(self): + longest = best_track(self.events.tracks, "JMUON") + + assert len(longest) == 10 + + assert longest.rec_stages[0] == [1, 3, 5, 4] + assert longest.rec_stages[1] == [1, 3, 5, 4] + assert longest.rec_stages[2] == [1, 3, 5, 4] + assert longest.rec_stages[3] == [1, 3, 5, 4] + + def test_best_track_selection_from_multiple_events_with_explicit_stages( + self): + best = best_track(self.events.tracks, "JMUON", stages=[1, 3, 5, 4]) + + assert len(best) == 10 + + assert best.rec_stages[0] == [1, 3, 5, 4] + assert best.rec_stages[1] == [1, 3, 5, 4] + assert best.rec_stages[2] == [1, 3, 5, 4] + assert best.rec_stages[3] == [1, 3, 5, 4] -# def test_best_track_from_multiple_events(self): -# events = self.events[self.events.n_tracks > 0] -# first_tracks = best_track(events.tracks, strategy="first") -# default_best = best_track(events.tracks, -# strategy="default", -# rec_type="JPP_RECONSTRUCTION_TYPE") + def test_best_track_selection_from_multiple_events_with_start_end(self): + best = best_track(self.events.tracks, "JMUON", start=1, end=4) + best2 = best_track(self.events.tracks, "JMUON", start=1, end=3) -# assert first_tracks.dir_z[0] == events.tracks.dir_z[0][0] -# assert first_tracks.dir_x[1] == events.tracks.dir_x[1][0] + assert len(best) == 10 + assert len(best2) == 10 + + assert best.rec_stages[0] == [1, 3, 5, 4] + assert best.rec_stages[1] == [1, 3, 5, 4] + assert best.rec_stages[2] == [1, 3, 5, 4] + assert best.rec_stages[3] == [1, 3, 5, 4] + + assert best2.rec_stages[0] == [1, 3] + assert best2.rec_stages[1] == [1, 3] + assert best2.rec_stages[2] == [1, 3] + assert best2.rec_stages[3] == [1, 3] + + def test_best_track_raises_when_unknown_reco(self): + with self.assertRaises(KeyError): + best_track(self.events.tracks, "Zineb") + + def test_best_track_raises_when_start_or_end_not_valid(self): + with self.assertRaises(ValueError): + best_track(self.events.tracks, "JMUON", start=100, end=103) + + def test_best_track_raises_when_wrong_reco_stages(self): + with self.assertRaises(KeyError): + best_track(self.events.tracks, "JMUON", stages=[233, 100, 500]) -# assert default_best.lik[0] == ak.max(events.tracks.lik[0]) -# assert default_best.lik[1] == ak.max(events.tracks.lik[1]) -# assert default_best.rec_type[0] == 4000 # def test_best_track_from_a_single_event(self): # first_track = best_track(self.one_event.tracks, strategy="first") @@ -130,6 +162,8 @@ class TestRecStagesMasks(unittest.TestCase): self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]], [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]]) + self.tracks = OFFLINE_FILE.events.tracks + def test_find(self): builder = ak.ArrayBuilder() _find(self.nested, ak.Array([1, 2, 3]), builder) @@ -140,10 +174,10 @@ class TestRecStagesMasks(unittest.TestCase): assert labels[0][2] == 0 assert labels[1][0] == 0 - def test_mask(self): - rec_stages = OFFLINE_FILE.events.tracks.rec_stages + def test_mask_with_explicit_rec_stages(self): + rec_stages = self.tracks.rec_stages stages = [1, 3, 5, 4] - masks = mask(rec_stages, stages) + masks = mask(rec_stages, stages=stages) assert masks[0][0] == all(rec_stages[0][0] == ak.Array(stages)) assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages)) -- GitLab