Skip to content
Snippets Groups Projects
Commit 9b36e9f3 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Track slicing prototype!

parent a462a0bc
No related branches found
No related tags found
2 merge requests!24Refactor offline,!22WIP: Slicing and refactoring offline
Pipeline #9244 failed
......@@ -22,6 +22,30 @@ class cached_property:
return prop
def _get_keys(tree, fake_branches=None):
"""Get tree keys except those in fake_branches
Parameters
----------
tree : uproot.Tree
The tree to look for keys
fake_branches : list of str or None
The fake branches to ignore
Returns
-------
list of str
The keys of the tree.
"""
keys = []
for key in tree.keys():
key = key.decode('utf-8')
if fake_branches is not None and key in fake_branches:
continue
keys.append(key)
return keys
class OfflineKeys:
"""wrapper for offline keys"""
def __init__(self, tree):
......@@ -47,29 +71,6 @@ class OfflineKeys:
def __repr__(self):
return "<{}>".format(self.__class__.__name__)
def _get_keys(self, tree, fake_branches=None):
"""Get tree keys except those in fake_branches
Parameters
----------
tree : uproot.Tree
The tree to look for keys
fake_branches : list of str or None
The fake branches to ignore
Returns
-------
list of str
The keys of the tree.
"""
keys = []
for key in tree.keys():
key = key.decode('utf-8')
if fake_branches is not None and key in fake_branches:
continue
keys.append(key)
return keys
@cached_property
def events_keys(self):
"""reads events keys from an offline file.
......@@ -83,7 +84,7 @@ class OfflineKeys:
fake_branches = ['Evt', 'AAObject', 'TObject', 't']
t_baskets = ['t.fSec', 't.fNanoSec']
tree = self._tree['Evt']
return self._get_keys(self._tree['Evt'], fake_branches) + t_baskets
return _get_keys(self._tree['Evt'], fake_branches) + t_baskets
@cached_property
def hits_keys(self):
......@@ -96,7 +97,7 @@ class OfflineKeys:
except those found in fake branches.
"""
fake_branches = ['hits.usr', 'hits.usr_names']
return self._get_keys(self._tree['hits'], fake_branches)
return _get_keys(self._tree['hits'], fake_branches)
@cached_property
def tracks_keys(self):
......@@ -111,7 +112,7 @@ class OfflineKeys:
# a solution can be tree['trks.usr_data'].array(
# uproot.asdtype(">i4"))
fake_branches = ['trks.usr_data', 'trks.usr', 'trks.usr_names']
return self._get_keys(self._tree['Evt']['trks'], fake_branches)
return _get_keys(self._tree['Evt']['trks'], fake_branches)
@cached_property
def mc_hits_keys(self):
......@@ -124,7 +125,7 @@ class OfflineKeys:
except those found in fake branches.
"""
fake_branches = ['mc_hits.usr', 'mc_hits.usr_names']
return self._get_keys(self._tree['Evt']['mc_hits'], fake_branches)
return _get_keys(self._tree['Evt']['mc_hits'], fake_branches)
@cached_property
def mc_tracks_keys(self):
......@@ -139,7 +140,7 @@ class OfflineKeys:
fake_branches = [
'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names'
]
return self._get_keys(self._tree['Evt']['mc_trks'], fake_branches)
return _get_keys(self._tree['Evt']['mc_trks'], fake_branches)
@cached_property
def valid_keys(self):
......@@ -300,9 +301,7 @@ class OfflineReader:
Class
OfflineTracks.
"""
return OfflineTracks(
self.keys.cut_tracks_keys,
[self._data[key] for key in self.keys.tracks_keys])
return OfflineTracks(self._tree['trks'])
@cached_property
def mc_hits(self):
......@@ -325,9 +324,7 @@ class OfflineReader:
Class
OfflineTracks.
"""
return OfflineTracks(
self.keys.cut_tracks_keys,
[self._data[key] for key in self.keys.mc_tracks_keys])
return OfflineTracks(self._tree['mc_trks'])
@cached_property
def usr(self):
......@@ -829,32 +826,28 @@ class OfflineHit:
class OfflineTracks:
"""wrapper for offline tracks"""
def __init__(self, keys, values):
"""wrapper for offline tracks
Parameters
----------
keys : list of str
list of cropped tracks keys.
values : list of arrays
list of arrays containting tracks data.
"""
def __init__(self, branch, index=slice(-1)):
keys = [k.decode('utf-8') for k in branch.keys()]
self._keymap = {k[5:].replace('.', '_'): k for k in keys}
self._branch = branch
self._keys = keys
self._values = values
for k, v in zip(self._keys, self._values):
setattr(self, k, v)
self._index = index
def __getitem__(self, item):
return OfflineTrack(self._keys, [v[item] for v in self._values])
if isinstance(item, slice):
return OfflineTracks(self._branch, index=item)
return self._branch[self._keymap[item]].lazyarray(
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))[self._index]
def __len__(self):
try:
return len(self._values[0])
except IndexError:
return 0
if self._index == slice(-1):
return len(self._branch)
else:
return len(self._branch[self._keymap['id']].lazyarray()[self._index])
def __str__(self):
return "Number of tracks: {}".format(len(self))
return "Number of tracks: {}".format(len(self._branch))
def __repr__(self):
return "<{}: {} parsed elements>".format(self.__class__.__name__,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment