Skip to content
Snippets Groups Projects
Commit 51b1a78b authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Usr improvement, still not working for nested

parent 36c2a064
No related branches found
No related tags found
1 merge request!27Refactor offline I/O
...@@ -127,7 +127,7 @@ class OfflineReader: ...@@ -127,7 +127,7 @@ class OfflineReader:
class Usr: class Usr:
"""Helper class to access AAObject `usr`` stuff""" """Helper class to access AAObject `usr` stuff"""
def __init__(self, mapper, branch, index=None): def __init__(self, mapper, branch, index=None):
self._mapper = mapper self._mapper = mapper
self._name = mapper.name self._name = mapper.name
...@@ -136,29 +136,30 @@ class Usr: ...@@ -136,29 +136,30 @@ class Usr:
self._usr_names = [] self._usr_names = []
self._usr_idx_lookup = {} self._usr_idx_lookup = {}
self._usr_key = 'usr' if mapper.flat else mapper.key + '.usr'
self._initialise()
def _initialise(self):
try: try:
branch['usr'] self._branch[self._usr_key]
# This will raise a KeyError in old aanet files # This will raise a KeyError in old aanet files
# which has a different strucuter and key (usr_data) # which has a different strucuter and key (usr_data)
# We do not support those # We do not support those (yet)
except (KeyError, IndexError): except (KeyError, IndexError):
print("The `usr` fields could not be parsed for the '{}' branch.". print("The `usr` fields could not be parsed for the '{}' branch.".
format(self._name)) format(self._name))
return return
if mapper.flat: if self._mapper.flat:
self._initialise_flat() self._initialise_flat()
else:
# self._initialise_nested()
# branch[self._mapper.key + '.usr']
pass
def _initialise_flat(self): def _initialise_flat(self):
# Here, we assume that every event has the same names in the same order # Here, we assume that every event has the same names in the same order
# to massively increase the performance. This needs triple check if # to massively increase the performance. This needs triple check if
# it's always the case. # it's always the case.
self._usr_names = [ self._usr_names = [
n.decode("utf-8") for n in self._branch['usr_names'].lazyarray( n.decode("utf-8") for n in self._branch[self._usr_key + '_names'].lazyarray(
basketcache=BASKET_CACHE)[0] basketcache=BASKET_CACHE)[0]
] ]
self._usr_idx_lookup = { self._usr_idx_lookup = {
...@@ -166,7 +167,7 @@ class Usr: ...@@ -166,7 +167,7 @@ class Usr:
for index, name in enumerate(self._usr_names) for index, name in enumerate(self._usr_names)
} }
data = self._branch['usr'].lazyarray(basketcache=BASKET_CACHE) data = self._branch[self._usr_key].lazyarray(basketcache=BASKET_CACHE)
if self._index is not None: if self._index is not None:
data = data[self._index] data = data[self._index]
...@@ -176,17 +177,16 @@ class Usr: ...@@ -176,17 +177,16 @@ class Usr:
for name in self._usr_names: for name in self._usr_names:
setattr(self, name, self[name]) setattr(self, name, self[name])
def _initialise_nested(self): # def _initialise_nested(self):
self._usr_names = [ # self._usr_names = [
n.decode("utf-8") for n in self.branch['usr_names'].lazyarray( # n.decode("utf-8") for n in self.branch['usr_names'].lazyarray(
# TODO this will be fixed soon in uproot, # # TODO this will be fixed soon in uproot,
# see https://github.com/scikit-hep/uproot/issues/465 # # see https://github.com/scikit-hep/uproot/issues/465
uproot.asgenobj( # uproot.asgenobj(
uproot.SimpleArray(uproot.STLVector(uproot.STLString())), # uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
self.branch['usr_names']._context, 6), # self.branch['usr_names']._context, 6),
basketcache=BASKET_CACHE)[0] # basketcache=BASKET_CACHE)[0]
] # ]
self.__getitem__ = self.__getitem_nested__
def __getitem__(self, item): def __getitem__(self, item):
if self._mapper.flat: if self._mapper.flat:
...@@ -199,6 +199,19 @@ class Usr: ...@@ -199,6 +199,19 @@ class Usr:
else: else:
return self._usr_data[:, self._usr_idx_lookup[item]] return self._usr_data[:, self._usr_idx_lookup[item]]
def __getitem_nested__(self, item):
data = self._branch[self._usr_key + '_names'].lazyarray(
# TODO this will be fixed soon in uproot,
# see https://github.com/scikit-hep/uproot/issues/465
uproot.asgenobj(
uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
self._branch[self._usr_key + '_names']._context, 6),
basketcache=BASKET_CACHE)
if self._index is None:
return data
else:
return data[self._index]
def keys(self): def keys(self):
return self._usr_names return self._usr_names
......
...@@ -307,11 +307,10 @@ class TestUsr(unittest.TestCase): ...@@ -307,11 +307,10 @@ class TestUsr(unittest.TestCase):
def setUp(self): def setUp(self):
self.f = OFFLINE_USR self.f = OFFLINE_USR
def test_str(self): def test_str_flat(self):
print(self.f.events.usr) print(self.f.events.usr)
@unittest.skip def test_keys_flat(self):
def test_keys(self):
self.assertListEqual([ self.assertListEqual([
'RecoQuality', 'RecoNDF', 'CoC', 'ToT', 'ChargeAbove', 'RecoQuality', 'RecoNDF', 'CoC', 'ToT', 'ChargeAbove',
'ChargeBelow', 'ChargeRatio', 'DeltaPosZ', 'FirstPartPosZ', 'ChargeBelow', 'ChargeRatio', 'DeltaPosZ', 'FirstPartPosZ',
...@@ -320,8 +319,7 @@ class TestUsr(unittest.TestCase): ...@@ -320,8 +319,7 @@ class TestUsr(unittest.TestCase):
'ClassficationScore' 'ClassficationScore'
], self.f.events.usr.keys()) ], self.f.events.usr.keys())
@unittest.skip def test_getitem_flat(self):
def test_getitem(self):
assert np.allclose( assert np.allclose(
[118.6302815337638, 44.33580521344907, 99.93916717621543], [118.6302815337638, 44.33580521344907, 99.93916717621543],
self.f.events.usr['CoC']) self.f.events.usr['CoC'])
...@@ -330,7 +328,10 @@ class TestUsr(unittest.TestCase): ...@@ -330,7 +328,10 @@ class TestUsr(unittest.TestCase):
self.f.events.usr['DeltaPosZ']) self.f.events.usr['DeltaPosZ'])
@unittest.skip @unittest.skip
def test_attributes(self): def test_keys_nested(self):
self.assertListEqual(["a"], self.f.events.mc_tracks.usr.keys())
def test_attributes_flat(self):
assert np.allclose( assert np.allclose(
[118.6302815337638, 44.33580521344907, 99.93916717621543], [118.6302815337638, 44.33580521344907, 99.93916717621543],
self.f.events.usr.CoC) self.f.events.usr.CoC)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment