Skip to content
Snippets Groups Projects
Commit d1dc6342 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Massive overhaul of branch parsing

parent 5408ca44
No related branches found
No related tags found
1 merge request!27Refactor offline I/O
from collections import namedtuple
import uproot
import numpy as np
import warnings
......@@ -10,6 +11,22 @@ MAIN_TREE_NAME = "E"
# Size limit handed to uproot.cache.ThreadSafeArrayCache when reading lazily
# (presumably bytes, i.e. 110 MiB -- confirm against the uproot cache docs).
BASKET_CACHE_SIZE = 110 * 1024**2

# Describes how one branch of the ROOT tree is exposed on the reader:
#   name       -- attribute name under which the branch is attached
#   key        -- key used to look the branch up in the ROOT tree
#   extra_keys -- additional {attribute name: ROOT key} entries merged into
#                 the key map on top of the parsed branch keys
#   attrparser -- callable turning a ROOT key into an attribute name
BranchMapper = namedtuple(
    "BranchMapper", ['name', 'key', 'extra_keys', 'attrparser'])


def _nested_mapper(key):
    """Map a nested ROOT key to a flat attribute name.

    The first dot-separated component (the branch name) is dropped and the
    remaining components are joined with underscores,
    e.g. ``trks.pos.x`` -> ``pos_x``.

    Parameters
    ----------
    key : str
        The dotted key as found in the ROOT file.

    Returns
    -------
    str
        The flattened key without its branch prefix.
    """
    return '_'.join(key.split('.')[1:])


BRANCH_MAPS = [
    BranchMapper("tracks", "trks", {}, _nested_mapper),
    BranchMapper("mc_tracks", "mc_trks", {}, _nested_mapper),
    # "hits" must read the "hits" branch; it previously pointed at "mc_hits",
    # which made it an accidental duplicate of the mapper below.
    BranchMapper("hits", "hits", {}, _nested_mapper),
    BranchMapper("mc_hits", "mc_hits", {}, _nested_mapper),
    # Event keys are used as they are; the two timestamp leaves get
    # explicit aliases via extra_keys.
    BranchMapper("events", "Evt",
                 {'t_sec': 't.fSec', 't_ns': 't.fNanoSec'}, lambda a: a),
]
class cached_property:
"""A simple cache decorator for properties."""
def __init__(self, function):
......@@ -22,189 +39,9 @@ class cached_property:
return prop
def _get_keys(tree, fake_branches=None):
    """Return the decoded keys of a tree, leaving out unwanted ones.

    Parameters
    ----------
    tree : uproot.Tree
        The tree to look for keys; its ``keys()`` must yield byte strings.
    fake_branches : list of str or None
        Keys to leave out of the result. ``None`` keeps every key.

    Returns
    -------
    list of str
        The UTF-8 decoded keys of the tree, in the order the tree
        reports them.
    """
    decoded = (raw.decode('utf-8') for raw in tree.keys())
    if fake_branches is None:
        return list(decoded)
    return [name for name in decoded if name not in fake_branches]
class OfflineKeys:
    """Wrapper for the keys of an offline file.

    The raw keys are read from the ROOT tree through cached properties;
    the ``cut_*`` properties offer the same keys reformatted so that they
    can serve as Python attribute names.
    """
    def __init__(self, tree):
        """Remember the tree; keys are only read when a property is used.

        Parameters
        ----------
        tree : uproot.TTree
            The main ROOT tree.
        """
        self._tree = tree

    def __str__(self):
        # One "<title> keys are:" section per group, keys listed one per
        # line and indented by a tab.
        sections = [
            ("Events", self.events_keys),
            ("Hits", self.hits_keys),
            ("Tracks", self.tracks_keys),
            ("Mc hits", self.mc_hits_keys),
            ("Mc tracks", self.mc_tracks_keys),
        ]
        return '\n'.join(
            "{} keys are:\n\t".format(title) + '\n\t'.join(keys)
            for title, keys in sections)

    def __repr__(self):
        return "<{}>".format(self.__class__.__name__)

    @cached_property
    def events_keys(self):
        """Read the events keys from an offline file.

        Returns
        -------
        list of str
            All keys of the ``Evt`` branch except the fake branches,
            followed by the two ``t`` leaves ``t.fSec`` and ``t.fNanoSec``.
        """
        fake_branches = ['Evt', 'AAObject', 'TObject', 't']
        # 't' itself is filtered out above; its two leaves are added
        # back explicitly.
        t_baskets = ['t.fSec', 't.fNanoSec']
        return _get_keys(self._tree['Evt'], fake_branches) + t_baskets

    @cached_property
    def hits_keys(self):
        """Read the hits keys from an offline file.

        Returns
        -------
        list of str
            All hits keys found in the file, except the fake branches.
        """
        fake_branches = ['hits.usr', 'hits.usr_names']
        # NOTE(review): looked up on the tree directly, whereas the
        # tracks / mc branches go through self._tree['Evt'] -- confirm
        # that both paths are equivalent.
        return _get_keys(self._tree['hits'], fake_branches)

    @cached_property
    def tracks_keys(self):
        """Read the tracks keys from an offline file.

        Returns
        -------
        list of str
            All tracks keys found in the file, except the fake branches.
        """
        # a solution can be tree['trks.usr_data'].array(
        # uproot.asdtype(">i4"))
        fake_branches = ['trks.usr_data', 'trks.usr', 'trks.usr_names']
        return _get_keys(self._tree['Evt']['trks'], fake_branches)

    @cached_property
    def mc_hits_keys(self):
        """Read the mc hits keys from an offline file.

        Returns
        -------
        list of str
            All mc hits keys found in the file, except the fake branches.
        """
        fake_branches = ['mc_hits.usr', 'mc_hits.usr_names']
        return _get_keys(self._tree['Evt']['mc_hits'], fake_branches)

    @cached_property
    def mc_tracks_keys(self):
        """Read the mc tracks keys from an offline file.

        Returns
        -------
        list of str
            All mc tracks keys found in the file, except the fake
            branches.
        """
        fake_branches = [
            'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names'
        ]
        return _get_keys(self._tree['Evt']['mc_trks'], fake_branches)

    @cached_property
    def valid_keys(self):
        """Construct the list of all keys that can be read from the file.

        Returns
        -------
        list of str
            Events, hits, tracks, mc tracks and mc hits keys, concatenated
            in that order.
        """
        return (self.events_keys + self.hits_keys + self.tracks_keys +
                self.mc_tracks_keys + self.mc_hits_keys)

    @cached_property
    def fit_keys(self):
        """Construct the list of fit parameters, not yet outsourced in an
        offline file.

        Returns
        -------
        list of str
            The names in ``km3io.definitions.fitparameters.data``, sorted
            in ascending order of the value mapped to each name.
        """
        # NOTE(review): `km3io` is not among the imports visible at the
        # top of this file -- confirm it is imported before this runs.
        return sorted(km3io.definitions.fitparameters.data,
                      key=km3io.definitions.fitparameters.data.get,
                      reverse=False)

    @cached_property
    def cut_hits_keys(self):
        """Adapt the hits keys to the format of Python attribute names.

        Returns
        -------
        list of str
            The hits keys without the ``hits.`` prefix and with remaining
            dots replaced by underscores.
        """
        return [k.split('hits.')[1].replace('.', '_') for k in self.hits_keys]

    @cached_property
    def cut_tracks_keys(self):
        """Adapt the tracks keys to the format of Python attribute names.

        Returns
        -------
        list of str
            The tracks keys without the ``trks.`` prefix and with
            remaining dots replaced by underscores.
        """
        return [
            k.split('trks.')[1].replace('.', '_') for k in self.tracks_keys
        ]

    @cached_property
    def cut_events_keys(self):
        """Adapt the events keys to the format of Python attribute names.

        Returns
        -------
        list of str
            The events keys with dots replaced by underscores.
        """
        return [k.replace('.', '_') for k in self.events_keys]
class OfflineReader:
"""reader for offline ROOT files"""
def __init__(self, file_path=None, fobj=None, data=None):
def __init__(self, file_path=None, fobj=None, data=None, index=slice(-1)):
""" OfflineReader class is an offline ROOT file wrapper
Parameters
......@@ -214,6 +51,7 @@ class OfflineReader:
path-like object that points to the file.
"""
self._index = index
if file_path is not None:
self._fobj = uproot.open(file_path)
self._tree = self._fobj[MAIN_TREE_NAME]
......@@ -225,6 +63,9 @@ class OfflineReader:
self._tree = self._fobj[MAIN_TREE_NAME]
self._data = data
for mapper in BRANCH_MAPS:
setattr(self, mapper.name, BranchElement(self._tree, mapper=mapper, index=self._index))
@classmethod
def from_index(cls, source, index):
"""Create an instance with a subtree of a given index
......@@ -232,18 +73,24 @@ class OfflineReader:
Parameters
----------
source: ROOTDirectory
The source file.
The source file object.
index: index or slice
The index or slice to create the subtree.
"""
instance = cls(fobj=source._fobj, data=source._data[index])
instance = cls(fobj=source._fobj, data=source._data[index], index=index)
return instance
def __getitem__(self, index):
return OfflineReader.from_index(source=self, index=index)
def __len__(self):
return len(self._data)
tree = self._fobj[MAIN_TREE_NAME]
if self._index == slice(-1):
return len(tree)
else:
return len(tree.lazyarrays(
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))[self.index])
@cached_property
def header(self):
......@@ -256,76 +103,6 @@ class OfflineReader:
else:
warnings.warn("Your file header has an unsupported format")
@cached_property
def keys(self):
"""wrapper for all keys in an offline file.
Returns
-------
Class
OfflineKeys.
"""
return OfflineKeys(self._tree)
@cached_property
def events(self):
"""wrapper for offline events.
Returns
-------
Class
OfflineEvents.
"""
return OfflineEvents(
self.keys.cut_events_keys,
[self._data[key] for key in self.keys.events_keys])
@cached_property
def hits(self):
"""wrapper for offline hits.
Returns
-------
Class
OfflineHits.
"""
return OfflineHits(self.keys.cut_hits_keys,
[self._data[key] for key in self.keys.hits_keys])
@cached_property
def tracks(self):
"""wrapper for offline tracks.
Returns
-------
Class
OfflineTracks.
"""
return OfflineTracks(self._tree['trks'])
@cached_property
def mc_hits(self):
"""wrapper for offline mc hits.
Returns
-------
Class
OfflineHits.
"""
return OfflineHits(self.keys.cut_hits_keys,
[self._data[key] for key in self.keys.mc_hits_keys])
@cached_property
def mc_tracks(self):
"""wrapper for offline mc tracks.
Returns
-------
Class
OfflineTracks.
"""
return OfflineTracks(self._tree['mc_trks'])
@cached_property
def usr(self):
return Usr(self._tree)
......@@ -705,137 +482,23 @@ class Usr:
return '\n'.join(entries)
class OfflineEvents:
    """Column-wise container for the events of an offline file."""
    def __init__(self, keys, values):
        """Store the columns and expose each one as an attribute.

        Parameters
        ----------
        keys : list of str
            list of valid events keys.
        values : list of arrays
            list of arrays containing events data, paired with ``keys``
            by position.
        """
        self._keys = keys
        self._values = values
        for name, column in zip(keys, values):
            setattr(self, name, column)

    def __getitem__(self, item):
        # Apply the same index to every column and wrap the result.
        selected = [column[item] for column in self._values]
        return OfflineEvent(self._keys, selected)

    def __len__(self):
        # The length of the first column counts for all; no columns
        # means no events.
        try:
            n_events = len(self._values[0])
        except IndexError:
            n_events = 0
        return n_events

    def __str__(self):
        return "Number of events: " + str(len(self))

    def __repr__(self):
        return "<{name}: {count} parsed events>".format(
            name=self.__class__.__name__, count=len(self))
class OfflineEvent:
    """Container for the data of one offline event."""
    def __init__(self, keys, values):
        """Store the values and expose each one as an attribute.

        Parameters
        ----------
        keys : list of str
            list of valid events keys.
        values : list of arrays
            list of arrays containing event data, paired with ``keys``
            by position.
        """
        self._keys = keys
        self._values = values
        for name, value in zip(keys, values):
            setattr(self, name, value)

    def __str__(self):
        # One aligned "key : value" row per entry.
        rows = []
        for name, value in zip(self._keys, self._values):
            rows.append("{:15} {:^10} {:>10}".format(name, ':', str(value)))
        return "offline event:\n\t" + "\n\t".join(rows)
