diff --git a/km3io/online.py b/km3io/online.py index 438e84517779152907bdb5a15645a00e0c114538..ad080f96a31e9f84a123f29aff60600406271f4e 100644 --- a/km3io/online.py +++ b/km3io/online.py @@ -96,6 +96,22 @@ class SummarysliceReader: *[getattr(chunk, bc.branch_address) for bc in self._subbranches] ) + def __getitem__(self, idx): + if idx >= len(self) or idx < -len(self): + raise IndexError("Chunk index out of range") + + s = self._step_size + + if idx < 0: + idx = len(self) + idx + + chunk = self._branch.arrays( + dict(self._subbranches), entry_start=idx * s, entry_stop=(idx + 1) * s + ) + return self.ChunksConstructor( + *[getattr(chunk, bc.branch_address) for bc in self._subbranches] + ) + def __iter__(self): self._chunks = self._chunks_generator() return self @@ -107,7 +123,14 @@ class SummarysliceReader: return int(np.ceil(self._branch.num_entries / self._step_size)) def __repr__(self): - return f"<{self.__class__.__name__} {self._branch.num_entries} items, step_size={self._step_size} ({len(self)} chunks)>" + step_size = self._step_size + n_items = self._branch.num_entries + cls_name = self.__class__.__name__ + n_chunks = len(self) + return ( + f"<{cls_name} {n_items} items, step_size={step_size} " + f"({n_chunks} chunk{'' if n_chunks == 1 else 's'})>" + ) @nb.vectorize( diff --git a/tests/test_online.py b/tests/test_online.py index 09fb9acc7438c78858e968e7c495beef9fd54d8e..e7efc6ec6d5727a4201baf51ee722bfcdeb73100 100644 --- a/tests/test_online.py +++ b/tests/test_online.py @@ -749,6 +749,46 @@ class TestSummarysliceReader(unittest.TestCase): sr = SummarysliceReader(data_path("online/km3net_online.root"), step_size=3) assert 1 == len(sr) + def test_getitem_raises_when_out_of_range(self): + sr = SummarysliceReader(data_path("online/km3net_online.root"), step_size=1) + with self.assertRaises(IndexError): + sr[123] + with self.assertRaises(IndexError): + sr[-123] + with self.assertRaises(IndexError): + sr[3] + sr[-3] # this should still work, gives the first element in this case + with self.assertRaises(IndexError): + sr[-4] + + def test_getitem(self): + sr = SummarysliceReader(data_path("online/km3net_online.root"), step_size=1) + for idx in range(len(sr)): + assert len(sr[idx].headers) == 1 + assert len(sr[idx].slices) == 1 + + first_frame_index = sr[0].headers.frame_index # 126 + last_frame_index = sr[2].headers.frame_index # 128 + + assert 126 == first_frame_index + assert 128 == last_frame_index + + sr = SummarysliceReader(data_path("online/km3net_online.root"), step_size=2) + assert len(sr[0].headers) == 2 + assert len(sr[0].slices) == 2 + assert len(sr[1].headers) == 1 + assert len(sr[1].slices) == 1 + with self.assertRaises(IndexError): + assert len(sr[2].headers) == 0 + assert len(sr[2].slices) == 0 + + assert first_frame_index == sr[0].headers[0].frame_index + assert last_frame_index == sr[1].headers[0].frame_index + + + assert last_frame_index == sr[-1].headers[0].frame_index + assert first_frame_index == sr[-2].headers[0].frame_index + def test_iterate_with_step_size_one(self): sr = SummarysliceReader(data_path("online/km3net_online.root"), step_size=1) i = 0