Skip to content
Snippets Groups Projects
Commit 2079c037 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Use tools

parent 7c562496
No related branches found
No related tags found
1 merge request: !24 "Refactor offline"
Pipeline #10063 passed with warnings
This commit is part of merge request !24. Comments created here will be created in the context of that merge request.
from collections import namedtuple
import uproot
import numpy as np
import warnings
from .definitions import mc_header
from .tools import Branch, BranchMapper, cached_property, _to_num
MAIN_TREE_NAME = "E"
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2
# Shared cache instance; it is handed to the `lazyarray(basketcache=...)`
# calls made further down in this file.
BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
# NOTE(review): `EXCLUDE_KEYS` is bound again (as a list) a few lines below
# and nothing reads it in between, so this set is never the value in effect.
EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"])
# NOTE(review): this rebinds the `BranchMapper` name that is also imported
# from `.tools` above -- confirm which definition is meant to be used.
BranchMapper = namedtuple(
"BranchMapper",
['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat'])
# Keys listed in the `exclude` field of the mappers below. It is a list so
# that it can be concatenated with mapper specific entries
# (`EXCLUDE_KEYS + [...]`); code needing set operations has to convert it.
EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"]
def _nested_mapper(key):
......@@ -26,7 +19,7 @@ EVENTS_MAP = BranchMapper(name="events",
't_sec': 't.fSec',
't_ns': 't.fNanoSec'
},
exclude=[],
exclude=EXCLUDE_KEYS,
update={
'n_hits': 'hits',
'n_mc_hits': 'mc_hits',
......@@ -41,14 +34,14 @@ SUBBRANCH_MAPS = [
name="tracks",
key="trks",
extra={},
exclude=['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'],
exclude=EXCLUDE_KEYS + ['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'],
update={},
attrparser=_nested_mapper,
flat=False),
BranchMapper(name="mc_tracks",
key="mc_trks",
extra={},
exclude=[
exclude=EXCLUDE_KEYS + [
'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.rec_stages',
'mc_trks.fitinf', 'mc_trks.fUniqueID', 'mc_trks.fBits'
],
......@@ -58,7 +51,7 @@ SUBBRANCH_MAPS = [
BranchMapper(name="hits",
key="hits",
extra={},
exclude=[
exclude=EXCLUDE_KEYS + [
'hits.usr', 'hits.pmt_id', 'hits.origin', 'hits.a',
'hits.pure_a', 'hits.fUniqueID', 'hits.fBits'
],
......@@ -68,7 +61,7 @@ SUBBRANCH_MAPS = [
BranchMapper(name="mc_hits",
key="mc_hits",
extra={},
exclude=[
exclude=EXCLUDE_KEYS + [
'mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id',
'mc_hits.tdc', 'mc_hits.tot', 'mc_hits.trig',
'mc_hits.fUniqueID', 'mc_hits.fBits'
......@@ -79,18 +72,6 @@ SUBBRANCH_MAPS = [
]
class cached_property:
    """A simple cache decorator for properties.

    The decorated function is called on the first attribute access of an
    instance and its return value is stored in the instance ``__dict__``
    under the name of the function. As this class defines no ``__set__``,
    the stored value is found first on every later lookup, so the function
    runs only once per instance.

    Accessing the attribute on the class itself returns the descriptor.
    """
    def __init__(self, function):
        self.function = function
        # Show the documentation of the wrapped function instead of the
        # generic docstring of this decorator class.
        self.__doc__ = function.__doc__

    def __get__(self, obj, cls=None):
        if obj is None:
            # class level access (e.g. ``Owner.prop``)
            return self
        prop = obj.__dict__[self.function.__name__] = self.function(obj)
        return prop
class OfflineReader:
"""reader for offline ROOT files"""
def __init__(self, file_path=None):
......@@ -126,113 +107,6 @@ class OfflineReader:
warnings.warn("Your file header has an unsupported format")
class Usr:
    """Helper class to access AAObject `usr` stuff

    For flat mappers the `usr` names of the first entry are read once and
    every name is made available as attribute and via ``usr[name]``.
    For non-flat mappers no names are parsed, so ``keys()`` stays empty.

    Parameters
    ----------
    mapper:
        Object providing the attributes ``name``, ``key`` and ``flat``.
    branch:
        Container which is indexed with the usr key (``'usr'`` for flat
        mappers, ``mapper.key + '.usr'`` otherwise) and with that key
        followed by ``'_names'``.
    index: optional
        Selection which is applied once to the flat `usr` data.
    """
    def __init__(self, mapper, branch, index=None):
        self._mapper = mapper
        self._name = mapper.name
        self._index = index
        self._branch = branch
        self._usr_names = []
        self._usr_idx_lookup = {}
        # Stays None if the fields could not be parsed or the mapper is
        # not flat.
        self._usr_data = None
        self._usr_key = 'usr' if mapper.flat else mapper.key + '.usr'
        self._initialise()

    def _initialise(self):
        """Check that the `usr` key is accessible and parse flat fields."""
        try:
            self._branch[self._usr_key]
            # This will raise a KeyError in old aanet files
            # which have a different structure and key (usr_data)
            # We do not support those (yet)
        except (KeyError, IndexError):
            # A warning (instead of printing to stdout) in line with the
            # other diagnostics of this module.
            warnings.warn(
                "The `usr` fields could not be parsed for the '{}' branch.".
                format(self._name))
            return

        if self._mapper.flat:
            self._initialise_flat()
        # TODO: the names of nested `usr` fields are not parsed up front,
        # this needs a custom interpretation in uproot, see
        # https://github.com/scikit-hep/uproot/issues/465

    def _initialise_flat(self):
        """Read names and data of a flat `usr` field and set attributes."""
        # Here, we assume that every event has the same names in the same order
        # to massively increase the performance. This needs triple check if
        # it's always the case.
        names_of_first_entry = self._branch[self._usr_key + '_names'].lazyarray(
            basketcache=BASKET_CACHE)[0]
        self._usr_names = [n.decode("utf-8") for n in names_of_first_entry]
        self._usr_idx_lookup = {
            name: index
            for index, name in enumerate(self._usr_names)
        }

        data = self._branch[self._usr_key].lazyarray(basketcache=BASKET_CACHE)
        if self._index is not None:
            # The selection is applied exactly once, here.
            data = data[self._index]
        self._usr_data = data

        for name in self._usr_names:
            setattr(self, name, self[name])

    def __getitem__(self, item):
        if self._mapper.flat:
            return self.__getitem_flat__(item)
        return self.__getitem_nested__(item)

    def __getitem_flat__(self, item):
        """Return the values of the flat `usr` field named `item`.

        Raises a KeyError if `item` is not a known `usr` name.
        """
        column = self._usr_idx_lookup[item]
        # `_usr_data` has already been restricted to `_index` in
        # `_initialise_flat`; indexing it a second time would pick the
        # wrong rows (or fail for a single entry).
        if isinstance(self._index, int):
            # a single entry was selected: one row, no entry axis left
            return self._usr_data[column]
        return self._usr_data[:, column]

    def __getitem_nested__(self, item):
        """Return the `usr` names data of a nested branch.

        NOTE(review): `item` is not used here, the data read from the
        ``'_names'`` key is returned as a whole -- confirm the intent.
        """
        names_branch = self._branch[self._usr_key + '_names']
        data = names_branch.lazyarray(
            # TODO this will be fixed soon in uproot,
            # see https://github.com/scikit-hep/uproot/issues/465
            uproot.asgenobj(
                uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
                names_branch._context, 6),
            basketcache=BASKET_CACHE)
        if self._index is None:
            return data
        return data[self._index]

    def keys(self):
        """Return the list of parsed `usr` names."""
        return self._usr_names

    def __str__(self):
        return '\n'.join("{}: {}".format(name, self[name])
                         for name in self.keys())

    def __repr__(self):
        return "<{}[{}]>".format(self.__class__.__name__, self._name)
def _to_num(value):
"""Convert a value to a numerical one if possible"""
for converter in (int, float):
try:
return converter(value)
except (ValueError, TypeError):
pass
return value
class Header:
......@@ -278,110 +152,3 @@ class Header:
return "\n".join(lines)
class Branch:
    """Branch accessor class

    Gives attribute and item access to the keys of ``tree[mapper.key]``.
    Indexing with an int, a slice or an array-like returns a new instance
    of the same class which carries that index; indexing with a string
    returns the data of the corresponding key.

    Parameters
    ----------
    tree:
        Container which is indexed with ``mapper.key``.
    mapper:
        Object providing ``name``, ``key``, ``exclude``, ``extra``,
        ``update`` and ``attrparser`` (and ``flat`` for `usr` access).
    index: optional
        Selection applied to the data of every key.
    subbranchmaps: optional
        Mappers for which a sub-branch is created and attached as
        attribute (named after ``mapper.name``).
    keymap: dict, optional
        Ready made mapping of attribute name to branch key. When omitted,
        it is built from the keys of the branch.
    """
    def __init__(self,
                 tree,
                 mapper,
                 index=None,
                 subbranchmaps=None,
                 keymap=None):
        self._tree = tree
        self._mapper = mapper
        self._index = index
        self._keymap = None
        self._branch = tree[mapper.key]
        self._subbranches = []
        # kept to hand the same sub-branches over to sliced instances
        self._subbranchmaps = subbranchmaps

        if keymap is None:
            self._initialise_keys()
        else:
            self._keymap = keymap

        if subbranchmaps is not None:
            for submapper in subbranchmaps:
                subbranch = Branch(self._tree,
                                   mapper=submapper,
                                   index=self._index)
                self._subbranches.append(subbranch)
        for subbranch in self._subbranches:
            setattr(self, subbranch._mapper.name, subbranch)

    def _initialise_keys(self):
        """Create the keymap and instance attributes for branch keys"""
        # TODO: this could be a cached property
        # `difference` takes any iterable, so `EXCLUDE_KEYS` may be a list
        # or a set.
        keys = {k.decode('utf-8') for k in self._branch.keys()}.difference(
            self._mapper.exclude, EXCLUDE_KEYS)
        self._keymap = {
            **{self._mapper.attrparser(k): k
               for k in keys},
            **self._mapper.extra
        }
        self._keymap.update(self._mapper.update)
        # The targets of `update` are reachable via their new names only.
        # A target which is not in the keymap is simply skipped.
        for k in self._mapper.update.values():
            self._keymap.pop(k, None)

        # Placeholders only: the actual values are served by
        # `__getattribute__` below.
        for key in self._keymap:
            setattr(self, key, None)

    def keys(self):
        """Return the attribute names of the keymap."""
        return self._keymap.keys()

    @cached_property
    def usr(self):
        return Usr(self._mapper, self._branch, index=self._index)

    def __getattribute__(self, attr):
        if attr.startswith("_"):  # let all private and magic methods pass
            return object.__getattribute__(self, attr)

        if attr in self._keymap:  # intercept branch key lookups
            return self.__getkey__(attr)

        return object.__getattribute__(self, attr)

    def __getkey__(self, key):
        """Return the (lazy) data of `key` with the index applied."""
        out = self._branch[self._keymap[key]].lazyarray(
            basketcache=BASKET_CACHE)
        if self._index is not None:
            out = out[self._index]
        return out

    def __getitem__(self, item):
        """Slicing magic"""
        if isinstance(item, str):
            return self.__getkey__(item)
        if isinstance(item, tuple):
            return self[item[0]][item[1]]
        if not isinstance(item, (int, slice)):
            item = np.array(item)
        # Pass on the sub-branch mappers this instance was created with,
        # so a sliced instance has the same sub-branches as the original.
        return self.__class__(self._tree,
                              self._mapper,
                              index=item,
                              keymap=self._keymap,
                              subbranchmaps=self._subbranchmaps)

    def __len__(self):
        if self._index is None:
            return len(self._branch)
        elif isinstance(self._index, int):
            return 1
        else:
            return len(self._branch[self._keymap['id']].lazyarray(
                basketcache=BASKET_CACHE)[self._index])

    def __str__(self):
        # `len(self)` takes the index into account
        return "Number of elements: {}".format(len(self))

    def __repr__(self):
        length = len(self)
        return "<{}[{}]: {} element{}>".format(self.__class__.__name__,
                                               self._mapper.name, length,
                                               's' if length != 1 else '')
......@@ -3,7 +3,7 @@ import numpy as np
from pathlib import Path
from km3io import OfflineReader
from km3io.offline import _nested_mapper, cached_property, _to_num, Header
from km3io.offline import _nested_mapper, Header
SAMPLES_DIR = Path(__file__).parent / 'samples'
OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root')
......@@ -340,22 +340,7 @@ class TestUsr(unittest.TestCase):
self.f.events.usr.DeltaPosZ)
class TestIndependentFunctions(unittest.TestCase):
class TestNestedMapper(unittest.TestCase):
    """Checks for the stand-alone helper functions."""
    def test_nested_mapper(self):
        """The key "trks.pos.x" is mapped to the name 'pos_x'."""
        mapped = _nested_mapper("trks.pos.x")
        self.assertEqual('pos_x', mapped)

    def test_to_num(self):
        """Numeric strings are converted, anything else is passed through."""
        expectations = ((10, "10"), (10.5, "10.5"), ("test", "test"))
        for expected, raw in expectations:
            self.assertEqual(expected, _to_num(raw))
        self.assertIsNone(_to_num(None))
class TestCachedProperty(unittest.TestCase):
    """Tests for the `cached_property` decorator."""
    def test_cached_properties(self):
        """Class access gives the descriptor, instance access is cached."""
        calls = []

        class Test:
            @cached_property
            def prop(self):
                calls.append(1)
                return 42

        self.assertIsInstance(Test.prop, cached_property)

        # The decorated function must run only once per instance.
        instance = Test()
        self.assertEqual(42, instance.prop)
        self.assertEqual(42, instance.prop)
        self.assertEqual(1, len(calls))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment