From 36e438534e014a0ae459e5a6b7b3c35aaf939427 Mon Sep 17 00:00:00 2001
From: zineb aly <aly.zineb.az@gmail.com>
Date: Tue, 21 Apr 2020 22:09:52 +0200
Subject: [PATCH] test rec_stages masks

---
 tests/test_offline.py | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/tests/test_offline.py b/tests/test_offline.py
index 14cd392..0df2e8f 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -4,7 +4,7 @@ import awkward1 as ak1
 from pathlib import Path
 
 from km3io import OfflineReader
-from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested
+from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask
 
 SAMPLES_DIR = Path(__file__).parent / 'samples'
 OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root')
@@ -46,6 +46,31 @@ class TestCountNested(unittest.TestCase):
         assert count_nested(fit, axis=2)[0][0:4] == ak1.Array([17, 11, 8, 8])
 
 
+class TestRecStagesMasks(unittest.TestCase):
+    def setUp(self):
+        self.nested = ak1.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
+                                 [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
+
+    def test_find(self):
+        builder = ak1.ArrayBuilder()
+        _find(self.nested, ak1.Array([1, 2, 3]), builder)
+        labels = builder.snapshot()
+
+        assert labels[0][0] == 1
+        assert labels[0][1] == 1
+        assert labels[0][2] == 0
+        assert labels[1][0] == 0
+
+    def test_mask(self):
+        rec_stages = OFFLINE_FILE.events.tracks.rec_stages
+        stages = [1, 3, 5, 4]
+        masks = mask(rec_stages, stages)
+
+        assert masks[0][0] == all(rec_stages[0][0] == ak1.Array(stages))
+        assert masks[1][0] == all(rec_stages[1][0] == ak1.Array(stages))
+        assert masks[0][1] == False
+
+
 class TestOfflineReader(unittest.TestCase):
     def setUp(self):
         self.r = OFFLINE_FILE
-- 
GitLab