diff --git a/km3io/offline.py b/km3io/offline.py index f6c5fa826fd1c539daaa66a97b248f7b13c0d770..692a47b326394c616fd2678ff5626ae85e008287 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -127,7 +127,7 @@ class OfflineReader: class Usr: - """Helper class to access AAObject `usr`` stuff""" + """Helper class to access AAObject `usr` stuff""" def __init__(self, mapper, branch, index=None): self._mapper = mapper self._name = mapper.name @@ -136,29 +136,30 @@ class Usr: self._usr_names = [] self._usr_idx_lookup = {} + self._usr_key = 'usr' if mapper.flat else mapper.key + '.usr' + + self._initialise() + + def _initialise(self): try: - branch['usr'] + self._branch[self._usr_key] # This will raise a KeyError in old aanet files # which has a different strucuter and key (usr_data) - # We do not support those + # We do not support those (yet) except (KeyError, IndexError): print("The `usr` fields could not be parsed for the '{}' branch.". format(self._name)) return - if mapper.flat: + if self._mapper.flat: self._initialise_flat() - else: - # self._initialise_nested() - # branch[self._mapper.key + '.usr'] - pass def _initialise_flat(self): # Here, we assume that every event has the same names in the same order # to massively increase the performance. This needs triple check if # it's always the case. self._usr_names = [ - n.decode("utf-8") for n in self._branch['usr_names'].lazyarray( + n.decode("utf-8") for n in self._branch[self._usr_key + '_names'].lazyarray( basketcache=BASKET_CACHE)[0] ] self._usr_idx_lookup = { @@ -166,7 +167,7 @@ class Usr: for index, name in enumerate(self._usr_names) } - data = self._branch['usr'].lazyarray(basketcache=BASKET_CACHE) + data = self._branch[self._usr_key].lazyarray(basketcache=BASKET_CACHE) if self._index is not None: data = data[self._index] @@ -176,17 +177,16 @@ class Usr: for name in self._usr_names: setattr(self, name, self[name]) - def _initialise_nested(self): - self._usr_names = [ - n.decode("utf-8") for n in self.branch['usr_names'].lazyarray( - # TODO this will be fixed soon in uproot, - # see https://github.com/scikit-hep/uproot/issues/465 - uproot.asgenobj( - uproot.SimpleArray(uproot.STLVector(uproot.STLString())), - self.branch['usr_names']._context, 6), - basketcache=BASKET_CACHE)[0] - ] - self.__getitem__ = self.__getitem_nested__ + # def _initialise_nested(self): + # self._usr_names = [ + # n.decode("utf-8") for n in self.branch['usr_names'].lazyarray( + # # TODO this will be fixed soon in uproot, + # # see https://github.com/scikit-hep/uproot/issues/465 + # uproot.asgenobj( + # uproot.SimpleArray(uproot.STLVector(uproot.STLString())), + # self.branch['usr_names']._context, 6), + # basketcache=BASKET_CACHE)[0] + # ] def __getitem__(self, item): if self._mapper.flat: @@ -199,6 +199,19 @@ class Usr: else: return self._usr_data[:, self._usr_idx_lookup[item]] + def __getitem_nested__(self, item): + data = self._branch[self._usr_key + '_names'].lazyarray( + # TODO this will be fixed soon in uproot, + # see https://github.com/scikit-hep/uproot/issues/465 + uproot.asgenobj( + uproot.SimpleArray(uproot.STLVector(uproot.STLString())), + self._branch[self._usr_key + '_names']._context, 6), + basketcache=BASKET_CACHE) + if self._index is None: + return data + else: + return data[self._index] + def keys(self): return self._usr_names diff --git a/tests/test_offline.py b/tests/test_offline.py index 5b3a61bb50b67c0cd24f849e5d5c353acd6ca6ff..635a34be23f1d82086b12f933aeb387b677bb822 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -307,11 +307,10 @@ class TestUsr(unittest.TestCase): def setUp(self): self.f = OFFLINE_USR - def test_str(self): + def test_str_flat(self): print(self.f.events.usr) - @unittest.skip - def test_keys(self): + def test_keys_flat(self): self.assertListEqual([ 'RecoQuality', 'RecoNDF', 'CoC', 'ToT', 'ChargeAbove', 'ChargeBelow', 'ChargeRatio', 'DeltaPosZ', 'FirstPartPosZ', @@ -320,8 +319,7 @@ class TestUsr(unittest.TestCase): 'ClassficationScore' ], self.f.events.usr.keys()) - @unittest.skip - def test_getitem(self): + def test_getitem_flat(self): assert np.allclose( [118.6302815337638, 44.33580521344907, 99.93916717621543], self.f.events.usr['CoC']) @@ -330,7 +328,10 @@ class TestUsr(unittest.TestCase): self.f.events.usr['DeltaPosZ']) @unittest.skip - def test_attributes(self): + def test_keys_nested(self): + self.assertListEqual(["a"], self.f.events.mc_tracks.usr.keys()) + + def test_attributes_flat(self): assert np.allclose( [118.6302815337638, 44.33580521344907, 99.93916717621543], self.f.events.usr.CoC)