offline.py 10.70 KiB
from collections import namedtuple
import uproot
import warnings
import awkward1 as ak1
from .definitions import mc_header, fitparameters
from .tools import Branch, BranchMapper, cached_property, _to_num, _unfold_indices
# Key of the main tree inside the ROOT file (read by OfflineReader).
MAIN_TREE_NAME = "E"
# Keys handed to the branch mappers through their `exclude` argument.
EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"]
# 110 MiB, based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024 * 1024
BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
def _nested_mapper(key):
"""Maps a key in the ROOT file to another key (e.g. trks.pos.x -> pos_x)"""
return '_'.join(key.split('.')[1:])
def fitinf(fitparam, tracks):
    """Extract the values of one fit parameter from ``tracks.fitinf``.

    Parameters
    ----------
    fitparam : str
        name of the fit parameter. It is looked up in ``fitparameters``
        (KM3NeT-Dataformat) to obtain its position inside ``fitinf``.
    tracks : class km3io.offline.OfflineBranch
        the tracks branch, either complete or a selection of it
        (for example tracks[:, 0]).

    Returns
    -------
    awkward array
        the values of the requested fit parameter. Entries whose
        ``fitinf`` is too short to contain the parameter are dropped.
    """
    fit_data = tracks.fitinf
    position = fitparameters[fitparam]
    try:
        long_enough = count_nested(fit_data, axis=2) > position
        selected = fit_data[long_enough]
        return ak1.Array([event[:, position] for event in selected])
    except ValueError:
        # Reached for tracks[:, 0] or any other selection: the data has one
        # nesting level less, so the counting is done on axis 1 instead.
        long_enough = count_nested(fit_data, axis=1) > position
        return fit_data[long_enough][:, position]
def fitparams():
    """List the fit parameter names known to this module, as defined in the
    official KM3NeT-Dataformat.

    Returns
    -------
    dict_keys
        the keys of the ``fitparameters`` mapping.
    """
    parameter_names = fitparameters.keys()
    return parameter_names
def count_nested(Array, axis=0):
    """count elements in a nested awkward Array.

    Parameters
    ----------
    Array : Awkward1 Array
        Array of data. Example tracks.fitinf or tracks.rec_stages.
    axis : int, optional
        axis = 0: to count elements in the outmost level of nesting.
        axis = 1: to count elements in the first level of nesting.
        axis = 2: to count elements in the second level of nesting.

    Returns
    -------
    awkward1 Array or int
        counts of elements found in a nested awkward1 Array.

    Raises
    ------
    ValueError
        if `axis` is not 0, 1 or 2. Such a call used to return None
        silently, which only surfaced later as an unrelated error.
    """
    if axis == 0:
        return ak1.num(Array, axis=0)
    if axis == 1:
        return ak1.num(Array, axis=1)
    if axis == 2:
        # The innermost level is counted with ak1.count, not ak1.num.
        return ak1.count(Array, axis=2)
    raise ValueError(
        "Unsupported axis {!r}: count_nested only supports axis 0, 1 or 2.".
        format(axis))
def best_track(tracks, strategy="first", rec_stages=None):
    """best track selection based on different strategies

    Parameters
    ----------
    tracks : class km3io.offline.OfflineBranch
        the tracks branch.
    strategy : str
        the strategy desired to select the best tracks. Only "first" is
        implemented; it returns ``tracks[:, 0]``.
    rec_stages : optional
        currently unused. Presumably reserved for a future selection based
        on reconstruction stages (see the TODO below).

    Returns
    -------
    the selection ``tracks[:, 0]`` for the "first" strategy.

    Raises
    ------
    ValueError
        if `strategy` is not supported. Such a call used to return None
        silently.
    """
    if strategy == "first":
        return tracks[:, 0]
    # TODO: implement a "rec_stages" strategy that builds a mask from
    # tracks.rec_stages and the `rec_stages` argument.
    raise ValueError(
        "Unsupported best track strategy {!r}; available: 'first'.".format(
            strategy))
# Mapper describing the top-level "Evt" branch, exposed under the name
# "events". The exact meaning of `extra` and `update` is defined by
# BranchMapper in .tools.
EVENTS_MAP = BranchMapper(
    name="events",
    key="Evt",
    extra=dict(t_sec='t.fSec', t_ns='t.fNanoSec'),
    exclude=EXCLUDE_KEYS,
    update=dict(
        n_hits='hits',
        n_mc_hits='mc_hits',
        n_tracks='trks',
        n_mc_tracks='mc_trks',
    ),
)
# One mapper per sub-branch of the events (tracks, MC tracks, hits, MC hits).
# Each one excludes the common EXCLUDE_KEYS plus a branch-specific list and
# uses _nested_mapper to turn dotted keys into attribute names.
SUBBRANCH_MAPS = [
    BranchMapper(
        name="tracks",
        key="trks",
        extra={},
        exclude=[
            *EXCLUDE_KEYS,
            'trks.usr_data',
            'trks.usr',
            'trks.fUniqueID',
            'trks.fBits',
        ],
        attrparser=_nested_mapper,
        flat=False,
        toawkward=['fitinf', 'rec_stages'],
    ),
    BranchMapper(
        name="mc_tracks",
        key="mc_trks",
        exclude=[
            *EXCLUDE_KEYS,
            'mc_trks.usr_data',
            'mc_trks.usr',
            'mc_trks.rec_stages',
            'mc_trks.fitinf',
            'mc_trks.fUniqueID',
            'mc_trks.fBits',
        ],
        attrparser=_nested_mapper,
        flat=False,
    ),
    BranchMapper(
        name="hits",
        key="hits",
        exclude=[
            *EXCLUDE_KEYS,
            'hits.usr',
            'hits.pmt_id',
            'hits.origin',
            'hits.a',
            'hits.pure_a',
            'hits.fUniqueID',
            'hits.fBits',
        ],
        attrparser=_nested_mapper,
        flat=False,
    ),
    BranchMapper(
        name="mc_hits",
        key="mc_hits",
        exclude=[
            *EXCLUDE_KEYS,
            'mc_hits.usr',
            'mc_hits.dom_id',
            'mc_hits.channel_id',
            'mc_hits.tdc',
            'mc_hits.tot',
            'mc_hits.trig',
            'mc_hits.fUniqueID',
            'mc_hits.fBits',
        ],
        attrparser=_nested_mapper,
        flat=False,
    ),
]
class OfflineBranch(Branch):
    """Branch of an offline file that additionally exposes `usr` data."""

    @cached_property
    def usr(self):
        """`Usr` helper built from this branch's mapper, branch and index
        chain (wrapped in `cached_property`)."""
        usr_accessor = Usr(self._mapper,
                           self._branch,
                           index_chain=self._index_chain)
        return usr_accessor
class Usr:
    """Helper class to access AAObject `usr` stuff

    For "flat" mappers the `usr` names are read once and every name is
    exposed both as an attribute and via ``usr[name]``. For non-flat
    (nested) mappers no names are parsed; see `__getitem_nested__`.

    Parameters
    ----------
    mapper : BranchMapper
        mapper of the branch; its `name`, `key` and `flat` attributes
        are used.
    branch :
        the branch object from which the `usr` keys are read.
    index_chain : list, optional
        chain of indices (selections) applied to the data via
        `_unfold_indices`. Defaults to an empty chain.
    """
    def __init__(self, mapper, branch, index_chain=None):
        self._mapper = mapper
        self._name = mapper.name
        self._index_chain = [] if index_chain is None else index_chain
        self._branch = branch
        self._usr_names = []
        self._usr_idx_lookup = {}
        # Filled by _initialise_flat; stays None if the usr fields could
        # not be parsed or the mapper is not flat.
        self._usr_data = None
        self._usr_key = 'usr' if mapper.flat else mapper.key + '.usr'
        self._initialise()

    def _initialise(self):
        """Check that the `usr` key exists and parse it for flat mappers."""
        try:
            self._branch[self._usr_key]
            # This will raise a KeyError in old aanet files
            # which has a different structure and key (usr_data)
            # We do not support those (yet)
        except (KeyError, IndexError):
            # Reported as a warning, like the unsupported header format in
            # OfflineReader.header, instead of being printed to stdout.
            warnings.warn(
                "The `usr` fields could not be parsed for the '{}' branch.".
                format(self._name))
            return
        if self._mapper.flat:
            self._initialise_flat()

    def _initialise_flat(self):
        """Read the usr names and data and expose each name as attribute."""
        # Here, we assume that every event has the same names in the same order
        # to massively increase the performance. This needs triple check if
        # it's always the case.
        self._usr_names = [
            n.decode("utf-8")
            for n in self._branch[self._usr_key + '_names'].lazyarray()[0]
        ]
        self._usr_idx_lookup = {
            name: index
            for index, name in enumerate(self._usr_names)
        }
        data = self._branch[self._usr_key].lazyarray()
        if self._index_chain:
            # The index chain is applied exactly once, here.
            data = _unfold_indices(data, self._index_chain)
        self._usr_data = data
        for name in self._usr_names:
            setattr(self, name, self[name])

    # TODO: a nested counterpart of _initialise_flat (reading `usr_names`
    # with an explicit uproot interpretation) is not implemented yet,
    # see https://github.com/scikit-hep/uproot/issues/465

    def __getitem__(self, item):
        if self._mapper.flat:
            return self.__getitem_flat__(item)
        return self.__getitem_nested__(item)

    def __getitem_flat__(self, item):
        """Return the column of the usr data that belongs to `item`.

        Raises KeyError if `item` is not a known usr name.
        """
        column = self._usr_idx_lookup[item]
        # self._usr_data already has the index chain applied (done in
        # _initialise_flat). Unfolding it again here would apply every
        # selection twice and pick the wrong entries.
        return self._usr_data[:, column]

    def __getitem_nested__(self, item):
        # NOTE(review): `item` is not used here; the (index-unfolded)
        # `usr_names` array is returned for any item. Presumably work in
        # progress -- confirm before relying on it.
        data = self._branch[self._usr_key + '_names'].lazyarray(
            # TODO this will be fixed soon in uproot,
            # see https://github.com/scikit-hep/uproot/issues/465
            uproot.asgenobj(
                uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
                self._branch[self._usr_key + '_names']._context, 6),
            basketcache=BASKET_CACHE)
        return _unfold_indices(data, self._index_chain)

    def keys(self):
        """Return the list of parsed usr names (empty if none were parsed)."""
        return self._usr_names

    def __str__(self):
        entries = []
        for name in self.keys():
            entries.append("{}: {}".format(name, self[name]))
        return '\n'.join(entries)

    def __repr__(self):
        return "<{}[{}]>".format(self.__class__.__name__, self._name)
class OfflineReader:
    """reader for offline ROOT files"""
    def __init__(self, file_path=None):
        """ OfflineReader class is an offline ROOT file wrapper

        Parameters
        ----------
        file_path : path-like object
            Path to the file of interest. It can be a str or any python
            path-like object that points to the file.
        """
        self._fobj = uproot.open(file_path)
        root_file = self._fobj
        self._tree = root_file[MAIN_TREE_NAME]

    @cached_property
    def events(self):
        """The `E` branch, containing all offline events."""
        branch = OfflineBranch(self._tree,
                               mapper=EVENTS_MAP,
                               subbranchmaps=SUBBRANCH_MAPS)
        return branch

    @cached_property
    def header(self):
        """The file header, or None (with a warning) if the file has no
        `Head` entry."""
        if 'Head' not in self._fobj:
            warnings.warn("Your file header has an unsupported format")
            return None
        raw_entries = self._fobj['Head']._map_3c_string_2c_string_3e_.items()
        # Keys and values are stored as bytes; decode them and strip the
        # surrounding whitespace of the values.
        decoded = {
            name.decode("utf-8"): content.decode("utf-8").strip()
            for name, content in raw_entries
        }
        return Header(decoded)
class Header:
    """The header

    Parses the raw header entries into numbers or namedtuples and exposes
    each of them as an attribute.

    Parameters
    ----------
    header : dict
        maps the attribute name (str) to its raw, whitespace separated
        values (str).
    """
    def __init__(self, header):
        self._data = {}
        for attribute, raw_values in header.items():
            values = raw_values.split()
            # Work on a copy: the list is extended below, and extending the
            # list stored in `mc_header` in place would leak the generated
            # "field_N" names into every Header created afterwards.
            fields = list(mc_header.get(attribute, []))
            n_values = len(values)
            n_fields = len(fields)
            if n_values == 1 and n_fields == 0:
                # A single value without known field names is stored as is.
                self._data[attribute] = _to_num(values[0])
                continue
            n_max = max(n_values, n_fields)
            if n_max == 0:
                # Neither values nor field names: nothing to store.
                continue
            # Pad the shorter of both lists: missing values become None,
            # missing field names become "field_<index>".
            values += [None] * (n_max - n_values)
            fields += ["field_{}".format(i) for i in range(n_fields, n_max)]
            Constructor = namedtuple(attribute, fields)
            self._data[attribute] = Constructor(
                **{f: _to_num(v)
                   for (f, v) in zip(fields, values)})
        for attribute, value in self._data.items():
            setattr(self, attribute, value)

    def __str__(self):
        lines = ["MC Header:"]
        for key, value in self._data.items():
            if key in mc_header:
                # Known entries are namedtuples whose repr already
                # contains the attribute name.
                lines.append(" {}".format(value))
            else:
                lines.append(" {}: {}".format(key, value))
        return "\n".join(lines)