Skip to content
Snippets Groups Projects
Commit bb3bdc67 authored by Zineb Aly's avatar Zineb Aly
Browse files

debug tests

parent 0c6f0bb1
No related branches found
No related tags found
1 merge request!45Adapt best track root access
Pipeline #14283 passed with warnings
......@@ -234,7 +234,7 @@ def get_multiplicity(tracks, rec_stages):
class km3io.offline.OfflineBranch
tracks branch with the desired reconstruction stages only.
"""
return tracks[mask(tracks.rec_stages, rec_stages)]
return tracks[mask(tracks, rec_stages)]
def _longest_tracks(tracks):
......@@ -377,7 +377,6 @@ def _find_between(rec_stages, start, end, builder):
builder.append(0)
builder.end_list()
@nb.jit(nopython=True)
def _find_between_single(rec_stages, start, end, builder):
"""construct an awkward1 array with the same structure as tracks.rec_stages.
......
......@@ -159,25 +159,25 @@ class TestCountNested(unittest.TestCase):
class TestRecStagesMasks(unittest.TestCase):
def setUp(self):
self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
[[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
# self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]],
# [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]])
self.tracks = OFFLINE_FILE.events.tracks
def test_find(self):
builder = ak.ArrayBuilder()
_find(self.nested, ak.Array([1, 2, 3]), builder)
labels = builder.snapshot()
# def test_find(self):
# builder = ak.ArrayBuilder()
# _find(self.nested, ak.Array([1, 2, 3]), builder)
# labels = builder.snapshot()
assert labels[0][0] == 1
assert labels[0][1] == 1
assert labels[0][2] == 0
assert labels[1][0] == 0
# assert labels[0][0] == 1
# assert labels[0][1] == 1
# assert labels[0][2] == 0
# assert labels[1][0] == 0
def test_mask_with_explicit_rec_stages(self):
rec_stages = self.tracks.rec_stages
stages = [1, 3, 5, 4]
masks = mask(rec_stages, stages=stages)
masks = mask(self.tracks, stages=stages)
assert masks[0][0] == all(rec_stages[0][0] == ak.Array(stages))
assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
......@@ -186,14 +186,13 @@ class TestRecStagesMasks(unittest.TestCase):
def test_mask_with_start_and_end_of_rec_stages(self):
rec_stages = self.tracks.rec_stages
stages = [1, 3, 5, 4]
masks = mask(rec_stages, start=1, end=4)
masks = mask(self.tracks, start=1, end=4)
assert masks[0][0] == all(rec_stages[0][0] == ak.Array(stages))
assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
assert masks[0][1] == False
class TestUnique(unittest.TestCase):
def run_random_test_with_dtype(self, dtype):
max_range = 100
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment