Skip to content
Snippets Groups Projects

WIP: Slicing and refactoring offline

Closed Tamas Gal requested to merge 37-user-parameters-seem-to-be-transposed into master
Compare and Show latest version
3 files
+ 178
132
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 126
119
from collections import namedtuple
from collections import namedtuple
import uproot
import uproot
import numpy as np
import numpy as np
import awkward1 as ak
import warnings
import warnings
from .definitions import mc_header
from .definitions import mc_header
MAIN_TREE_NAME = "E"
MAIN_TREE_NAME = "E"
# 110 MB based on the size of the largest basket found so far in km3net
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2
BASKET_CACHE_SIZE = 110 * 1024**2
 
BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
 
EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"])
BranchMapper = namedtuple(
BranchMapper = namedtuple(
"BranchMapper",
"BranchMapper",
['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat'])
['name', 'key', 'extra', 'exclude', 'update', 'attrparser'])
def _nested_mapper(key):
def _nested_mapper(key):
@@ -19,26 +20,57 @@ def _nested_mapper(key):
@@ -19,26 +20,57 @@ def _nested_mapper(key):
return '_'.join(key.split('.')[1:])
return '_'.join(key.split('.')[1:])
EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"])
EVENTS_MAP = BranchMapper(name="events",
BRANCH_MAPS = [
key="Evt",
BranchMapper("tracks", "trks", {}, ['trks.usr_data', 'trks.usr'], {},
extra={
_nested_mapper, False),
't_sec': 't.fSec',
BranchMapper("mc_tracks", "mc_trks", {},
't_ns': 't.fNanoSec'
['mc_trks.usr_data', 'mc_trks.usr'], {}, _nested_mapper,
},
False),
exclude=[],
BranchMapper("hits", "hits", {}, ['hits.usr'], {}, _nested_mapper, False),
update={
BranchMapper("mc_hits", "mc_hits", {},
'n_hits': 'hits',
['mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id'], {},
'n_mc_hits': 'mc_hits',
_nested_mapper, False),
'n_tracks': 'trks',
BranchMapper("events", "Evt", {
'n_mc_tracks': 'mc_trks'
't_sec': 't.fSec',
},
't_ns': 't.fNanoSec'
attrparser=lambda a: a)
}, [], {
'n_hits': 'hits',
SUBBRANCH_MAPS = [
'n_mc_hits': 'mc_hits',
BranchMapper(
'n_tracks': 'trks',
name="tracks",
'n_mc_tracks': 'mc_trks'
key="trks",
}, lambda a: a, True),
extra={},
 
exclude=['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'],
 
update={},
 
attrparser=_nested_mapper),
 
BranchMapper(name="mc_tracks",
 
key="mc_trks",
 
extra={},
 
exclude=[
 
'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.rec_stages',
 
'mc_trks.fitinf', 'mc_trks.fUniqueID', 'mc_trks.fBits'
 
],
 
update={},
 
attrparser=_nested_mapper),
 
BranchMapper(name="hits",
 
key="hits",
 
extra={},
 
exclude=[
 
'hits.usr', 'hits.pmt_id', 'hits.origin', 'hits.a',
 
'hits.pure_a', 'hits.fUniqueID', 'hits.fBits'
 
],
 
update={},
 
attrparser=_nested_mapper),
 
BranchMapper(name="mc_hits",
 
key="mc_hits",
 
extra={},
 
exclude=[
 
'mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id',
 
'mc_hits.tdc', 'mc_hits.tot', 'mc_hits.trig',
 
'mc_hits.fUniqueID', 'mc_hits.fBits'
 
],
 
update={},
 
attrparser=_nested_mapper),
]
]
@@ -70,18 +102,18 @@ class OfflineReader:
@@ -70,18 +102,18 @@ class OfflineReader:
if file_path is not None:
if file_path is not None:
self._fobj = uproot.open(file_path)
self._fobj = uproot.open(file_path)
self._tree = self._fobj[MAIN_TREE_NAME]
self._tree = self._fobj[MAIN_TREE_NAME]
self._data = self._tree.lazyarrays(
self._data = self._tree.lazyarrays(basketcache=BASKET_CACHE)
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))
else:
else:
self._fobj = fobj
self._fobj = fobj
self._tree = self._fobj[MAIN_TREE_NAME]
self._tree = self._fobj[MAIN_TREE_NAME]
self._data = data
self._data = data
for mapper in BRANCH_MAPS:
@cached_property
# print("setting mapper {}".format(mapper.name))
def events(self):
setattr(self, mapper.name,
return Branch(self._tree,
Branch(self._tree, mapper=mapper, index=self._index))
mapper=EVENTS_MAP,
 
index=self._index,
 
subbranchmaps=SUBBRANCH_MAPS)
@classmethod
@classmethod
def from_index(cls, source, index):
def from_index(cls, source, index):
@@ -107,9 +139,7 @@ class OfflineReader:
@@ -107,9 +139,7 @@ class OfflineReader:
if self._index is None:
if self._index is None:
return len(tree)
return len(tree)
else:
else:
return len(
return len(tree.lazyarrays(basketcache=BASKET_CACHE)[self.index])
tree.lazyarrays(basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))[self.index])
@cached_property
@cached_property
def header(self):
def header(self):
@@ -415,10 +445,10 @@ class OfflineReader:
@@ -415,10 +445,10 @@ class OfflineReader:
are not found, None is returned as the stages index.
are not found, None is returned as the stages index.
"""
"""
if mc is False:
if mc is False:
stages_data = self.tracks.rec_stages
stages_data = self.events.tracks.rec_stages
if mc is True:
if mc is True:
stages_data = self.mc_tracks.rec_stages
stages_data = self.events.mc_tracks.rec_stages
for trk_index, rec_stages in enumerate(stages_data):
for trk_index, rec_stages in enumerate(stages_data):
try:
try:
@@ -462,38 +492,32 @@ class OfflineReader:
@@ -462,38 +492,32 @@ class OfflineReader:
class Usr:
class Usr:
"""Helper class to access AAObject usr stuff"""
"""Helper class to access AAObject `usr`` stuff"""
def __init__(self, name, tree, index=None):
def __init__(self, name, tree, index=None):
# Here, we assume that every event has the same names in the same order
# Here, we assume that every event has the same names in the same order
# to massively increase the performance. This needs triple check if it's
# to massively increase the performance. This needs triple check if it's
# always the case; the usr-format is simply a very bad design.
# always the case; the usr-format is simply a very bad design.
# print("initialising usr for {}".format(name))
# print("Setting up usr")
self._name = name
self._name = name
try:
try:
tree['usr'] # This will raise a KeyError in old aanet files
tree['usr'] # This will raise a KeyError in old aanet files
# which has a different strucuter and key (usr_data)
# which has a different strucuter and key (usr_data)
# We do not support those...
# We do not support those...
self._usr_names = [
self._usr_names = [
n.decode("utf-8") for n in tree['usr_names'].lazyarray()[0]
n.decode("utf-8") for n in tree['usr_names'].lazyarray(
 
basketcache=BASKET_CACHE)[0]
]
]
except (KeyError, IndexError): # e.g. old aanet files
except (KeyError, IndexError): # e.g. old aanet files
self._usr_names = []
self._usr_names = []
else:
else:
# print(" checking usr data")
self._usr_idx_lookup = {
self._usr_idx_lookup = {
name: index
name: index
for index, name in enumerate(self._usr_names)
for index, name in enumerate(self._usr_names)
}
}
data = tree['usr'].lazyarray(
data = tree['usr'].lazyarray(basketcache=BASKET_CACHE)
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))
if index is not None:
if index is not None:
data = data[index]
data = data[index]
self._usr_data = data
self._usr_data = data
# print(" adding attributes")
for name in self._usr_names:
for name in self._usr_names:
# print(" setting {}".format(name))
setattr(self, name, self[name])
setattr(self, name, self[name])
def __getitem__(self, item):
def __getitem__(self, item):
@@ -528,7 +552,7 @@ def _to_num(value):
@@ -528,7 +552,7 @@ def _to_num(value):
class Header:
class Header:
"""The online header"""
"""The header"""
def __init__(self, header):
def __init__(self, header):
self._data = {}
self._data = {}
for attribute, fields in mc_header.items():
for attribute, fields in mc_header.items():
@@ -554,17 +578,36 @@ class Header:
@@ -554,17 +578,36 @@ class Header:
class Branch:
class Branch:
"""Branch accessor class"""
"""Branch accessor class"""
def __init__(self, tree, mapper, index=None):
def __init__(self,
 
tree,
 
mapper,
 
index=None,
 
subbranchmaps=None,
 
keymap=None):
self._tree = tree
self._tree = tree
self._mapper = mapper
self._mapper = mapper
self._index = index
self._index = index
self._keymap = None
self._keymap = None
self._branch = tree[mapper.key]
self._branch = tree[mapper.key]
 
self._subbranches = []
self._initialise_keys()
if keymap is None:
 
self._initialise_keys() #
 
else:
 
self._keymap = keymap
 
 
if subbranchmaps is not None:
 
for mapper in subbranchmaps:
 
subbranch = Branch(self._tree,
 
mapper=mapper,
 
index=self._index)
 
self._subbranches.append(subbranch)
 
for subbranch in self._subbranches:
 
setattr(self, subbranch._mapper.name, subbranch)
def _initialise_keys(self):
def _initialise_keys(self):
"""Create the keymap and instance attributes"""
"""Create the keymap and instance attributes for branch keys"""
 
# TODO: this could be a cached property
keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(
keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(
self._mapper.exclude) - EXCLUDE_KEYS
self._mapper.exclude) - EXCLUDE_KEYS
self._keymap = {
self._keymap = {
@@ -576,11 +619,8 @@ class Branch:
@@ -576,11 +619,8 @@ class Branch:
for k in self._mapper.update.values():
for k in self._mapper.update.values():
del self._keymap[k]
del self._keymap[k]
# self._EntryType = namedtuple(mapper.name[:-1], self.keys())
for key in self._keymap.keys():
setattr(self, key, None)
for key in self.keys():
# print("setting", self._mapper.name, key)
setattr(self, key, self[key])
def keys(self):
def keys(self):
return self._keymap.keys()
return self._keymap.keys()
@@ -589,84 +629,51 @@ class Branch:
@@ -589,84 +629,51 @@ class Branch:
def usr(self):
def usr(self):
return Usr(self._mapper.name, self._branch, index=self._index)
return Usr(self._mapper.name, self._branch, index=self._index)
def __getitem__(self, item):
def __getattribute__(self, attr):
"""Slicing magic a la numpy"""
if attr.startswith("_"): # let all private and magic methods pass
if isinstance(item, slice):
return object.__getattribute__(self, attr)
return self.__class__(self._tree, self._mapper, index=item)
if isinstance(item, int):
if self._mapper.flat:
return BranchElement(
self._mapper.name, {
key:
self._branch[self._keymap[key]].array()[self._index]
for key in self.keys()
})[item]
else:
return BranchElement(
self._mapper.name, {
key:
self._branch[self._keymap[key]].array()[self._index,
item]
for key in self.keys()
})
if isinstance(item, tuple):
return self[item[0]][item[1]]
if isinstance(item, str):
if attr in self._keymap.keys(): # intercept branch key lookups
item = self._keymap[item]
out = self._branch[self._keymap[attr]].lazyarray(
basketcache=BASKET_CACHE)
out = self._branch[item].lazyarray(
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))
if self._index is not None:
if self._index is not None:
out = out[self._index]
out = out[self._index]
return out
return out
return self.__class__(self._tree, self._mapper, index=np.array(item))
return object.__getattribute__(self, attr)
 
 
def __getitem__(self, item):
 
"""Slicing magic"""
 
if isinstance(item, (int, slice)):
 
return self.__class__(self._tree,
 
self._mapper,
 
index=item,
 
keymap=self._keymap,
 
subbranchmaps=SUBBRANCH_MAPS)
 
 
if isinstance(item, tuple):
 
return self[item[0]][item[1]]
 
 
return self.__class__(self._tree,
 
self._mapper,
 
index=np.array(item),
 
keymap=self._keymap,
 
subbranchmaps=SUBBRANCH_MAPS)
def __len__(self):
def __len__(self):
if self._index is None:
if self._index is None:
return len(self._branch)
return len(self._branch)
 
elif isinstance(self._index, int):
 
return 1
else:
else:
return len(
return len(self._branch[self._keymap['id']].lazyarray(
self._branch[self._keymap['id']].lazyarray()[self._index])
basketcache=BASKET_CACHE)[self._index])
def __str__(self):
def __str__(self):
return "Number of elements: {}".format(len(self._branch))
return "Number of elements: {}".format(len(self._branch))
def __repr__(self):
def __repr__(self):
return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__,
length = len(self)
self._mapper.name,
return "<{}[{}]: {} element{}>".format(self.__class__.__name__,
len(self))
self._mapper.name, length,
's' if length > 1 else '')
class BranchElement:
"""Represents a single branch element
Parameters
----------
name: str
The name of the branch
dct: dict (keys=attributes, values=arrays of values)
The data
index: slice
The slice mask to be applied to the sub-arrays
"""
def __init__(self, name, dct, index=None):
self._dct = dct
self._name = name
self._index = index
self.ItemConstructor = namedtuple(self._name[:-1], dct.keys())
for key, values in dct.items():
setattr(self, key, values[index])
def __getitem__(self, item):
if isinstance(item, slice):
return self.__class__(self._name, self._dct, index=item)
if isinstance(item, int):
return self.ItemConstructor(
**{k: v[self._index][item]
for k, v in self._dct.items()})
def __repr__(self):
return "<{}[{}]>".format(self.__class__.__name__, self._name)
Loading