diff --git a/km3io/tools.py b/km3io/tools.py index 5448d61bf092407e186c5dccb0184b0a8c499ae9..642e37082a3cb5cc9899c009efdb7b38f17e4ad2 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -159,6 +159,9 @@ class Branch: if isinstance(item, str): return self.__getkey__(item) + if item.__class__.__name__ == "ChunkedArray": + item = np.array(item) + return self.__class__(self._tree, self._mapper, index_chain=self._index_chain + [item], diff --git a/tests/test_offline.py b/tests/test_offline.py index 1506c8700c22ff77b3a051ff1f1d9012b44c6868..b6971d7d8850910ca5092b5814a748ebc756184d 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -179,6 +179,14 @@ class TestOfflineEvents(unittest.TestCase): assert np.allclose(self.events.hits[3:5][1][4].dom_id, self.events[3:5][1][4].hits.dom_id) + def test_fancy_indexing(self): + mask = self.events.n_tracks > 55 + tracks = self.events.tracks[mask] + first_tracks = tracks[:, 0] + assert 8 == len(first_tracks) + assert 8 == len(first_tracks.rec_stages) + assert 8 == len(first_tracks.lik) + def test_iteration(self): i = 0 for event in self.events: