diff --git a/km3io/tools.py b/km3io/tools.py index e28acdbf1311a44f6908d3b9c38208e0a674861b..d6cff15c0625737e5f501146cd17825170ab649c 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -315,13 +315,21 @@ def best_track(tracks, strategy="default", rec_type=None): else: out = tracks[:, 0] - elif strategy == "default" and rec_type is not None: + if strategy == "default" and rec_type is None: + raise KeyError( + "rec_type must be provided when the default strategy is used.") + + if strategy == "default" and rec_type is not None: if n_events == 1: - out = tracks[tracks.rec_type == reconstruction[rec_type]][ - tracks.lik == ak1.max(tracks.lik, axis=0)][0] + len_stages = count_nested(tracks.rec_stages, axis=1) + rec_types = tracks[tracks.rec_type == reconstruction[rec_type]] + longest = rec_types[len_stages == ak1.max(len_stages, axis=0)] + out = longest[longest.lik == ak1.max(longest.lik, axis=0)] else: - out = tracks[tracks.rec_type == reconstruction[rec_type]][ - tracks.lik == ak1.max(tracks.lik, axis=1)][:, 0] + len_stages = count_nested(tracks.rec_stages, axis=2) + rec_types = tracks[tracks.rec_type == reconstruction[rec_type]] + longest = rec_types[len_stages == ak1.max(len_stages, axis=1)] + out = longest[longest.lik == ak1.max(longest.lik, axis=1)] return out