diff --git a/km3io/tools.py b/km3io/tools.py index a422264e88bbc235051eac33d45ef6fa50808c42..c4c157a722fddd6954c484f306ea45b46446b99d 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -145,14 +145,22 @@ def fitinf(fitparam, tracks): """ fit = tracks.fitinf index = fitparam - try: - params = fit[count_nested(fit, axis=2) > index] - return ak1.Array([i[:, index] for i in params]) - except ValueError: - # This is the case for tracks[:, 0] or any other selection. + if tracks.is_single and len(tracks) != 1: params = fit[count_nested(fit, axis=1) > index] - return params[:, index] + out = params[:, index] + if tracks.is_single and len(tracks) == 1: + out = fit[index] + + else: + if len(tracks[0]) == 1: # case of tracks slice with 1 track per event. + params = fit[count_nested(fit, axis=1) > index] + out = params[:, index] + else: + params = fit[count_nested(fit, axis=2) > index] + out = ak1.Array([i[:, index] for i in params]) + + return out def count_nested(arr, axis=0): """Count elements in a nested awkward Array. diff --git a/tests/test_tools.py b/tests/test_tools.py index d5cfd00090250df768a135936e59d202db0bc5f1..d576c67f445076a681c68ceb4bb3e1238126a825 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -39,17 +39,24 @@ class TestFitinf(unittest.TestCase): self.best = self.tracks[:, 0] self.best_fit = self.best.fitinf - def test_fitinf(self): + def test_fitinf_from_all_events(self): beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks) - best_beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.best) assert beta[0][0] == self.fit[0][0][0] assert beta[0][1] == self.fit[0][1][0] assert beta[0][2] == self.fit[0][2][0] - assert best_beta[0] == self.best_fit[0][0] - assert best_beta[1] == self.best_fit[1][0] - assert best_beta[2] == self.best_fit[2][0] + def test_fitinf_from_one_event(self): + beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.best) + + assert beta[0] == self.best_fit[0][0] + assert beta[1] == self.best_fit[1][0] + assert beta[2] == self.best_fit[2][0] + + def test_fitinf_from_one_event_and_one_track(self): + beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks[0][0]) + + assert beta == self.tracks[0][0].fitinf[0] class TestBestTrackSelection(unittest.TestCase):