diff --git a/km3io/tools.py b/km3io/tools.py index 599268c8dffa836aae4874345e0d24cf28a2fc41..0536391f069c47bcbe1b1d217a7e72e7e31e337c 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -118,6 +118,7 @@ def get_w2list_param(events, generator, param): if (generator == "gseagen") and param in w2list_gseagen_keys: return events.w2list[:, kw2gsg[param]] + if generator == "genhen" and param in w2list_genhen_keys: return events.w2list[:, kw2gen[param]] @@ -136,7 +137,7 @@ def fitinf(fitparam, tracks): Returns ------- - awkward array + awkward1.Array awkward array of the values of the fit parameter requested. """ fit = tracks.fitinf @@ -203,7 +204,7 @@ def get_multiplicity(tracks, rec_stages): awkward1.Array tracks multiplicty. """ - masked_tracks = tracks[mask(tracks, rec_stages)] + masked_tracks = tracks[mask(tracks, stages=rec_stages)] if tracks.is_single: out = count_nested(masked_tracks.rec_stages, axis=0) @@ -251,24 +252,36 @@ def best_track(tracks, - no inputs are specified. """ - if (start_stages is None) and (end_stages is None) and (stages - is not None): - selected_tracks = tracks[mask(tracks, stages=stages)] - if (start_stages is not None) and (end_stages - is not None) and (stages is None): - selected_tracks = tracks[mask(tracks, - start_stages=start_stages, - end_stages=end_stages)] + inputs = [stages, start_stages, end_stages, min_stages, max_stages] + min_max = [min_stages, max_stages] + start_end = [start_stages, end_stages] - if (start_stages is None) and (end_stages is None) and (stages is None): + if all(v is None for v in inputs): raise ValueError("No reconstruction stages were specified") - if ((start_stages is not None) or - (end_stages is not None)) and (stages is not None): + if all(v is not None for v in inputs) or all(v is not None + for v in inputs[0:3]): raise ValueError( "Please specify either a range or a set of rec stages.") + if all(v is None for v in inputs[1:]) and (stages is not None): + selected_tracks = tracks[mask(tracks, stages=stages)] + + if all(v is not None + for v in start_end) and all(v is None + for v in min_max) and (stages is None): + selected_tracks = tracks[mask(tracks, + start_stages=start_stages, + end_stages=end_stages)] + + if all(v is not None + for v in min_max) and all(v is None + for v in start_end) and (stages is None): + selected_tracks = tracks[mask(tracks, + min_stages=start_stages, + max_stages=end_stages)] + return _max_lik_track(_longest_tracks(selected_tracks)) @@ -337,19 +350,21 @@ def mask(tracks, - too many inputs specified. - no inputs are specified. """ - if (stages is None) and (start_stages is None) and ( - end_stages is None) and (min_stages is None) and (max_stages is - None): + + inputs = [stages, start_stages, end_stages, min_stages, max_stages] + min_max = [min_stages, max_stages] + start_end = [start_stages, end_stages] + + if all(v is None for v in inputs): raise ValueError( - "either stages or (start_stages and end_stages) or (min_stages and max_stages) must be specified" + "either stages or (start_stages and end_stages) or (min_stages and max_stages) must be specified." ) - if (stages is not None) and (start_stages is not None) and (end_stages - is not None): - raise ValueError("too many inputs are specified") + if all(v is not None for v in inputs) or all(v is not None + for v in inputs[0:3]): + raise ValueError("too many inputs are specified.") - if (stages - is not None) and (start_stages is None) and (end_stages is None): + if (stages is not None) and all(v is None for v in inputs[1:]): if isinstance(stages, list): # order of stages is conserved return _mask_explicit_rec_stages(tracks, stages) @@ -358,13 +373,15 @@ def mask(tracks, return _mask_rec_stages_in_range_min_max(tracks, valid_stages=stages) - if (stages is None) and (start_stages is not None) and (end_stages - is not None): + if all(v is not None + for v in start_end) and all(v is None + for v in min_max) and (stages is None): return _mask_rec_stages_between_start_end(tracks, start_stages, end_stages) - if (stages is None) and (min_stages is not None) and (max_stages - is not None): + if all(v is None + for v in start_end) and all(v is not None + for v in min_max) and (stages is None): return _mask_rec_stages_in_range_min_max(tracks, min_stages=min_stages, max_stages=max_stages)