Skip to content
Snippets Groups Projects

Refactor offline

Closed Tamas Gal requested to merge refactor-offline into master
Compare and Show latest version
3 files
+ 248
78
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 121
51
@@ -12,7 +12,7 @@ EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"])
# Record describing how one branch is exposed by the reader.
# `flat` is consulted by `Usr` to choose between the top-level 'usr' key
# (flat=True) and the '<key>.usr' sub-branch key (flat=False).
BranchMapper = namedtuple(
    "BranchMapper",
    ['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat'])
def _nested_mapper(key):
@@ -33,7 +33,8 @@ EVENTS_MAP = BranchMapper(name="events",
'n_tracks': 'trks',
'n_mc_tracks': 'mc_trks'
},
attrparser=lambda a: a)
attrparser=lambda a: a,
flat=True)
SUBBRANCH_MAPS = [
BranchMapper(
@@ -42,7 +43,8 @@ SUBBRANCH_MAPS = [
extra={},
exclude=['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'],
update={},
attrparser=_nested_mapper),
attrparser=_nested_mapper,
flat=False),
BranchMapper(name="mc_tracks",
key="mc_trks",
extra={},
@@ -51,7 +53,8 @@ SUBBRANCH_MAPS = [
'mc_trks.fitinf', 'mc_trks.fUniqueID', 'mc_trks.fBits'
],
update={},
attrparser=_nested_mapper),
attrparser=_nested_mapper,
flat=False),
BranchMapper(name="hits",
key="hits",
extra={},
@@ -60,7 +63,8 @@ SUBBRANCH_MAPS = [
'hits.pure_a', 'hits.fUniqueID', 'hits.fBits'
],
update={},
attrparser=_nested_mapper),
attrparser=_nested_mapper,
flat=False),
BranchMapper(name="mc_hits",
key="mc_hits",
extra={},
@@ -70,7 +74,8 @@ SUBBRANCH_MAPS = [
'mc_hits.fUniqueID', 'mc_hits.fBits'
],
update={},
attrparser=_nested_mapper),
attrparser=_nested_mapper,
flat=False),
]
@@ -122,38 +127,90 @@ class OfflineReader:
class Usr:
    """Helper class to access AAObject `usr` stuff.

    Parameters
    ----------
    mapper : BranchMapper
        Mapper of the owning branch; its ``name``, ``key`` and ``flat``
        attributes are read.
    branch
        Object indexable by branch key. For flat mappers the entries are
        expected to provide ``lazyarray(...)``.
    index : optional
        Index applied once to the `usr` data when it is loaded.
    """
    def __init__(self, mapper, branch, index=None):
        self._mapper = mapper
        self._name = mapper.name
        self._index = index
        self._branch = branch
        self._usr_names = []
        self._usr_idx_lookup = {}

        # Flat branches store the data under 'usr', nested ones under
        # '<key>.usr'; the names live under the same key plus '_names'.
        self._usr_key = 'usr' if mapper.flat else mapper.key + '.usr'

        self._initialise()

    def _initialise(self):
        """Probe for the `usr` key and load the flat data if applicable."""
        try:
            self._branch[self._usr_key]
            # This will raise a KeyError in old aanet files
            # which have a different structure and key (usr_data).
            # We do not support those (yet)
        except (KeyError, IndexError):
            print("The `usr` fields could not be parsed for the '{}' branch.".
                  format(self._name))
            return

        if self._mapper.flat:
            self._initialise_flat()

    def _initialise_flat(self):
        """Read names and data of a flat `usr` branch and expose each name
        as an attribute of this instance."""
        # Here, we assume that every event has the same names in the same order
        # to massively increase the performance. This needs triple check if
        # it's always the case.
        names_branch = self._branch[self._usr_key + '_names']
        self._usr_names = [
            n.decode("utf-8")
            for n in names_branch.lazyarray(basketcache=BASKET_CACHE)[0]
        ]
        self._usr_idx_lookup = {
            name: index
            for index, name in enumerate(self._usr_names)
        }

        data = self._branch[self._usr_key].lazyarray(basketcache=BASKET_CACHE)

        # The index is applied exactly once, here.
        if self._index is not None:
            data = data[self._index]

        self._usr_data = data

        for name in self._usr_names:
            setattr(self, name, self[name])

    def __getitem__(self, item):
        if self._mapper.flat:
            return self.__getitem_flat__(item)
        return self.__getitem_nested__(item)

    def __getitem_flat__(self, item):
        """Return the column of the flat `usr` data belonging to `item`."""
        # `_usr_data` has already been restricted to `self._index` in
        # `_initialise_flat`; indexing it again would pick the wrong rows.
        return self._usr_data[:, self._usr_idx_lookup[item]]

    def __getitem_nested__(self, item):
        # NOTE(review): `item` is not used here; the '<key>.usr_names' array
        # is returned as a whole -- confirm this is the intended behaviour.
        names_branch = self._branch[self._usr_key + '_names']
        data = names_branch.lazyarray(
            # TODO this will be fixed soon in uproot,
            # see https://github.com/scikit-hep/uproot/issues/465
            uproot.asgenobj(
                uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
                names_branch._context, 6),
            basketcache=BASKET_CACHE)
        if self._index is None:
            return data
        return data[self._index]

    def keys(self):
        """Return the list of `usr` field names (empty if none were read)."""
        return self._usr_names
