Skip to content
Snippets Groups Projects
Commit 13584962 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Further cleanup

parent e2390f0c
No related branches found
No related tags found
1 merge request!21Resolve "Reduce the amount of uproot.open (to one)"
Pipeline #9184 passed with warnings
......@@ -34,15 +34,6 @@ class OfflineKeys:
The main ROOT tree.
"""
self._tree = tree
self._mc_tracks_keys = None
self._valid_keys = None
self._fit_keys = None
self._cut_hits_keys = None
self._cut_tracks_keys = None
self._cut_events_keys = None
self._trigger = None
self._fitparameters = None
self._reconstruction = None
def __str__(self):
return '\n'.join([
......@@ -56,6 +47,29 @@ class OfflineKeys:
def __repr__(self):
return "<{}>".format(self.__class__.__name__)
def _get_keys(self, tree, fake_branches=None):
"""Get tree keys except those in fake_branches
Parameters
----------
tree : uproot.Tree
The tree to look for keys
fake_branches : list of str or None
The fake branches to ignore
Returns
-------
list of str
The keys of the tree.
"""
keys = []
for key in tree.keys():
key = key.decode('utf-8')
if fake_branches is not None and key in fake_branches:
continue
keys.append(key)
return keys
@cached_property
def events_keys(self):
"""reads events keys from an offline file.
......@@ -69,10 +83,7 @@ class OfflineKeys:
fake_branches = ['Evt', 'AAObject', 'TObject', 't']
t_baskets = ['t.fSec', 't.fNanoSec']
tree = self._tree['Evt']
return [
key.decode('utf-8')
for key in tree.keys() if key.decode('utf-8') not in fake_branches
] + t_baskets
return self._get_keys(self._tree['Evt'], fake_branches) + t_baskets
@cached_property
def hits_keys(self):
......@@ -84,13 +95,8 @@ class OfflineKeys:
list of all hits keys found in an offline file,
except those found in fake branches.
"""
fake_branches = ['hits.usr', 'hits.usr_names'
] # to be treated like trks.usr and trks.usr_names
tree = self._tree['hits']
return [
key.decode('utf8') for key in tree.keys()
if key.decode('utf8') not in fake_branches
]
fake_branches = ['hits.usr', 'hits.usr_names']
return self._get_keys(self._tree['hits'], fake_branches)
@cached_property
def tracks_keys(self):
......@@ -104,14 +110,8 @@ class OfflineKeys:
"""
# a solution can be tree['trks.usr_data'].array(
# uproot.asdtype(">i4"))
fake_branches = [
'trks.usr_data', 'trks.usr', 'trks.usr_names'
] # can be accessed using tree['trks.usr_names'].array()
tree = self._tree['Evt']['trks']
return [
key.decode('utf8') for key in tree.keys()
if key.decode('utf8') not in fake_branches
]
fake_branches = ['trks.usr_data', 'trks.usr', 'trks.usr_names']
return self._get_keys(self._tree['Evt']['trks'], fake_branches)
@cached_property
def mc_hits_keys(self):
......@@ -124,11 +124,7 @@ class OfflineKeys:
except those found in fake branches.
"""
fake_branches = ['mc_hits.usr', 'mc_hits.usr_names']
tree = self._tree['Evt']['mc_hits']
return [
key.decode('utf8') for key in tree.keys()
if key.decode('utf8') not in fake_branches
]
return self._get_keys(self._tree['Evt']['mc_hits'], fake_branches)
@cached_property
def mc_tracks_keys(self):
......@@ -140,14 +136,8 @@ class OfflineKeys:
list of all mc tracks keys found in an offline file,
except those found in fake branches.
"""
fake_branches = [
'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names'
] # same solution as above can be used
tree = self._tree['Evt']['mc_trks']
return [
key.decode('utf8') for key in tree.keys()
if key.decode('utf8') not in fake_branches
]
fake_branches = ['mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names']
return self._get_keys(self._tree['Evt']['mc_trks'], fake_branches)
@cached_property
def valid_keys(self):
......@@ -169,8 +159,8 @@ class OfflineKeys:
list of str
list of all "trks.fitinf" keys.
"""
return sorted(self.fitparameters,
key=self.fitparameters.get,
return sorted(km3io.definitions.fitparameters.data,
key=km3io.definitions.fitparameters.data.get,
reverse=False)
@cached_property
......@@ -232,18 +222,6 @@ class OfflineKeys:
"""
return km3io.definitions.reconstruction.data
@cached_property
def fitparameters(self):
"""fit parameters parameters and their index from km3net-Dataformat.
Returns
-------
dict
dictionary of fit parameters and their index in an Offline
file.
"""
return km3io.definitions.fitparameters.data
class Reader:
"""Reader for one offline ROOT file"""
......@@ -400,8 +378,8 @@ class OfflineReader:
"""
return OfflineTracks(
self.keys.cut_tracks_keys,
[self._data[key] for key in self.keys.tracks_keys],
fitparameters=self.keys.fitparameters)
[self._data[key] for key in self.keys.tracks_keys]
)
@cached_property
def mc_hits(self):
......@@ -426,8 +404,7 @@ class OfflineReader:
"""
return OfflineTracks(
self.keys.cut_tracks_keys,
[self._data[key] for key in self.keys.mc_tracks_keys],
fitparameters=self.keys.fitparameters)
[self._data[key] for key in self.keys.mc_tracks_keys])
@cached_property
def usr(self):
......@@ -931,7 +908,7 @@ class OfflineHit:
class OfflineTracks:
"""wrapper for offline tracks"""
def __init__(self, keys, values, fitparameters=None):
def __init__(self, keys, values):
"""wrapper for offline tracks
Parameters
......@@ -940,20 +917,14 @@ class OfflineTracks:
list of cropped tracks keys.
values : list of arrays
list of arrays containting tracks data.
fitparameters : None, optional
dictionary of tracks fit information (not yet outsourced in offline
files).
"""
self._keys = keys
self._values = values
if fitparameters is not None:
self._fitparameters = fitparameters
for k, v in zip(self._keys, self._values):
setattr(self, k, v)
def __getitem__(self, item):
return OfflineTrack(self._keys, [v[item] for v in self._values],
fitparameters=self._fitparameters)
return OfflineTrack(self._keys, [v[item] for v in self._values])
def __len__(self):
try:
......@@ -971,7 +942,7 @@ class OfflineTracks:
class OfflineTrack:
"""wrapper for an offline track"""
def __init__(self, keys, values, fitparameters=None):
def __init__(self, keys, values):
"""wrapper for one offline track.
Parameters
......@@ -980,14 +951,9 @@ class OfflineTrack:
list of cropped tracks keys.
values : list of arrays
list of arrays containting track data.
fitparameters : None, optional
dictionary of tracks fit information (not yet outsourced in offline
files).
"""
self._keys = keys
self._values = values
if fitparameters is not None:
self._fitparameters = fitparameters
for k, v in zip(self._keys, self._values):
setattr(self, k, v)
......@@ -998,7 +964,7 @@ class OfflineTrack:
]) + "\n\t" + "\n\t".join([
"{:30} {:^2} {:>26}".format(k, ':', str(
getattr(self, 'fitinf')[v]))
for k, v in self._fitparameters.items()
for k, v in km3io.definitions.fitparameters.data.items()
if len(getattr(self, 'fitinf')) > v
]) # I don't like 18 being explicit here
......
......@@ -66,15 +66,6 @@ class TestOfflineKeys(unittest.TestCase):
for k, v in zip(keys, values):
self.assertEqual(v, reco[k])
def test_fitparameters(self):
# there are 18 parameters in v1.1.2 of km3net-Dataformat
fit = self.keys.fitparameters
values = [i for i in range(18)]
self.assertEqual(18, len([*fit.keys()]))
for k, v in fit.items():
self.assertEqual(values[v], fit[k])
class TestReader(unittest.TestCase):
def setUp(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment