Skip to content
Snippets Groups Projects
Commit 3a9f500a authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Merge remote-tracking branch 'origin/master'

parents abac4dd0 9f4af33c
No related branches found
No related tags found
No related merge requests found
Pipeline #11599 failed
import numpy as np
import awkward as ak
import awkward1 as ak1
from pkg_resources import get_distribution, DistributionNotFound
version = get_distribution(__name__).version
......@@ -9,18 +5,4 @@ version = get_distribution(__name__).version
from .offline import OfflineReader
from .online import OnlineReader
from .gseagen import GSGReader
# to avoid infinite recursion
old_getitem = ak.ChunkedArray.__getitem__
def new_getitem(self, item):
"""Monkey patch the getitem in awkward.ChunkedArray to apply
awkward1.Array masks on awkward.ChunkedArray"""
if isinstance(item, (ak1.Array, ak.ChunkedArray)):
return ak1.Array(self)[item]
else:
return old_getitem(self, item)
ak.ChunkedArray.__getitem__ = new_getitem
from . import patches
......@@ -4,7 +4,7 @@ import warnings
import numba as nb
import awkward1 as ak1
from .definitions import mc_header, fitparameters
from .definitions import mc_header, fitparameters, reconstruction
from .tools import cached_property, to_num, unfold_indices
from .rootio import Branch, BranchMapper
......@@ -21,6 +21,18 @@ def _nested_mapper(key):
return '_'.join(key.split('.')[1:])
def rec_types():
"""name of the reconstruction type as defined in the official
KM3NeT-Dataformat.
Returns
-------
dict_keys
reconstruction types.
"""
return reconstruction.keys()
def fitinf(fitparam, tracks):
"""Access fit parameters in tracks.fitinf.
......@@ -140,7 +152,7 @@ def mask(rec_stages, stages):
return builder.snapshot() == 1
def best_track(tracks, strategy="first", rec_stages=None):
def best_track(tracks, strategy="first", rec_type=None, rec_stages=None):
"""best track selection based on different strategies
Parameters
......@@ -154,6 +166,9 @@ def best_track(tracks, strategy="first", rec_stages=None):
return tracks[:, 0]
if strategy == "rec_stages" and rec_stages is not None:
return tracks[mask(tracks.rec_stages, rec_stages)]
if strategy == "default" and rec_type is not None:
return tracks[tracks.rec_type == reconstruction[rec_type]][
tracks.lik == ak1.max(tracks.lik, axis=1)][:, 0]
EVENTS_MAP = BranchMapper(name="events",
......
import awkward as ak
import awkward1 as ak1
# to avoid infinite recursion
old_getitem = ak.ChunkedArray.__getitem__
def new_getitem(self, item):
"""Monkey patch the getitem in awkward.ChunkedArray to apply
awkward1.Array masks on awkward.ChunkedArray"""
if isinstance(item, (ak1.Array, ak.ChunkedArray)):
return ak1.Array(self)[item]
else:
return old_getitem(self, item)
ak.ChunkedArray.__getitem__ = new_getitem
\ No newline at end of file
......@@ -4,7 +4,7 @@ import awkward1 as ak1
from pathlib import Path
from km3io import OfflineReader
from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask, best_track
from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask, best_track, rec_types
SAMPLES_DIR = Path(__file__).parent / 'samples'
OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root')
......@@ -41,6 +41,13 @@ class TestFitinf(unittest.TestCase):
assert "JGANDALF_BETA0_RAD" in keys
class TestRecoTypes(unittest.TestCase):
def test_reco_types(self):
keys = set(rec_types())
assert "JPP_RECONSTRUCTION_TYPE" in keys
class TestBestTrack(unittest.TestCase):
def setUp(self):
self.tracks = OFFLINE_FILE.events.tracks
......@@ -50,10 +57,19 @@ class TestBestTrack(unittest.TestCase):
rec_stages_tracks = best_track(self.tracks,
strategy="rec_stages",
rec_stages=[1, 3, 5, 4])
default_best = best_track(self.tracks,
strategy="default",
rec_type="JPP_RECONSTRUCTION_TYPE")
assert first_tracks.dir_z[0] == self.tracks.dir_z[0][0]
assert first_tracks.dir_x[1] == self.tracks.dir_x[1][0]
assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4]
assert rec_stages_tracks.rec_stages[1] == [1, 3, 5, 4]
assert default_best.lik[0] == ak1.max(self.tracks.lik[0])
assert default_best.lik[1] == ak1.max(self.tracks.lik[1])
assert default_best.rec_type[0] == 4000
class TestCountNested(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment