diff --git a/tests/test_offline.py b/tests/test_offline.py index 14cd3926f65d0506ce39f2314927c5c390a9474e..0df2e8f613fdf3c8860ee8f17ed5a4afdc67df0a 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -4,7 +4,7 @@ import awkward1 as ak1 from pathlib import Path from km3io import OfflineReader -from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested +from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask SAMPLES_DIR = Path(__file__).parent / 'samples' OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root') @@ -46,6 +46,31 @@ class TestCountNested(unittest.TestCase): assert count_nested(fit, axis=2)[0][0:4] == ak1.Array([17, 11, 8, 8]) +class TestRecStagesMasks(unittest.TestCase): + def setUp(self): + self.nested = ak1.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]], + [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]]) + + def test_find(self): + builder = ak1.ArrayBuilder() + _find(self.nested, ak1.Array([1, 2, 3]), builder) + labels = builder.snapshot() + + assert labels[0][0] == 1 + assert labels[0][1] == 1 + assert labels[0][2] == 0 + assert labels[1][0] == 0 + + def test_mask(self): + rec_stages = OFFLINE_FILE.events.tracks.rec_stages + stages = [1, 3, 5, 4] + masks = mask(rec_stages, stages) + + assert masks[0][0] == all(rec_stages[0][0] == ak1.Array(stages)) + assert masks[1][0] == all(rec_stages[1][0] == ak1.Array(stages)) + assert masks[0][1] == False + + class TestOfflineReader(unittest.TestCase): def setUp(self): self.r = OFFLINE_FILE