diff --git a/km3io/offline.py b/km3io/offline.py index ebe891cecec59b6aadddb2c8a012ad327494af7e..993b87931b1938913016551a94b1cc5c2381bcfc 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -3,7 +3,7 @@ import uproot4 as uproot import warnings from .definitions import mc_header -from .tools import cached_property +from .tools import cached_property, to_num class OfflineReader: @@ -90,6 +90,7 @@ class OfflineReader: self._uuid = self._fobj._file.uuid self._iterator_index = 0 self._keys = None + self._grouped_counts = {} # TODO: e.g. {"events": [3, 66, 34]} self._initialise_keys() @@ -104,10 +105,14 @@ class OfflineReader: ) def _initialise_keys(self): + skip_keys = set(self.skip_keys) toplevel_keys = set(k.split("/")[0] for k in self._fobj[self.event_path].keys()) - keys = (toplevel_keys - set(self.skip_keys)).union( + keys = (toplevel_keys - skip_keys).union( list(self.aliases.keys()) + list(self.special_aliases) ) + for key in list(self.special_branches) + list(self.special_aliases): + keys.add("n_" + key) + # self._grouped_branches = {k for k in toplevel_keys - skip_keys if isinstance(self._fobj[self.event_path][k].interpretation, uproot.AsGrouped)} self._keys = keys def keys(self): @@ -124,6 +129,7 @@ class OfflineReader: def __getattr__(self, attr): attr = self._keyfor(attr) + # if attr in self.keys() or (attr.startswith("n_") and self._keyfor(attr.split("n_")[1]) in self._grouped_branches): if attr in self.keys(): return self.__getitem__(attr) raise AttributeError( @@ -131,6 +137,10 @@ class OfflineReader: ) def __getitem__(self, key): + if key.startswith("n_"): # group counts, for e.g. n_events, n_hits etc. + key = self._keyfor(key.split("n_")[1]) + return self._fobj[self.event_path][key].array(uproot.AsDtype(">i4")) + key = self._keyfor(key) branch = self._fobj[self.event_path] # These are special branches which are nested, like hits/trks/mc_trks @@ -149,11 +159,13 @@ class OfflineReader: def _event_generator(self): events = self._fobj[self.event_path] - keys = list( + group_count_keys = set(k for k in self.keys() if k.startswith("n_")) + keys = set(list( set(self.keys()) - set(self.special_branches.keys()) - set(self.special_aliases) - ) + list(self.aliases.keys()) + - group_count_keys + ) + list(self.aliases.keys())) events_it = events.iterate(keys, aliases=self.aliases, step_size=self.step_size) specials = [] special_keys = ( @@ -167,6 +179,9 @@ class OfflineReader: step_size=self.step_size, ) ) + group_counts = {} + for key in group_count_keys: + group_counts[key] = iter(self[key]) for event_set, *special_sets in zip(events_it, *specials): for _event, *special_items in zip(event_set, *special_sets): data = { @@ -175,6 +190,8 @@ class OfflineReader: } for tokey, fromkey in self.special_aliases.items(): data[tokey] = data[fromkey] + for key in group_counts: + data[key] = next(group_counts[key]) yield self._event_ctor(**data) def __next__(self):