From f62e97f51beacd359a88bb2f9ba075a42d6ecd22 Mon Sep 17 00:00:00 2001 From: zineb aly <aly.zineb.az@gmail.com> Date: Tue, 26 May 2020 09:29:47 +0200 Subject: [PATCH] test best track --- tests/test_offline.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/tests/test_offline.py b/tests/test_offline.py index 965ead5..a89a584 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -4,12 +4,15 @@ import awkward1 as ak1 from pathlib import Path from km3io import OfflineReader -from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask +from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask, best_track SAMPLES_DIR = Path(__file__).parent / 'samples' OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root') OFFLINE_USR = OfflineReader(SAMPLES_DIR / 'usr-sample.root') -OFFLINE_MC_TRACK_USR = OfflineReader(SAMPLES_DIR / 'mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root') +OFFLINE_MC_TRACK_USR = OfflineReader( + SAMPLES_DIR / + 'mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root' +) OFFLINE_NUMUCC = OfflineReader(SAMPLES_DIR / "numucc.root") # with mc data @@ -38,6 +41,21 @@ class TestFitinf(unittest.TestCase): assert "JGANDALF_BETA0_RAD" in keys +class TestBestTrack(unittest.TestCase): + def setUp(self): + self.tracks = OFFLINE_FILE.events.tracks + + def test_best_tracks(self): + first_tracks = best_track(self.tracks, strategy="first") + rec_stages_tracks = best_track(self.tracks, + strategy="rec_stages", + rec_stages=[1, 3, 5, 4]) + + assert first_tracks.dir_z[0] == self.tracks.dir_z[0][0] + assert first_tracks.dir_x[1] == self.tracks.dir_x[1][0] + assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4] + + class TestCountNested(unittest.TestCase): def test_count_nested(self): fit = OFFLINE_FILE.events.tracks.fitinf @@ -442,10 +460,12 @@ class TestMcTrackUsr(unittest.TestCase): def test_usr_names(self): n_tracks = len(self.f.events) for i in range(3): - self.assertListEqual([b'bx', b'by', b'ichan', b'cc'], - self.f.events.mc_tracks.usr_names[i][0].tolist()) - self.assertListEqual([b'energy_lost_in_can'], - self.f.events.mc_tracks.usr_names[i][1].tolist()) + self.assertListEqual( + [b'bx', b'by', b'ichan', b'cc'], + self.f.events.mc_tracks.usr_names[i][0].tolist()) + self.assertListEqual( + [b'energy_lost_in_can'], + self.f.events.mc_tracks.usr_names[i][1].tolist()) def test_usr(self): assert np.allclose([0.0487, 0.0588, 3, 2], -- GitLab