From f85a6cff691d0f999dc15a6b9a3bbcc21018e446 Mon Sep 17 00:00:00 2001 From: Tamas Gal <tgal@km3net.de> Date: Fri, 6 Mar 2020 17:18:04 +0100 Subject: [PATCH] Fix indexing --- km3io/offline.py | 45 ++++++++++++++++++++++++++++++++----------- tests/test_offline.py | 37 ++++++++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/km3io/offline.py b/km3io/offline.py index 4368ee9..7aa8708 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -594,21 +594,35 @@ class Branch: if isinstance(item, slice): return self.__class__(self._tree, self._mapper, index=item) if isinstance(item, int): + # A bit ugly, but whatever works if self._mapper.flat: - return BranchElement( - self._mapper.name, { + if self._index is None: + dct = { + key: self._branch[self._keymap[key]].array() + for key in self.keys() + } + else: + dct = { key: self._branch[self._keymap[key]].array()[self._index] for key in self.keys() - })[item] + } + return BranchElement(self._mapper.name, dct)[item] else: - return BranchElement( - self._mapper.name, { + if self._index is None: + dct = { + key: self._branch[self._keymap[key]].array()[item] + for key in self.keys() + } + else: + dct = { key: self._branch[self._keymap[key]].array()[self._index, item] for key in self.keys() - }) + } + return BranchElement(self._mapper.name, dct) + if isinstance(item, tuple): return self[item[0]][item[1]] @@ -657,16 +671,25 @@ class BranchElement: self._name = name self._index = index self.ItemConstructor = namedtuple(self._name[:-1], dct.keys()) - for key, values in dct.items(): - setattr(self, key, values[index]) + if index is None: + for key, values in dct.items(): + setattr(self, key, values) + else: + for key, values in dct.items(): + setattr(self, key, values[index]) def __getitem__(self, item): if isinstance(item, slice): return self.__class__(self._name, self._dct, index=item) if isinstance(item, int): - return self.ItemConstructor( - **{k: v[self._index][item] - for k, v in self._dct.items()}) + if self._index is None: + return self.ItemConstructor( + **{k: v[item] + for k, v in self._dct.items()}) + else: + return self.ItemConstructor( + **{k: v[self._index][item] + for k, v in self._dct.items()}) def __repr__(self): return "<{}[{}]>".format(self.__class__.__name__, self._name) diff --git a/tests/test_offline.py b/tests/test_offline.py index fb8ed19..fee41d7 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -182,6 +182,18 @@ class TestOfflineEvents(unittest.TestCase): self.assertListEqual(self.t_sec[s], list(s_events.t_sec)) self.assertListEqual(self.t_ns[s], list(s_events.t_ns)) + def test_slicing_consistency(self): + for s in [slice(1, 3), slice(2, 7, 3)]: + assert np.allclose(OFFLINE_FILE[s].events.n_hits, + self.events.n_hits[s]) + assert np.allclose(self.events[s].n_hits, self.events.n_hits[s]) + + def test_index_consistency(self): + for i in range(self.n_events): + assert np.allclose(self.events[i].n_hits, self.events.n_hits[i]) + assert np.allclose(OFFLINE_FILE[i].events.n_hits, + self.events.n_hits[i]) + def test_str(self): assert str(self.n_events) in str(self.events) @@ -235,9 +247,32 @@ class TestOfflineHits(unittest.TestCase): for idx, t in self.t.items(): assert np.allclose(t, self.hits.t[idx][:len(t)]) + def test_slicing(self): + s = slice(2, 8, 2) + s_hits = self.hits[s] + assert 3 == len(s_hits) + for idx, dom_id in self.dom_id.items(): + self.assertListEqual(dom_id[s], list(self.hits.dom_id[idx][s])) + for idx, t in self.t.items(): + self.assertListEqual(t[s], list(self.hits.t[idx][s])) + + def test_slicing_consistency(self): + for s in [slice(1, 3), slice(2, 7, 3)]: + for idx in range(3): + assert np.allclose(self.hits.dom_id[idx][s], + self.hits[idx].dom_id[s]) + assert np.allclose(OFFLINE_FILE[idx].hits.dom_id[s], + self.hits.dom_id[idx][s]) -class TestOfflineTracks(unittest.TestCase): @unittest.skip + def test_index_consistency(self): + for i in range(self.n_events): + assert np.allclose(self.events[i].n_hits, self.events.n_hits[i]) + assert np.allclose(OFFLINE_FILE[i].events.n_hits, + self.events.n_hits[i]) + + +class TestOfflineTracks(unittest.TestCase): def setUp(self): self.tracks = OFFLINE_FILE.tracks self.r_mc = OFFLINE_NUMUCC -- GitLab