From 1cf61e8848eeb0ddfbd737f65b7821d70de235e5 Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Sun, 29 Mar 2020 10:36:19 +0200
Subject: [PATCH] Use tools

---
 km3io/offline.py      | 247 ++----------------------------------------
 tests/test_offline.py |  21 +---
 2 files changed, 10 insertions(+), 258 deletions(-)

diff --git a/km3io/offline.py b/km3io/offline.py
index 692a47b..af884ed 100644
--- a/km3io/offline.py
+++ b/km3io/offline.py
@@ -1,18 +1,11 @@
 from collections import namedtuple
 import uproot
-import numpy as np
 import warnings
 from .definitions import mc_header
+from .tools import Branch, BranchMapper, cached_property, _to_num
 
 MAIN_TREE_NAME = "E"
-# 110 MB based on the size of the largest basket found so far in km3net
-BASKET_CACHE_SIZE = 110 * 1024**2
-BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
-EXCLUDE_KEYS = set(["AAObject", "t", "fBits", "fUniqueID"])
-
-BranchMapper = namedtuple(
-    "BranchMapper",
-    ['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat'])
+EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"]
 
 
 def _nested_mapper(key):
@@ -26,7 +19,7 @@ EVENTS_MAP = BranchMapper(name="events",
                               't_sec': 't.fSec',
                               't_ns': 't.fNanoSec'
                           },
-                          exclude=[],
+                          exclude=EXCLUDE_KEYS,
                           update={
                               'n_hits': 'hits',
                               'n_mc_hits': 'mc_hits',
@@ -41,14 +34,14 @@ SUBBRANCH_MAPS = [
         name="tracks",
         key="trks",
         extra={},
-        exclude=['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'],
+        exclude=EXCLUDE_KEYS + ['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'],
         update={},
         attrparser=_nested_mapper,
         flat=False),
     BranchMapper(name="mc_tracks",
                  key="mc_trks",
                  extra={},
-                 exclude=[
+                 exclude=EXCLUDE_KEYS + [
                      'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.rec_stages',
                      'mc_trks.fitinf', 'mc_trks.fUniqueID', 'mc_trks.fBits'
                  ],
@@ -58,7 +51,7 @@ SUBBRANCH_MAPS = [
     BranchMapper(name="hits",
                  key="hits",
                  extra={},
-                 exclude=[
+                 exclude=EXCLUDE_KEYS + [
                      'hits.usr', 'hits.pmt_id', 'hits.origin', 'hits.a',
                      'hits.pure_a', 'hits.fUniqueID', 'hits.fBits'
                  ],
@@ -68,7 +61,7 @@ SUBBRANCH_MAPS = [
     BranchMapper(name="mc_hits",
                  key="mc_hits",
                  extra={},
-                 exclude=[
+                 exclude=EXCLUDE_KEYS + [
                      'mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id',
                      'mc_hits.tdc', 'mc_hits.tot', 'mc_hits.trig',
                      'mc_hits.fUniqueID', 'mc_hits.fBits'
@@ -79,18 +72,6 @@ SUBBRANCH_MAPS = [
 ]
 
 
-class cached_property:
-    """A simple cache decorator for properties."""
-    def __init__(self, function):
-        self.function = function
-
-    def __get__(self, obj, cls):
-        if obj is None:
-            return self
-        prop = obj.__dict__[self.function.__name__] = self.function(obj)
-        return prop
-
-
 class OfflineReader:
     """reader for offline ROOT files"""
     def __init__(self, file_path=None):
@@ -126,113 +107,6 @@ class OfflineReader:
             warnings.warn("Your file header has an unsupported format")
 
 
-class Usr:
-    """Helper class to access AAObject `usr` stuff"""
-    def __init__(self, mapper, branch, index=None):
-        self._mapper = mapper
-        self._name = mapper.name
-        self._index = index
-        self._branch = branch
-        self._usr_names = []
-        self._usr_idx_lookup = {}
-
-        self._usr_key = 'usr' if mapper.flat else mapper.key + '.usr'
-
-        self._initialise()
-
-    def _initialise(self):
-        try:
-            self._branch[self._usr_key]
-            # This will raise a KeyError in old aanet files
-            # which has a different strucuter and key (usr_data)
-            # We do not support those (yet)
-        except (KeyError, IndexError):
-            print("The `usr` fields could not be parsed for the '{}' branch.".
-                  format(self._name))
-            return
-
-        if self._mapper.flat:
-            self._initialise_flat()
-
-    def _initialise_flat(self):
-        # Here, we assume that every event has the same names in the same order
-        # to massively increase the performance. This needs triple check if
-        # it's always the case.
-        self._usr_names = [
-            n.decode("utf-8") for n in self._branch[self._usr_key + '_names'].lazyarray(
-                basketcache=BASKET_CACHE)[0]
-        ]
-        self._usr_idx_lookup = {
-            name: index
-            for index, name in enumerate(self._usr_names)
-        }
-
-        data = self._branch[self._usr_key].lazyarray(basketcache=BASKET_CACHE)
-
-        if self._index is not None:
-            data = data[self._index]
-
-        self._usr_data = data
-
-        for name in self._usr_names:
-            setattr(self, name, self[name])
-
-    # def _initialise_nested(self):
-    #     self._usr_names = [
-    #         n.decode("utf-8") for n in self.branch['usr_names'].lazyarray(
-    #             # TODO this will be fixed soon in uproot,
-    #             # see https://github.com/scikit-hep/uproot/issues/465
-    #             uproot.asgenobj(
-    #                 uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
-    #                 self.branch['usr_names']._context, 6),
-    #             basketcache=BASKET_CACHE)[0]
-    #     ]
-
-    def __getitem__(self, item):
-        if self._mapper.flat:
-            return self.__getitem_flat__(item)
-        return self.__getitem_nested__(item)
-
-    def __getitem_flat__(self, item):
-        if self._index is not None:
-            return self._usr_data[self._index][:, self._usr_idx_lookup[item]]
-        else:
-            return self._usr_data[:, self._usr_idx_lookup[item]]
-
-    def __getitem_nested__(self, item):
-        data = self._branch[self._usr_key + '_names'].lazyarray(
-                # TODO this will be fixed soon in uproot,
-                # see https://github.com/scikit-hep/uproot/issues/465
-                uproot.asgenobj(
-                    uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
-                    self._branch[self._usr_key + '_names']._context, 6),
-                basketcache=BASKET_CACHE)
-        if self._index is None:
-            return data
-        else:
-            return data[self._index]
-
-    def keys(self):
-        return self._usr_names
-
-    def __str__(self):
-        entries = []
-        for name in self.keys():
-            entries.append("{}: {}".format(name, self[name]))
-        return '\n'.join(entries)
-
-    def __repr__(self):
-        return "<{}[{}]>".format(self.__class__.__name__, self._name)
-
-
-def _to_num(value):
-    """Convert a value to a numerical one if possible"""
-    for converter in (int, float):
-        try:
-            return converter(value)
-        except (ValueError, TypeError):
-            pass
-    return value
 
 
 class Header:
@@ -278,110 +152,3 @@ class Header:
         return "\n".join(lines)
 
 
-class Branch:
-    """Branch accessor class"""
-    def __init__(self,
-                 tree,
-                 mapper,
-                 index=None,
-                 subbranchmaps=None,
-                 keymap=None):
-        self._tree = tree
-        self._mapper = mapper
-        self._index = index
-        self._keymap = None
-        self._branch = tree[mapper.key]
-        self._subbranches = []
-
-        if keymap is None:
-            self._initialise_keys()  #
-        else:
-            self._keymap = keymap
-
-        if subbranchmaps is not None:
-            for mapper in subbranchmaps:
-                subbranch = Branch(self._tree,
-                                   mapper=mapper,
-                                   index=self._index)
-                self._subbranches.append(subbranch)
-        for subbranch in self._subbranches:
-            setattr(self, subbranch._mapper.name, subbranch)
-
-    def _initialise_keys(self):
-        """Create the keymap and instance attributes for branch keys"""
-        # TODO: this could be a cached property
-        keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(
-            self._mapper.exclude) - EXCLUDE_KEYS
-        self._keymap = {
-            **{self._mapper.attrparser(k): k
-               for k in keys},
-            **self._mapper.extra
-        }
-        self._keymap.update(self._mapper.update)
-        for k in self._mapper.update.values():
-            del self._keymap[k]
-
-        for key in self._keymap.keys():
-            setattr(self, key, None)
-
-    def keys(self):
-        return self._keymap.keys()
-
-    @cached_property
-    def usr(self):
-        return Usr(self._mapper, self._branch, index=self._index)
-
-    def __getattribute__(self, attr):
-        if attr.startswith("_"):  # let all private and magic methods pass
-            return object.__getattribute__(self, attr)
-
-        if attr in self._keymap.keys():  # intercept branch key lookups
-            return self.__getkey__(attr)
-
-        return object.__getattribute__(self, attr)
-
-    def __getkey__(self, key):
-        out = self._branch[self._keymap[key]].lazyarray(
-            basketcache=BASKET_CACHE)
-        if self._index is not None:
-            out = out[self._index]
-        return out
-
-    def __getitem__(self, item):
-        """Slicing magic"""
-        if isinstance(item, (int, slice)):
-            return self.__class__(self._tree,
-                                  self._mapper,
-                                  index=item,
-                                  keymap=self._keymap,
-                                  subbranchmaps=SUBBRANCH_MAPS)
-
-        if isinstance(item, tuple):
-            return self[item[0]][item[1]]
-
-        if isinstance(item, str):
-            return self.__getkey__(item)
-
-        return self.__class__(self._tree,
-                              self._mapper,
-                              index=np.array(item),
-                              keymap=self._keymap,
-                              subbranchmaps=SUBBRANCH_MAPS)
-
-    def __len__(self):
-        if self._index is None:
-            return len(self._branch)
-        elif isinstance(self._index, int):
-            return 1
-        else:
-            return len(self._branch[self._keymap['id']].lazyarray(
-                basketcache=BASKET_CACHE)[self._index])
-
-    def __str__(self):
-        return "Number of elements: {}".format(len(self._branch))
-
-    def __repr__(self):
-        length = len(self)
-        return "<{}[{}]: {} element{}>".format(self.__class__.__name__,
-                                               self._mapper.name, length,
-                                               's' if length > 1 else '')
diff --git a/tests/test_offline.py b/tests/test_offline.py
index 635a34b..8e1f685 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -3,7 +3,7 @@ import numpy as np
 from pathlib import Path
 
 from km3io import OfflineReader
-from km3io.offline import _nested_mapper, cached_property, _to_num, Header
+from km3io.offline import _nested_mapper, Header
 
 SAMPLES_DIR = Path(__file__).parent / 'samples'
 OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root')
@@ -340,22 +340,7 @@ class TestUsr(unittest.TestCase):
             self.f.events.usr.DeltaPosZ)
 
 
-class TestIndependentFunctions(unittest.TestCase):
+
+class TestNestedMapper(unittest.TestCase):
     def test_nested_mapper(self):
         self.assertEqual('pos_x', _nested_mapper("trks.pos.x"))
-
-    def test_to_num(self):
-        self.assertEqual(10, _to_num("10"))
-        self.assertEqual(10.5, _to_num("10.5"))
-        self.assertEqual("test", _to_num("test"))
-        self.assertIsNone(_to_num(None))
-
-
-class TestCachedProperty(unittest.TestCase):
-    def test_cached_properties(self):
-        class Test:
-            @cached_property
-            def prop(self):
-                pass
-
-        self.assertTrue(isinstance(Test.prop, cached_property))
-- 
GitLab