Skip to content
Snippets Groups Projects

Resolve "uproot4 integration"

Merged Tamas Gal requested to merge 58-uproot4-integration-2 into master
Compare and
5 files
+ 456
788
Compare changes
  • Side-by-side
  • Inline
Files
5
+ 76
203
import binascii
from collections import namedtuple
import uproot3
import logging
import warnings
import numba as nb
import uproot
import numpy as np
import awkward as ak
from .definitions import mc_header, fitparameters, reconstruction
from .definitions import mc_header
from .tools import cached_property, to_num, unfold_indices
from .rootio import Branch, BranchMapper
from .rootio import EventReader
MAIN_TREE_NAME = "E"
EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"]
log = logging.getLogger("offline")
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024 ** 2
BASKET_CACHE = uproot3.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
def _nested_mapper(key):
"""Maps a key in the ROOT file to another key (e.g. trks.pos.x -> pos_x)"""
return "_".join(key.split(".")[1:])
EVENTS_MAP = BranchMapper(
name="events",
key="Evt",
extra={"t_sec": "t.fSec", "t_ns": "t.fNanoSec"},
exclude=EXCLUDE_KEYS,
update={
"n_hits": "hits",
"n_mc_hits": "mc_hits",
"n_tracks": "trks",
"n_mc_tracks": "mc_trks",
},
)
SUBBRANCH_MAPS = [
BranchMapper(
name="tracks",
key="trks",
extra={},
exclude=EXCLUDE_KEYS
+ ["trks.usr_data", "trks.usr", "trks.fUniqueID", "trks.fBits"],
attrparser=_nested_mapper,
flat=False,
toawkward=["fitinf", "rec_stages"],
),
BranchMapper(
name="mc_tracks",
key="mc_trks",
exclude=EXCLUDE_KEYS
+ [
"mc_trks.rec_stages",
"mc_trks.fitinf",
"mc_trks.fUniqueID",
"mc_trks.fBits",
],
attrparser=_nested_mapper,
toawkward=["usr", "usr_names"],
flat=False,
),
BranchMapper(
name="hits",
key="hits",
exclude=EXCLUDE_KEYS
+ [
"hits.usr",
"hits.pmt_id",
"hits.origin",
"hits.a",
"hits.pure_a",
"hits.fUniqueID",
"hits.fBits",
],
attrparser=_nested_mapper,
flat=False,
),
BranchMapper(
name="mc_hits",
key="mc_hits",
exclude=EXCLUDE_KEYS
+ [
"mc_hits.usr",
"mc_hits.dom_id",
"mc_hits.channel_id",
"mc_hits.tdc",
"mc_hits.tot",
"mc_hits.trig",
"mc_hits.fUniqueID",
"mc_hits.fBits",
],
attrparser=_nested_mapper,
flat=False,
),
]
class OfflineBranch(Branch):
@cached_property
def usr(self):
return Usr(self._mapper, self._branch, index_chain=self._index_chain)
class Usr:
"""Helper class to access AAObject `usr` stuff (only for events.usr)"""
def __init__(self, mapper, branch, index_chain=None):
self._mapper = mapper
self._name = mapper.name
self._index_chain = [] if index_chain is None else index_chain
self._branch = branch
self._usr_names = []
self._usr_idx_lookup = {}
self._usr_key = "usr" if mapper.flat else mapper.key + ".usr"
self._initialise()
def _initialise(self):
try:
self._branch[self._usr_key]
# This will raise a KeyError in old aanet files
# which has a different strucuter and key (usr_data)
# We do not support those (yet)
except (KeyError, IndexError):
print(
"The `usr` fields could not be parsed for the '{}' branch.".format(
self._name
)
)
return
self._usr_names = [
n.decode("utf-8")
for n in self._branch[self._usr_key + "_names"].lazyarray()[0]
]
self._usr_idx_lookup = {
name: index for index, name in enumerate(self._usr_names)
}
data = self._branch[self._usr_key].lazyarray()
if self._index_chain:
data = unfold_indices(data, self._index_chain)
self._usr_data = data
for name in self._usr_names:
setattr(self, name, self[name])
def __getitem__(self, item):
if self._index_chain:
return unfold_indices(self._usr_data, self._index_chain)[
:, self._usr_idx_lookup[item]
]
else:
return self._usr_data[:, self._usr_idx_lookup[item]]
def keys(self):
return self._usr_names
def __str__(self):
entries = []
for name in self.keys():
entries.append("{}: {}".format(name, self[name]))
return "\n".join(entries)
def __repr__(self):
return "<{}[{}]>".format(self.__class__.__name__, self._name)
class OfflineReader:
class OfflineReader(EventReader):
"""reader for offline ROOT files"""
def __init__(self, file_path=None):
"""OfflineReader class is an offline ROOT file wrapper
Parameters
----------
file_path : path-like object
Path to the file of interest. It can be a str or any python
path-like object that points to the file.
"""
self._fobj = uproot3.open(file_path)
self._filename = file_path
self._tree = self._fobj[MAIN_TREE_NAME]
self._uuid = binascii.hexlify(self._fobj._context.uuid).decode("ascii")
@property
def uuid(self):
return self._uuid
def close(self):
self._fobj.close()
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
@cached_property
def events(self):
"""The `E` branch, containing all offline events."""
return OfflineBranch(
self._tree, mapper=EVENTS_MAP, subbranchmaps=SUBBRANCH_MAPS
)
event_path = "E/Evt"
item_name = "OfflineEvent"
skip_keys = ["t", "AAObject"]
aliases = {
"t_sec": "t/t.fSec",
"t_ns": "t/t.fNanoSec",
"usr": "AAObject/usr",
"usr_names": "AAObject/usr_names",
}
nested_branches = {
"hits": {
"id": "hits.id",
"channel_id": "hits.channel_id",
"dom_id": "hits.dom_id",
"t": "hits.t",
"tot": "hits.tot",
"trig": "hits.trig", # non-zero if the hit is a triggered hit
},
"mc_hits": {
"id": "mc_hits.id",
"pmt_id": "mc_hits.pmt_id",
"t": "mc_hits.t", # hit time (MC truth)
"a": "mc_hits.a", # hit amplitude (in p.e.)
"origin": "mc_hits.origin", # track id of the track that created this hit
"pure_t": "mc_hits.pure_t", # photon time before pmt simultion
"pure_a": "mc_hits.pure_a", # amplitude before pmt simution,
"type": "mc_hits.type", # particle type or parametrisation used for hit
},
"trks": {
"id": "trks.id",
"pos_x": "trks.pos.x",
"pos_y": "trks.pos.y",
"pos_z": "trks.pos.z",
"dir_x": "trks.dir.x",
"dir_y": "trks.dir.y",
"dir_z": "trks.dir.z",
"t": "trks.t",
"E": "trks.E",
"len": "trks.len",
"lik": "trks.lik",
"rec_type": "trks.rec_type",
"rec_stages": "trks.rec_stages",
"fitinf": "trks.fitinf",
},
"mc_trks": {
"id": "mc_trks.id",
"pos_x": "mc_trks.pos.x",
"pos_y": "mc_trks.pos.y",
"pos_z": "mc_trks.pos.z",
"dir_x": "mc_trks.dir.x",
"dir_y": "mc_trks.dir.y",
"dir_z": "mc_trks.dir.z",
"E": "mc_trks.E",
"t": "mc_trks.t",
"len": "mc_trks.len",
# "status": "mc_trks.status", # TODO: check this
# "mother_id": "mc_trks.mother_id", # TODO: check this
"pdgid": "mc_trks.type",
"hit_ids": "mc_trks.hit_ids",
"usr": "mc_trks.usr", # TODO: trouble with uproot4
"usr_names": "mc_trks.usr_names", # TODO: trouble with uproot4
},
}
nested_aliases = {
"tracks": "trks",
"mc_tracks": "mc_trks",
}
@cached_property
def header(self):
"""The file header"""
if "Head" in self._fobj:
header = {}
for n, x in self._fobj["Head"]._map_3c_string_2c_string_3e_.items():
header[n.decode("utf-8")] = x.decode("utf-8").strip()
return Header(header)
return Header(self._fobj["Head"].tojson()["map<string,string>"])
else:
warnings.warn("Your file header has an unsupported format")
Loading