diff --git a/km3io/tools.py b/km3io/tools.py index d6cff15c0625737e5f501146cd17825170ab649c..69e88061f2e50bf792a9961657d5cbcb0985b083 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -316,7 +316,7 @@ def best_track(tracks, strategy="default", rec_type=None): out = tracks[:, 0] if strategy == "default" and rec_type is None: - raise KeyError( + raise ValueError( "rec_type must be provided when the default strategy is used.") if strategy == "default" and rec_type is not None: diff --git a/tests/test_tools.py b/tests/test_tools.py index 2bd8984c7ae7622844ea82f34232544411ec19e2..279222813252c383cbaf9af5d1f6e82918dd2dcf 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -58,8 +58,7 @@ class TestBestTrack(unittest.TestCase): self.events = OFFLINE_FILE.events self.one_event = OFFLINE_FILE.events[0] - def test_best_tracks(self): - # test selection from multiple events + def test_best_track_from_multiple_events(self): events = self.events[self.events.n_tracks > 0] first_tracks = best_track(events.tracks, strategy="first") default_best = best_track(events.tracks, @@ -73,7 +72,7 @@ class TestBestTrack(unittest.TestCase): assert default_best.lik[1] == ak.max(events.tracks.lik[1]) assert default_best.rec_type[0] == 4000 - # test selection from one event + def test_best_track_from_a_single_event(self): first_track = best_track(self.one_event.tracks, strategy="first") best = best_track(self.one_event.tracks, strategy="default", @@ -85,13 +84,13 @@ class TestBestTrack(unittest.TestCase): assert best.lik == ak.max(self.one_event.tracks.lik) assert best.rec_type == 4000 - # test raising ValueError + def test_best_track_raises_when_unknown_strategy(self): with self.assertRaises(ValueError): - best_track(events.tracks, strategy="Zineb") + best_track(self.events.tracks, strategy="Zineb") - # test raising KeyError - with self.assertRaises(KeyError): - best_track(events.tracks) + def test_best_track_raises_when_default_strategy_and_no_rectype(self): + with self.assertRaises(ValueError): + best_track(self.events.tracks) class TestGetMultiplicity(unittest.TestCase):