From 006b607ed8872b60f1363992b0be1b33d0ea6e80 Mon Sep 17 00:00:00 2001 From: zineb aly <aly.zineb.az@gmail.com> Date: Mon, 22 Jun 2020 15:11:22 +0200 Subject: [PATCH] adapt best track --- km3io/tools.py | 34 ++++++++++++++++++++++++++-------- tests/test_tools.py | 18 +++++++++++------- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/km3io/tools.py b/km3io/tools.py index 2246276..f0694d9 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -271,13 +271,13 @@ def mask(rec_stages, stages): return builder.snapshot() == 1 -def best_track(events, strategy="first", rec_type=None, rec_stages=None): +def best_track(events, strategy="default", rec_type=None, rec_stages=None): """best track selection based on different strategies Parameters ---------- events : class km3io.offline.OfflineBranch - the events branch. + a subset of reconstructed events where `events.n_tracks > 0` is always true. strategy : str the trategy desired to select the best tracks. It is either: - "first" : to select the first tracks. @@ -296,14 +296,32 @@ def best_track(events, strategy="first", rec_type=None, rec_stages=None): ------- class km3io.offline.OfflineBranch tracks class with the desired "best tracks" selection. + + Raises + ------ + ValueError + ValueError raised when: + - an invalid strategy is requested. + - a subset of events with empty tracks is used. """ - tracks = events.tracks[events.n_tracks > 0] + options = ['first', 'rec_stages', 'default'] + if strategy not in options: + raise ValueError("{} not in {}".format(strategy, options)) + + if any(events.n_tracks == 0): + raise ValueError( + "'events' should not contain empty tracks. Consider applying the mask: events.n_tracks>0" + ) + + tracks = events.tracks if strategy == "first": - return tracks[:, 0] + out = tracks[:, 0] - if strategy == "rec_stages" and rec_stages is not None: - return tracks[mask(tracks.rec_stages, rec_stages)] + elif strategy == "rec_stages" and rec_stages is not None: + out = tracks[mask(tracks.rec_stages, rec_stages)] - if strategy == "default" and rec_type is not None: - return tracks[tracks.rec_type == reconstruction[rec_type]][ + elif strategy == "default" and rec_type is not None: + out = tracks[tracks.rec_type == reconstruction[rec_type]][ tracks.lik == ak1.max(tracks.lik, axis=1)][:, 0] + + return out diff --git a/tests/test_tools.py b/tests/test_tools.py index 71c4a40..743ba82 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -56,24 +56,28 @@ class TestBestTrack(unittest.TestCase): self.events = OFFLINE_FILE.events def test_best_tracks(self): - first_tracks = best_track(self.events, strategy="first") - rec_stages_tracks = best_track(self.events, + events = self.events[self.events.n_tracks > 0] + first_tracks = best_track(events, strategy="first") + rec_stages_tracks = best_track(events, strategy="rec_stages", rec_stages=[1, 3, 5, 4]) - default_best = best_track(self.events, + default_best = best_track(events, strategy="default", rec_type="JPP_RECONSTRUCTION_TYPE") - assert first_tracks.dir_z[0] == self.events.tracks.dir_z[0][0] - assert first_tracks.dir_x[1] == self.events.tracks.dir_x[1][0] + assert first_tracks.dir_z[0] == events.tracks.dir_z[0][0] + assert first_tracks.dir_x[1] == events.tracks.dir_x[1][0] assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4] assert rec_stages_tracks.rec_stages[1] == [1, 3, 5, 4] - assert default_best.lik[0] == ak.max(self.events.tracks.lik[0]) - assert default_best.lik[1] == ak.max(self.events.tracks.lik[1]) + assert default_best.lik[0] == ak.max(events.tracks.lik[0]) + assert default_best.lik[1] == ak.max(events.tracks.lik[1]) assert default_best.rec_type[0] == 4000 + with self.assertRaises(ValueError): + best_track(events, strategy="Zineb") + class TestCountNested(unittest.TestCase): def test_count_nested(self): -- GitLab