From d1dc6342e4aadeafb08755a1a5bd2dfae5c48d7f Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Wed, 4 Mar 2020 23:36:41 +0100
Subject: [PATCH] Massive overhaul of branch parsing

---
 km3io/offline.py | 462 +++++------------------------------------------
 1 file changed, 48 insertions(+), 414 deletions(-)

diff --git a/km3io/offline.py b/km3io/offline.py
index dac9f81..981de83 100644
--- a/km3io/offline.py
+++ b/km3io/offline.py
@@ -1,3 +1,4 @@
+from collections import namedtuple
 import uproot
 import numpy as np
 import warnings
@@ -10,6 +11,22 @@ MAIN_TREE_NAME = "E"
 BASKET_CACHE_SIZE = 110 * 1024**2
 
 
+BranchMapper = namedtuple("BranchMapper", ['name', 'key', 'extra_keys', 'attrparser'])
+
+def _nested_mapper(key):
+    """Drop the leading branch name and join the rest with '_' (trks.pos.x -> pos_x)."""
+    return key.partition('.')[2].replace('.', '_')
+
+
+BRANCH_MAPS = [  # fields: attribute name, ROOT branch key, extra keys, key parser
+    BranchMapper("tracks", "trks", {}, _nested_mapper),
+    BranchMapper("mc_tracks", "mc_trks", {}, _nested_mapper),
+    BranchMapper("hits", "hits", {}, _nested_mapper),
+    BranchMapper("mc_hits", "mc_hits", {}, _nested_mapper),
+    BranchMapper("events", "Evt", {'t_sec': 't.fSec', 't_ns': 't.fNanoSec'}, lambda a: a),
+]
+
+
 class cached_property:
     """A simple cache decorator for properties."""
     def __init__(self, function):
@@ -22,189 +39,9 @@ class cached_property:
         return prop
 
 
