From 41e151accb7d4651f2d8d2eaeeef500d85c27676 Mon Sep 17 00:00:00 2001 From: Tamas Gal <tgal@km3net.de> Date: Tue, 10 Mar 2020 10:17:49 +0100 Subject: [PATCH] Allow dict-key access to attributes --- km3io/offline.py | 16 +++++++++++----- tests/test_offline.py | 8 ++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/km3io/offline.py b/km3io/offline.py index 3a04e79..e694f0a 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -266,14 +266,17 @@ class Branch: return object.__getattribute__(self, attr) if attr in self._keymap.keys(): # intercept branch key lookups - out = self._branch[self._keymap[attr]].lazyarray( - basketcache=BASKET_CACHE) - if self._index is not None: - out = out[self._index] - return out + return self.__getkey__(attr) return object.__getattribute__(self, attr) + def __getkey__(self, key): + out = self._branch[self._keymap[key]].lazyarray( + basketcache=BASKET_CACHE) + if self._index is not None: + out = out[self._index] + return out + def __getitem__(self, item): """Slicing magic""" if isinstance(item, (int, slice)): @@ -286,6 +289,9 @@ class Branch: if isinstance(item, tuple): return self[item[0]][item[1]] + if isinstance(item, str): + return self.__getkey__(item) + return self.__class__(self._tree, self._mapper, index=np.array(item), diff --git a/tests/test_offline.py b/tests/test_offline.py index 5500416..86f6390 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -63,10 +63,10 @@ class TestOfflineEvents(unittest.TestCase): self.assertListEqual(self.t_ns, list(self.events.t_ns)) def test_keys(self): - self.assertListEqual(self.n_hits, list(self.events['n_hits'])) - self.assertListEqual(self.n_tracks, list(self.events['n_tracks'])) - self.assertListEqual(self.t_sec, list(self.events['t_sec'])) - self.assertListEqual(self.t_ns, list(self.events['t_ns'])) + assert np.allclose(self.n_hits, self.events['n_hits']) + assert np.allclose(self.n_tracks, self.events['n_tracks']) + assert np.allclose(self.t_sec, self.events['t_sec']) + assert np.allclose(self.t_ns, self.events['t_ns']) def test_slicing(self): s = slice(2, 8, 2) -- GitLab