From c41a87316b586aca8b96214082235c657fcdf952 Mon Sep 17 00:00:00 2001 From: Zineb Aly <zaly@km3net.de> Date: Tue, 6 Oct 2020 12:16:16 +0200 Subject: [PATCH] add tests for rec_stages masks for signle event --- tests/test_tools.py | 42 ++++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index 5031552..53518f9 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -159,22 +159,22 @@ class TestCountNested(unittest.TestCase): class TestRecStagesMasks(unittest.TestCase): def setUp(self): - # self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]], - # [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]]) + self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]], + [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]]) self.tracks = OFFLINE_FILE.events.tracks - # def test_find(self): - # builder = ak.ArrayBuilder() - # _find(self.nested, ak.Array([1, 2, 3]), builder) - # labels = builder.snapshot() + def test_find(self): + builder = ak.ArrayBuilder() + _find(self.nested, ak.Array([1, 2, 3]), builder) + labels = builder.snapshot() - # assert labels[0][0] == 1 - # assert labels[0][1] == 1 - # assert labels[0][2] == 0 - # assert labels[1][0] == 0 + assert labels[0][0] == 1 + assert labels[0][1] == 1 + assert labels[0][2] == 0 + assert labels[1][0] == 0 - def test_mask_with_explicit_rec_stages(self): + def test_mask_with_explicit_rec_stages_with_multiple_events(self): rec_stages = self.tracks.rec_stages stages = [1, 3, 5, 4] masks = mask(self.tracks, stages=stages) @@ -183,7 +183,7 @@ class TestRecStagesMasks(unittest.TestCase): assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages)) assert masks[0][1] == False - def test_mask_with_start_and_end_of_rec_stages(self): + def test_mask_with_start_and_end_of_rec_stages_with_multiple_events(self): rec_stages = self.tracks.rec_stages stages = [1, 3, 5, 4] masks = mask(self.tracks, start=1, end=4) @@ -192,6 +192,24 @@ class TestRecStagesMasks(unittest.TestCase): assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages)) assert masks[0][1] == False + def test_mask_with_start_and_end_of_rec_stages_signle_event(self): + rec_stages = self.tracks.rec_stages[0][0] + stages = [1, 3, 5, 4] + track = self.tracks[0] + masks = mask(track, start=1, end=4) + + assert track[masks].rec_stages[0][0] == 1 + assert track[masks].rec_stages[0][-1] == 4 + + def test_mask_with_explicit_rec_stages_with_single_event(self): + rec_stages = self.tracks.rec_stages[0][0] + stages = [1, 3] + track = self.tracks[0] + masks = mask(track, stages=stages) + + assert track[masks].rec_stages[0][0] == stages[0] + assert track[masks].rec_stages[0][1] == stages[1] + class TestUnique(unittest.TestCase): def run_random_test_with_dtype(self, dtype): -- GitLab