-def _get_keys(tree, fake_branches=None):
-    """Get tree keys except those in fake_branches
-
-    Parameters
-    ----------
-    tree : uproot.Tree
-        The tree to look for keys
-    fake_branches : list of str or None
-        The fake branches to ignore
-
-    Returns
-    -------
-    list of str
-        The keys of the tree.
-    """
-    keys = []
-    for key in tree.keys():
-        key = key.decode('utf-8')
-        if fake_branches is not None and key in fake_branches:
-            continue
-        keys.append(key)
-    return keys
-
-
-class OfflineKeys:
-    """wrapper for offline keys"""
-    def __init__(self, tree):
-        """OfflineKeys is a class that reads all the available keys in an offline
-        file and adapts the keys format to Python format.
-
-        Parameters
-        ----------
-        tree : uproot.TTree
-            The main ROOT tree.
-        """
-        self._tree = tree
-
-    def __str__(self):
-        return '\n'.join([
-            "Events keys are:\n\t" + "\n\t".join(self.events_keys),
-            "Hits keys are:\n\t" + '\n\t'.join(self.hits_keys),
-            "Tracks keys are:\n\t" + '\n\t'.join(self.tracks_keys),
-            "Mc hits keys are:\n\t" + '\n\t'.join(self.mc_hits_keys),
-            "Mc tracks keys are:\n\t" + '\n\t'.join(self.mc_tracks_keys)
-        ])
-
-    def __repr__(self):
-        return "<{}>".format(self.__class__.__name__)
-
-    @cached_property
-    def events_keys(self):
-        """reads events keys from an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all events keys found in an offline file,
-            except those found in fake branches.
-        """
-        fake_branches = ['Evt', 'AAObject', 'TObject', 't']
-        t_baskets = ['t.fSec', 't.fNanoSec']
-        tree = self._tree['Evt']
-        return _get_keys(self._tree['Evt'], fake_branches) + t_baskets
-
-    @cached_property
-    def hits_keys(self):
-        """reads hits keys from an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all hits keys found in an offline file,
-            except those found in fake branches.
-        """
-        fake_branches = ['hits.usr', 'hits.usr_names']
-        return _get_keys(self._tree['hits'], fake_branches)
-
-    @cached_property
-    def tracks_keys(self):
-        """reads tracks keys from an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all tracks keys found in an offline file,
-            except those found in fake branches.
-        """
-        # a solution can be tree['trks.usr_data'].array(
-        # uproot.asdtype(">i4"))
-        fake_branches = ['trks.usr_data', 'trks.usr', 'trks.usr_names']
-        return _get_keys(self._tree['Evt']['trks'], fake_branches)
-
-    @cached_property
-    def mc_hits_keys(self):
-        """reads mc hits keys from an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all mc hits keys found in an offline file,
-            except those found in fake branches.
-        """
-        fake_branches = ['mc_hits.usr', 'mc_hits.usr_names']
-        return _get_keys(self._tree['Evt']['mc_hits'], fake_branches)
-
-    @cached_property
-    def mc_tracks_keys(self):
-        """reads mc tracks keys from an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all mc tracks keys found in an offline file,
-            except those found in fake branches.
-        """
-        fake_branches = [
-            'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names'
-        ]
-        return _get_keys(self._tree['Evt']['mc_trks'], fake_branches)
-
-    @cached_property
-    def valid_keys(self):
-        """constructs a list of all valid keys to be read from an offline event file.
-        Returns
-        -------
-        list of str
-            list of all valid keys.
-        """
-        return (self.events_keys + self.hits_keys + self.tracks_keys +
-                self.mc_tracks_keys + self.mc_hits_keys)
-
-    @cached_property
-    def fit_keys(self):
-        """constructs a list of fit parameters, not yet outsourced in an offline file.
-
-        Returns
-        -------
-        list of str
-            list of all "trks.fitinf" keys.
-        """
-        return sorted(km3io.definitions.fitparameters.data,
-                      key=km3io.definitions.fitparameters.data.get,
-                      reverse=False)
-
-    @cached_property
-    def cut_hits_keys(self):
-        """adapts hits keys for instance variables format in a Python class.
-
-        Returns
-        -------
-        list of str
-            list of adapted hits keys.
-        """
-        return [k.split('hits.')[1].replace('.', '_') for k in self.hits_keys]
-
-    @cached_property
-    def cut_tracks_keys(self):
-        """adapts tracks keys for instance variables format in a Python class.
-
-        Returns
-        -------
-        list of str
-            list of adapted tracks keys.
-        """
-        return [
-            k.split('trks.')[1].replace('.', '_') for k in self.tracks_keys
-        ]
-
-    @cached_property
-    def cut_events_keys(self):
-        """adapts events keys for instance variables format in a Python class.
-
-        Returns
-        -------
-        list of str
-            list of adapted events keys.
-        """
-        return [k.replace('.', '_') for k in self.events_keys]
-
-
 class OfflineReader:
     """reader for offline ROOT files"""
-    def __init__(self, file_path=None, fobj=None, data=None):
+    def __init__(self, file_path=None, fobj=None, data=None, index=slice(-1)):
         """ OfflineReader class is an offline ROOT file wrapper
 
         Parameters
@@ -214,6 +51,7 @@ class OfflineReader:
             path-like object that points to the file.
 
         """
+        self._index = index
         if file_path is not None:
             self._fobj = uproot.open(file_path)
             self._tree = self._fobj[MAIN_TREE_NAME]
@@ -225,6 +63,9 @@ class OfflineReader:
             self._tree = self._fobj[MAIN_TREE_NAME]
             self._data = data
 
+        for mapper in BRANCH_MAPS:
+            setattr(self, mapper.name, BranchElement(self._tree, mapper=mapper, index=self._index))
+
     @classmethod
     def from_index(cls, source, index):
         """Create an instance with a subtree of a given index
@@ -232,18 +73,24 @@ class OfflineReader:
         Parameters
         ----------
         source: ROOTDirectory
-            The source file.
+            The source file object.
         index: index or slice
             The index or slice to create the subtree.
         """
-        instance = cls(fobj=source._fobj, data=source._data[index])
+        instance = cls(fobj=source._fobj, data=source._data[index], index=index)
         return instance
 
     def __getitem__(self, index):
         return OfflineReader.from_index(source=self, index=index)
 
     def __len__(self):
-        return len(self._data)
+        """Return the number of tree entries covered by ``self._index``."""
+        tree = self._fobj[MAIN_TREE_NAME]
+        if self._index == slice(-1):  # constructor default: no sub-selection
+            return len(tree)
+        return len(tree.lazyarrays(
+            basketcache=uproot.cache.ThreadSafeArrayCache(
+                BASKET_CACHE_SIZE))[self._index])
 
     @cached_property
     def header(self):
@@ -256,76 +103,6 @@ class OfflineReader:
         else:
             warnings.warn("Your file header has an unsupported format")
 
-    @cached_property
-    def keys(self):
-        """wrapper for all keys in an offline file.
-
-        Returns
-        -------
-        Class
-            OfflineKeys.
-        """
-        return OfflineKeys(self._tree)
-
-    @cached_property
-    def events(self):
-        """wrapper for offline events.
-
-        Returns
-        -------
-        Class
-            OfflineEvents.
-        """
-        return OfflineEvents(
-            self.keys.cut_events_keys,
-            [self._data[key] for key in self.keys.events_keys])
-
-    @cached_property
-    def hits(self):
-        """wrapper for offline hits.
-
-        Returns
-        -------
-        Class
-            OfflineHits.
-        """
-        return OfflineHits(self.keys.cut_hits_keys,
-                           [self._data[key] for key in self.keys.hits_keys])
-
-    @cached_property
-    def tracks(self):
-        """wrapper for offline tracks.
-
-        Returns
-        -------
-        Class
-            OfflineTracks.
-        """
-        return OfflineTracks(self._tree['trks'])
-
-    @cached_property
-    def mc_hits(self):
-        """wrapper for offline mc hits.
-
-        Returns
-        -------
-        Class
-            OfflineHits.
-        """
-        return OfflineHits(self.keys.cut_hits_keys,
-                           [self._data[key] for key in self.keys.mc_hits_keys])
-
-    @cached_property
-    def mc_tracks(self):
-        """wrapper for offline mc tracks.
-
-        Returns
-        -------
-        Class
-            OfflineTracks.
-        """
-        return OfflineTracks(self._tree['mc_trks'])
-
     @cached_property
     def usr(self):
         return Usr(self._tree)
@@ -705,137 +482,23 @@ class Usr:
         return '\n'.join(entries)
 
 
-class OfflineEvents:
-    """wrapper for offline events"""
-    def __init__(self, keys, values):
-        """wrapper for offline events.
-
-        Parameters
-        ----------
-        keys : list of str
-            list of valid events keys.
-        values : list of arrays
-            list of arrays containting events data.
-        """
-        self._keys = keys
-        self._values = values
-        for k, v in zip(self._keys, self._values):
-            setattr(self, k, v)
-
-    def __getitem__(self, item):
-        return OfflineEvent(self._keys, [v[item] for v in self._values])
-
-    def __len__(self):
-        try:
-            return len(self._values[0])
-        except IndexError:
-            return 0
-
-    def __str__(self):
-        return "Number of events: {}".format(len(self))
-
-    def __repr__(self):
-        return "<{}: {} parsed events>".format(self.__class__.__name__,
-                                               len(self))
-
-
-class OfflineEvent:
-    """wrapper for an offline event"""
-    def __init__(self, keys, values):
-        """wrapper for one offline event.
-
-        Parameters
-        ----------
-        keys : list of str
-            list of valid events keys.
-        values : list of arrays
-            list of arrays containting event data.
-        """
-        self._keys = keys
-        self._values = values
-        for k, v in zip(self._keys, self._values):
-            setattr(self, k, v)
-
-    def __str__(self):
-        return "offline event:\n\t" + "\n\t".join([
-            "{:15} {:^10} {:>10}".format(k, ':', str(v))
-            for k, v in zip(self._keys, self._values)
-        ])
-
-
-class OfflineHits:
-    """wrapper for offline hits"""
-    def __init__(self, keys, values):
-        """wrapper for offline hits.
-
-        Parameters
-        ----------
-        keys : list of str
-            list of cropped hits keys.
-        values : list of arrays
-            list of arrays containting hits data.
-        """
-        self._keys = keys
-        self._values = values
-        for k, v in zip(self._keys, self._values):
-            setattr(self, k, v)
-
-    def __getitem__(self, item):
-        return OfflineHit(self._keys, [v[item] for v in self._values])
-
-    def __len__(self):
-        try:
-            return len(self._values[0])
-        except IndexError:
-            return 0
-
-    def __str__(self):
-        return "Number of hits: {}".format(len(self))
-
-    def __repr__(self):
-        return "<{}: {} parsed elements>".format(self.__class__.__name__,
-                                                 len(self))
-
-
-class OfflineHit:
-    """wrapper for an offline hit"""
-    def __init__(self, keys, values):
-        """wrapper for one offline hit.
-
-        Parameters
-        ----------
-        keys : list of str
-            list of cropped hits keys.
-        values : list of arrays
-            list of arrays containting hit data.
-        """
-        self._keys = keys
-        self._values = values
-        for k, v in zip(self._keys, self._values):
-            setattr(self, k, v)
-
-    def __str__(self):
-        return "offline hit:\n\t" + "\n\t".join([
-            "{:15} {:^10} {:>10}".format(k, ':', str(v))
-            for k, v in zip(self._keys, self._values)
-        ])
-
-    def __getitem__(self, item):
-        return self._values[item]
-
-
-class OfflineTracks:
+class BranchElement:
     """wrapper for offline tracks"""
