diff --git a/km3io/offline.py b/km3io/offline.py index 19cf5aa55d83f46af5790127e9148f332eef6ad7..de606c6306e1407a103da996365a52f3cbf3bd0f 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -79,7 +79,15 @@ class OfflineReader: "mc_tracks": "mc_trks", } - def __init__(self, f, index_chain=None, step_size=2000, keys=None, aliases=None, event_ctor=None): + def __init__( + self, + f, + index_chain=None, + step_size=2000, + keys=None, + aliases=None, + event_ctor=None, + ): """OfflineReader class is an offline ROOT file wrapper Parameters @@ -187,10 +195,12 @@ class OfflineReader: step_size=self._step_size, aliases=self.aliases, keys=self.keys(), - event_ctor=self._event_ctor + event_ctor=self._event_ctor, ) - if isinstance(key, str) and key.startswith("n_"): # group counts, for e.g. n_events, n_hits etc. + if isinstance(key, str) and key.startswith( + "n_" + ): # group counts, for e.g. n_events, n_hits etc. key = self._keyfor(key.split("n_")[1]) arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4")) return unfold_indices(arr, self._index_chain) @@ -207,9 +217,7 @@ class OfflineReader: if from_field in branch[key].keys(): fields.append(to_field) log.debug(fields) - out = branch[key].arrays( - fields, aliases=self.special_branches[key] - ) + out = branch[key].arrays(fields, aliases=self.special_branches[key]) else: out = branch[self.aliases.get(key, key)].array() @@ -220,9 +228,57 @@ class OfflineReader: return self def _event_generator(self): - for i in range(len(self)): - yield self[i] - return + events = self._fobj[self.event_path] + group_count_keys = set( + k for k in self.keys() if k.startswith("n_") + ) # special keys to make it easy to count subbranch lengths + log.debug("group_count_keys: %s", group_count_keys) + keys = set( + list( + set(self.keys()) + - set(self.special_branches.keys()) + - set(self.special_aliases) + - group_count_keys + ) + + list(self.aliases.keys()) + ) # all top-level keys for regular branches + log.debug("keys: %s", keys) + log.debug("aliases: %s", self.aliases) + events_it = events.iterate( + keys, aliases=self.aliases, step_size=self._step_size + ) + specials = [] + special_keys = ( + self.special_branches.keys() + ) # dict-key ordering is an implementation detail + log.debug("special_keys: %s", special_keys) + for key in special_keys: + # print(f"adding {key} with keys {self.special_branches[key].keys()} and aliases {self.special_branches[key]}") + + specials.append( + events[key].iterate( + self.special_branches[key].keys(), + aliases=self.special_branches[key], + step_size=self._step_size, + ) + ) + group_counts = {} + for key in group_count_keys: + group_counts[key] = iter(self[key]) + + log.debug("group_counts: %s", group_counts) + for event_set, *special_sets in zip(events_it, *specials): + for _event, *special_items in zip(event_set, *special_sets): + data = {} + for k in keys: + data[k] = _event[k] + for (k, i) in zip(special_keys, special_items): + data[k] = i + for tokey, fromkey in self.special_aliases.items(): + data[tokey] = data[fromkey] + for key in group_counts: + data[key] = next(group_counts[key]) + yield self._event_ctor(**data) def __next__(self): return next(self._events) @@ -246,7 +302,6 @@ class OfflineReader: """The raw number of events without any indexing/slicing magic""" return len(self._fobj[self.event_path]["id"].array()) - def __repr__(self): length = len(self) actual_length = self.__actual_len__() diff --git a/km3io/tools.py b/km3io/tools.py index d75ce95fed15d8cae5f2790ad5a660d1624daff4..2344e74a47e5b461e4f6e615429dfb8e7a9f4ea0 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -275,7 +275,7 @@ def best_track(tracks, startend=None, minmax=None, stages=None): tracks = tracks[m1] - rec_stage_lengths = ak.num(tracks.rec_stages, axis=axis+1) + rec_stage_lengths = ak.num(tracks.rec_stages, axis=axis + 1) max_rec_stage_length = ak.max(rec_stage_lengths, axis=axis) m2 = rec_stage_lengths == max_rec_stage_length tracks = tracks[m2] @@ -284,7 +284,9 @@ def best_track(tracks, startend=None, minmax=None, stages=None): out = tracks[m3] if isinstance(out, ak.highlevel.Record): - return namedtuple("BestTrack", out.fields)(*[getattr(out, a)[0] for a in out.fields]) + return namedtuple("BestTrack", out.fields)( + *[getattr(out, a)[0] for a in out.fields] + ) return out @@ -308,20 +310,22 @@ def mask(arr, sequence=None, startend=None, minmax=None, atleast=None): inputs = (sequence, startend, minmax, atleast) if all(v is None for v in inputs): - raise ValueError("either sequence, startend, minmax or atleast must be specified.") + raise ValueError( + "either sequence, startend, minmax or atleast must be specified." + ) builder = ak.ArrayBuilder() _mask(arr, builder, sequence, startend, minmax, atleast) return builder.snapshot() -#nb.njit # TODO: not supported in awkward yet -# see https://github.com/scikit-hep/awkward-1.0/issues/572 + +# @nb.njit def _mask(arr, builder, sequence=None, startend=None, minmax=None, atleast=None): if arr.ndim == 2: # recursion stop if startend is not None: start, end = startend for els in arr: - if ak.count(els) > 0 and els[0] == start and els[-1] == end: + if len(els) > 0 and els[0] == start and els[-1] == end: builder.boolean(True) else: builder.boolean(False) @@ -362,7 +366,6 @@ def _mask(arr, builder, sequence=None, startend=None, minmax=None, atleast=None) builder.end_list() - def best_jmuon(tracks): """Select the best JMUON track.""" return best_track(tracks, minmax=(krec.JMUONBEGIN, krec.JMUONEND)) diff --git a/tests/test_offline.py b/tests/test_offline.py index eba3ccecb4902f6b6317dffdccee05f9787298a8..8c6d3cdde7fef3286526f980a9a8a603f6654aeb 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -183,17 +183,13 @@ class TestOfflineEvents(unittest.TestCase): def test_index_consistency(self): for i in [0, 2, 5]: - assert np.allclose( - self.events[i].n_hits, self.events.n_hits[i] - ) + assert np.allclose(self.events[i].n_hits, self.events.n_hits[i]) def test_index_chaining(self): assert np.allclose( self.events[3:5].n_hits.tolist(), self.events.n_hits[3:5].tolist() ) - assert np.allclose( - self.events[3:5][0].n_hits, self.events.n_hits[3:5][0] - ) + assert np.allclose(self.events[3:5][0].n_hits, self.events.n_hits[3:5][0]) @unittest.skip def test_index_chaining_on_nested_branches_aka_records(self): @@ -361,7 +357,22 @@ class TestOfflineTracks(unittest.TestCase): self.n_events = 10 def test_fields(self): - for field in ['id', 'pos_x', 'pos_y', 'pos_z', 'dir_x', 'dir_y', 'dir_z', 't', 'E', 'len', 'lik', 'rec_type', 'rec_stages', 'fitinf']: + for field in [ + "id", + "pos_x", + "pos_y", + "pos_z", + "dir_x", + "dir_y", + "dir_z", + "t", + "E", + "len", + "lik", + "rec_type", + "rec_stages", + "fitinf", + ]: getattr(self.tracks, field) def test_item_selection(self): diff --git a/tests/test_tools.py b/tests/test_tools.py index 8ae5a4f9266c8c4fa8ed23ebb9a929c02f4ccb78..3505b6a5299fe97902d966c273bf3bdad62ab1d4 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -230,7 +230,6 @@ class TestBestTrackSelection(unittest.TestCase): assert len(best) == 5 - import pdb; pdb.set_trace() assert best.lik == ak.max(tracks_slice.lik) assert best.rec_stages[0].tolist() == [1, 3, 5, 4] @@ -419,6 +418,27 @@ class TestRecStagesMasks(unittest.TestCase): mask(self.tracks) +class TestMask(unittest.TestCase): + def test_minmax_2dim_mask(self): + arr = ak.Array([[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]]) + m = mask(arr, minmax=(1, 4)) + self.assertListEqual(m.tolist(), [True, False, False]) + + def test_minmax_3dim_mask(self): + arr = ak.Array([[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]]) + m = mask(arr, minmax=(1, 4)) + self.assertListEqual(m.tolist(), [[True, False, False], [True]]) + + def test_minmax_4dim_mask(self): + arr = ak.Array( + [[[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]], [[[1, 9], [3, 3]]]] + ) + m = mask(arr, minmax=(1, 4)) + self.assertListEqual( + m.tolist(), [[[True, False, False], [True]], [[False, True]]] + ) + + class TestUnique(unittest.TestCase): def run_random_test_with_dtype(self, dtype): max_range = 100