Skip to content
Snippets Groups Projects

Resolve "Reduce the amount of uproot.open (to one)"

Merged Tamas Gal requested to merge 36-reduce-the-amount-of-uproot-open-to-one into master
2 files
+ 40
83
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 40
74
@@ -34,15 +34,6 @@ class OfflineKeys:
The main ROOT tree.
"""
self._tree = tree
self._mc_tracks_keys = None
self._valid_keys = None
self._fit_keys = None
self._cut_hits_keys = None
self._cut_tracks_keys = None
self._cut_events_keys = None
self._trigger = None
self._fitparameters = None
self._reconstruction = None
def __str__(self):
return '\n'.join([
@@ -56,6 +47,29 @@ class OfflineKeys:
def __repr__(self):
return "<{}>".format(self.__class__.__name__)
def _get_keys(self, tree, fake_branches=None):
"""Get tree keys except those in fake_branches
Parameters
----------
tree : uproot.Tree
The tree to look for keys
fake_branches : list of str or None
The fake branches to ignore
Returns
-------
list of str
The keys of the tree.
"""
keys = []
for key in tree.keys():
key = key.decode('utf-8')
if fake_branches is not None and key in fake_branches:
continue
keys.append(key)
return keys
@cached_property
def events_keys(self):
"""reads events keys from an offline file.
@@ -69,10 +83,7 @@ class OfflineKeys:
fake_branches = ['Evt', 'AAObject', 'TObject', 't']
t_baskets = ['t.fSec', 't.fNanoSec']
tree = self._tree['Evt']
return [
key.decode('utf-8')
for key in tree.keys() if key.decode('utf-8') not in fake_branches
] + t_baskets
return self._get_keys(self._tree['Evt'], fake_branches) + t_baskets
@cached_property
def hits_keys(self):
@@ -84,13 +95,8 @@ class OfflineKeys:
list of all hits keys found in an offline file,
except those found in fake branches.
"""
fake_branches = ['hits.usr', 'hits.usr_names'
] # to be treated like trks.usr and trks.usr_names
tree = self._tree['hits']
return [
key.decode('utf8') for key in tree.keys()
if key.decode('utf8') not in fake_branches
]
fake_branches = ['hits.usr', 'hits.usr_names']
return self._get_keys(self._tree['hits'], fake_branches)
@cached_property
def tracks_keys(self):
@@ -104,14 +110,8 @@ class OfflineKeys:
"""
# a solution can be tree['trks.usr_data'].array(
# uproot.asdtype(">i4"))
fake_branches = [
'trks.usr_data', 'trks.usr', 'trks.usr_names'
] # can be accessed using tree['trks.usr_names'].array()
tree = self._tree['Evt']['trks']
return [
key.decode('utf8') for key in tree.keys()
if key.decode('utf8') not in fake_branches
]
fake_branches = ['trks.usr_data', 'trks.usr', 'trks.usr_names']
return self._get_keys(self._tree['Evt']['trks'], fake_branches)
@cached_property
def mc_hits_keys(self):
@@ -124,11 +124,7 @@ class OfflineKeys:
except those found in fake branches.
"""
fake_branches = ['mc_hits.usr', 'mc_hits.usr_names']
tree = self._tree['Evt']['mc_hits']
return [
key.decode('utf8') for key in tree.keys()
if key.decode('utf8') not in fake_branches
]
return self._get_keys(self._tree['Evt']['mc_hits'], fake_branches)
@cached_property
def mc_tracks_keys(self):
@@ -140,14 +136,8 @@ class OfflineKeys:
list of all mc tracks keys found in an offline file,
except those found in fake branches.
"""
fake_branches = [
'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names'
] # same solution as above can be used
tree = self._tree['Evt']['mc_trks']
return [
key.decode('utf8') for key in tree.keys()
if key.decode('utf8') not in fake_branches
]
fake_branches = ['mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names']
return self._get_keys(self._tree['Evt']['mc_trks'], fake_branches)
@cached_property
def valid_keys(self):
@@ -169,8 +159,8 @@ class OfflineKeys:
list of str
list of all "trks.fitinf" keys.
"""
return sorted(self.fitparameters,
key=self.fitparameters.get,
return sorted(km3io.definitions.fitparameters.data,
key=km3io.definitions.fitparameters.data.get,
reverse=False)
@cached_property
@@ -232,18 +222,6 @@ class OfflineKeys:
"""
return km3io.definitions.reconstruction.data
@cached_property
def fitparameters(self):
"""fit parameters parameters and their index from km3net-Dataformat.
Returns
-------
dict
dictionary of fit parameters and their index in an Offline
file.
"""
return km3io.definitions.fitparameters.data
class Reader:
"""Reader for one offline ROOT file"""
@@ -400,8 +378,8 @@ class OfflineReader:
"""
return OfflineTracks(
self.keys.cut_tracks_keys,
[self._data[key] for key in self.keys.tracks_keys],
fitparameters=self.keys.fitparameters)
[self._data[key] for key in self.keys.tracks_keys]
)
@cached_property
def mc_hits(self):
@@ -426,8 +404,7 @@ class OfflineReader:
"""
return OfflineTracks(
self.keys.cut_tracks_keys,
[self._data[key] for key in self.keys.mc_tracks_keys],
fitparameters=self.keys.fitparameters)
[self._data[key] for key in self.keys.mc_tracks_keys])
@cached_property
def usr(self):
@@ -931,7 +908,7 @@ class OfflineHit:
class OfflineTracks:
"""wrapper for offline tracks"""
def __init__(self, keys, values, fitparameters=None):
def __init__(self, keys, values):
"""wrapper for offline tracks
Parameters
@@ -940,20 +917,14 @@ class OfflineTracks:
list of cropped tracks keys.
values : list of arrays
list of arrays containting tracks data.
fitparameters : None, optional
dictionary of tracks fit information (not yet outsourced in offline
files).
"""
self._keys = keys
self._values = values
if fitparameters is not None:
self._fitparameters = fitparameters
for k, v in zip(self._keys, self._values):
setattr(self, k, v)
def __getitem__(self, item):
return OfflineTrack(self._keys, [v[item] for v in self._values],
fitparameters=self._fitparameters)
return OfflineTrack(self._keys, [v[item] for v in self._values])
def __len__(self):
try:
@@ -971,7 +942,7 @@ class OfflineTracks:
class OfflineTrack:
"""wrapper for an offline track"""
def __init__(self, keys, values, fitparameters=None):
def __init__(self, keys, values):
"""wrapper for one offline track.
Parameters
@@ -980,14 +951,9 @@ class OfflineTrack:
list of cropped tracks keys.
values : list of arrays
list of arrays containting track data.
fitparameters : None, optional
dictionary of tracks fit information (not yet outsourced in offline
files).
"""
self._keys = keys
self._values = values
if fitparameters is not None:
self._fitparameters = fitparameters
for k, v in zip(self._keys, self._values):
setattr(self, k, v)
@@ -998,7 +964,7 @@ class OfflineTrack:
]) + "\n\t" + "\n\t".join([
"{:30} {:^2} {:>26}".format(k, ':', str(
getattr(self, 'fitinf')[v]))
for k, v in self._fitparameters.items()
for k, v in km3io.definitions.fitparameters.data.items()
if len(getattr(self, 'fitinf')) > v
]) # I don't like 18 being explicit here
Loading