-    def __init__(self, branch, index=slice(-1)):
-        keys = [k.decode('utf-8') for k in branch.keys()]
-        self._keymap = {k[5:].replace('.', '_'): k for k in keys}
-        self._branch = branch
-        self._keys = keys
+    def __init__(self, tree, mapper, index=slice(-1)):
+        """Wrap branch ``mapper.key`` of ``tree``; ``index`` is applied on item access."""
+        self.mapper = mapper
+        self.name = mapper.name
+        self._tree = tree
+        self._branch = tree[mapper.key]
+        keys = [k.decode('utf-8') for k in self._branch.keys()]
+        # Map attribute names to ROOT keys; on a name clash the entry from
+        # mapper.extra_keys wins because it is merged last.
+        self._keymap = {**{mapper.attrparser(k): k for k in keys}, **mapper.extra_keys}
         self._index = index
 
     def __getitem__(self, item):
         if isinstance(item, slice):
-            return OfflineTracks(self._branch, index=item)
+            return self.__class__(self._tree, self.mapper, index=item)
         return self._branch[self._keymap[item]].lazyarray(
                 basketcache=uproot.cache.ThreadSafeArrayCache(
                     BASKET_CACHE_SIZE))[self._index]
@@ -846,41 +509,12 @@ class OfflineTracks:
         else:
             return len(self._branch[self._keymap['id']].lazyarray()[self._index])
 
+    def keys(self):
+        return self._keymap.keys()
+
     def __str__(self):
-        return "Number of tracks: {}".format(len(self._branch))
+        return "Number of elements: {}".format(len(self._branch))
 
     def __repr__(self):
-        return "<{}: {} parsed elements>".format(self.__class__.__name__,
+        return "<{}[{}]: {} parsed elements>".format(self.__class__.__name__, self.name,
                                                  len(self))
-
-
-class OfflineTrack:
-    """wrapper for an offline track"""
-    def __init__(self, keys, values):
-        """wrapper for one offline track.
-
-        Parameters
-        ----------
-        keys : list of str
-            list of cropped tracks keys.
-        values : list of arrays
-            list of arrays containting track data.
-        """
-        self._keys = keys
-        self._values = values
-        for k, v in zip(self._keys, self._values):
-            setattr(self, k, v)
-
-    def __str__(self):
-        return "offline track:\n\t" + "\n\t".join([
-            "{:30} {:^2} {:>26}".format(k, ':', str(v))
-            for k, v in zip(self._keys, self._values) if k not in ['fitinf']
-        ]) + "\n\t" + "\n\t".join([
-            "{:30} {:^2} {:>26}".format(k, ':', str(
-                getattr(self, 'fitinf')[v]))
-            for k, v in km3io.definitions.fitparameters.data.items()
-            if len(getattr(self, 'fitinf')) > v
-        ])  # I don't like 18 being explicit here
-
-    def __getitem__(self, item):
-        return self._values[item]
-- 
GitLab