Skip to content
Snippets Groups Projects

Resolve "uproot4 integration"

Merged Tamas Gal requested to merge 58-uproot4-integration-2 into master
3 files
+ 364
237
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 280
178
import binascii
from collections import namedtuple
import uproot3
import logging
import warnings
import numba as nb
import uproot
import numpy as np
import awkward as ak
from .definitions import mc_header, fitparameters, reconstruction
from .definitions import mc_header
from .tools import cached_property, to_num, unfold_indices
from .rootio import Branch, BranchMapper
# Name of the ROOT tree that is looked up on the opened file.
MAIN_TREE_NAME = "E"
# Keys handed to every BranchMapper below through its `exclude=` argument.
EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"]
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024 ** 2
# Module-level uproot3 array cache, sized by BASKET_CACHE_SIZE (in bytes).
BASKET_CACHE = uproot3.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
def _nested_mapper(key):
"""Maps a key in the ROOT file to another key (e.g. trks.pos.x -> pos_x)"""
return "_".join(key.split(".")[1:])
# Mapper for the "Evt" branch, exposed under the name "events".
EVENTS_MAP = BranchMapper(
name="events",
key="Evt",
# Additional keys: the two components of the `t` member.
extra={"t_sec": "t.fSec", "t_ns": "t.fNanoSec"},
exclude=EXCLUDE_KEYS,
# Each "n_*" key points at the sub-branch of the same collection.
# NOTE(review): presumably BranchMapper resolves these to per-event counts —
# confirm against rootio.BranchMapper.
update={
"n_hits": "hits",
"n_mc_hits": "mc_hits",
"n_tracks": "trks",
"n_mc_tracks": "mc_trks",
},
)
# One BranchMapper per nested sub-branch. All of them use `_nested_mapper` as
# attribute parser and extend the common EXCLUDE_KEYS with branch-specific keys.
# NOTE(review): the exact meaning of `flat` and `toawkward` is defined by
# rootio.BranchMapper, which is not visible here — confirm before changing.
SUBBRANCH_MAPS = [
# "trks" sub-branch, exposed as "tracks"
BranchMapper(
name="tracks",
key="trks",
extra={},
exclude=EXCLUDE_KEYS
+ ["trks.usr_data", "trks.usr", "trks.fUniqueID", "trks.fBits"],
attrparser=_nested_mapper,
flat=False,
toawkward=["fitinf", "rec_stages"],
),
# "mc_trks" sub-branch, exposed as "mc_tracks"
BranchMapper(
name="mc_tracks",
key="mc_trks",
exclude=EXCLUDE_KEYS
+ [
"mc_trks.rec_stages",
"mc_trks.fitinf",
"mc_trks.fUniqueID",
"mc_trks.fBits",
],
attrparser=_nested_mapper,
toawkward=["usr", "usr_names"],
flat=False,
),
# "hits" sub-branch; the excluded keys are the ones that "mc_hits" keeps
# (pmt_id, origin, a, pure_a) plus usr and the ROOT bookkeeping members
BranchMapper(
name="hits",
key="hits",
exclude=EXCLUDE_KEYS
+ [
"hits.usr",
"hits.pmt_id",
"hits.origin",
"hits.a",
"hits.pure_a",
"hits.fUniqueID",
"hits.fBits",
],
attrparser=_nested_mapper,
flat=False,
),
# "mc_hits" sub-branch; excludes dom_id, channel_id, tdc, tot and trig
BranchMapper(
name="mc_hits",
key="mc_hits",
exclude=EXCLUDE_KEYS
+ [
"mc_hits.usr",
"mc_hits.dom_id",
"mc_hits.channel_id",
"mc_hits.tdc",
"mc_hits.tot",
"mc_hits.trig",
"mc_hits.fUniqueID",
"mc_hits.fBits",
],
attrparser=_nested_mapper,
flat=False,
),
]
class OfflineBranch(Branch):
    """A `Branch` that additionally offers access to the `usr` data."""

    @cached_property
    def usr(self):
        """`Usr` helper built from this branch's mapper, branch and index chain."""
        usr_accessor = Usr(
            self._mapper,
            self._branch,
            index_chain=self._index_chain,
        )
        return usr_accessor
class Usr:
"""Helper class to access AAObject `usr` stuff (only for events.usr)"""
# Logger shared by all instances (class attribute).
log = logging.getLogger("offline")
# Stores mapper/branch/index chain, starts with empty name and lookup tables
# and then tries to fill them via `_initialise()`.
def __init__(self, mapper, branch, index_chain=None):
self._mapper = mapper
self._name = mapper.name
# `None` is replaced by a fresh list so instances never share a default
self._index_chain = [] if index_chain is None else index_chain
self._branch = branch
self._usr_names = []
self._usr_idx_lookup = {}
# flat mappers address the data as "usr", nested ones as "<key>.usr"
self._usr_key = "usr" if mapper.flat else mapper.key + ".usr"
self._initialise()
# Probes the branch for the usr key; when it is missing, a message is printed
# and the (empty) defaults set in `__init__` are kept.
def _initialise(self):
try:
self._branch[self._usr_key]
# This will raise a KeyError in old aanet files
# which have a different structure and key (usr_data)
# We do not support those (yet)
except (KeyError, IndexError):
print(
"The `usr` fields could not be parsed for the '{}' branch.".format(
self._name
)
)
return
# The names are read from the first entry of "<usr_key>_names" and decoded
# from bytes to str.
self._usr_names = [
n.decode("utf-8")
for n in self._branch[self._usr_key + "_names"].lazyarray()[0]
]
# name -> position of that name in `_usr_names`
self._usr_idx_lookup = {
name: index for index, name in enumerate(self._usr_names)
}
class OfflineReader:
"""reader for offline ROOT files"""
# NOTE(review): the next statement reads `self` at class level and uses the
# `_branch`/`_usr_key` attributes of the `Usr` helper; it looks like a leftover
# from the diff view this file was taken from — confirm and remove.
data = self._branch[self._usr_key].lazyarray()
# Path of the event branch inside the opened file.
event_path = "E/Evt"
# Type name given to the namedtuple that represents a single event.
item_name = "OfflineEvent"
# Top-level keys that are not exposed directly.
skip_keys = ["t", "AAObject"]
# alias -> branch expression for regular (non-nested) branches
aliases = {
"t_sec": "t.fSec",
"t_ns": "t.fNanoSec",
"usr": "AAObject/usr",
"usr_names": "AAObject/usr_names",
}
# Nested branches: for each branch key, the exposed field name -> the
# sub-branch it is read from.
special_branches = {
"hits": {
"id": "hits.id",
"channel_id": "hits.channel_id",
"dom_id": "hits.dom_id",
"t": "hits.t",
"tot": "hits.tot",
"trig": "hits.trig", # non-zero if the hit is a triggered hit
},
"mc_hits": {
"id": "mc_hits.id",
"pmt_id": "mc_hits.pmt_id",
"t": "mc_hits.t", # hit time (MC truth)
"a": "mc_hits.a", # hit amplitude (in p.e.)
"origin": "mc_hits.origin", # track id of the track that created this hit
"pure_t": "mc_hits.pure_t", # photon time before pmt simulation
"pure_a": "mc_hits.pure_a", # amplitude before pmt simulation
"type": "mc_hits.type", # particle type or parametrisation used for hit
},
"trks": {
"id": "trks.id",
"pos_x": "trks.pos.x",
"pos_y": "trks.pos.y",
"pos_z": "trks.pos.z",
"dir_x": "trks.dir.x",
"dir_y": "trks.dir.y",
"dir_z": "trks.dir.z",
"t": "trks.t",
"E": "trks.E",
"len": "trks.len",
"lik": "trks.lik",
"rec_type": "trks.rec_type",
"rec_stages": "trks.rec_stages",
"fitinf": "trks.fitinf",
},
"mc_trks": {
"id": "mc_trks.id",
"pos_x": "mc_trks.pos.x",
"pos_y": "mc_trks.pos.y",
"pos_z": "mc_trks.pos.z",
"dir_x": "mc_trks.dir.x",
"dir_y": "mc_trks.dir.y",
"dir_z": "mc_trks.dir.z",
# "status": "mc_trks.status", # TODO: check this
# "mother_id": "mc_trks.mother_id", # TODO: check this
"type": "mc_trks.type",
"hit_ids": "mc_trks.hit_ids",
"usr": "mc_trks.usr", # TODO: trouble with uproot4
"usr_names": "mc_trks.usr_names", # TODO: trouble with uproot4
},
}
# user-facing name -> key in `special_branches`
special_aliases = {
"tracks": "trks",
"mc_tracks": "mc_trks",
}
def __init__(self, f, index_chain=None, step_size=2000, keys=None, aliases=None, event_ctor=None):
    """OfflineReader class is an offline ROOT file wrapper

    Parameters
    ----------
    f: str or uproot4.reading.ReadOnlyDirectory (from uproot4.open)
        Path to the file of interest or uproot4 filedescriptor.
    step_size: int, optional
        Number of events to read into the cache when iterating.
        Choosing higher numbers may improve the speed but also increases
        the memory overhead.
    index_chain: list, optional
        Keeps track of index chaining.
    keys: list or set, optional
        Branch keys.
    aliases: dict, optional
        Branch key aliases. Accepted so that sliced copies can be built
        with the same call signature; the value is currently not applied
        and the class-level `aliases` stay in effect.
    event_ctor: class or namedtuple, optional
        Event constructor.

    Raises
    ------
    TypeError
        If `f` is neither a path string nor an uproot directory object.
    """
    if isinstance(f, str):
        self._fobj = uproot.open(f)
        self._filepath = f
    elif isinstance(f, uproot.reading.ReadOnlyDirectory):
        self._fobj = f
        self._filepath = f._file.file_path
    else:
        raise TypeError("Unsupported file descriptor.")
    self._step_size = step_size
    self._uuid = self._fobj._file.uuid
    self._iterator_index = 0
    # `_keys` has to exist before anything triggers `__getattr__`, which
    # consults `keys()` for every attribute that is not found normally.
    self._keys = keys
    self._event_ctor = event_ctor
    self._index_chain = [] if index_chain is None else index_chain

    # TODO: honour the `aliases` argument (and detect the legacy
    # "E/Evt/AAObject/usr" layout) once the usr handling is settled.

    if self._keys is None:
        self._initialise_keys()

    if self._event_ctor is None:
        # One field per accessible key, alias and special branch; the set
        # removes names that occur in more than one of those groups.
        self._event_ctor = namedtuple(
            self.item_name,
            set(
                list(self.keys())
                + list(self.aliases)
                + list(self.special_branches)
                + list(self.special_aliases)
            ),
        )
def __getitem__(self, item):
if self._index_chain:
return unfold_indices(self._usr_data, self._index_chain)[
:, self._usr_idx_lookup[item]
]
else:
return self._usr_data[:, self._usr_idx_lookup[item]]
def _initialise_keys(self):
skip_keys = set(self.skip_keys)
toplevel_keys = set(k.split("/")[0] for k in self._fobj[self.event_path].keys())
keys = (toplevel_keys - skip_keys).union(
list(self.aliases.keys()) + list(self.special_aliases)
)
for key in list(self.special_branches) + list(self.special_aliases):
keys.add("n_" + key)
# self._grouped_branches = {k for k in toplevel_keys - skip_keys if isinstance(self._fobj[self.event_path][k].interpretation, uproot.AsGrouped)}
self._keys = keys
def keys(self):
    """Returns all accessible branch keys, without the skipped ones."""
    # The key set is either passed to the constructor or assembled by
    # `_initialise_keys()`; `_usr_names` belongs to the `Usr` helper and
    # is never set on this class.
    return self._keys
def __str__(self):
entries = []
for name in self.keys():
entries.append("{}: {}".format(name, self[name]))
return "\n".join(entries)
@property
def events(self):
    """Iterator over this reader, obtained via `iter(self)`."""
    # TODO: deprecate this, since `self` is already the container type
    event_iterator = iter(self)
    return event_iterator
def _keyfor(self, key):
"""Return the correct key for a given alias/key"""
return self.special_aliases.get(key, key)
def __getattr__(self, attr):
attr = self._keyfor(attr)
# if attr in self.keys() or (attr.startswith("n_") and self._keyfor(attr.split("n_")[1]) in self._grouped_branches):
if attr in self.keys():
return self.__getitem__(attr)
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
)
def __repr__(self):
return "<{}[{}]>".format(self.__class__.__name__, self._name)
def __getitem__(self, key):
# indexing
if isinstance(key, (slice, int, np.int32, np.int64)):
if not isinstance(key, slice):
key = int(key)
return self.__class__(
self._fobj,
index_chain=self._index_chain + [key],
step_size=self._step_size,
aliases=self.aliases,
keys=self.keys(),
event_ctor=self._event_ctor
)
if isinstance(key, str) and key.startswith("n_"): # group counts, for e.g. n_events, n_hits etc.
key = self._keyfor(key.split("n_")[1])
arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4"))
return unfold_indices(arr, self._index_chain)
key = self._keyfor(key)
branch = self._fobj[self.event_path]
# These are special branches which are nested, like hits/trks/mc_trks
# We are explicitly grabbing just a predefined set of subbranches
# and also alias them to be backwards compatible (and attribute-accessible)
if key in self.special_branches:
fields = []
# some fields are not always available, like `usr_names`
for to_field, from_field in self.special_branches[key].items():
if from_field in branch[key].keys():
fields.append(to_field)
log.debug(fields)
out = branch[key].arrays(
fields, aliases=self.special_branches[key]
)
else:
out = branch[self.aliases.get(key, key)].array()
class OfflineReader:
"""reader for offline ROOT files"""
return unfold_indices(out, self._index_chain)
def __init__(self, file_path=None):
"""OfflineReader class is an offline ROOT file wrapper
def __iter__(self):
self._iterator_index = 0
self._events = self._event_generator()
return self
Parameters
----------
file_path : path-like object
Path to the file of interest. It can be a str or any python
path-like object that points to the file.
# Generator that yields one `self._event_ctor(**data)` instance per event.
# Regular branches, the nested "special" branches and the "n_*" group counts
# are read through separate iterators which are advanced in lock-step.
def _event_generator(self):
events = self._fobj[self.event_path]
group_count_keys = set(k for k in self.keys() if k.startswith("n_")) # special keys to make it easy to count subbranch lengths
log.debug("group_count_keys: %s", group_count_keys)
keys = set(
list(
set(self.keys())
- set(self.special_branches.keys())
- set(self.special_aliases)
- group_count_keys
)
+ list(self.aliases.keys())
) # all top-level keys for regular branches
log.debug("keys: %s", keys)
log.debug("aliases: %s", self.aliases)
# chunked iterator over the regular branches, `step_size` entries at a time
events_it = events.iterate(
keys, aliases=self.aliases, step_size=self._step_size
)
# one chunked iterator per special branch, in the order of `special_keys`
specials = []
special_keys = (
self.special_branches.keys()
) # dict-key ordering is an implementation detail
log.debug("special_keys: %s", special_keys)
# NOTE(review): all configured fields are requested here, whereas
# `__getitem__` first drops fields that are missing in the file — confirm
# that files lacking e.g. `usr_names` can be iterated.
for key in special_keys:
# print(f"adding {key} with keys {self.special_branches[key].keys()} and aliases {self.special_branches[key]}")
specials.append(
events[key].iterate(
self.special_branches[key].keys(),
aliases=self.special_branches[key],
step_size=self._step_size,
)
)
# the count arrays are fetched via `self[key]` and consumed one value per event
group_counts = {}
for key in group_count_keys:
group_counts[key] = iter(self[key])
log.debug("group_counts: %s", group_counts)
# outer loop: one chunk from every iterator; inner loop: the single events
# of that chunk
for event_set, *special_sets in zip(events_it, *specials):
for _event, *special_items in zip(event_set, *special_sets):
data = {}
for k in keys:
data[k] = _event[k]
for (k, i) in zip(special_keys, special_items):
data[k] = i
# expose the special branches under their alias names as well
for tokey, fromkey in self.special_aliases.items():
data[tokey] = data[fromkey]
for key in group_counts:
data[key] = next(group_counts[key])
yield self._event_ctor(**data)
def __next__(self):
return next(self._events)
def __len__(self):
if not self._index_chain:
return self._fobj[self.event_path].num_entries
elif isinstance(self._index_chain[-1], (int, np.int32, np.int64)):
if len(self._index_chain) == 1:
return 1
# try:
# return len(self[:])
# except IndexError:
# return 1
return 1
else:
# ignore the usual index magic and access `id` directly
return len(self._fobj[self.event_path]["id"].array(), self._index_chain)
"""
self._fobj = uproot3.open(file_path)
self._filename = file_path
self._tree = self._fobj[MAIN_TREE_NAME]
self._uuid = binascii.hexlify(self._fobj._context.uuid).decode("ascii")
def __actual_len__(self):
"""The raw number of events without any indexing/slicing magic"""
return len(self._fobj[self.event_path]["id"].array())
def __repr__(self):
length = len(self)
actual_length = self.__actual_len__()
return f"{self.__class__.__name__} ({length}{'/' + str(actual_length) if length < actual_length else ''} events)"
@property
def uuid(self):
@@ -200,21 +312,11 @@ class OfflineReader:
def __exit__(self, *args):
self.close()
@cached_property
def events(self):
    """The `E` branch, containing all offline events."""
    branch_options = {
        "mapper": EVENTS_MAP,
        "subbranchmaps": SUBBRANCH_MAPS,
    }
    return OfflineBranch(self._tree, **branch_options)
@cached_property
def header(self):
    """The file header

    Returns a `Header` built from the "Head" object of the file, or `None`
    (after emitting a warning) when the file has no "Head" entry.
    """
    if "Head" in self._fobj:
        # The uproot3-style loop over `_map_3c_string_2c_string_3e_` used
        # to return first and left this uproot4 conversion unreachable.
        return Header(self._fobj["Head"].tojson()["map<string,string>"])
    warnings.warn("Your file header has an unsupported format")
Loading