Skip to content
Snippets Groups Projects

WIP: Slicing and refactoring offline

Closed Tamas Gal requested to merge 37-user-parameters-seem-to-be-transposed into master
Compare and Show latest version
4 files
+ 249
82
Compare changes
  • Side-by-side
  • Inline
Files
4
+ 158
39
from collections import namedtuple
import uproot
import numpy as np
import awkward as ak
import awkward1 as ak
import warnings
from .definitions import mc_header
@@ -9,8 +9,9 @@ MAIN_TREE_NAME = "E"
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2
BranchMapper = namedtuple("BranchMapper", ['name', 'key', 'extra', 'exclude', 'update', 'attrparser'])
BranchMapper = namedtuple(
"BranchMapper",
['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat'])
def _nested_mapper(key):
@@ -20,11 +21,24 @@ def _nested_mapper(key):
EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"])
BRANCH_MAPS = [
BranchMapper("tracks", "trks", {}, ['trks.usr_data', 'trks.usr'], {}, _nested_mapper),
BranchMapper("mc_tracks", "mc_trks", {}, ['mc_trks.usr_data', 'mc_trks.usr'], {}, _nested_mapper),
BranchMapper("hits", "hits", {}, ['hits.usr'], {}, _nested_mapper),
BranchMapper("mc_hits", "mc_hits", {}, ['mc_hits.usr'], {}, _nested_mapper),
BranchMapper("events", "Evt", {'t_sec': 't.fSec', 't_ns': 't.fNanoSec'}, [], {'n_hits': 'hits', 'n_mc_hits': 'mc_hits', 'n_tracks': 'trks', 'n_mc_tracks': 'mc_trks'}, lambda a: a),
BranchMapper("tracks", "trks", {}, ['trks.usr_data', 'trks.usr'], {},
_nested_mapper, False),
BranchMapper("mc_tracks", "mc_trks", {},
['mc_trks.usr_data', 'mc_trks.usr'], {}, _nested_mapper,
False),
BranchMapper("hits", "hits", {}, ['hits.usr'], {}, _nested_mapper, False),
BranchMapper("mc_hits", "mc_hits", {},
['mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id'], {},
_nested_mapper, False),
BranchMapper("events", "Evt", {
't_sec': 't.fSec',
't_ns': 't.fNanoSec'
}, [], {
'n_hits': 'hits',
'n_mc_hits': 'mc_hits',
'n_tracks': 'trks',
'n_mc_tracks': 'mc_trks'
}, lambda a: a, True),
]
@@ -42,7 +56,7 @@ class cached_property:
class OfflineReader:
"""reader for offline ROOT files"""
def __init__(self, file_path=None, fobj=None, data=None, index=slice(None)):
def __init__(self, file_path=None, fobj=None, data=None, index=None):
""" OfflineReader class is an offline ROOT file wrapper
Parameters
@@ -65,7 +79,9 @@ class OfflineReader:
self._data = data
for mapper in BRANCH_MAPS:
setattr(self, mapper.name, BranchElement(self._tree, mapper=mapper, index=self._index))
# print("setting mapper {}".format(mapper.name))
setattr(self, mapper.name,
Branch(self._tree, mapper=mapper, index=self._index))
@classmethod
def from_index(cls, source, index):
@@ -78,7 +94,9 @@ class OfflineReader:
index: index or slice
The index or slice to create the subtree.
"""
instance = cls(fobj=source._fobj, data=source._data[index], index=index)
instance = cls(fobj=source._fobj,
data=source._data[index],
index=index)
return instance
def __getitem__(self, index):
@@ -86,11 +104,11 @@ class OfflineReader:
def __len__(self):
tree = self._fobj[MAIN_TREE_NAME]
if self._index == slice(None):
if self._index is None:
return len(tree)
else:
return len(tree.lazyarrays(
basketcache=uproot.cache.ThreadSafeArrayCache(
return len(
tree.lazyarrays(basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))[self.index])
@cached_property
@@ -104,10 +122,6 @@ class OfflineReader:
else:
warnings.warn("Your file header has an unsupported format")
@cached_property
def usr(self):
return Usr(self._tree)
def get_best_reco(self):
"""returns the best reconstructed track fit data. The best fit is defined
as the track fit with the maximum reconstruction stages. When "nan" is
@@ -449,25 +463,37 @@ class OfflineReader:
class Usr:
"""Helper class to access AAObject usr stuff"""
def __init__(self, tree):
def __init__(self, name, tree, index=None):
# Here, we assume that every event has the same names in the same order
# to massively increase the performance. This needs triple check if it's
# always the case; the usr-format is simply a very bad design.
# print("initialising usr for {}".format(name))
# print("Setting up usr")
self._name = name
try:
tree['usr'] # This will raise a KeyError in old aanet files
# which has a different strucuter and key (usr_data)
# We do not support those...
self._usr_names = [
n.decode("utf-8") for n in tree['Evt']['usr_names'].array()[0]
n.decode("utf-8") for n in tree['usr_names'].lazyarray()[0]
]
except (KeyError, IndexError): # e.g. old aanet files
self._usr_names = []
else:
# print(" checking usr data")
self._usr_idx_lookup = {
name: index
for index, name in enumerate(self._usr_names)
}
self._usr_data = tree['Evt']['usr'].lazyarray(
data = tree['usr'].lazyarray(
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))
if index is not None:
data = data[index]
self._usr_data = data
# print(" adding attributes")
for name in self._usr_names:
# print(" setting {}".format(name))
setattr(self, name, self[name])
def __getitem__(self, item):
@@ -482,11 +508,14 @@ class Usr:
entries.append("{}: {}".format(name, self[name]))
return '\n'.join(entries)
def __repr__(self):
return "<{}[{}]>".format(self.__class__.__name__, self._name)
def _to_num(value):
"""Convert value to a numerical value if possible"""
"""Convert a value to a numerical one if possible"""
if value is None:
return None
return
try:
return int(value)
except ValueError:
@@ -509,7 +538,9 @@ class Header:
Constructor = namedtuple(attribute, fields)
if len(values) < len(fields):
values += [None] * (len(fields) - len(values))
self._data[attribute] = Constructor(**{f: _to_num(v) for (f, v) in zip(fields, values)})
self._data[attribute] = Constructor(
**{f: _to_num(v)
for (f, v) in zip(fields, values)})
for attribute, value in self._data.items():
setattr(self, attribute, value)
@@ -521,22 +552,26 @@ class Header:
return "\n".join(lines)
class BranchElement:
"""wrapper for offline tracks"""
def __init__(self, tree, mapper, index=slice(None)):
class Branch:
"""Branch accessor class"""
def __init__(self, tree, mapper, index=None):
self._tree = tree
self._mapper = mapper
self._index = index
self._keymap = None
self._branch = tree[mapper.key]
self._initialise_keys()
def _initialise_keys(self):
"""Create the keymap and instance attributes"""
keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(self._mapper.exclude) - EXCLUDE_KEYS
self._keymap = {**{self._mapper.attrparser(k): k for k in keys}, **self._mapper.extra}
keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(
self._mapper.exclude) - EXCLUDE_KEYS
self._keymap = {
**{self._mapper.attrparser(k): k
for k in keys},
**self._mapper.extra
}
self._keymap.update(self._mapper.update)
for k in self._mapper.update.values():
del self._keymap[k]
@@ -550,27 +585,111 @@ class BranchElement:
def keys(self):
return self._keymap.keys()
@cached_property
def usr(self):
return Usr(self._mapper.name, self._branch, index=self._index)
def __getitem__(self, item):
"""Slicing magic a la numpy"""
if isinstance(item, slice):
return self.__class__(self._tree, self._mapper, index=item)
if isinstance(item, int):
return {
key: self._branch[self._keymap[key]].array()[self._index, item] for key in self.keys()
}
return self._branch[self._keymap[item]].lazyarray(
# A bit ugly, but whatever works
if self._mapper.flat:
if self._index is None:
dct = {
key: self._branch[self._keymap[key]].array()
for key in self.keys()
}
else:
dct = {
key:
self._branch[self._keymap[key]].array()[self._index]
for key in self.keys()
}
return BranchElement(self._mapper.name, dct)[item]
else:
if self._index is None:
dct = {
key: self._branch[self._keymap[key]].array()[item]
for key in self.keys()
}
else:
dct = {
key:
self._branch[self._keymap[key]].array()[self._index,
item]
for key in self.keys()
}
return BranchElement(self._mapper.name, dct)
if isinstance(item, tuple):
return self[item[0]][item[1]]
if isinstance(item, str):
item = self._keymap[item]
out = self._branch[item].lazyarray(
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))[self._index]
BASKET_CACHE_SIZE))
if self._index is not None:
out = out[self._index]
return out
return self.__class__(self._tree, self._mapper, index=np.array(item))
def __len__(self):
if self._index == slice(None):
if self._index is None:
return len(self._branch)
else:
return len(self._branch[self._keymap['id']].lazyarray()[self._index])
return len(
self._branch[self._keymap['id']].lazyarray()[self._index])
def __str__(self):
return "Number of elements: {}".format(len(self._branch))
def __repr__(self):
return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__, self._mapper.name,
len(self))
return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__,
self._mapper.name,
len(self))
class BranchElement:
"""Represents a single branch element
Parameters
----------
name: str
The name of the branch
dct: dict (keys=attributes, values=arrays of values)
The data
index: slice
The slice mask to be applied to the sub-arrays
"""
def __init__(self, name, dct, index=None):
self._dct = dct
self._name = name
self._index = index
self.ItemConstructor = namedtuple(self._name[:-1], dct.keys())
if index is None:
for key, values in dct.items():
setattr(self, key, values)
else:
for key, values in dct.items():
setattr(self, key, values[index])
def __getitem__(self, item):
if isinstance(item, slice):
return self.__class__(self._name, self._dct, index=item)
if isinstance(item, int):
if self._index is None:
return self.ItemConstructor(
**{k: v[item]
for k, v in self._dct.items()})
else:
return self.ItemConstructor(
**{k: v[self._index][item]
for k, v in self._dct.items()})
def __repr__(self):
return "<{}[{}]>".format(self.__class__.__name__, self._name)
Loading