Skip to content
Snippets Groups Projects
Commit 90b7bec7 authored by Zineb Aly's avatar Zineb Aly
Browse files

add additional tests for more coverage

parent c6e2f211
No related branches found
No related tags found
1 merge request!45Adapt best track root access
Pipeline #14313 passed with warnings
......@@ -12,7 +12,8 @@ from km3io.tools import (to_num, cached_property, unfold_indices, unique,
uniquecount, fitinf, fitparams, count_nested, _find,
mask, best_track, rec_types, get_w2list_param,
get_multiplicity, best_jmuon, best_jshower,
best_aashower, best_dusjshower)
best_aashower, best_dusjshower, w2list_genhen_keys,
w2list_gseagen_keys)
OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root"))
......@@ -353,7 +354,7 @@ class TestRecStagesMasks(unittest.TestCase):
assert labels[0][2] == 0
assert labels[1][0] == 0
def test_mask_with_explicit_rec_stages_with_multiple_events(self):
def test_mask_with_explicit_rec_stages_in_list_with_multiple_events(self):
rec_stages = self.tracks.rec_stages
stages = [1, 3, 5, 4]
masks = mask(self.tracks, stages=stages)
......@@ -362,6 +363,16 @@ class TestRecStagesMasks(unittest.TestCase):
assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
assert masks[0][1] == False
def test_mask_with_explicit_rec_stages_in_set_with_multiple_events(self):
stages = {1, 2, 3, 4, 5}
masks = mask(self.tracks, stages=stages)
tracks = self.tracks[masks]
assert 1 in tracks.rec_stages[0][0]
assert 3 in tracks.rec_stages[0][0]
assert 4 in tracks.rec_stages[0][0]
assert 5 in tracks.rec_stages[0][0]
def test_mask_with_start_and_end_of_rec_stages_with_multiple_events(self):
rec_stages = self.tracks.rec_stages
stages = [1, 3, 5, 4]
......@@ -389,6 +400,24 @@ class TestRecStagesMasks(unittest.TestCase):
assert track[masks].rec_stages[0][0] == stages[0]
assert track[masks].rec_stages[0][1] == stages[1]
def test_mask_raises_when_too_many_inputs(self):
with self.assertRaises(ValueError):
mask(self.tracks, start=1, end=4, stages=[1, 3, 5, 4])
def test_mask_raises_when_no_inputs(self):
with self.assertRaises(ValueError):
mask(self.tracks)
class TestW2listGenhenKeys(unittest.TestCase):
def test_w2list_genhen_keys(self):
assert 'W2LIST_GENHEN_REFF' in w2list_genhen_keys()
class TestW2listGseangenKeys(unittest.TestCase):
def test_w2list_gseagen_keys(self):
assert 'W2LIST_GSEAGEN_PS' in w2list_gseagen_keys()
class TestUnique(unittest.TestCase):
def run_random_test_with_dtype(self, dtype):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment