Skip to content
Snippets Groups Projects

WIP: Resolve "uproot4 integration"

Open Tamas Gal requested to merge 58-uproot4-integration into master
3 files
+ 176
152
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 141
118
import binascii
from collections import namedtuple, defaultdict
import uproot4 as uproot
from collections import namedtuple
import warnings
import numba as nb
import awkward1 as ak1
import uproot4 as uproot
import numpy as np
import awkward1 as ak
from .definitions import mc_header, fitparameters, reconstruction
from .definitions import mc_header
from .tools import cached_property, to_num, unfold_indices
from .rootio import Branch, BranchMapper
# Name of the top-level ROOT tree that holds offline events.
MAIN_TREE_NAME = "E"
# Branch keys that are ROOT/aanet bookkeeping and never exposed to users.
EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"]
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024 ** 2
# Shared LRU cache for uproot basket reads, bounded by BASKET_CACHE_SIZE.
BASKET_CACHE = uproot.cache.LRUArrayCache(BASKET_CACHE_SIZE)
class Usr:
    """Helper class to access AAObject `usr` stuff (only for events.usr)"""

    def __init__(self, mapper, branch, index_chain=None):
        self._mapper = mapper
        self._name = mapper.name
        self._index_chain = index_chain if index_chain is not None else []
        self._branch = branch
        self._usr_names = []
        self._usr_idx_lookup = {}
        if mapper.flat:
            self._usr_key = "usr"
        else:
            self._usr_key = mapper.key + ".usr"
        self._initialise()

    def _initialise(self):
        # Old aanet files use a different structure and key (usr_data);
        # those are not supported (yet) and raise KeyError/IndexError here.
        try:
            self._branch[self._usr_key]
        except (KeyError, IndexError):
            print(
                "The `usr` fields could not be parsed for the '{}' branch.".format(
                    self._name
                )
            )
            return
        names = self._branch[self._usr_key + "_names"].array()[0]
        self._usr_names = names
        self._usr_idx_lookup = dict((n, i) for i, n in enumerate(names))
        data = self._branch[self._usr_key].array()
        if self._index_chain:
            data = unfold_indices(data, self._index_chain)
        self._usr_data = data
        # Expose each usr field as an attribute for convenience.
        for name in names:
            setattr(self, name, self[name])

    def __getitem__(self, item):
        column = self._usr_idx_lookup[item]
        if not self._index_chain:
            return self._usr_data[:, column]
        # NOTE(review): _usr_data was already unfolded in _initialise;
        # unfolding again here mirrors the original behaviour — confirm
        # this double application is intended.
        return unfold_indices(self._usr_data, self._index_chain)[:, column]

    def keys(self):
        """Return the list of available usr field names."""
        return self._usr_names

    def __str__(self):
        lines = ["{}: {}".format(name, self[name]) for name in self.keys()]
        return "\n".join(lines)

    def __repr__(self):
        return "<{}[{}]>".format(self.__class__.__name__, self._name)
class OfflineReader:
@@ -88,18 +14,23 @@ class OfflineReader:
event_path = "E/Evt"
item_name = "OfflineEvent"
skip_keys = ["t", "AAObject"]
aliases = {"t_s": "t.fSec", "t_ns": "t.fNanoSec"}
aliases = {
"t_sec": "t.fSec",
"t_ns": "t.fNanoSec",
}
special_branches = {
"hits": {
"id": "hits.id",
"channel_id": "hits.channel_id",
"dom_id": "hits.dom_id",
"time": "hits.t",
"t": "hits.t",
"tot": "hits.tot",
"triggered": "hits.trig", # non-zero if the hit is a triggered hit
"trig": "hits.trig", # non-zero if the hit is a triggered hit
},
"mc_hits": {
"id": "mc_hits.id",
"pmt_id": "mc_hits.pmt_id",
"time": "mc_hits.t", # hit time (MC truth)
"t": "mc_hits.t", # hit time (MC truth)
"a": "mc_hits.a", # hit amplitude (in p.e.)
"origin": "mc_hits.origin", # track id of the track that created this hit
"pure_t": "mc_hits.pure_t", # photon time before pmt simultion
@@ -141,44 +72,79 @@ class OfflineReader:
"mc_tracks": "mc_trks",
}
def __init__(self, f, index_chain=None, step_size=2000, keys=None, aliases=None, event_ctor=None):
    """OfflineReader class is an offline ROOT file wrapper

    Parameters
    ----------
    f: str or uproot4.reading.ReadOnlyDirectory (from uproot4.open)
        Path to the file of interest or uproot4 filedescriptor.
    index_chain: list, optional
        Keeps track of index chaining.
    step_size: int, optional
        Number of events to read into the cache when iterating.
        Choosing higher numbers may improve the speed but also increases
        the memory overhead.
    keys: list or set, optional
        Branch keys.
    aliases: dict, optional
        Branch key aliases.
    event_ctor: class or namedtuple, optional
        Event constructor.

    Raises
    ------
    TypeError
        If `f` is neither a path string nor an uproot4 directory.
    """
    if isinstance(f, str):
        self._fobj = uproot.open(f)
        self._filepath = f
    elif isinstance(f, uproot.reading.ReadOnlyDirectory):
        self._fobj = f
        self._filepath = f._file.file_path
    else:
        raise TypeError("Unsupported file descriptor.")
    self._step_size = step_size
    self._uuid = self._fobj._file.uuid
    self._iterator_index = 0
    self._keys = keys
    self._event_ctor = event_ctor
    self._index_chain = [] if index_chain is None else index_chain

    if aliases is not None:
        self.aliases = aliases
    else:
        # Check for usr-awesomeness backward compatibility crap.
        # Only expose `usr` aliases when the branch exists AND holds data.
        # Fix: probe via self._fobj — `f` may be a plain path string here.
        if "E/Evt/AAObject/usr" in self._fobj:
            if ak.count(self._fobj["E/Evt/AAObject/usr"].array()) > 0:
                # NOTE(review): this mutates the class-level `aliases`
                # dict shared by all instances — confirm that is intended.
                self.aliases.update(
                    {
                        "usr": "AAObject/usr",
                        "usr_names": "AAObject/usr_names",
                    }
                )

    if self._keys is None:
        self._initialise_keys()

    if self._event_ctor is None:
        self._event_ctor = namedtuple(
            self.item_name,
            set(
                list(self.keys())
                + list(self.aliases)
                + list(self.special_branches)
                + list(self.special_aliases)
            ),
        )
def _initialise_keys(self):
    """Collect user-facing branch keys: top-level branches (minus the
    skip list), aliases, special aliases, and `n_`-prefixed group counts."""
    skip_keys = set(self.skip_keys)
    toplevel_keys = set(k.split("/")[0] for k in self._fobj[self.event_path].keys())
    keys = (toplevel_keys - skip_keys).union(
        list(self.aliases.keys()) + list(self.special_aliases)
    )
    # Pseudo-keys for group counts, e.g. n_hits, n_tracks.
    for key in list(self.special_branches) + list(self.special_aliases):
        keys.add("n_" + key)
    self._keys = keys
def keys(self):
@@ -195,6 +161,7 @@ class OfflineReader:
def __getattr__(self, attr):
attr = self._keyfor(attr)
# if attr in self.keys() or (attr.startswith("n_") and self._keyfor(attr.split("n_")[1]) in self._grouped_branches):
if attr in self.keys():
return self.__getitem__(attr)
raise AttributeError(
@@ -202,16 +169,37 @@ class OfflineReader:
)
def __getitem__(self, key):
    """Index with an int/slice (returns a chained sub-reader), with an
    `n_`-prefixed key (returns group counts), or with a branch name."""
    # Indexing: spawn a new reader that remembers the index chain.
    if isinstance(key, (slice, int, np.int32, np.int64)):
        if not isinstance(key, slice):
            key = int(key)
        return self.__class__(
            self._fobj,
            index_chain=self._index_chain + [key],
            step_size=self._step_size,
            aliases=self.aliases,
            keys=self.keys(),
            event_ctor=self._event_ctor,
        )
    # Group counts, for e.g. n_events, n_hits etc.
    if isinstance(key, str) and key.startswith("n_"):
        key = self._keyfor(key.split("n_")[1])
        arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4"))
        return unfold_indices(arr, self._index_chain)
    key = self._keyfor(key)
    branch = self._fobj[self.event_path]
    # These are special branches which are nested, like hits/trks/mc_trks
    # We are explicitly grabbing just a predefined set of subbranches
    # and also alias them to be backwards compatible (and attribute-accessible)
    if key in self.special_branches:
        out = branch[key].arrays(
            self.special_branches[key].keys(), aliases=self.special_branches[key]
        )
    else:
        out = branch[self.aliases.get(key, key)].array()
    return unfold_indices(out, self._index_chain)
def __iter__(self):
self._iterator_index = 0
@@ -220,12 +208,19 @@ class OfflineReader:
def _event_generator(self):
events = self._fobj[self.event_path]
keys = list(
set(self.keys())
- set(self.special_branches.keys())
- set(self.special_aliases)
) + list(self.aliases.keys())
events_it = events.iterate(keys, aliases=self.aliases, step_size=self.step_size)
group_count_keys = set(k for k in self.keys() if k.startswith("n_"))
keys = set(
list(
set(self.keys())
- set(self.special_branches.keys())
- set(self.special_aliases)
- group_count_keys
)
+ list(self.aliases.keys())
)
events_it = events.iterate(
keys, aliases=self.aliases, step_size=self._step_size
)
specials = []
special_keys = (
self.special_branches.keys()
@@ -235,24 +230,52 @@ class OfflineReader:
events[key].iterate(
self.special_branches[key].keys(),
aliases=self.special_branches[key],
step_size=self.step_size,
step_size=self._step_size,
)
)
group_counts = {}
for key in group_count_keys:
group_counts[key] = iter(self[key])
for event_set, *special_sets in zip(events_it, *specials):
for _event, *special_items in zip(event_set, *special_sets):
data = {
**{k: _event[k] for k in keys},
**{k: i for (k, i) in zip(special_keys, special_items)},
}
data = {}
for k in keys:
data[k] = _event[k]
for (k, i) in zip(special_keys, special_items):
data[k] = i
for tokey, fromkey in self.special_aliases.items():
data[tokey] = data[fromkey]
for key in group_counts:
data[key] = next(group_counts[key])
yield self._event_ctor(**data)
def __next__(self):
    # Delegate to the events iterator; `self._events` is presumably set
    # up by __iter__ elsewhere in this class — not visible in this chunk.
    return next(self._events)
def __len__(self):
    """Number of events, honouring any applied indexing/slicing.

    Returns
    -------
    int
        Total entries for an unindexed reader, 1 when the last index in
        the chain is a plain integer, otherwise the length after
        unfolding the index chain.
    """
    if not self._index_chain:
        return self._fobj[self.event_path].num_entries
    elif isinstance(self._index_chain[-1], (int, np.int32, np.int64)):
        # A trailing integer index always selects exactly one event.
        return 1
    else:
        # Ignore the usual index magic and access `id` directly.
        # Fix: the original called len(arr, chain) — len() takes one
        # argument; the chain must be unfolded first.
        arr = self._fobj[self.event_path]["id"].array()
        return len(unfold_indices(arr, self._index_chain))
def __actual_len__(self):
    """The raw number of events without any indexing/slicing magic"""
    event_ids = self._fobj[self.event_path]["id"].array()
    return len(event_ids)
def __repr__(self):
    """Show the (possibly sliced) event count, and the raw total when
    the reader is restricted, e.g. ``OfflineReader (5/10 events)``."""
    length = len(self)
    actual_length = self.__actual_len__()
    if length < actual_length:
        counts = f"{length}/{actual_length}"
    else:
        counts = f"{length}"
    return f"{self.__class__.__name__} ({counts} events)"
@property
def uuid(self):
Loading