From 3a913f81922c18939728a33deb1da6180f6b69f1 Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Wed, 4 Mar 2020 15:48:57 +0100
Subject: [PATCH] Cleanup OfflineReader interface

---
 km3io/offline.py | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/km3io/offline.py b/km3io/offline.py
index 4da2a7c..70f6887 100644
--- a/km3io/offline.py
+++ b/km3io/offline.py
@@ -201,7 +201,7 @@ class OfflineKeys:
 
 class OfflineReader:
     """reader for offline ROOT files"""
-    def __init__(self, file_path=None):
+    def __init__(self, file_path=None, fobj=None, data=None):
         """ OfflineReader class is an offline ROOT file wrapper
 
         Parameters
@@ -209,34 +209,37 @@ class OfflineReader:
         file_path : path-like object
             Path to the file of interest. It can be a str or any python
             path-like object that points to the file.
+
         """
         self._file_path = file_path
 
         if file_path is not None:
-            self._tree = uproot.open(self._file_path)[MAIN_TREE_NAME]
+            self._fobj = uproot.open(self._file_path)
+            self._tree = self._fobj[MAIN_TREE_NAME]
             self._data = self._tree.lazyarrays(
                 basketcache=uproot.cache.ThreadSafeArrayCache(
                     BASKET_CACHE_SIZE))
+        else:
+            self._fobj = fobj
+            self._tree = self._fobj[MAIN_TREE_NAME]
+            self._data = data
 
     @classmethod
-    def from_tree(cls, tree, data):
-        instance = cls()
-        instance._tree = tree
-        instance._data = data
+    def from_index(cls, source, index):
+        instance = cls(fobj=source._fobj, data=source._data[index])
         return instance
 
-    def __getitem__(self, item):
-        return OfflineReader.from_tree(tree=self._tree, data=self._data[item])
+    def __getitem__(self, index):
+        return OfflineReader.from_index(source=self, index=index)
 
     def __len__(self):
         return len(self._data)
 
     @cached_property
     def header(self):
-        fobj = uproot.open(self._file_path)
-        if 'Head' in fobj:
+        if 'Head' in self._fobj:
             header = {}
-            for n, x in fobj['Head']._map_3c_string_2c_string_3e_.items():
+            for n, x in self._fobj['Head']._map_3c_string_2c_string_3e_.items():
                 header[n.decode("utf-8")] = x.decode("utf-8").strip()
             return header
         else:
-- 
GitLab