class OfflineHits:
    """Column-wise container for the hits of an offline file."""
    def __init__(self, keys, values):
        """Store the columns and expose each one as an attribute.

        Parameters
        ----------
        keys : list of str
            list of cropped hits keys.
        values : list of arrays
            list of arrays containing hits data, paired with ``keys``
            by position.
        """
        self._keys = keys
        self._values = values
        for name, column in zip(keys, values):
            setattr(self, name, column)

    def __getitem__(self, item):
        # Apply the same index to every column and wrap the result.
        selected = [column[item] for column in self._values]
        return OfflineHit(self._keys, selected)

    def __len__(self):
        # The length of the first column counts for all; no columns
        # means no hits.
        try:
            n_hits = len(self._values[0])
        except IndexError:
            n_hits = 0
        return n_hits

    def __str__(self):
        return "Number of hits: " + str(len(self))

    def __repr__(self):
        return "<{name}: {count} parsed elements>".format(
            name=self.__class__.__name__, count=len(self))
class OfflineHit:
    """Container for the data of one offline hit."""
    def __init__(self, keys, values):
        """Store the values and expose each one as an attribute.

        Parameters
        ----------
        keys : list of str
            list of cropped hits keys.
        values : list of arrays
            list of arrays containing hit data, paired with ``keys``
            by position.
        """
        self._keys = keys
        self._values = values
        for name, value in zip(keys, values):
            setattr(self, name, value)

    def __str__(self):
        # One aligned "key : value" row per entry.
        rows = []
        for name, value in zip(self._keys, self._values):
            rows.append("{:15} {:^10} {:>10}".format(name, ':', str(value)))
        return "offline hit:\n\t" + "\n\t".join(rows)

    def __getitem__(self, item):
        # Positional access into the stored values.
        return self._values[item]
class OfflineTracks:
class BranchElement:
"""wrapper for offline tracks"""
def __init__(self, branch, index=slice(-1)):
keys = [k.decode('utf-8') for k in branch.keys()]
self._keymap = {k[5:].replace('.', '_'): k for k in keys}
self._branch = branch
self._keys = keys
def __init__(self, tree, mapper, index=slice(-1)):
self.mapper = mapper
self.name = mapper.name
self._tree = tree
self._branch = tree[mapper.key]
keys = [k.decode('utf-8') for k in self._branch.keys()]
self._keymap = {**{mapper.attrparser(k): k for k in keys}, **mapper.extra_keys}
self._index = index
# for key in keys:
# setattr(self, key, cached_property(self[key]))
def __getitem__(self, item):
if isinstance(item, slice):
return OfflineTracks(self._branch, index=item)
return self.__class__(self._tree, self.mapper, index=item)
return self._branch[self._keymap[item]].lazyarray(
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))[self._index]
......@@ -846,41 +509,12 @@ class OfflineTracks:
else:
return len(self._branch[self._keymap['id']].lazyarray()[self._index])
def keys(self):
return self._keymap.keys()
def __str__(self):
return "Number of tracks: {}".format(len(self._branch))
return "Number of elements: {}".format(len(self._branch))
def __repr__(self):
return "<{}: {} parsed elements>".format(self.__class__.__name__,
return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__, self.name,
len(self))
class OfflineTrack:
    """Container for the data of one offline track."""
    def __init__(self, keys, values):
        """Store the values and expose each one as an attribute.

        Parameters
        ----------
        keys : list of str
            list of cropped tracks keys.
        values : list of arrays
            list of arrays containing track data, paired with ``keys``
            by position.
        """
        self._keys = keys
        self._values = values
        for name, value in zip(keys, values):
            setattr(self, name, value)

    def __str__(self):
        row = "{:30} {:^2} {:>26}"
        # Every entry except 'fitinf' is printed as it is ...
        plain_rows = [
            row.format(name, ':', str(value))
            for name, value in zip(self._keys, self._values)
            if name != 'fitinf'
        ]
        # ... while 'fitinf' is expanded into one row per fit parameter
        # whose position lies within the stored fitinf entries.
        fit_rows = [
            row.format(name, ':', str(self.fitinf[position]))
            for name, position in km3io.definitions.fitparameters.data.items()
            if len(self.fitinf) > position
        ]
        return ("offline track:\n\t" + "\n\t".join(plain_rows) + "\n\t" +
                "\n\t".join(fit_rows))

    def __getitem__(self, item):
        # Positional access into the stored values.
        return self._values[item]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment