Skip to content
Snippets Groups Projects
Commit 006b607e authored by Zineb Aly's avatar Zineb Aly
Browse files

adapt best track

parent e77f79ee
No related branches found
No related tags found
No related merge requests found
Pipeline #12334 passed with warnings
......@@ -271,13 +271,13 @@ def mask(rec_stages, stages):
return builder.snapshot() == 1
def best_track(events, strategy="first", rec_type=None, rec_stages=None):
def best_track(events, strategy="default", rec_type=None, rec_stages=None):
"""best track selection based on different strategies
Parameters
----------
events : class km3io.offline.OfflineBranch
the events branch.
a subset of reconstructed events where `events.n_tracks > 0` is always true.
strategy : str
the trategy desired to select the best tracks. It is either:
- "first" : to select the first tracks.
......@@ -296,14 +296,32 @@ def best_track(events, strategy="first", rec_type=None, rec_stages=None):
-------
class km3io.offline.OfflineBranch
tracks class with the desired "best tracks" selection.
Raises
------
ValueError
ValueError raised when:
- an invalid strategy is requested.
- a subset of events with empty tracks is used.
"""
tracks = events.tracks[events.n_tracks > 0]
options = ['first', 'rec_stages', 'default']
if strategy not in options:
raise ValueError("{} not in {}".format(strategy, options))
if any(events.n_tracks == 0):
raise ValueError(
"'events' should not contain empty tracks. Consider applying the mask: events.n_tracks>0"
)
tracks = events.tracks
if strategy == "first":
return tracks[:, 0]
out = tracks[:, 0]
if strategy == "rec_stages" and rec_stages is not None:
return tracks[mask(tracks.rec_stages, rec_stages)]
elif strategy == "rec_stages" and rec_stages is not None:
out = tracks[mask(tracks.rec_stages, rec_stages)]
if strategy == "default" and rec_type is not None:
return tracks[tracks.rec_type == reconstruction[rec_type]][
elif strategy == "default" and rec_type is not None:
out = tracks[tracks.rec_type == reconstruction[rec_type]][
tracks.lik == ak1.max(tracks.lik, axis=1)][:, 0]
return out
......@@ -56,24 +56,28 @@ class TestBestTrack(unittest.TestCase):
self.events = OFFLINE_FILE.events
def test_best_tracks(self):
first_tracks = best_track(self.events, strategy="first")
rec_stages_tracks = best_track(self.events,
events = self.events[self.events.n_tracks > 0]
first_tracks = best_track(events, strategy="first")
rec_stages_tracks = best_track(events,
strategy="rec_stages",
rec_stages=[1, 3, 5, 4])
default_best = best_track(self.events,
default_best = best_track(events,
strategy="default",
rec_type="JPP_RECONSTRUCTION_TYPE")
assert first_tracks.dir_z[0] == self.events.tracks.dir_z[0][0]
assert first_tracks.dir_x[1] == self.events.tracks.dir_x[1][0]
assert first_tracks.dir_z[0] == events.tracks.dir_z[0][0]
assert first_tracks.dir_x[1] == events.tracks.dir_x[1][0]
assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4]
assert rec_stages_tracks.rec_stages[1] == [1, 3, 5, 4]
assert default_best.lik[0] == ak.max(self.events.tracks.lik[0])
assert default_best.lik[1] == ak.max(self.events.tracks.lik[1])
assert default_best.lik[0] == ak.max(events.tracks.lik[0])
assert default_best.lik[1] == ak.max(events.tracks.lik[1])
assert default_best.rec_type[0] == 4000
with self.assertRaises(ValueError):
best_track(events, strategy="Zineb")
class TestCountNested(unittest.TestCase):
def test_count_nested(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment