diff --git a/km3io/rootio.py b/km3io/rootio.py index 5c4e4de0bede7f414b336aeea6ee5551123e1c5f..17442feb9d13bdd47fe92595b5d8957764011d15 100644 --- a/km3io/rootio.py +++ b/km3io/rootio.py @@ -143,10 +143,8 @@ class EventReader: keys=self.keys(), event_ctor=self._event_ctor, ) - - if isinstance(key, str) and key.startswith( - "n_" - ): # group counts, for e.g. n_events, n_hits etc. + # group counts, for e.g. n_events, n_hits etc. + if isinstance(key, str) and key.startswith("n_"): key = self._keyfor(key.split("n_")[1]) arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4")) return unfold_indices(arr, self._index_chain) @@ -163,11 +161,10 @@ class EventReader: if from_field in branch[key].keys(): fields.append(to_field) log.debug(fields) - out = branch[key].arrays(fields, aliases=self.special_branches[key]) + # out = branch[key].arrays(fields, aliases=self.special_branches[key]) + return Branch(branch[key], fields, self.special_branches[key], self._index_chain) else: - out = branch[self.aliases.get(key, key)].array() - - return unfold_indices(out, self._index_chain) + return unfold_indices(branch[self.aliases.get(key, key)].array(), self._index_chain) def __iter__(self): self._events = self._event_generator() @@ -269,3 +266,52 @@ class EventReader: def __exit__(self, *args): self.close() + + +class Branch: + """Helper class for nested branches likes tracks/hits""" + def __init__(self, branch, fields, aliases, index_chain): + self._branch = branch + self.fields = fields + self._aliases = aliases + self._index_chain = index_chain + + def __getattr__(self, attr): + if attr not in self._aliases: + raise AttributeError(f"No field named {attr}. Available fields: {self.fields}") + return unfold_indices(self._branch[self._aliases[attr]].array(), self._index_chain) + + def __getitem__(self, key): + return self.__class__(self._branch, self.fields, self._aliases, self._index_chain + [key]) + + def __len__(self): + if not self._index_chain: + return self._branch.num_entries + elif isinstance(self._index_chain[-1], (int, np.int32, np.int64)): + if len(self._index_chain) == 1: + return 1 + # try: + # return len(self[:]) + # except IndexError: + # return 1 + return 1 + else: + # ignore the usual index magic and access `id` directly + return len(self.id) + + def __actual_len__(self): + """The raw number of events without any indexing/slicing magic""" + return len(self._branch[self._aliases["id"]].array()) + + def __repr__(self): + length = len(self) + actual_length = self.__actual_len__() + return f"{self.__class__.__name__} ({length}{'/' + str(actual_length) if length < actual_length else ''} {self._branch.name})" + + @property + def ndim(self): + if not self._index_chain: + return 2 + elif any(isinstance(i, (int, np.int32, np.int64)) for i in self._index_chain): + return 1 + return 2 diff --git a/km3io/tools.py b/km3io/tools.py index 2e00758916b39558330d202d77ccfb9ca2d025fc..4fcdc38ad952714586942a41828b36e861cfe05b 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -209,12 +209,7 @@ def get_multiplicity(tracks, rec_stages): """ masked_tracks = tracks[mask(tracks.rec_stages, sequence=rec_stages)] - try: - axis = tracks.ndim - except AttributeError: - axis = 0 - - out = count_nested(masked_tracks.rec_stages, axis=axis) + out = count_nested(masked_tracks.rec_stages, axis=tracks.ndim - 1) return out @@ -269,10 +264,8 @@ def best_track(tracks, startend=None, minmax=None, stages=None): if minmax is not None: m1 = mask(tracks.rec_stages, minmax=minmax) - try: - axis = tracks.ndim - except AttributeError: - axis = 0 + original_ndim = tracks.ndim + axis = 1 if original_ndim == 2 else 0 tracks = tracks[m1] @@ -284,10 +277,8 @@ def best_track(tracks, startend=None, minmax=None, stages=None): m3 = ak.argmax(tracks.lik, axis=axis, keepdims=True) out = tracks[m3] - if isinstance(out, ak.highlevel.Record): - return namedtuple("BestTrack", out.fields)( - *[getattr(out, a)[0] for a in out.fields] - ) + if original_ndim == 1: + return out[0] return out[:, 0] diff --git a/tests/test_offline.py b/tests/test_offline.py index 8e9e52e0a4d11ab0ed96922ad0e1507e77617c18..575ce3db56d458cacc0ed86326137539e3a6ce2a 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -381,7 +381,7 @@ class TestOfflineTracks(unittest.TestCase): ) def test_repr(self): - assert "10 * " in repr(self.tracks) + assert "10" in repr(self.tracks) def test_slicing(self): tracks = self.tracks