diff --git a/tests/test_tools.py b/tests/test_tools.py index e719649c5319bd9718467ab6f94a1d05304737e1..dcc22aba9de766b357a16a5475a7d45493d94532 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -12,7 +12,8 @@ from km3io.tools import (to_num, cached_property, unfold_indices, unique, uniquecount, fitinf, fitparams, count_nested, _find, mask, best_track, rec_types, get_w2list_param, get_multiplicity, best_jmuon, best_jshower, - best_aashower, best_dusjshower) + best_aashower, best_dusjshower, w2list_genhen_keys, + w2list_gseagen_keys) OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root")) @@ -353,7 +354,7 @@ class TestRecStagesMasks(unittest.TestCase): assert labels[0][2] == 0 assert labels[1][0] == 0 - def test_mask_with_explicit_rec_stages_with_multiple_events(self): + def test_mask_with_explicit_rec_stages_in_list_with_multiple_events(self): rec_stages = self.tracks.rec_stages stages = [1, 3, 5, 4] masks = mask(self.tracks, stages=stages) @@ -362,6 +363,16 @@ class TestRecStagesMasks(unittest.TestCase): assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages)) assert masks[0][1] == False + def test_mask_with_explicit_rec_stages_in_set_with_multiple_events(self): + stages = {1, 2, 3, 4, 5} + masks = mask(self.tracks, stages=stages) + tracks = self.tracks[masks] + + assert 1 in tracks.rec_stages[0][0] + assert 3 in tracks.rec_stages[0][0] + assert 4 in tracks.rec_stages[0][0] + assert 5 in tracks.rec_stages[0][0] + def test_mask_with_start_and_end_of_rec_stages_with_multiple_events(self): rec_stages = self.tracks.rec_stages stages = [1, 3, 5, 4] @@ -389,6 +400,24 @@ class TestRecStagesMasks(unittest.TestCase): assert track[masks].rec_stages[0][0] == stages[0] assert track[masks].rec_stages[0][1] == stages[1] + def test_mask_raises_when_too_many_inputs(self): + with self.assertRaises(ValueError): + mask(self.tracks, start=1, end=4, stages=[1, 3, 5, 4]) + + def test_mask_raises_when_no_inputs(self): + with self.assertRaises(ValueError): + mask(self.tracks) + + +class TestW2listGenhenKeys(unittest.TestCase): + def test_w2list_genhen_keys(self): + assert 'W2LIST_GENHEN_REFF' in w2list_genhen_keys() + + +class TestW2listGseangenKeys(unittest.TestCase): + def test_w2list_gseagen_keys(self): + assert 'W2LIST_GSEAGEN_PS' in w2list_gseagen_keys() + class TestUnique(unittest.TestCase): def run_random_test_with_dtype(self, dtype):