Skip to content
Snippets Groups Projects

WIP: Slicing and refactoring offline

Closed Tamas Gal requested to merge 37-user-parameters-seem-to-be-transposed into master
Compare and Show latest version
2 files
+ 85
141
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 74
136
from collections import namedtuple
import uproot
import numpy as np
import awkward1 as ak
import warnings
from .definitions import mc_header
@@ -9,10 +8,11 @@ MAIN_TREE_NAME = "E"
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2
BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"])
BranchMapper = namedtuple(
"BranchMapper",
['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat'])
['name', 'key', 'extra', 'exclude', 'update', 'attrparser'])
def _nested_mapper(key):
@@ -20,37 +20,57 @@ def _nested_mapper(key):
return '_'.join(key.split('.')[1:])
EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"])
EVENTS_MAP = BranchMapper("events", "Evt", {
't_sec': 't.fSec',
't_ns': 't.fNanoSec'
}, [], {
'n_hits': 'hits',
'n_mc_hits': 'mc_hits',
'n_tracks': 'trks',
'n_mc_tracks': 'mc_trks'
}, lambda a: a, True)
EVENTS_MAP = BranchMapper(name="events",
key="Evt",
extra={
't_sec': 't.fSec',
't_ns': 't.fNanoSec'
},
exclude=[],
update={
'n_hits': 'hits',
'n_mc_hits': 'mc_hits',
'n_tracks': 'trks',
'n_mc_tracks': 'mc_trks'
},
attrparser=lambda a: a)
SUBBRANCH_MAPS = [
BranchMapper("tracks", "trks", {}, ['trks.usr_data', 'trks.usr'], {},
_nested_mapper, False),
BranchMapper("mc_tracks", "mc_trks", {},
['mc_trks.usr_data', 'mc_trks.usr'], {}, _nested_mapper,
False),
BranchMapper("hits", "hits", {}, ['hits.usr'], {}, _nested_mapper, False),
BranchMapper("mc_hits", "mc_hits", {},
['mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id'], {},
_nested_mapper, False),
BranchMapper("events", "Evt", {
't_sec': 't.fSec',
't_ns': 't.fNanoSec'
}, [], {
'n_hits': 'hits',
'n_mc_hits': 'mc_hits',
'n_tracks': 'trks',
'n_mc_tracks': 'mc_trks'
}, lambda a: a, True),
BranchMapper(
name="tracks",
key="trks",
extra={},
exclude=['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'],
update={},
attrparser=_nested_mapper),
BranchMapper(name="mc_tracks",
key="mc_trks",
extra={},
exclude=[
'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.rec_stages',
'mc_trks.fitinf', 'mc_trks.fUniqueID', 'mc_trks.fBits'
],
update={},
attrparser=_nested_mapper),
BranchMapper(name="hits",
key="hits",
extra={},
exclude=[
'hits.usr', 'hits.pmt_id', 'hits.origin', 'hits.a',
'hits.pure_a', 'hits.fUniqueID', 'hits.fBits'
],
update={},
attrparser=_nested_mapper),
BranchMapper(name="mc_hits",
key="mc_hits",
extra={},
exclude=[
'mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id',
'mc_hits.tdc', 'mc_hits.tot', 'mc_hits.trig',
'mc_hits.fUniqueID', 'mc_hits.fBits'
],
update={},
attrparser=_nested_mapper),
]
@@ -119,8 +139,7 @@ class OfflineReader:
if self._index is None:
return len(tree)
else:
return len(
tree.lazyarrays(basketcache=BASKET_CACHE)[self.index])
return len(tree.lazyarrays(basketcache=BASKET_CACHE)[self.index])
@cached_property
def header(self):
@@ -473,7 +492,7 @@ class OfflineReader:
class Usr:
"""Helper class to access AAObject usr stuff"""
"""Helper class to access AAObject `usr`` stuff"""
def __init__(self, name, tree, index=None):
# Here, we assume that every event has the same names in the same order
# to massively increase the performance. This needs triple check if it's
@@ -484,7 +503,8 @@ class Usr:
# which has a different strucuter and key (usr_data)
# We do not support those...
self._usr_names = [
n.decode("utf-8") for n in tree['usr_names'].lazyarray()[0]
n.decode("utf-8") for n in tree['usr_names'].lazyarray(
basketcache=BASKET_CACHE)[0]
]
except (KeyError, IndexError): # e.g. old aanet files
self._usr_names = []
@@ -532,7 +552,7 @@ def _to_num(value):
class Header:
"""The online header"""
"""The header"""
def __init__(self, header):
self._data = {}
for attribute, fields in mc_header.items():
@@ -558,12 +578,10 @@ class Header:
class Branch:
"""Branch accessor class"""
# @profile
def __init__(self,
tree,
mapper,
index=None,
subbranches=None,
subbranchmaps=None,
keymap=None):
self._tree = tree
@@ -578,8 +596,6 @@ class Branch:
else:
self._keymap = keymap
if subbranches is not None:
self._subbranches = subbranches
if subbranchmaps is not None:
for mapper in subbranchmaps:
subbranch = Branch(self._tree,
@@ -589,9 +605,9 @@ class Branch:
for subbranch in self._subbranches:
setattr(self, subbranch._mapper.name, subbranch)
# @profile
def _initialise_keys(self):
"""Create the keymap and instance attributes for branch keys"""
# TODO: this could be a cached property
keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(
self._mapper.exclude) - EXCLUDE_KEYS
self._keymap = {
@@ -616,126 +632,48 @@ class Branch:
def __getattribute__(self, attr):
if attr.startswith("_"): # let all private and magic methods pass
return object.__getattribute__(self, attr)
if attr in self._keymap.keys(): # intercept branch key lookups
item = self._keymap[attr]
out = self._branch[item].lazyarray(
if attr in self._keymap.keys(): # intercept branch key lookups
out = self._branch[self._keymap[attr]].lazyarray(
basketcache=BASKET_CACHE)
if self._index is not None:
out = out[self._index]
return out
return object.__getattribute__(self, attr)
# @profile
def __getitem__(self, item):
"""Slicing magic a la numpy"""
if isinstance(item, slice):
"""Slicing magic"""
if isinstance(item, (int, slice)):
return self.__class__(self._tree,
self._mapper,
index=item,
subbranches=self._subbranches)
if isinstance(item, int):
# TODO refactor this
if self._mapper.flat:
if self._index is None:
dct = {
key: self._branch[self._keymap[key]].lazyarray()
for key in self.keys()
}
else:
dct = {
key: self._branch[self._keymap[key]].lazyarray()[
self._index]
for key in self.keys()
}
for subbranch in self._subbranches:
dct[subbranch._mapper.name] = subbranch
return BranchElement(self._mapper.name, dct)[item]
else:
if self._index is None:
dct = {
key: self._branch[self._keymap[key]].lazyarray()[item]
for key in self.keys()
}
else:
dct = {
key: self._branch[self._keymap[key]].lazyarray()[
self._index, item]
for key in self.keys()
}
for subbranch in self._subbranches:
dct[subbranch._mapper.name] = subbranch
return BranchElement(self._mapper.name, dct)
keymap=self._keymap,
subbranchmaps=SUBBRANCH_MAPS)
if isinstance(item, tuple):
return self[item[0]][item[1]]
if isinstance(item, str):
item = self._keymap[item]
out = self._branch[item].lazyarray(
basketcache=BASKET_CACHE)
if self._index is not None:
out = out[self._index]
return out
return self.__class__(self._tree,
self._mapper,
index=np.array(item),
subbranches=self._subbranches)
keymap=self._keymap,
subbranchmaps=SUBBRANCH_MAPS)
def __len__(self):
if self._index is None:
return len(self._branch)
elif isinstance(self._index, int):
return 1
else:
return len(
self._branch[self._keymap['id']].lazyarray()[self._index])
return len(self._branch[self._keymap['id']].lazyarray(
basketcache=BASKET_CACHE)[self._index])
def __str__(self):
return "Number of elements: {}".format(len(self._branch))
def __repr__(self):
return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__,
self._mapper.name,
len(self))
class BranchElement:
"""Represents a single branch element
Parameters
----------
name: str
The name of the branch
dct: dict (keys=attributes, values=arrays of values)
The data
index: slice
The slice mask to be applied to the sub-arrays
"""
def __init__(self, name, dct, index=None, subbranches=[]):
self._dct = dct
self._name = name
self._index = index
self.ItemConstructor = namedtuple(self._name[:-1], dct.keys())
if index is None:
for key, values in dct.items():
setattr(self, key, values)
else:
for key, values in dct.items():
setattr(self, key, values[index])
def __getitem__(self, item):
if isinstance(item, slice):
return self.__class__(self._name, self._dct, index=item)
if isinstance(item, int):
if self._index is None:
return self.ItemConstructor(
**{k: v[item]
for k, v in self._dct.items()})
else:
return self.ItemConstructor(
**{k: v[self._index][item]
for k, v in self._dct.items()})
def __repr__(self):
return "<{}[{}]>".format(self.__class__.__name__, self._name)
length = len(self)
return "<{}[{}]: {} element{}>".format(self.__class__.__name__,
self._mapper.name, length,
's' if length > 1 else '')
Loading