Skip to content
Snippets Groups Projects
Commit 36e43853 authored by Zineb Aly's avatar Zineb Aly
Browse files

test rec_stages masks

parent addbf250
No related branches found
No related tags found
No related merge requests found
Pipeline #10694 failed
......@@ -4,7 +4,7 @@ import awkward1 as ak1
from pathlib import Path
from km3io import OfflineReader
from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested
from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask
SAMPLES_DIR = Path(__file__).parent / 'samples'
OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root')
......@@ -46,6 +46,31 @@ class TestCountNested(unittest.TestCase):
assert count_nested(fit, axis=2)[0][0:4] == ak1.Array([17, 11, 8, 8])
class TestRecStagesMasks(unittest.TestCase):
def setUp(self):
self.nested = ak1.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
[[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
def test_find(self):
builder = ak1.ArrayBuilder()
_find(self.nested, ak1.Array([1, 2, 3]), builder)
labels = builder.snapshot()
assert labels[0][0] == 1
assert labels[0][1] == 1
assert labels[0][2] == 0
assert labels[1][0] == 0
def test_mask(self):
rec_stages = OFFLINE_FILE.events.tracks.rec_stages
stages = [1, 3, 5, 4]
masks = mask(rec_stages, stages)
assert masks[0][0] == all(rec_stages[0][0] == ak1.Array(stages))
assert masks[1][0] == all(rec_stages[1][0] == ak1.Array(stages))
assert masks[0][1] == False
class TestOfflineReader(unittest.TestCase):
def setUp(self):
self.r = OFFLINE_FILE
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment