From 88d3fa79fb376a4053153f8889cd31cd8d6b5393 Mon Sep 17 00:00:00 2001 From: Tamas Gal <tgal@km3net.de> Date: Sun, 26 Apr 2020 08:50:41 +0200 Subject: [PATCH] Add close and context manager --- km3io/offline.py | 10 ++++++++++ km3io/online.py | 10 ++++++++++ tests/test_offline.py | 5 +++++ tests/test_online.py | 7 +++++++ 4 files changed, 32 insertions(+) diff --git a/km3io/offline.py b/km3io/offline.py index e39238c..604704a 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -323,8 +323,18 @@ class OfflineReader: """ self._fobj = uproot.open(file_path) + self._filename = file_path self._tree = self._fobj[MAIN_TREE_NAME] + def close(self): + self._fobj.close() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + @cached_property def events(self): """The `E` branch, containing all offline events.""" diff --git a/km3io/online.py b/km3io/online.py index 20c2dcc..5884821 100644 --- a/km3io/online.py +++ b/km3io/online.py @@ -111,10 +111,20 @@ class OnlineReader: """Reader for online ROOT files""" def __init__(self, filename): self._fobj = uproot.open(filename) + self._filename = filename self._events = None self._timeslices = None self._summaryslices = None + def close(self): + self._fobj.close() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + @property def events(self): if self._events is None: diff --git a/tests/test_offline.py b/tests/test_offline.py index 0df2e8f..c489bfc 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -77,6 +77,11 @@ class TestOfflineReader(unittest.TestCase): self.nu = OFFLINE_NUMUCC self.n_events = 10 + def test_context_manager(self): + filename = SAMPLES_DIR / 'aanet_v2.0.0.root' + with OfflineReader(filename) as r: + assert r._filename == filename + def test_number_events(self): assert self.n_events == len(self.r.events) diff --git a/tests/test_online.py b/tests/test_online.py index 3536420..7e0dfdf 100644 --- a/tests/test_online.py +++ b/tests/test_online.py @@ -7,6 +7,13 @@ from km3io.online import OnlineReader, get_rate, has_udp_trailer, get_udp_max_se SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "samples") +class TestOnlineReaderContextManager(unittest.TestCase): + def test_context_manager(self): + filename = os.path.join(SAMPLES_DIR, "daq_v1.0.0.root") + with OnlineReader(filename) as r: + assert r._filename == filename + + class TestOnlineEvents(unittest.TestCase): def setUp(self): self.events = OnlineReader(os.path.join(SAMPLES_DIR, -- GitLab