diff --git a/km3io/__init__.py b/km3io/__init__.py index 2cb507fe32e62be07ef4aaf275212fe127b61546..52ba6348e74fb7334c04ebe5a1a1b1025869c31a 100644 --- a/km3io/__init__.py +++ b/km3io/__init__.py @@ -1,7 +1,3 @@ -import numpy as np -import awkward as ak -import awkward1 as ak1 - from pkg_resources import get_distribution, DistributionNotFound version = get_distribution(__name__).version @@ -9,18 +5,4 @@ version = get_distribution(__name__).version from .offline import OfflineReader from .online import OnlineReader from .gseagen import GSGReader - -# to avoid infinite recursion -old_getitem = ak.ChunkedArray.__getitem__ - - -def new_getitem(self, item): - """Monkey patch the getitem in awkward.ChunkedArray to apply - awkward1.Array masks on awkward.ChunkedArray""" - if isinstance(item, (ak1.Array, ak.ChunkedArray)): - return ak1.Array(self)[item] - else: - return old_getitem(self, item) - - -ak.ChunkedArray.__getitem__ = new_getitem +from . import patches diff --git a/km3io/offline.py b/km3io/offline.py index b3ed691ed3b73f6503d2a5420756ed1472989b3e..e51f319d33220b03f2a7b304069044ee128967c7 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -4,7 +4,7 @@ import warnings import numba as nb import awkward1 as ak1 -from .definitions import mc_header, fitparameters +from .definitions import mc_header, fitparameters, reconstruction from .tools import cached_property, to_num, unfold_indices from .rootio import Branch, BranchMapper @@ -21,6 +21,18 @@ def _nested_mapper(key): return '_'.join(key.split('.')[1:]) +def rec_types(): + """name of the reconstruction type as defined in the official + KM3NeT-Dataformat. + + Returns + ------- + dict_keys + reconstruction types. + """ + return reconstruction.keys() + + def fitinf(fitparam, tracks): """Access fit parameters in tracks.fitinf. @@ -140,7 +152,7 @@ def mask(rec_stages, stages): return builder.snapshot() == 1 -def best_track(tracks, strategy="first", rec_stages=None): +def best_track(tracks, strategy="first", rec_type=None, rec_stages=None): """best track selection based on different strategies Parameters @@ -154,6 +166,9 @@ def best_track(tracks, strategy="first", rec_stages=None): return tracks[:, 0] if strategy == "rec_stages" and rec_stages is not None: return tracks[mask(tracks.rec_stages, rec_stages)] + if strategy == "default" and rec_type is not None: + return tracks[tracks.rec_type == reconstruction[rec_type]][ + tracks.lik == ak1.max(tracks.lik, axis=1)][:, 0] EVENTS_MAP = BranchMapper(name="events", diff --git a/km3io/patches.py b/km3io/patches.py new file mode 100644 index 0000000000000000000000000000000000000000..3ee5b34d961b1ce31f52f20751b6e7de1a602f6b --- /dev/null +++ b/km3io/patches.py @@ -0,0 +1,17 @@ +import awkward as ak +import awkward1 as ak1 + +# to avoid infinite recursion +old_getitem = ak.ChunkedArray.__getitem__ + + +def new_getitem(self, item): + """Monkey patch the getitem in awkward.ChunkedArray to apply + awkward1.Array masks on awkward.ChunkedArray""" + if isinstance(item, (ak1.Array, ak.ChunkedArray)): + return ak1.Array(self)[item] + else: + return old_getitem(self, item) + + +ak.ChunkedArray.__getitem__ = new_getitem \ No newline at end of file diff --git a/tests/test_offline.py b/tests/test_offline.py index a89a5840fc069dd17fd3b189bc6f3de4344d64b8..3ccc07459a4a159a1142c1a5261deec458172a52 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -4,7 +4,7 @@ import awkward1 as ak1 from pathlib import Path from km3io import OfflineReader -from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask, best_track +from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask, best_track, rec_types SAMPLES_DIR = Path(__file__).parent / 'samples' OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root') @@ -41,6 +41,13 @@ class TestFitinf(unittest.TestCase): assert "JGANDALF_BETA0_RAD" in keys +class TestRecoTypes(unittest.TestCase): + def test_reco_types(self): + keys = set(rec_types()) + + assert "JPP_RECONSTRUCTION_TYPE" in keys + + class TestBestTrack(unittest.TestCase): def setUp(self): self.tracks = OFFLINE_FILE.events.tracks @@ -50,10 +57,19 @@ class TestBestTrack(unittest.TestCase): rec_stages_tracks = best_track(self.tracks, strategy="rec_stages", rec_stages=[1, 3, 5, 4]) + default_best = best_track(self.tracks, + strategy="default", + rec_type="JPP_RECONSTRUCTION_TYPE") assert first_tracks.dir_z[0] == self.tracks.dir_z[0][0] assert first_tracks.dir_x[1] == self.tracks.dir_x[1][0] + assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4] + assert rec_stages_tracks.rec_stages[1] == [1, 3, 5, 4] + + assert default_best.lik[0] == ak1.max(self.tracks.lik[0]) + assert default_best.lik[1] == ak1.max(self.tracks.lik[1]) + assert default_best.rec_type[0] == 4000 class TestCountNested(unittest.TestCase):