diff --git a/tests/test_offline.py b/tests/test_offline.py index 1506c8700c22ff77b3a051ff1f1d9012b44c6868..994df65cab75f530f5b267e7ba2269b301de0b3b 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -1,9 +1,10 @@ import unittest import numpy as np +import awkward1 as ak1 from pathlib import Path from km3io import OfflineReader -from km3io.offline import _nested_mapper, Header, fitinf, fitparams +from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested SAMPLES_DIR = Path(__file__).parent / 'samples' OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root') @@ -15,19 +16,34 @@ class TestFitinf(unittest.TestCase): def setUp(self): self.tracks = OFFLINE_FILE.events.tracks self.fit = self.tracks.fitinf + self.best = self.tracks[:, 0] + self.best_fit = self.best.fitinf def test_fitinf(self): beta = fitinf('JGANDALF_BETA0_RAD', self.tracks) + best_beta = fitinf('JGANDALF_BETA0_RAD', self.best) assert beta[0][0] == self.fit[0][0][0] assert beta[0][1] == self.fit[0][1][0] assert beta[0][2] == self.fit[0][2][0] + assert best_beta[0] == self.best_fit[0][0] + assert best_beta[1] == self.best_fit[1][0] + assert best_beta[2] == self.best_fit[2][0] + def test_fitparams(self): keys = set(fitparams()) assert "JGANDALF_BETA0_RAD" in keys +class TestCountNested(unittest.TestCase): + def test_count_nested(self): + fit = OFFLINE_FILE.events.tracks.fitinf + + assert count_nested(fit, axis=0) == 10 + assert count_nested(fit, axis=1)[0:4] == ak1.Array([56, 55, 56, 56]) + assert count_nested(fit, axis=2)[0][0:4] == ak1.Array([17, 11, 8, 8]) + class TestOfflineReader(unittest.TestCase): def setUp(self):