Skip to content
Snippets Groups Projects
Commit 0c6f0bb1 authored by Zineb Aly's avatar Zineb Aly
Browse files

adapt rec stages mask and best track to one track

parent f6aad06b
No related branches found
No related tags found
1 merge request!45Adapt best track root access
Pipeline #14282 failed
......@@ -264,10 +264,10 @@ def _max_lik_track(tracks):
def _best_track(tracks, start=None, end=None, stages=[]):
if (len(stages) > 0) and (start is None) and (end is None):
selected_tracks = tracks[mask(tracks.rec_stages, stages=stages)]
selected_tracks = tracks[mask(tracks, stages=stages)]
if (start is not None) and (end is not None) and (len(stages) == 0):
selected_tracks = tracks[mask(tracks.rec_stages, start=start, end=end)]
selected_tracks = tracks[mask(tracks, start=start, end=end)]
if (start is None) and (end is None) and (len(stages) == 0):
# this should be modified to a log print and not just a simple print
......@@ -378,7 +378,38 @@ def _find_between(rec_stages, start, end, builder):
builder.end_list()
def _mask_rec_stages_between_start_end(rec_stages, start, end):
@nb.jit(nopython=True)
def _find_between_single(rec_stages, start, end, builder):
"""construct an awkward1 array with the same structure as tracks.rec_stages.
When stages are between start and end, the Array is filled with value 1, otherwise it is filled
with value 0.
Parameters
----------
rec_stages : awkward1 Array
tracks.rec_stages .
start : int
start of reconstruction stages of interest.
end : int
end of reconstruction stages of interest.
builder : awkward1.highlevel.ArrayBuilder
awkward1 Array builder.
"""
builder.begin_list()
for s in rec_stages:
num_stages = len(s)
if num_stages != 0:
if (s[0] == start) and (s[-1] == end):
builder.append(1)
else:
builder.append(0)
else:
builder.append(0)
builder.end_list()
def _mask_rec_stages_between_start_end(tracks, start, end):
"""mask tracks where tracks.rec_stages are between start and end .
Parameters
......@@ -397,8 +428,12 @@ def _mask_rec_stages_between_start_end(rec_stages, start, end):
where stages were found. False otherwise.
"""
builder = ak1.ArrayBuilder()
_find_between(rec_stages, start, end, builder)
return builder.snapshot() == 1
if tracks.is_single:
_find_between_single(tracks.rec_stages, start, end, builder)
return (builder.snapshot() == 1)[0]
else:
_find_between(tracks.rec_stages, start, end, builder)
return builder.snapshot() == 1
@nb.jit(nopython=True)
......@@ -434,7 +469,39 @@ def _find(rec_stages, stages, builder):
builder.end_list()
def _mask_explicit_rec_stages(rec_stages, stages):
@nb.jit(nopython=True)
def _find_single(rec_stages, stages, builder):
"""construct an awkward1 array with the same structure as tracks.rec_stages.
When stages are found, the Array is filled with value 1, otherwise it is filled
with value 0.
Parameters
----------
rec_stages : awkward1 Array
tracks.rec_stages .
stages : awkward1 Array
reconstruction stages of interest.
builder : awkward1.highlevel.ArrayBuilder
awkward1 Array builder.
"""
builder.begin_list()
for s in rec_stages:
num_stages = len(s)
if num_stages == len(stages):
found = 0
for j in range(num_stages):
if s[j] == stages[j]:
found += 1
if found == num_stages:
builder.append(1)
else:
builder.append(0)
else:
builder.append(0)
builder.end_list()
def _mask_explicit_rec_stages(tracks, stages):
"""create a mask on tracks.rec_stages .
Parameters
......@@ -450,12 +517,17 @@ def _mask_explicit_rec_stages(rec_stages, stages):
an awkward1 Array mask where True corresponds to the positions
where stages were found. False otherwise.
"""
# rec_stages = tracks.rec_stages
builder = ak1.ArrayBuilder()
_find(rec_stages, ak1.Array(stages), builder)
return builder.snapshot() == 1
if tracks.is_single:
_find_single(tracks.rec_stages, ak1.Array(stages), builder)
return (builder.snapshot() == 1)[0]
else:
_find(tracks.rec_stages, ak1.Array(stages), builder)
return builder.snapshot() == 1
def mask(rec_stages, stages=None, start=None, end=None):
def mask(tracks, stages=None, start=None, end=None):
"""create a mask on tracks.rec_stages .
Parameters
......@@ -478,7 +550,7 @@ def mask(rec_stages, stages=None, start=None, end=None):
raise ValueError("too many inputs are specified")
if (stages is not None) and (start is None) and (end is None):
return _mask_explicit_rec_stages(rec_stages, stages)
return _mask_explicit_rec_stages(tracks, stages)
if (stages is None) and (start is not None) and (end is not None):
return _mask_rec_stages_between_start_end(rec_stages, start, end)
return _mask_rec_stages_between_start_end(tracks, start, end)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment