From 36e438534e014a0ae459e5a6b7b3c35aaf939427 Mon Sep 17 00:00:00 2001 From: zineb aly <aly.zineb.az@gmail.com> Date: Tue, 21 Apr 2020 22:09:52 +0200 Subject: [PATCH] test rec_stages masks --- tests/test_offline.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/test_offline.py b/tests/test_offline.py index 14cd392..0df2e8f 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -4,7 +4,7 @@ import awkward1 as ak1 from pathlib import Path from km3io import OfflineReader -from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested +from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask SAMPLES_DIR = Path(__file__).parent / 'samples' OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root') @@ -46,6 +46,31 @@ class TestCountNested(unittest.TestCase): assert count_nested(fit, axis=2)[0][0:4] == ak1.Array([17, 11, 8, 8]) +class TestRecStagesMasks(unittest.TestCase): + def setUp(self): + self.nested = ak1.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]], + [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]]) + + def test_find(self): + builder = ak1.ArrayBuilder() + _find(self.nested, ak1.Array([1, 2, 3]), builder) + labels = builder.snapshot() + + assert labels[0][0] == 1 + assert labels[0][1] == 1 + assert labels[0][2] == 0 + assert labels[1][0] == 0 + + def test_mask(self): + rec_stages = OFFLINE_FILE.events.tracks.rec_stages + stages = [1, 3, 5, 4] + masks = mask(rec_stages, stages) + + assert masks[0][0] == all(rec_stages[0][0] == ak1.Array(stages)) + assert masks[1][0] == all(rec_stages[1][0] == ak1.Array(stages)) + assert masks[0][1] == False + + class TestOfflineReader(unittest.TestCase): def setUp(self): self.r = OFFLINE_FILE -- GitLab