diff --git a/km3io/definitions/__init__.py b/km3io/definitions/__init__.py index f9d18d31391b2e59a2cd26145e69dbc9233e49cb..fbf2ee5d3bcbd1cb966048a2cda6dc4a28248d05 100644 --- a/km3io/definitions/__init__.py +++ b/km3io/definitions/__init__.py @@ -13,3 +13,10 @@ fitparameters_idx = {v: k for k, v in fitparameters.items()} reconstruction_idx = {v: k for k, v in reconstruction.items()} w2list_genhen_idx = {v: k for k, v in w2list_genhen.items()} w2list_gseagen_idx = {v: k for k, v in w2list_gseagen.items()} + +recos = { + "jmuon": reconstruction["JPP_RECONSTRUCTION_TYPE"], + "jshower": reconstruction["JPP_RECONSTRUCTION_TYPE"], + "dusjshower": reconstruction["DUSJ_RECONSTRUCTION_TYPE"], + "aashower": reconstruction["AANET_RECONSTRUCTION_TYPE"], +} diff --git a/km3io/tools.py b/km3io/tools.py index 0b856e2fafa37063608d21e9cf638a743e5b3f18..5855c525f24d7e3dd2ba4f50ce3fc91bc6a3ede5 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -4,7 +4,7 @@ import numpy as np import awkward1 as ak1 import uproot -from .definitions import fitparameters, reconstruction, w2list_genhen, w2list_gseagen +from .definitions import fitparameters, reconstruction, w2list_genhen, w2list_gseagen, recos # 110 MB based on the size of the largest basket found so far in km3net BASKET_CACHE_SIZE = 110 * 1024**2 @@ -270,7 +270,7 @@ def mask(rec_stages, stages): return builder.snapshot() == 1 -def best_track(tracks, strategy="default", rec_type=None): +def best_track(tracks, strategy="default", reco=None): """best track selection based on different strategies Parameters @@ -283,8 +283,8 @@ def best_track(tracks, strategy="default", rec_type=None): - "default": to select the best tracks (the first ones) corresponding to a specific reconstruction algorithm (JGandalf, Jshowerfit, etc). This requires rec_type input. Example: best_track(my_tracks, strategy="default", rec_type="JPP_RECONSTRUCTION_TYPE"). - rec_type : str, optional - reconstruction type as defined in the official KM3NeT-Dataformat. + reco : str, required when using the "default" strategy + Reconstruction type: e.g. jmuon, jshower, dusjshower or aashower Returns ------- @@ -294,9 +294,10 @@ def best_track(tracks, strategy="default", rec_type=None): Raises ------ ValueError - ValueError raised when: - - an invalid strategy is requested. - - a subset of events with empty tracks is used. + - an invalid strategy is requested. + - a subset of events with empty tracks is used. + KeyError + - unable to find reconstruction stages """ options = ['first', 'default'] if strategy not in options: @@ -311,27 +312,46 @@ def best_track(tracks, strategy="default", rec_type=None): if strategy == "first": if n_events == 1: - out = tracks[0] - else: - out = tracks[:, 0] + return tracks[0] + return tracks[:, 0] - if strategy == "default" and rec_type is None: + if strategy == "default" and reco is None: raise ValueError( - "rec_type must be provided when the default strategy is used.") + "The reco parameter must be provided when the default strategy is used.") + + if strategy == "default" and reco is not None: + if reco not in recos: + raise ValueError("Unknown reconstruction, please choose from: {}".format(", ".join(recos.keys))) + rec_type = recos[reco] + reco = reco.upper() + try: + rec_stage_min_idx = reconstruction[reco + "BEGIN"] + rec_stage_max_idx = reconstruction[reco + "END"] + except KeyError: + raise KeyError("Unable to find the reconstruction stages for {}".format(reco)) + # backward compat for aanet + if reco == "AASHOWER": + rec_stage_min_idx = 0 + rec_stage_max_idx = np.iinfo(np.int32).max - if strategy == "default" and rec_type is not None: if n_events == 1: + raise NotImplementedError + # first mask to select those with the correct reconstruction type rec_types = tracks[tracks.rec_type == reconstruction[rec_type]] + # TODO: check if rec_stages are all between rec_stage_min_idx and rec_stage_max_idx len_stages = count_nested(rec_types.rec_stages, axis=1) + # TODO: etc. longest = rec_types[len_stages == ak1.max(len_stages, axis=0)] out = longest[longest.lik == ak1.max(longest.lik, axis=0)] else: + # TODO: not done yet + raise NotImplementedError rec_types = tracks[tracks.rec_type == reconstruction[rec_type]] len_stages = count_nested(rec_types.rec_stages, axis=2) longest = rec_types[len_stages == ak1.max(len_stages, axis=1)] out = longest[longest.lik == ak1.max(longest.lik, axis=1)] - return out + return out def get_multiplicity(tracks, rec_stages):