Skip to content
Snippets Groups Projects

WIP: Resolve "uproot4 integration"

Open Tamas Gal requested to merge 58-uproot4-integration into master
3 files
+ 184
76
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 146
42
from collections import namedtuple
import uproot4 as uproot
import warnings
import uproot4 as uproot
import numpy as np
import awkward1 as ak
from .definitions import mc_header
from .tools import cached_property
from .tools import cached_property, to_num, unfold_indices
class OfflineReader:
@@ -13,22 +15,24 @@ class OfflineReader:
item_name = "OfflineEvent"
skip_keys = ["t", "AAObject"]
aliases = {
"t_s": "t.fSec",
"t_sec": "t.fSec",
"t_ns": "t.fNanoSec",
"usr": "AAObject/usr",
"usr_names": "AAObject/usr_names",
}
special_branches = {
"hits": {
"id": "hits.id",
"channel_id": "hits.channel_id",
"dom_id": "hits.dom_id",
"time": "hits.t",
"t": "hits.t",
"tot": "hits.tot",
"triggered": "hits.trig", # non-zero if the hit is a triggered hit
"trig": "hits.trig", # non-zero if the hit is a triggered hit
},
"mc_hits": {
"id": "mc_hits.id",
"pmt_id": "mc_hits.pmt_id",
"time": "mc_hits.t", # hit time (MC truth)
"t": "mc_hits.t", # hit time (MC truth)
"a": "mc_hits.a", # hit amplitude (in p.e.)
"origin": "mc_hits.origin", # track id of the track that created this hit
"pure_t": "mc_hits.pure_t", # photon time before pmt simultion
@@ -63,6 +67,8 @@ class OfflineReader:
# "mother_id": "mc_trks.mother_id", # TODO: check this
"type": "mc_trks.type",
"hit_ids": "mc_trks.hit_ids",
"usr": "mc_trks.usr",
"usr_names": "mc_trks.usr_names",
},
}
special_aliases = {
@@ -70,44 +76,79 @@ class OfflineReader:
"mc_tracks": "mc_trks",
}
def __init__(self, f, index_chain=None, step_size=2000, keys=None, aliases=None, event_ctor=None):
    """OfflineReader class is an offline ROOT file wrapper

    Parameters
    ----------
    f : str or uproot4.reading.ReadOnlyDirectory (from uproot4.open)
        Path to the file of interest or uproot4 filedescriptor.
    index_chain : list, optional
        Keeps track of index chaining (slices/ints applied so far).
    step_size : int, optional
        Number of events to read into the cache when iterating.
        Choosing higher numbers may improve the speed but also increases
        the memory overhead.
    keys : list or set, optional
        Branch keys.
    aliases : dict, optional
        Branch key aliases.
    event_ctor : class or namedtuple, optional
        Event constructor.

    Raises
    ------
    TypeError
        If `f` is neither a path string nor an uproot4 directory.
    """
    if isinstance(f, str):
        self._fobj = uproot.open(f)
        self._filepath = f
    elif isinstance(f, uproot.reading.ReadOnlyDirectory):
        # already-opened file: reuse the descriptor and recover its path
        self._fobj = f
        self._filepath = f._file.file_path
    else:
        raise TypeError("Unsupported file descriptor.")
    self._step_size = step_size
    self._uuid = self._fobj._file.uuid
    self._iterator_index = 0
    self._keys = keys
    self._event_ctor = event_ctor
    self._index_chain = [] if index_chain is None else index_chain
    # honor a caller-supplied alias mapping; `__getitem__` forwards
    # `aliases=self.aliases` when chaining, so ignoring it would silently
    # drop customized aliases on sliced/indexed copies
    if aliases is not None:
        self.aliases = aliases
    if self._keys is None:
        self._initialise_keys()
    if self._event_ctor is None:
        self._event_ctor = namedtuple(
            self.item_name,
            set(
                list(self.keys())
                + list(self.aliases)
                + list(self.special_branches)
                + list(self.special_aliases)
            ),
        )
