Skip to content
Snippets Groups Projects
Commit 1cf61e88 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Use tools

parent 51b1a78b
No related branches found
No related tags found
1 merge request!27Refactor offline I/O
from collections import namedtuple
import uproot
import numpy as np
import warnings
from .definitions import mc_header
from .tools import Branch, BranchMapper, cached_property, _to_num
MAIN_TREE_NAME = "E"
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2
BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"])
BranchMapper = namedtuple(
"BranchMapper",
['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat'])
EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"]
def _nested_mapper(key):
......@@ -26,7 +19,7 @@ EVENTS_MAP = BranchMapper(name="events",
't_sec': 't.fSec',
't_ns': 't.fNanoSec'
},
exclude=[],
exclude=EXCLUDE_KEYS,
update={
'n_hits': 'hits',
'n_mc_hits': 'mc_hits',
......@@ -41,14 +34,14 @@ SUBBRANCH_MAPS = [
name="tracks",
key="trks",
extra={},
exclude=['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'],
exclude=EXCLUDE_KEYS + ['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'],
update={},
attrparser=_nested_mapper,
flat=False),
BranchMapper(name="mc_tracks",
key="mc_trks",
extra={},
exclude=[
exclude=EXCLUDE_KEYS + [
'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.rec_stages',
'mc_trks.fitinf', 'mc_trks.fUniqueID', 'mc_trks.fBits'
],
......@@ -58,7 +51,7 @@ SUBBRANCH_MAPS = [
BranchMapper(name="hits",
key="hits",
extra={},
exclude=[
exclude=EXCLUDE_KEYS + [
'hits.usr', 'hits.pmt_id', 'hits.origin', 'hits.a',
'hits.pure_a', 'hits.fUniqueID', 'hits.fBits'
],
......@@ -68,7 +61,7 @@ SUBBRANCH_MAPS = [
BranchMapper(name="mc_hits",
key="mc_hits",
extra={},
exclude=[
exclude=EXCLUDE_KEYS + [
'mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id',
'mc_hits.tdc', 'mc_hits.tot', 'mc_hits.trig',
'mc_hits.fUniqueID', 'mc_hits.fBits'
......@@ -79,18 +72,6 @@ SUBBRANCH_MAPS = [
]
class cached_property:
"""A simple cache decorator for properties."""
def __init__(self, function):
self.function = function
def __get__(self, obj, cls):
if obj is None:
return self
prop = obj.__dict__[self.function.__name__] = self.function(obj)
return prop
class OfflineReader:
"""reader for offline ROOT files"""
def __init__(self, file_path=None):
......@@ -126,113 +107,6 @@ class OfflineReader:
warnings.warn("Your file header has an unsupported format")
class Usr:
"""Helper class to access AAObject `usr` stuff"""
def __init__(self, mapper, branch, index=None):
self._mapper = mapper
self._name = mapper.name
self._index = index
self._branch = branch
self._usr_names = []
self._usr_idx_lookup = {}
self._usr_key = 'usr' if mapper.flat else mapper.key + '.usr'
self._initialise()
def _initialise(self):
try:
self._branch[self._usr_key]
# This will raise a KeyError in old aanet files
# which has a different strucuter and key (usr_data)
# We do not support those (yet)
except (KeyError, IndexError):
print("The `usr` fields could not be parsed for the '{}' branch.".
format(self._name))
return
if self._mapper.flat:
self._initialise_flat()
def _initialise_flat(self):
# Here, we assume that every event has the same names in the same order
# to massively increase the performance. This needs triple check if
# it's always the case.
self._usr_names = [
n.decode("utf-8") for n in self._branch[self._usr_key + '_names'].lazyarray(
basketcache=BASKET_CACHE)[0]
]
self._usr_idx_lookup = {
name: index
for index, name in enumerate(self._usr_names)
}
data = self._branch[self._usr_key].lazyarray(basketcache=BASKET_CACHE)
if self._index is not None:
data = data[self._index]
self._usr_data = data
for name in self._usr_names:
setattr(self, name, self[name])
# def _initialise_nested(self):
# self._usr_names = [
# n.decode("utf-8") for n in self.branch['usr_names'].lazyarray(
# # TODO this will be fixed soon in uproot,
# # see https://github.com/scikit-hep/uproot/issues/465
# uproot.asgenobj(
# uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
# self.branch['usr_names']._context, 6),
# basketcache=BASKET_CACHE)[0]
# ]
def __getitem__(self, item):
if self._mapper.flat:
return self.__getitem_flat__(item)
return self.__getitem_nested__(item)
def __getitem_flat__(self, item):
if self._index is not None:
return self._usr_data[self._index][:, self._usr_idx_lookup[item]]
else:
return self._usr_data[:, self._usr_idx_lookup[item]]
def __getitem_nested__(self, item):
data = self._branch[self._usr_key + '_names'].lazyarray(
# TODO this will be fixed soon in uproot,
# see https://github.com/scikit-hep/uproot/issues/465
uproot.asgenobj(
uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
self._branch[self._usr_key + '_names']._context, 6),
basketcache=BASKET_CACHE)
if self._index is None:
return data
else:
return data[self._index]
def keys(self):
return self._usr_names
def __str__(self):
entries = []
for name in self.keys():
entries.append("{}: {}".format(name, self[name]))
return '\n'.join(entries)
def __repr__(self):
return "<{}[{}]>".format(self.__class__.__name__, self._name)
def _to_num(value):
"""Convert a value to a numerical one if possible"""
for converter in (int, float):
try:
return converter(value)
except (ValueError, TypeError):
pass
return value
class Header:
......@@ -278,110 +152,3 @@ class Header:
return "\n".join(lines)
class Branch:
"""Branch accessor class"""
def __init__(self,
tree,
mapper,
index=None,
subbranchmaps=None,
keymap=None):
self._tree = tree
self._mapper = mapper
self._index = index
self._keymap = None
self._branch = tree[mapper.key]
self._subbranches = []
if keymap is None:
self._initialise_keys() #
else:
self._keymap = keymap
if subbranchmaps is not None:
for mapper in subbranchmaps:
subbranch = Branch(self._tree,
mapper=mapper,
index=self._index)
self._subbranches.append(subbranch)
for subbranch in self._subbranches:
setattr(self, subbranch._mapper.name, subbranch)
def _initialise_keys(self):
"""Create the keymap and instance attributes for branch keys"""
# TODO: this could be a cached property
keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(
self._mapper.exclude) - EXCLUDE_KEYS
self._keymap = {
**{self._mapper.attrparser(k): k
for k in keys},
**self._mapper.extra
}
self._keymap.update(self._mapper.update)
for k in self._mapper.update.values():
del self._keymap[k]
for key in self._keymap.keys():
setattr(self, key, None)
def keys(self):
return self._keymap.keys()
@cached_property
def usr(self):
return Usr(self._mapper, self._branch, index=self._index)
def __getattribute__(self, attr):
if attr.startswith("_"): # let all private and magic methods pass
return object.__getattribute__(self, attr)
if attr in self._keymap.keys(): # intercept branch key lookups
return self.__getkey__(attr)
return object.__getattribute__(self, attr)
def __getkey__(self, key):
out = self._branch[self._keymap[key]].lazyarray(
basketcache=BASKET_CACHE)
if self._index is not None:
out = out[self._index]
return out
def __getitem__(self, item):
"""Slicing magic"""
if isinstance(item, (int, slice)):
return self.__class__(self._tree,
self._mapper,
index=item,
keymap=self._keymap,
subbranchmaps=SUBBRANCH_MAPS)
if isinstance(item, tuple):
return self[item[0]][item[1]]
if isinstance(item, str):
return self.__getkey__(item)
return self.__class__(self._tree,
self._mapper,
index=np.array(item),
keymap=self._keymap,
subbranchmaps=SUBBRANCH_MAPS)
def __len__(self):
if self._index is None:
return len(self._branch)
elif isinstance(self._index, int):
return 1
else:
return len(self._branch[self._keymap['id']].lazyarray(
basketcache=BASKET_CACHE)[self._index])
def __str__(self):
return "Number of elements: {}".format(len(self._branch))
def __repr__(self):
length = len(self)
return "<{}[{}]: {} element{}>".format(self.__class__.__name__,
self._mapper.name, length,
's' if length > 1 else '')
......@@ -3,7 +3,7 @@ import numpy as np
from pathlib import Path
from km3io import OfflineReader
from km3io.offline import _nested_mapper, cached_property, _to_num, Header
from km3io.offline import _nested_mapper, Header
SAMPLES_DIR = Path(__file__).parent / 'samples'
OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root')
......@@ -340,22 +340,7 @@ class TestUsr(unittest.TestCase):
self.f.events.usr.DeltaPosZ)
class TestIndependentFunctions(unittest.TestCase):
class TestNestedMapper(unittest.TestCase):
def test_nested_mapper(self):
self.assertEqual('pos_x', _nested_mapper("trks.pos.x"))
def test_to_num(self):
self.assertEqual(10, _to_num("10"))
self.assertEqual(10.5, _to_num("10.5"))
self.assertEqual("test", _to_num("test"))
self.assertIsNone(_to_num(None))
class TestCachedProperty(unittest.TestCase):
def test_cached_properties(self):
class Test:
@cached_property
def prop(self):
pass
self.assertTrue(isinstance(Test.prop, cached_property))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment