diff --git a/km3io/offline.py b/km3io/offline.py index e39238c61ce523ebf0fc9be623871a8c1524d0bd..604704a48f981f2e34bf42b5b1c617361f1834db 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -323,8 +323,18 @@ class OfflineReader: """ self._fobj = uproot.open(file_path) + self._filename = file_path self._tree = self._fobj[MAIN_TREE_NAME] + def close(self): + self._fobj.close() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + @cached_property def events(self): """The `E` branch, containing all offline events.""" diff --git a/km3io/online.py b/km3io/online.py index 20c2dccc11eadb2677f7cab3f9de97aea5f721b9..5884821c8fe74c913e7015e936a1e73a4632ccb7 100644 --- a/km3io/online.py +++ b/km3io/online.py @@ -111,10 +111,20 @@ class OnlineReader: """Reader for online ROOT files""" def __init__(self, filename): self._fobj = uproot.open(filename) + self._filename = filename self._events = None self._timeslices = None self._summaryslices = None + def close(self): + self._fobj.close() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + @property def events(self): if self._events is None: diff --git a/tests/test_offline.py b/tests/test_offline.py index 0df2e8f613fdf3c8860ee8f17ed5a4afdc67df0a..c489bfc34f7b589f4fc4c4a4034c74d72140dea3 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -77,6 +77,11 @@ class TestOfflineReader(unittest.TestCase): self.nu = OFFLINE_NUMUCC self.n_events = 10 + def test_context_manager(self): + filename = SAMPLES_DIR / 'aanet_v2.0.0.root' + with OfflineReader(filename) as r: + assert r._filename == filename + def test_number_events(self): assert self.n_events == len(self.r.events) diff --git a/tests/test_online.py b/tests/test_online.py index 35364205d7dd3aad375357d25e8ec4c7e6d825c1..7e0dfdf30fffd427cf5caf66a52e806a130bc272 100644 --- a/tests/test_online.py +++ b/tests/test_online.py @@ -7,6 +7,13 @@ from km3io.online import OnlineReader, get_rate, has_udp_trailer, get_udp_max_se SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "samples") +class TestOnlineReaderContextManager(unittest.TestCase): + def test_context_manager(self): + filename = os.path.join(SAMPLES_DIR, "daq_v1.0.0.root") + with OnlineReader(filename) as r: + assert r._filename == filename + + class TestOnlineEvents(unittest.TestCase): def setUp(self): self.events = OnlineReader(os.path.join(SAMPLES_DIR,