diff --git a/km3io/offline.py b/km3io/offline.py index 8bd986d53834e73ff8d5a2f66e8080333c2462b6..7d42d99fbf41f4e751ba7214c79cf07d044f0cb3 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -1,5 +1,6 @@ import uproot import numpy as np + import warnings import km3io.definitions.trigger import km3io.definitions.fitparameters @@ -794,6 +795,11 @@ class OfflineTracks: if isinstance(item, int): return OfflineTrack(self._keys, [v[item] for v in self._values], fitparameters=self._fitparameters) + elif isinstance(item, list) and all(isinstance(i, str) for i in item): + cols = item + data = [getattr(self, c) for c in cols] + dtype = dict(names=cols, formats=[d.dtype for d in data]) + return np.rec.fromarrays(data, dtype=dtype) else: return OfflineTracks(self._keys, [v[item] for v in self._values], fitparameters=self._fitparameters) diff --git a/tests/test_offline.py b/tests/test_offline.py index 5257b4d99b834f5a33f7624041ce145351e7d674..69867ed8e708a54e954b971950fddff126d175d7 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -419,6 +419,11 @@ class TestOfflineTracks(unittest.TestCase): self.assertListEqual(list(tracks.E[:, 0][_slice]), list(tracks[_slice].E[:, 0])) + def test_slicing_via_columns(self): + tracks = self.tracks + data = tracks[['E', 'lik']] + assert 1 == data.E + class TestOfflineTrack(unittest.TestCase): def setUp(self):