Skip to content
Snippets Groups Projects
Commit 0a92e8c2 authored by Zineb Aly's avatar Zineb Aly
Browse files

add tests

parent 72a9ef61
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,7 @@ from pathlib import Path
from km3io import OfflineReader
from km3io.tools import (to_num, cached_property, unfold_indices, unique,
uniquecount, fitinf, fitparams, count_nested, _find,
mask, best_track, rec_types, get_w2list_param)
mask, best_track, rec_types, get_w2list_param, get_multiplicity)
SAMPLES_DIR = Path(__file__).parent / 'samples'
OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'km3net_offline.root')
......@@ -54,13 +54,12 @@ class TestRecoTypes(unittest.TestCase):
class TestBestTrack(unittest.TestCase):
def setUp(self):
self.events = OFFLINE_FILE.events
self.one_event = OFFLINE_FILE.events[0]
def test_best_tracks(self):
# test selection from multiple events
events = self.events[self.events.n_tracks > 0]
first_tracks = best_track(events, strategy="first")
rec_stages_tracks = best_track(events,
strategy="rec_stages",
rec_stages=[1, 3, 5, 4])
default_best = best_track(events,
strategy="default",
rec_type="JPP_RECONSTRUCTION_TYPE")
......@@ -68,17 +67,34 @@ class TestBestTrack(unittest.TestCase):
assert first_tracks.dir_z[0] == events.tracks.dir_z[0][0]
assert first_tracks.dir_x[1] == events.tracks.dir_x[1][0]
assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4]
assert rec_stages_tracks.rec_stages[1] == [1, 3, 5, 4]
assert default_best.lik[0] == ak.max(events.tracks.lik[0])
assert default_best.lik[1] == ak.max(events.tracks.lik[1])
assert default_best.rec_type[0] == 4000
# test selection from one event
first_track = best_track(self.one_event, strategy="first")
best = best_track(self.one_event,
strategy="default",
rec_type="JPP_RECONSTRUCTION_TYPE")
assert first_track.dir_z == self.one_event.tracks.dir_z[0]
assert first_track.lik == self.one_event.tracks.lik[0]
assert best.lik == ak.max(self.one_event.tracks.lik)
assert best.rec_type == 4000
# test raising ValueError
with self.assertRaises(ValueError):
best_track(events, strategy="Zineb")
class TestGetMultiplicity(unittest.TestCase):
def test_get_multiplicity(self):
rec_stages_tracks = get_multiplicity(OFFLINE_FILE.events.tracks, [1, 3, 5, 4])
assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4]
assert rec_stages_tracks.rec_stages[1] == [1, 3, 5, 4]
class TestCountNested(unittest.TestCase):
def test_count_nested(self):
fit = OFFLINE_FILE.events.tracks.fitinf
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment