Skip to content
Snippets Groups Projects
Commit c41a8731 authored by Zineb Aly's avatar Zineb Aly
Browse files

add tests for rec_stages masks for signle event

parent bb3bdc67
No related branches found
No related tags found
1 merge request!45Adapt best track root access
Pipeline #14284 passed with warnings
......@@ -159,22 +159,22 @@ class TestCountNested(unittest.TestCase):
class TestRecStagesMasks(unittest.TestCase):
def setUp(self):
# self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
# [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
[[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
self.tracks = OFFLINE_FILE.events.tracks
# def test_find(self):
# builder = ak.ArrayBuilder()
# _find(self.nested, ak.Array([1, 2, 3]), builder)
# labels = builder.snapshot()
def test_find(self):
builder = ak.ArrayBuilder()
_find(self.nested, ak.Array([1, 2, 3]), builder)
labels = builder.snapshot()
# assert labels[0][0] == 1
# assert labels[0][1] == 1
# assert labels[0][2] == 0
# assert labels[1][0] == 0
assert labels[0][0] == 1
assert labels[0][1] == 1
assert labels[0][2] == 0
assert labels[1][0] == 0
def test_mask_with_explicit_rec_stages(self):
def test_mask_with_explicit_rec_stages_with_multiple_events(self):
rec_stages = self.tracks.rec_stages
stages = [1, 3, 5, 4]
masks = mask(self.tracks, stages=stages)
......@@ -183,7 +183,7 @@ class TestRecStagesMasks(unittest.TestCase):
assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
assert masks[0][1] == False
def test_mask_with_start_and_end_of_rec_stages(self):
def test_mask_with_start_and_end_of_rec_stages_with_multiple_events(self):
rec_stages = self.tracks.rec_stages
stages = [1, 3, 5, 4]
masks = mask(self.tracks, start=1, end=4)
......@@ -192,6 +192,24 @@ class TestRecStagesMasks(unittest.TestCase):
assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
assert masks[0][1] == False
def test_mask_with_start_and_end_of_rec_stages_signle_event(self):
rec_stages = self.tracks.rec_stages[0][0]
stages = [1, 3, 5, 4]
track = self.tracks[0]
masks = mask(track, start=1, end=4)
assert track[masks].rec_stages[0][0] == 1
assert track[masks].rec_stages[0][-1] == 4
def test_mask_with_explicit_rec_stages_with_single_event(self):
rec_stages = self.tracks.rec_stages[0][0]
stages = [1, 3]
track = self.tracks[0]
masks = mask(track, stages=stages)
assert track[masks].rec_stages[0][0] == stages[0]
assert track[masks].rec_stages[0][1] == stages[1]
class TestUnique(unittest.TestCase):
def run_random_test_with_dtype(self, dtype):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment