Skip to content
Snippets Groups Projects
Commit f9312fd8 authored by Zineb Aly's avatar Zineb Aly
Browse files

resolve bugs in best_Track

parent 632f7aee
No related branches found
No related tags found
1 merge request!45Adapt best track root access
Pipeline #14280 passed with warnings
...@@ -218,124 +218,6 @@ def count_nested(Array, axis=0): ...@@ -218,124 +218,6 @@ def count_nested(Array, axis=0):
return ak1.count(Array, axis=2) return ak1.count(Array, axis=2)
@nb.jit(nopython=True)
def _find(rec_stages, stages, builder):
"""construct an awkward1 array with the same structure as tracks.rec_stages.
When stages are found, the Array is filled with value 1, otherwise it is filled
with value 0.
Parameters
----------
rec_stages : awkward1 Array
tracks.rec_stages .
stages : awkward1 Array
reconstruction stages of interest.
builder : awkward1.highlevel.ArrayBuilder
awkward1 Array builder.
"""
for s in rec_stages:
builder.begin_list()
for i in s:
num_stages = len(i)
if num_stages == len(stages):
found = 0
for j in range(num_stages):
if i[j] == stages[j]:
found += 1
if found == num_stages:
builder.append(1)
else:
builder.append(0)
else:
builder.append(0)
builder.end_list()
def mask(rec_stages, stages):
"""create a mask on tracks.rec_stages .
Parameters
----------
rec_stages : awkward1 Array
tracks.rec_stages .
stages : list
reconstruction stages of interest.
Returns
-------
awkward1 Array
an awkward1 Array mask where True corresponds to the positions
where stages were found. False otherwise.
"""
builder = ak1.ArrayBuilder()
_find(rec_stages, ak1.Array(stages), builder)
return builder.snapshot() == 1
# def best_track(tracks, strategy="default", rec_type=None):
# """best track selection based on different strategies
# Parameters
# ----------
# tracks : class km3io.offline.OfflineBranch
# a subset of reconstructed tracks where `events.n_tracks > 0` is always true.
# strategy : str
# the trategy desired to select the best tracks. It is either:
# - "first" : to select the first tracks.
# - "default": to select the best tracks (the first ones) corresponding to a specific
# reconstruction algorithm (JGandalf, Jshowerfit, etc). This requires rec_type input.
# Example: best_track(my_tracks, strategy="default", rec_type="JPP_RECONSTRUCTION_TYPE").
# rec_type : str, optional
# reconstruction type as defined in the official KM3NeT-Dataformat.
# Returns
# -------
# class km3io.offline.OfflineBranch
# tracks class with the desired "best tracks" selection.
# Raises
# ------
# ValueError
# ValueError raised when:
# - an invalid strategy is requested.
# - a subset of events with empty tracks is used.
# """
# options = ['first', 'default']
# if strategy not in options:
# raise ValueError("{} not in {}".format(strategy, options))
# n_events = 1 if tracks.is_single else len(tracks)
# if n_events > 1 and any(count_nested(tracks.lik, axis=1) == 0):
# raise ValueError(
# "'events' should not contain empty tracks. Consider applying the mask: events.n_tracks>0"
# )
# if strategy == "first":
# if n_events == 1:
# out = tracks[0]
# else:
# out = tracks[:, 0]
# if strategy == "default" and rec_type is None:
# raise ValueError(
# "rec_type must be provided when the default strategy is used.")
# if strategy == "default" and rec_type is not None:
# if n_events == 1:
# rec_types = tracks[tracks.rec_type == krec[rec_type]]
# len_stages = count_nested(rec_types.rec_stages, axis=1)
# longest = rec_types[len_stages == ak1.max(len_stages, axis=0)]
# out = longest[longest.lik == ak1.max(longest.lik, axis=0)]
# else:
# rec_types = tracks[tracks.rec_type == krec[rec_type]]
# len_stages = count_nested(rec_types.rec_stages, axis=2)
# longest = rec_types[len_stages == ak1.max(len_stages, axis=1)]
# out = longest[longest.lik == ak1.max(longest.lik, axis=1)]
# return out
def get_multiplicity(tracks, rec_stages): def get_multiplicity(tracks, rec_stages):
"""tracks selection based on specific reconstruction stages (for multiplicity """tracks selection based on specific reconstruction stages (for multiplicity
calculations). calculations).
...@@ -382,10 +264,10 @@ def _max_lik_track(tracks): ...@@ -382,10 +264,10 @@ def _max_lik_track(tracks):
def _best_track(tracks, start=None, end=None, stages=[]): def _best_track(tracks, start=None, end=None, stages=[]):
if (len(stages) > 0) and (start is None) and (end is None): if (len(stages) > 0) and (start is None) and (end is None):
selected_tracks = tracks[mask(tracks.rec_stages, stages)] selected_tracks = tracks[mask(tracks.rec_stages, stages=stages)]
if (start is not None) and (end is not None) and (len(stages) == 0): if (start is not None) and (end is not None) and (len(stages) == 0):
selected_tracks = tracks[mask_tracks(tracks.rec_stages, start, end)] selected_tracks = tracks[mask(tracks.rec_stages, start=start, end=end)]
if (start is None) and (end is None) and (len(stages) == 0): if (start is None) and (end is None) and (len(stages) == 0):
# this should be modified to a log print and not just a simple print # this should be modified to a log print and not just a simple print
...@@ -424,6 +306,8 @@ def _DUSJShower_stages(): ...@@ -424,6 +306,8 @@ def _DUSJShower_stages():
def _reco_stages(reco): def _reco_stages(reco):
valid_recos = set(("JSHOWER", "JMUON", "AASHOWER", "DUSJSHOWER"))
if reco == "JSHOWER": if reco == "JSHOWER":
stages = _JShower_stages() stages = _JShower_stages()
...@@ -436,7 +320,7 @@ def _reco_stages(reco): ...@@ -436,7 +320,7 @@ def _reco_stages(reco):
if reco == "DUSJSHOWER": if reco == "DUSJSHOWER":
stages == _DUSJShower_stages() stages == _DUSJShower_stages()
else: if reco not in valid_recos:
raise KeyError( raise KeyError(
f"{reco} must be either: 'JSHOWER', 'JMUON', 'AASHOWER', 'DUSJSHOWER'." f"{reco} must be either: 'JSHOWER', 'JMUON', 'AASHOWER', 'DUSJSHOWER'."
) )
...@@ -451,7 +335,7 @@ def best_track(tracks, reco, start=None, end=None, stages=[]): ...@@ -451,7 +335,7 @@ def best_track(tracks, reco, start=None, end=None, stages=[]):
if (start is not None) and (end is not None): if (start is not None) and (end is not None):
if (start not in valid_stages) or (end not in valid_stages): if (start not in valid_stages) or (end not in valid_stages):
raise KeyError( raise KeyError(
f" start and/or end are not in JMuon reconstruction stages") f" start and/or end are not in {reco} reconstruction stages")
if len(stages) > 0: if len(stages) > 0:
if not set(stages).issubset(valid_stages): if not set(stages).issubset(valid_stages):
...@@ -479,6 +363,7 @@ def _find_between(rec_stages, start, end, builder): ...@@ -479,6 +363,7 @@ def _find_between(rec_stages, start, end, builder):
builder : awkward1.highlevel.ArrayBuilder builder : awkward1.highlevel.ArrayBuilder
awkward1 Array builder. awkward1 Array builder.
""" """
for s in rec_stages: for s in rec_stages:
builder.begin_list() builder.begin_list()
for i in s: for i in s:
...@@ -493,7 +378,7 @@ def _find_between(rec_stages, start, end, builder): ...@@ -493,7 +378,7 @@ def _find_between(rec_stages, start, end, builder):
builder.end_list() builder.end_list()
def mask_tracks(rec_stages, start, end): def _mask_rec_stages_between_start_end(rec_stages, start, end):
"""mask tracks where tracks.rec_stages are between start and end . """mask tracks where tracks.rec_stages are between start and end .
Parameters Parameters
...@@ -514,3 +399,86 @@ def mask_tracks(rec_stages, start, end): ...@@ -514,3 +399,86 @@ def mask_tracks(rec_stages, start, end):
builder = ak1.ArrayBuilder() builder = ak1.ArrayBuilder()
_find_between(rec_stages, start, end, builder) _find_between(rec_stages, start, end, builder)
return builder.snapshot() == 1 return builder.snapshot() == 1
@nb.jit(nopython=True)
def _find(rec_stages, stages, builder):
"""construct an awkward1 array with the same structure as tracks.rec_stages.
When stages are found, the Array is filled with value 1, otherwise it is filled
with value 0.
Parameters
----------
rec_stages : awkward1 Array
tracks.rec_stages .
stages : awkward1 Array
reconstruction stages of interest.
builder : awkward1.highlevel.ArrayBuilder
awkward1 Array builder.
"""
for s in rec_stages:
builder.begin_list()
for i in s:
num_stages = len(i)
if num_stages == len(stages):
found = 0
for j in range(num_stages):
if i[j] == stages[j]:
found += 1
if found == num_stages:
builder.append(1)
else:
builder.append(0)
else:
builder.append(0)
builder.end_list()
def _mask_explicit_rec_stages(rec_stages, stages):
"""create a mask on tracks.rec_stages .
Parameters
----------
rec_stages : awkward1 Array
tracks.rec_stages .
stages : list
reconstruction stages of interest.
Returns
-------
awkward1 Array
an awkward1 Array mask where True corresponds to the positions
where stages were found. False otherwise.
"""
builder = ak1.ArrayBuilder()
_find(rec_stages, ak1.Array(stages), builder)
return builder.snapshot() == 1
def mask(rec_stages, stages=None, start=None, end=None):
"""create a mask on tracks.rec_stages .
Parameters
----------
rec_stages : awkward1 Array
tracks.rec_stages .
stages : list
reconstruction stages of interest.
Returns
-------
awkward1 Array
an awkward1 Array mask where True corresponds to the positions
where stages were found. False otherwise.
"""
if (stages is None) and (start is None) and (end is None):
raise KeyError("either stages or (start and end) must be specified")
if (stages is not None) and (start is not None) and (end is not None):
raise ValueError("too many inputs are specified")
if (stages is not None) and (start is None) and (end is None):
return _mask_explicit_rec_stages(rec_stages, stages)
if (stages is None) and (start is not None) and (end is not None):
return _mask_rec_stages_between_start_end(rec_stages, start, end)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment