Skip to content
Snippets Groups Projects
Commit f85a6cff authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Fix indexing

parent 9115d4b8
No related branches found
No related tags found
1 merge request!27Refactor offline I/O
...@@ -594,21 +594,35 @@ class Branch: ...@@ -594,21 +594,35 @@ class Branch:
if isinstance(item, slice): if isinstance(item, slice):
return self.__class__(self._tree, self._mapper, index=item) return self.__class__(self._tree, self._mapper, index=item)
if isinstance(item, int): if isinstance(item, int):
# A bit ugly, but whatever works
if self._mapper.flat: if self._mapper.flat:
return BranchElement( if self._index is None:
self._mapper.name, { dct = {
key: self._branch[self._keymap[key]].array()
for key in self.keys()
}
else:
dct = {
key: key:
self._branch[self._keymap[key]].array()[self._index] self._branch[self._keymap[key]].array()[self._index]
for key in self.keys() for key in self.keys()
})[item] }
return BranchElement(self._mapper.name, dct)[item]
else: else:
return BranchElement( if self._index is None:
self._mapper.name, { dct = {
key: self._branch[self._keymap[key]].array()[item]
for key in self.keys()
}
else:
dct = {
key: key:
self._branch[self._keymap[key]].array()[self._index, self._branch[self._keymap[key]].array()[self._index,
item] item]
for key in self.keys() for key in self.keys()
}) }
return BranchElement(self._mapper.name, dct)
if isinstance(item, tuple): if isinstance(item, tuple):
return self[item[0]][item[1]] return self[item[0]][item[1]]
...@@ -657,16 +671,25 @@ class BranchElement: ...@@ -657,16 +671,25 @@ class BranchElement:
self._name = name self._name = name
self._index = index self._index = index
self.ItemConstructor = namedtuple(self._name[:-1], dct.keys()) self.ItemConstructor = namedtuple(self._name[:-1], dct.keys())
for key, values in dct.items(): if index is None:
setattr(self, key, values[index]) for key, values in dct.items():
setattr(self, key, values)
else:
for key, values in dct.items():
setattr(self, key, values[index])
def __getitem__(self, item): def __getitem__(self, item):
if isinstance(item, slice): if isinstance(item, slice):
return self.__class__(self._name, self._dct, index=item) return self.__class__(self._name, self._dct, index=item)
if isinstance(item, int): if isinstance(item, int):
return self.ItemConstructor( if self._index is None:
**{k: v[self._index][item] return self.ItemConstructor(
for k, v in self._dct.items()}) **{k: v[item]
for k, v in self._dct.items()})
else:
return self.ItemConstructor(
**{k: v[self._index][item]
for k, v in self._dct.items()})
def __repr__(self): def __repr__(self):
return "<{}[{}]>".format(self.__class__.__name__, self._name) return "<{}[{}]>".format(self.__class__.__name__, self._name)
...@@ -182,6 +182,18 @@ class TestOfflineEvents(unittest.TestCase): ...@@ -182,6 +182,18 @@ class TestOfflineEvents(unittest.TestCase):
self.assertListEqual(self.t_sec[s], list(s_events.t_sec)) self.assertListEqual(self.t_sec[s], list(s_events.t_sec))
self.assertListEqual(self.t_ns[s], list(s_events.t_ns)) self.assertListEqual(self.t_ns[s], list(s_events.t_ns))
def test_slicing_consistency(self):
for s in [slice(1, 3), slice(2, 7, 3)]:
assert np.allclose(OFFLINE_FILE[s].events.n_hits,
self.events.n_hits[s])
assert np.allclose(self.events[s].n_hits, self.events.n_hits[s])
def test_index_consistency(self):
for i in range(self.n_events):
assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])
assert np.allclose(OFFLINE_FILE[i].events.n_hits,
self.events.n_hits[i])
def test_str(self): def test_str(self):
assert str(self.n_events) in str(self.events) assert str(self.n_events) in str(self.events)
...@@ -235,9 +247,32 @@ class TestOfflineHits(unittest.TestCase): ...@@ -235,9 +247,32 @@ class TestOfflineHits(unittest.TestCase):
for idx, t in self.t.items(): for idx, t in self.t.items():
assert np.allclose(t, self.hits.t[idx][:len(t)]) assert np.allclose(t, self.hits.t[idx][:len(t)])
def test_slicing(self):
s = slice(2, 8, 2)
s_hits = self.hits[s]
assert 3 == len(s_hits)
for idx, dom_id in self.dom_id.items():
self.assertListEqual(dom_id[s], list(self.hits.dom_id[idx][s]))
for idx, t in self.t.items():
self.assertListEqual(t[s], list(self.hits.t[idx][s]))
def test_slicing_consistency(self):
for s in [slice(1, 3), slice(2, 7, 3)]:
for idx in range(3):
assert np.allclose(self.hits.dom_id[idx][s],
self.hits[idx].dom_id[s])
assert np.allclose(OFFLINE_FILE[idx].hits.dom_id[s],
self.hits.dom_id[idx][s])
class TestOfflineTracks(unittest.TestCase):
@unittest.skip @unittest.skip
def test_index_consistency(self):
for i in range(self.n_events):
assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])
assert np.allclose(OFFLINE_FILE[i].events.n_hits,
self.events.n_hits[i])
class TestOfflineTracks(unittest.TestCase):
def setUp(self): def setUp(self):
self.tracks = OFFLINE_FILE.tracks self.tracks = OFFLINE_FILE.tracks
self.r_mc = OFFLINE_NUMUCC self.r_mc = OFFLINE_NUMUCC
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment