diff --git a/km3io/offline.py b/km3io/offline.py index 3a04e793af707062eb24364410e86b99c865403f..e694f0a35ddc08a88351959c590f0e0d281bec14 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -266,14 +266,17 @@ class Branch: return object.__getattribute__(self, attr) if attr in self._keymap.keys(): # intercept branch key lookups - out = self._branch[self._keymap[attr]].lazyarray( - basketcache=BASKET_CACHE) - if self._index is not None: - out = out[self._index] - return out + return self.__getkey__(attr) return object.__getattribute__(self, attr) + def __getkey__(self, key): + out = self._branch[self._keymap[key]].lazyarray( + basketcache=BASKET_CACHE) + if self._index is not None: + out = out[self._index] + return out + def __getitem__(self, item): """Slicing magic""" if isinstance(item, (int, slice)): @@ -286,6 +289,9 @@ class Branch: if isinstance(item, tuple): return self[item[0]][item[1]] + if isinstance(item, str): + return self.__getkey__(item) + return self.__class__(self._tree, self._mapper, index=np.array(item), diff --git a/tests/test_offline.py b/tests/test_offline.py index 5500416885f7acb25a3d8684321a21abaac3a511..86f6390407e899fa4cbd913665258e8861faf0b2 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -63,10 +63,10 @@ class TestOfflineEvents(unittest.TestCase): self.assertListEqual(self.t_ns, list(self.events.t_ns)) def test_keys(self): - self.assertListEqual(self.n_hits, list(self.events['n_hits'])) - self.assertListEqual(self.n_tracks, list(self.events['n_tracks'])) - self.assertListEqual(self.t_sec, list(self.events['t_sec'])) - self.assertListEqual(self.t_ns, list(self.events['t_ns'])) + assert np.allclose(self.n_hits, self.events['n_hits']) + assert np.allclose(self.n_tracks, self.events['n_tracks']) + assert np.allclose(self.t_sec, self.events['t_sec']) + assert np.allclose(self.t_ns, self.events['t_ns']) def test_slicing(self): s = slice(2, 8, 2)