Skip to content
Snippets Groups Projects
Commit 88d3fa79 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Add close and context manager

parent ac6e0c1e
No related branches found
No related tags found
1 merge request!30Resolve "Closing the files?"
Pipeline #10812 failed
...@@ -323,8 +323,18 @@ class OfflineReader: ...@@ -323,8 +323,18 @@ class OfflineReader:
""" """
self._fobj = uproot.open(file_path) self._fobj = uproot.open(file_path)
self._filename = file_path
self._tree = self._fobj[MAIN_TREE_NAME] self._tree = self._fobj[MAIN_TREE_NAME]
def close(self):
self._fobj.close()
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
@cached_property @cached_property
def events(self): def events(self):
"""The `E` branch, containing all offline events.""" """The `E` branch, containing all offline events."""
......
...@@ -111,10 +111,20 @@ class OnlineReader: ...@@ -111,10 +111,20 @@ class OnlineReader:
"""Reader for online ROOT files""" """Reader for online ROOT files"""
def __init__(self, filename): def __init__(self, filename):
self._fobj = uproot.open(filename) self._fobj = uproot.open(filename)
self._filename = filename
self._events = None self._events = None
self._timeslices = None self._timeslices = None
self._summaryslices = None self._summaryslices = None
def close(self):
self._fobj.close()
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
@property @property
def events(self): def events(self):
if self._events is None: if self._events is None:
......
...@@ -77,6 +77,11 @@ class TestOfflineReader(unittest.TestCase): ...@@ -77,6 +77,11 @@ class TestOfflineReader(unittest.TestCase):
self.nu = OFFLINE_NUMUCC self.nu = OFFLINE_NUMUCC
self.n_events = 10 self.n_events = 10
def test_context_manager(self):
filename = SAMPLES_DIR / 'aanet_v2.0.0.root'
with OfflineReader(filename) as r:
assert r._filename == filename
def test_number_events(self): def test_number_events(self):
assert self.n_events == len(self.r.events) assert self.n_events == len(self.r.events)
......
...@@ -7,6 +7,13 @@ from km3io.online import OnlineReader, get_rate, has_udp_trailer, get_udp_max_se ...@@ -7,6 +7,13 @@ from km3io.online import OnlineReader, get_rate, has_udp_trailer, get_udp_max_se
SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "samples") SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "samples")
class TestOnlineReaderContextManager(unittest.TestCase):
def test_context_manager(self):
filename = os.path.join(SAMPLES_DIR, "daq_v1.0.0.root")
with OnlineReader(filename) as r:
assert r._filename == filename
class TestOnlineEvents(unittest.TestCase): class TestOnlineEvents(unittest.TestCase):
def setUp(self): def setUp(self):
self.events = OnlineReader(os.path.join(SAMPLES_DIR, self.events = OnlineReader(os.path.join(SAMPLES_DIR,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment