diff --git a/km3io/tools.py b/km3io/tools.py index 22462762068b4d9bc32a6174265d28d4352dc080..f0694d9e4e2e13a1ebdf24ba96d5004f34da0831 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -271,13 +271,13 @@ def mask(rec_stages, stages): return builder.snapshot() == 1 -def best_track(events, strategy="first", rec_type=None, rec_stages=None): +def best_track(events, strategy="default", rec_type=None, rec_stages=None): """best track selection based on different strategies Parameters ---------- events : class km3io.offline.OfflineBranch - the events branch. + a subset of reconstructed events where `events.n_tracks > 0` is always true. strategy : str the trategy desired to select the best tracks. It is either: - "first" : to select the first tracks. @@ -296,14 +296,32 @@ def best_track(events, strategy="first", rec_type=None, rec_stages=None): ------- class km3io.offline.OfflineBranch tracks class with the desired "best tracks" selection. + + Raises + ------ + ValueError + ValueError raised when: + - an invalid strategy is requested. + - a subset of events with empty tracks is used. """ - tracks = events.tracks[events.n_tracks > 0] + options = ['first', 'rec_stages', 'default'] + if strategy not in options: + raise ValueError("{} not in {}".format(strategy, options)) + + if any(events.n_tracks == 0): + raise ValueError( + "'events' should not contain empty tracks. Consider applying the mask: events.n_tracks>0" + ) + + tracks = events.tracks if strategy == "first": - return tracks[:, 0] + out = tracks[:, 0] - if strategy == "rec_stages" and rec_stages is not None: - return tracks[mask(tracks.rec_stages, rec_stages)] + elif strategy == "rec_stages" and rec_stages is not None: + out = tracks[mask(tracks.rec_stages, rec_stages)] - if strategy == "default" and rec_type is not None: - return tracks[tracks.rec_type == reconstruction[rec_type]][ + elif strategy == "default" and rec_type is not None: + out = tracks[tracks.rec_type == reconstruction[rec_type]][ tracks.lik == ak1.max(tracks.lik, axis=1)][:, 0] + + return out diff --git a/tests/test_tools.py b/tests/test_tools.py index 71c4a409c0c603d873e97e934fbb9bc3279f816c..743ba8298d7a23cabe3ee2b1a48fbec447b4c1b2 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -56,24 +56,28 @@ class TestBestTrack(unittest.TestCase): self.events = OFFLINE_FILE.events def test_best_tracks(self): - first_tracks = best_track(self.events, strategy="first") - rec_stages_tracks = best_track(self.events, + events = self.events[self.events.n_tracks > 0] + first_tracks = best_track(events, strategy="first") + rec_stages_tracks = best_track(events, strategy="rec_stages", rec_stages=[1, 3, 5, 4]) - default_best = best_track(self.events, + default_best = best_track(events, strategy="default", rec_type="JPP_RECONSTRUCTION_TYPE") - assert first_tracks.dir_z[0] == self.events.tracks.dir_z[0][0] - assert first_tracks.dir_x[1] == self.events.tracks.dir_x[1][0] + assert first_tracks.dir_z[0] == events.tracks.dir_z[0][0] + assert first_tracks.dir_x[1] == events.tracks.dir_x[1][0] assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4] assert rec_stages_tracks.rec_stages[1] == [1, 3, 5, 4] - assert default_best.lik[0] == ak.max(self.events.tracks.lik[0]) - assert default_best.lik[1] == ak.max(self.events.tracks.lik[1]) + assert default_best.lik[0] == ak.max(events.tracks.lik[0]) + assert default_best.lik[1] == ak.max(events.tracks.lik[1]) assert default_best.rec_type[0] == 4000 + with self.assertRaises(ValueError): + best_track(events, strategy="Zineb") + class TestCountNested(unittest.TestCase): def test_count_nested(self):