From d82c3c5da170818bcd74e234499fbb9227bc8064 Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Sat, 7 Mar 2020 13:37:19 +0100
Subject: [PATCH] Lightning fast slicing

---
 km3io/offline.py      | 51 ++++++++++++++++++++++++-------------------
 tests/test_offline.py |  4 ----
 2 files changed, 29 insertions(+), 26 deletions(-)

diff --git a/km3io/offline.py b/km3io/offline.py
index b2ce7b3..eb285f6 100644
--- a/km3io/offline.py
+++ b/km3io/offline.py
@@ -8,6 +8,7 @@ from .definitions import mc_header
 MAIN_TREE_NAME = "E"
 # 110 MB based on the size of the largest basket found so far in km3net
 BASKET_CACHE_SIZE = 110 * 1024**2
+BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
 
 BranchMapper = namedtuple(
     "BranchMapper",
@@ -81,9 +82,7 @@ class OfflineReader:
         if file_path is not None:
             self._fobj = uproot.open(file_path)
             self._tree = self._fobj[MAIN_TREE_NAME]
-            self._data = self._tree.lazyarrays(
-                basketcache=uproot.cache.ThreadSafeArrayCache(
-                    BASKET_CACHE_SIZE))
+            self._data = self._tree.lazyarrays(basketcache=BASKET_CACHE)
         else:
             self._fobj = fobj
             self._tree = self._fobj[MAIN_TREE_NAME]
@@ -121,8 +120,7 @@ class OfflineReader:
             return len(tree)
         else:
             return len(
-                tree.lazyarrays(basketcache=uproot.cache.ThreadSafeArrayCache(
-                    BASKET_CACHE_SIZE))[self.index])
+                tree.lazyarrays(basketcache=BASKET_CACHE)[self.index])
 
     @cached_property
     def header(self):
@@ -480,8 +478,6 @@ class Usr:
         # Here, we assume that every event has the same names in the same order
         # to massively increase the performance. This needs triple check if it's
         # always the case; the usr-format is simply a very bad design.
-        # print("initialising usr for {}".format(name))
-        # print("Setting up usr")
         self._name = name
         try:
             tree['usr']  # This will raise a KeyError in old aanet files
@@ -493,20 +489,15 @@ class Usr:
         except (KeyError, IndexError):  # e.g. old aanet files
             self._usr_names = []
         else:
-            # print(" checking usr data")
             self._usr_idx_lookup = {
                 name: index
                 for index, name in enumerate(self._usr_names)
             }
-            data = tree['usr'].lazyarray(
-                basketcache=uproot.cache.ThreadSafeArrayCache(
-                    BASKET_CACHE_SIZE))
+            data = tree['usr'].lazyarray(basketcache=BASKET_CACHE)
             if index is not None:
                 data = data[index]
             self._usr_data = data
-            # print("    adding attributes")
             for name in self._usr_names:
-                # print("  setting {}".format(name))
                 setattr(self, name, self[name])
 
     def __getitem__(self, item):
@@ -567,12 +558,14 @@ class Header:
 
 class Branch:
     """Branch accessor class"""
+    # @profile
     def __init__(self,
                  tree,
                  mapper,
                  index=None,
                  subbranches=None,
-                 subbranchmaps=None):
+                 subbranchmaps=None,
+                 keymap=None):
         self._tree = tree
         self._mapper = mapper
         self._index = index
@@ -580,7 +573,10 @@ class Branch:
         self._branch = tree[mapper.key]
         self._subbranches = []
 
-        self._initialise_keys()  #
+        if keymap is None:
+            self._initialise_keys()  #
+        else:
+            self._keymap = keymap
 
         if subbranches is not None:
             self._subbranches = subbranches
@@ -593,6 +589,7 @@ class Branch:
         for subbranch in self._subbranches:
             setattr(self, subbranch._mapper.name, subbranch)
 
+    # @profile
     def _initialise_keys(self):
         """Create the keymap and instance attributes for branch keys"""
         keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(
@@ -607,8 +604,7 @@ class Branch:
             del self._keymap[k]
 
         for key in self._keymap.keys():
-            # print("setting", self._mapper.name, key)
-            setattr(self, key, self[key])
+            setattr(self, key, None)
 
     def keys(self):
         return self._keymap.keys()
@@ -617,16 +613,29 @@ class Branch:
     def usr(self):
         return Usr(self._mapper.name, self._branch, index=self._index)
 
+    def __getattribute__(self, attr):
+        if attr.startswith("_"):  # let all private and magic methods pass
+            return object.__getattribute__(self, attr)
+        if attr in self._keymap.keys():  # intercept branch key lookups
+            item = self._keymap[attr]
+
+            out = self._branch[item].lazyarray(
+                basketcache=BASKET_CACHE)
+            if self._index is not None:
+                out = out[self._index]
+            return out
+        return object.__getattribute__(self, attr)
+
+    # @profile
     def __getitem__(self, item):
         """Slicing magic a la numpy"""
-        print("Getting item '{}'".format(item))
         if isinstance(item, slice):
             return self.__class__(self._tree,
                                   self._mapper,
                                   index=item,
                                   subbranches=self._subbranches)
         if isinstance(item, int):
-            # A bit ugly, but whatever works
+            # TODO refactor this
             if self._mapper.flat:
                 if self._index is None:
                     dct = {
@@ -665,8 +674,7 @@ class Branch:
             item = self._keymap[item]
 
             out = self._branch[item].lazyarray(
-                basketcache=uproot.cache.ThreadSafeArrayCache(
-                    BASKET_CACHE_SIZE))
+                basketcache=BASKET_CACHE)
             if self._index is not None:
                 out = out[self._index]
             return out
@@ -705,7 +713,6 @@ class BranchElement:
         The slice mask to be applied to the sub-arrays
     """
     def __init__(self, name, dct, index=None, subbranches=[]):
-        print("Creating branch element '{}'".format(name))
         self._dct = dct
         self._name = name
         self._index = index
diff --git a/tests/test_offline.py b/tests/test_offline.py
index 4d409f4..c5ec14c 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -184,15 +184,11 @@ class TestOfflineEvents(unittest.TestCase):
 
     def test_slicing_consistency(self):
         for s in [slice(1, 3), slice(2, 7, 3)]:
-            assert np.allclose(OFFLINE_FILE[s].events.n_hits,
-                               self.events.n_hits[s])
             assert np.allclose(self.events[s].n_hits, self.events.n_hits[s])
 
     def test_index_consistency(self):
         for i in [0,2,5]:
             assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])
-            assert np.allclose(OFFLINE_FILE[i].events.n_hits,
-                               self.events.n_hits[i])
 
     def test_str(self):
         assert str(self.n_events) in str(self.events)
-- 
GitLab