From 88d3fa79fb376a4053153f8889cd31cd8d6b5393 Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Sun, 26 Apr 2020 08:50:41 +0200
Subject: [PATCH] Add close and context manager

---
 km3io/offline.py      | 10 ++++++++++
 km3io/online.py       | 10 ++++++++++
 tests/test_offline.py |  5 +++++
 tests/test_online.py  |  7 +++++++
 4 files changed, 32 insertions(+)

diff --git a/km3io/offline.py b/km3io/offline.py
index e39238c..604704a 100644
--- a/km3io/offline.py
+++ b/km3io/offline.py
@@ -323,8 +323,18 @@ class OfflineReader:
 
         """
         self._fobj = uproot.open(file_path)
+        self._filename = file_path
         self._tree = self._fobj[MAIN_TREE_NAME]
 
+    def close(self):
+        self._fobj.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
     @cached_property
     def events(self):
         """The `E` branch, containing all offline events."""
diff --git a/km3io/online.py b/km3io/online.py
index 20c2dcc..5884821 100644
--- a/km3io/online.py
+++ b/km3io/online.py
@@ -111,10 +111,20 @@ class OnlineReader:
     """Reader for online ROOT files"""
     def __init__(self, filename):
         self._fobj = uproot.open(filename)
+        self._filename = filename
         self._events = None
         self._timeslices = None
         self._summaryslices = None
 
+    def close(self):
+        self._fobj.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
     @property
     def events(self):
         if self._events is None:
diff --git a/tests/test_offline.py b/tests/test_offline.py
index 0df2e8f..c489bfc 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -77,6 +77,11 @@ class TestOfflineReader(unittest.TestCase):
         self.nu = OFFLINE_NUMUCC
         self.n_events = 10
 
+    def test_context_manager(self):
+        filename = SAMPLES_DIR / 'aanet_v2.0.0.root'
+        with OfflineReader(filename) as r:
+            assert r._filename == filename
+
     def test_number_events(self):
         assert self.n_events == len(self.r.events)
 
diff --git a/tests/test_online.py b/tests/test_online.py
index 3536420..7e0dfdf 100644
--- a/tests/test_online.py
+++ b/tests/test_online.py
@@ -7,6 +7,13 @@ from km3io.online import OnlineReader, get_rate, has_udp_trailer, get_udp_max_se
 SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "samples")
 
 
+class TestOnlineReaderContextManager(unittest.TestCase):
+    def test_context_manager(self):
+        filename = os.path.join(SAMPLES_DIR, "daq_v1.0.0.root")
+        with OnlineReader(filename) as r:
+            assert r._filename == filename
+
+
 class TestOnlineEvents(unittest.TestCase):
     def setUp(self):
         self.events = OnlineReader(os.path.join(SAMPLES_DIR,
-- 
GitLab