def _initialise_keys(self):
skip_keys = set(self.skip_keys)
toplevel_keys = set(k.split("/")[0] for k in self._fobj[self.event_path].keys())
keys = (toplevel_keys - set(self.skip_keys)).union(
keys = (toplevel_keys - skip_keys).union(
list(self.aliases.keys()) + list(self.special_aliases)
)
for key in list(self.special_branches) + list(self.special_aliases):
keys.add("n_" + key)
# self._grouped_branches = {k for k in toplevel_keys - skip_keys if isinstance(self._fobj[self.event_path][k].interpretation, uproot.AsGrouped)}
self._keys = keys
def keys(self):
@@ -124,6 +165,7 @@ class OfflineReader:
def __getattr__(self, attr):
attr = self._keyfor(attr)
# if attr in self.keys() or (attr.startswith("n_") and self._keyfor(attr.split("n_")[1]) in self._grouped_branches):
if attr in self.keys():
return self.__getitem__(attr)
raise AttributeError(
@@ -131,16 +173,43 @@ class OfflineReader:
)
def __getitem__(self, key):
    """Index/slice the reader or fetch a branch by (aliased) name.

    Parameters
    ----------
    key : slice, int, np.int32/np.int64 or str
        An integer or slice returns a new reader instance with the
        selection appended to the index chain; a string returns the
        corresponding branch data (unfolded through the index chain).
    """
    # indexing: return a lazy copy carrying the recorded selection
    if isinstance(key, (slice, int, np.int32, np.int64)):
        if not isinstance(key, slice):
            key = int(key)
        return self.__class__(
            self._fobj,
            index_chain=self._index_chain + [key],
            step_size=self._step_size,
            aliases=self.aliases,
            keys=self.keys(),
            event_ctor=self._event_ctor,
        )
    if isinstance(key, str) and key.startswith("n_"):  # group counts, for e.g. n_events, n_hits etc.
        key = self._keyfor(key.split("n_")[1])
        arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4"))
        return unfold_indices(arr, self._index_chain)
    key = self._keyfor(key)
    branch = self._fobj[self.event_path]
    # These are special branches which are nested, like hits/trks/mc_trks
    # We are explicitly grabbing just a predefined set of subbranches
    # and also alias them to be backwards compatible (and attribute-accessible)
    if key in self.special_branches:
        keys_of_interest = []
        # some fields are not always available, like `usr_names`
        # BUG fix: iterate .items(), not .keys() — the original unpacked
        # (from_key, to_key) pairs from .keys(), which yields single strings
        # and raises ValueError at runtime.
        for from_key, to_key in self.special_branches[key].items():
            # NOTE(review): to_key (e.g. "hits.t") is checked against the
            # event-tree keys — confirm this matches uproot4's key naming
            if to_key in branch.keys():
                keys_of_interest.append(from_key)
        out = branch[key].arrays(
            keys_of_interest, aliases=self.special_branches[key]
        )
    else:
        out = branch[self.aliases.get(key, key)].array()
    return unfold_indices(out, self._index_chain)
def __iter__(self):
self._iterator_index = 0
@@ -149,12 +218,19 @@ class OfflineReader:
def _event_generator(self):
events = self._fobj[self.event_path]
keys = list(
set(self.keys())
- set(self.special_branches.keys())
- set(self.special_aliases)
) + list(self.aliases.keys())
events_it = events.iterate(keys, aliases=self.aliases, step_size=self.step_size)
group_count_keys = set(k for k in self.keys() if k.startswith("n_"))
keys = set(
list(
set(self.keys())
- set(self.special_branches.keys())
- set(self.special_aliases)
- group_count_keys
)
+ list(self.aliases.keys())
)
events_it = events.iterate(
keys, aliases=self.aliases, step_size=self._step_size
)
specials = []
special_keys = (
self.special_branches.keys()
@@ -164,24 +240,52 @@ class OfflineReader:
events[key].iterate(
self.special_branches[key].keys(),
aliases=self.special_branches[key],
step_size=self.step_size,
step_size=self._step_size,
)
)
group_counts = {}
for key in group_count_keys:
group_counts[key] = iter(self[key])
for event_set, *special_sets in zip(events_it, *specials):
for _event, *special_items in zip(event_set, *special_sets):
data = {
**{k: _event[k] for k in keys},
**{k: i for (k, i) in zip(special_keys, special_items)},
}
data = {}
for k in keys:
data[k] = _event[k]
for (k, i) in zip(special_keys, special_items):
data[k] = i
for tokey, fromkey in self.special_aliases.items():
data[tokey] = data[fromkey]
for key in group_counts:
data[key] = next(group_counts[key])
yield self._event_ctor(**data)
def __next__(self):
return next(self._events)
def __len__(self):
return self._fobj[self.event_path].num_entries
if not self._index_chain:
return self._fobj[self.event_path].num_entries
elif isinstance(self._index_chain[-1], (int, np.int32, np.int64)):
if len(self._index_chain) == 1:
return 1
# try:
# return len(self[:])
# except IndexError:
# return 1
return 1
else:
# ignore the usual index magic and access `id` directly
return len(self._fobj[self.event_path]["id"].array(), self._index_chain)
def __actual_len__(self):
"""The raw number of events without any indexing/slicing magic"""
return len(self._fobj[self.event_path]["id"].array())
def __repr__(self):
length = len(self)
actual_length = self.__actual_len__()
return f"{self.__class__.__name__} ({length}{'/' + str(actual_length) if length < actual_length else ''} events)"
@property
def uuid(self):
Loading