@@ -170,30 +227,39 @@ class Usr:
def _to_num(value):
"""Convert a value to a numerical one if possible"""
if value is None:
return
try:
return int(value)
except ValueError:
for converter in (int, float):
try:
return float(value)
except ValueError:
return converter(value)
except (ValueError, TypeError):
pass
else:
return value
return value
class Header:
"""The header"""
def __init__(self, header):
    """Parse the raw header entries into `self._data`.

    `header` is iterated with ``.items()`` and every value is split on
    whitespace. Field names for an attribute are looked up in `mc_header`;
    values without a known field name get generic ``field_<i>`` names and
    missing values are filled with ``None``. A single value without any
    known field name is stored as a plain (converted) value instead of a
    namedtuple.
    """
    self._data = {}

    for attribute, raw in header.items():
        values = raw.split()
        # Work on a copy: the list is extended below and those generic
        # names must not leak into the shared `mc_header` definition.
        fields = list(mc_header.get(attribute, []))

        n_values = len(values)
        n_fields = len(fields)

        if n_values == 1 and n_fields == 0:
            self._data[attribute] = _to_num(values[0])
            continue

        # Pad both lists to the same length
        n_max = max(n_values, n_fields)
        values += [None] * (n_max - n_values)
        fields += ["field_{}".format(i) for i in range(n_fields, n_max)]

        # Nothing to store (neither values nor known fields)
        if not values:
            continue

        Constructor = namedtuple(attribute, fields)
        self._data[attribute] = Constructor(
            **{f: _to_num(v)
               for (f, v) in zip(fields, values)})
@@ -203,8 +269,12 @@ class Header:
def __str__(self):
    """Return a multi-line summary of the parsed header.

    Entries whose key is defined in `mc_header` are printed by value only;
    all other entries are printed as ``key: value``.
    """
    lines = ["MC Header:"]
    for key, value in self._data.items():
        if key in mc_header:
            lines.append(" {}".format(value))
        else:
            lines.append(" {}: {}".format(key, value))
    return "\n".join(lines)
@@ -259,7 +329,7 @@ class Branch:
@cached_property
def usr(self):
    """`Usr` accessor for this branch, built from its mapper, branch and
    index."""
    # `Usr` reads `mapper.name`, `mapper.key` and `mapper.flat`, so the
    # whole mapper is handed over rather than just its name.
    return Usr(self._mapper, self._branch, index=self._index)
def __getattribute__(self, attr):
if attr.startswith("_"): # let all private and magic methods pass
Loading