Skip to content
Snippets Groups Projects
Commit 3cce5598 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Restructure and adapt GSGReader

parent 437c0bea
No related branches found
No related tags found
1 merge request!47Resolve "uproot4 integration"
Pipeline #16289 failed
This commit is part of merge request !47. Comments created here will be created in the context of that merge request.
......@@ -3,28 +3,15 @@
# Filename: gseagen.py
# Author: Johannes Schumann <jschumann@km3net.de>
import uproot3
import numpy as np
import warnings
from .rootio import Branch, BranchMapper
from .rootio import EventReader
from .tools import cached_property
MAIN_TREE_NAME = "Events"
class GSGReader:
class GSGReader(EventReader):
"""reader for gSeaGen ROOT files"""
def __init__(self, file_path=None, fobj=None):
"""GSGReader class is a gSeaGen ROOT file wrapper
Parameters
----------
file_path : file path or file-like object
The file handler. It can be a str or any python path-like object
that points to the file.
"""
self._fobj = uproot3.open(file_path)
event_path = "Events"
skip_keys = ["Header"]
@cached_property
def header(self):
......@@ -42,7 +29,3 @@ class GSGReader:
return header
else:
warnings.warn("Your file header has an unsupported format")
@cached_property
def events(self):
return Branch(self._fobj, BranchMapper(name="Events", key="Events"))
......@@ -7,11 +7,12 @@ import awkward as ak
from .definitions import mc_header
from .tools import cached_property, to_num, unfold_indices
from .rootio import EventReader
log = logging.getLogger("offline")
class OfflineReader:
class OfflineReader(EventReader):
"""reader for offline ROOT files"""
event_path = "E/Evt"
......@@ -79,249 +80,6 @@ class OfflineReader:
"mc_tracks": "mc_trks",
}
def __init__(
self,
f,
index_chain=None,
step_size=2000,
keys=None,
aliases=None,
event_ctor=None,
):
"""OfflineReader class is an offline ROOT file wrapper
Parameters
----------
f: str or uproot4.reading.ReadOnlyDirectory (from uproot4.open)
Path to the file of interest or uproot4 filedescriptor.
step_size: int, optional
Number of events to read into the cache when iterating.
Choosing higher numbers may improve the speed but also increases
the memory overhead.
index_chain: list, optional
Keeps track of index chaining.
keys: list or set, optional
Branch keys.
aliases: dict, optional
Branch key aliases.
event_ctor: class or namedtuple, optional
Event constructor.
"""
if isinstance(f, str):
self._fobj = uproot.open(f)
self._filepath = f
elif isinstance(f, uproot.reading.ReadOnlyDirectory):
self._fobj = f
self._filepath = f._file.file_path
else:
raise TypeError("Unsupported file descriptor.")
self._step_size = step_size
self._uuid = self._fobj._file.uuid
self._iterator_index = 0
self._keys = keys
self._event_ctor = event_ctor
self._index_chain = [] if index_chain is None else index_chain
# if aliases is not None:
# self.aliases = aliases
# else:
# # Check for usr-awesomeness backward compatibility crap
# if "E/Evt/AAObject/usr" in self._fobj:
# print("Found usr data")
# if ak.count(f["E/Evt/AAObject/usr"].array()) > 0:
# self.aliases.update(
# {
# "usr": "AAObject/usr",
# "usr_names": "AAObject/usr_names",
# }
# )
if self._keys is None:
self._initialise_keys()
if self._event_ctor is None:
self._event_ctor = namedtuple(
self.item_name,
set(
list(self.keys())
+ list(self.aliases)
+ list(self.special_branches)
+ list(self.special_aliases)
),
)
def _initialise_keys(self):
skip_keys = set(self.skip_keys)
toplevel_keys = set(k.split("/")[0] for k in self._fobj[self.event_path].keys())
keys = (toplevel_keys - skip_keys).union(
list(self.aliases.keys()) + list(self.special_aliases)
)
for key in list(self.special_branches) + list(self.special_aliases):
keys.add("n_" + key)
# self._grouped_branches = {k for k in toplevel_keys - skip_keys if isinstance(self._fobj[self.event_path][k].interpretation, uproot.AsGrouped)}
self._keys = keys
def keys(self):
"""Returns all accessible branch keys, without the skipped ones."""
return self._keys
@property
def events(self):
# TODO: deprecate this, since `self` is already the container type
return iter(self)
def _keyfor(self, key):
"""Return the correct key for a given alias/key"""
return self.special_aliases.get(key, key)
def __getattr__(self, attr):
attr = self._keyfor(attr)
# if attr in self.keys() or (attr.startswith("n_") and self._keyfor(attr.split("n_")[1]) in self._grouped_branches):
if attr in self.keys():
return self.__getitem__(attr)
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
)
def __getitem__(self, key):
# indexing
# TODO: maybe just propagate everything to awkward and let it deal
# with the type?
if isinstance(key, (slice, int, np.int32, np.int64, list, np.ndarray, ak.Array)):
if isinstance(key, (int, np.int32, np.int64)):
key = int(key)
return self.__class__(
self._fobj,
index_chain=self._index_chain + [key],
step_size=self._step_size,
aliases=self.aliases,
keys=self.keys(),
event_ctor=self._event_ctor,
)
if isinstance(key, str) and key.startswith(
"n_"
): # group counts, for e.g. n_events, n_hits etc.
key = self._keyfor(key.split("n_")[1])
arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4"))
return unfold_indices(arr, self._index_chain)
key = self._keyfor(key)
branch = self._fobj[self.event_path]
# These are special branches which are nested, like hits/trks/mc_trks
# We are explicitly grabbing just a predefined set of subbranches
# and also alias them to be backwards compatible (and attribute-accessible)
if key in self.special_branches:
fields = []
# some fields are not always available, like `usr_names`
for to_field, from_field in self.special_branches[key].items():
if from_field in branch[key].keys():
fields.append(to_field)
log.debug(fields)
out = branch[key].arrays(fields, aliases=self.special_branches[key])
else:
out = branch[self.aliases.get(key, key)].array()
return unfold_indices(out, self._index_chain)
def __iter__(self):
self._events = self._event_generator()
return self
def _event_generator(self):
events = self._fobj[self.event_path]
group_count_keys = set(
k for k in self.keys() if k.startswith("n_")
) # special keys to make it easy to count subbranch lengths
log.debug("group_count_keys: %s", group_count_keys)
keys = set(
list(
set(self.keys())
- set(self.special_branches.keys())
- set(self.special_aliases)
- group_count_keys
)
+ list(self.aliases.keys())
) # all top-level keys for regular branches
log.debug("keys: %s", keys)
log.debug("aliases: %s", self.aliases)
events_it = events.iterate(
keys, aliases=self.aliases, step_size=self._step_size
)
specials = []
special_keys = (
self.special_branches.keys()
) # dict-key ordering is an implementation detail
log.debug("special_keys: %s", special_keys)
for key in special_keys:
# print(f"adding {key} with keys {self.special_branches[key].keys()} and aliases {self.special_branches[key]}")
specials.append(
events[key].iterate(
self.special_branches[key].keys(),
aliases=self.special_branches[key],
step_size=self._step_size,
)
)
group_counts = {}
for key in group_count_keys:
group_counts[key] = iter(self[key])
log.debug("group_counts: %s", group_counts)
for event_set, *special_sets in zip(events_it, *specials):
for _event, *special_items in zip(event_set, *special_sets):
data = {}
for k in keys:
data[k] = _event[k]
for (k, i) in zip(special_keys, special_items):
data[k] = i
for tokey, fromkey in self.special_aliases.items():
data[tokey] = data[fromkey]
for key in group_counts:
data[key] = next(group_counts[key])
yield self._event_ctor(**data)
def __next__(self):
return next(self._events)
def __len__(self):
if not self._index_chain:
return self._fobj[self.event_path].num_entries
elif isinstance(self._index_chain[-1], (int, np.int32, np.int64)):
if len(self._index_chain) == 1:
return 1
# try:
# return len(self[:])
# except IndexError:
# return 1
return 1
else:
# ignore the usual index magic and access `id` directly
return len(unfold_indices(self._fobj[self.event_path]["id"].array(), self._index_chain))
def __actual_len__(self):
"""The raw number of events without any indexing/slicing magic"""
return len(self._fobj[self.event_path]["id"].array())
def __repr__(self):
length = len(self)
actual_length = self.__actual_len__()
return f"{self.__class__.__name__} ({length}{'/' + str(actual_length) if length < actual_length else ''} events)"
@property
def uuid(self):
return self._uuid
def close(self):
self._fobj.close()
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
@cached_property
def header(self):
"""The file header"""
......
#!/usr/bin/env python3
from collections import namedtuple
import numpy as np
import awkward as ak
import uproot3
import uproot
from .tools import unfold_indices
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024 ** 2
BASKET_CACHE = uproot3.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
class BranchMapper:
"""
Mapper helper for keys in a ROOT branch.
Parameters
----------
name: str
The name of the mapper helper which is displayed to the user
key: str
The key of the branch in the ROOT tree.
exclude: ``None``, ``list(str)``
Keys to exclude from parsing.
update: ``None``, ``dict(str: str)``
An update map for keys which are to be presented with a different
key to the user e.g. ``{"n_hits": "hits"}`` will rename the ``hits``
key to ``n_hits``.
extra: ``None``, ``dict(str: str)``
An extra mapper for hidden object, primarily nested ones like
``t.fSec``, which can be revealed and mapped to e.g. ``t_sec``
via ``{"t_sec", "t.fSec"}``.
attrparser: ``None``, ``function(str) -> str``
The function to be used to create attribute names. This is only
needed if unsupported characters are present, like ``.``, which
would prevent setting valid Python attribute names.
toawkward: ``None``, ``list(str)``
List of keys to convert to awkward arrays (recommended for
doubly ragged arrays)
"""
import logging
def __init__(
self,
name,
key,
extra=None,
exclude=None,
update=None,
attrparser=None,
flat=True,
interpretations=None,
toawkward=None,
):
self.name = name
self.key = key
log = logging.getLogger("km3io.rootio")
self.extra = {} if extra is None else extra
self.exclude = [] if exclude is None else exclude
self.update = {} if update is None else update
self.attrparser = (lambda x: x) if attrparser is None else attrparser
self.flat = flat
self.interpretations = {} if interpretations is None else interpretations
self.toawkward = [] if toawkward is None else toawkward
class EventReader:
"""reader for offline ROOT files"""
class Branch:
"""Branch accessor class"""
event_path = None
item_name = "Event"
skip_keys = []
aliases = {}
special_branches = {}
special_aliases = {}
def __init__(
self,
tree,
mapper,
f,
index_chain=None,
subbranchmaps=None,
keymap=None,
awkward_cache=None,
step_size=2000,
keys=None,
aliases=None,
event_ctor=None,
):
self._tree = tree
self._mapper = mapper
self._index_chain = [] if index_chain is None else index_chain
self._keymap = None
self._branch = tree[mapper.key]
self._subbranches = []
self._subbranchmaps = subbranchmaps
# FIXME preliminary cache to improve performance. Hopefully uproot4
# will fix this automatically!
self._awkward_cache = {} if awkward_cache is None else awkward_cache
"""OfflineReader class is an offline ROOT file wrapper
Parameters
----------
f: str or uproot4.reading.ReadOnlyDirectory (from uproot4.open)
Path to the file of interest or uproot4 filedescriptor.
step_size: int, optional
Number of events to read into the cache when iterating.
Choosing higher numbers may improve the speed but also increases
the memory overhead.
index_chain: list, optional
Keeps track of index chaining.
keys: list or set, optional
Branch keys.
aliases: dict, optional
Branch key aliases.
event_ctor: class or namedtuple, optional
Event constructor.
"""
if isinstance(f, str):
self._fobj = uproot.open(f)
self._filepath = f
elif isinstance(f, uproot.reading.ReadOnlyDirectory):
self._fobj = f
self._filepath = f._file.file_path
else:
raise TypeError("Unsupported file descriptor.")
self._step_size = step_size
self._uuid = self._fobj._file.uuid
self._iterator_index = 0
self._keys = keys
self._event_ctor = event_ctor
self._index_chain = [] if index_chain is None else index_chain
if keymap is None:
self._initialise_keys() #
else:
self._keymap = keymap
if subbranchmaps is not None:
for mapper in subbranchmaps:
subbranch = self.__class__(
self._tree,
mapper=mapper,
index_chain=self._index_chain,
awkward_cache=self._awkward_cache,
)
self._subbranches.append(subbranch)
for subbranch in self._subbranches:
setattr(self, subbranch._mapper.name, subbranch)
# if aliases is not None:
# self.aliases = aliases
# else:
# # Check for usr-awesomeness backward compatibility crap
# if "E/Evt/AAObject/usr" in self._fobj:
# print("Found usr data")
# if ak.count(f["E/Evt/AAObject/usr"].array()) > 0:
# self.aliases.update(
# {
# "usr": "AAObject/usr",
# "usr_names": "AAObject/usr_names",
# }
# )
if self._keys is None:
self._initialise_keys()
if self._event_ctor is None:
self._event_ctor = namedtuple(
self.item_name,
set(
list(self.keys())
+ list(self.aliases)
+ list(self.special_branches)
+ list(self.special_aliases)
),
)
def _initialise_keys(self):
"""Create the keymap and instance attributes for branch keys"""
# TODO: this could be a cached property
keys = set(k.decode("utf-8") for k in self._branch.keys()) - set(
self._mapper.exclude
skip_keys = set(self.skip_keys)
toplevel_keys = set(k.split("/")[0] for k in self._fobj[self.event_path].keys())
keys = (toplevel_keys - skip_keys).union(
list(self.aliases.keys()) + list(self.special_aliases)
)
self._keymap = {
**{self._mapper.attrparser(k): k for k in keys},
**self._mapper.extra,
}
self._keymap.update(self._mapper.update)
for k in self._mapper.update.values():
del self._keymap[k]
for key in self._keymap.keys():
setattr(self, key, None)
for key in list(self.special_branches) + list(self.special_aliases):
keys.add("n_" + key)
# self._grouped_branches = {k for k in toplevel_keys - skip_keys if isinstance(self._fobj[self.event_path][k].interpretation, uproot.AsGrouped)}
self._keys = keys
def keys(self):
return self._keymap.keys()
"""Returns all accessible branch keys, without the skipped ones."""
return self._keys
def __getattribute__(self, attr):
if attr.startswith("_"): # let all private and magic methods pass
return object.__getattribute__(self, attr)
if attr in self._keymap.keys(): # intercept branch key lookups
return self.__getkey__(attr)
return object.__getattribute__(self, attr)
def __getkey__(self, key):
interpretation = self._mapper.interpretations.get(key)
@property
def events(self):
# TODO: deprecate this, since `self` is already the container type
return iter(self)
def _keyfor(self, key):
"""Return the correct key for a given alias/key"""
return self.special_aliases.get(key, key)
def __getattr__(self, attr):
attr = self._keyfor(attr)
# if attr in self.keys() or (attr.startswith("n_") and self._keyfor(attr.split("n_")[1]) in self._grouped_branches):
if attr in self.keys():
return self.__getitem__(attr)
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
)
if key == "usr_names":
# TODO this will be fixed soon in uproot,
# see https://github.com/scikit-hep/uproot/issues/465
interpretation = uproot3.asgenobj(
uproot3.SimpleArray(uproot3.STLVector(uproot3.STLString())),
self._branch[self._keymap[key]]._context,
6,
def __getitem__(self, key):
# indexing
# TODO: maybe just propagate everything to awkward and let it deal
# with the type?
if isinstance(key, (slice, int, np.int32, np.int64, list, np.ndarray, ak.Array)):
if isinstance(key, (int, np.int32, np.int64)):
key = int(key)
return self.__class__(
self._fobj,
index_chain=self._index_chain + [key],
step_size=self._step_size,
aliases=self.aliases,
keys=self.keys(),
event_ctor=self._event_ctor,
)
if key == "usr":
# triple jagged array is wrongly parsed in uproot3
interpretation = uproot3.asgenobj(
uproot3.SimpleArray(uproot3.STLVector(uproot3.asdtype(">f8"))),
self._branch[self._keymap[key]]._context,
6,
)
if isinstance(key, str) and key.startswith(
"n_"
): # group counts, for e.g. n_events, n_hits etc.
key = self._keyfor(key.split("n_")[1])
arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4"))
return unfold_indices(arr, self._index_chain)
key = self._keyfor(key)
branch = self._fobj[self.event_path]
# These are special branches which are nested, like hits/trks/mc_trks
# We are explicitly grabbing just a predefined set of subbranches
# and also alias them to be backwards compatible (and attribute-accessible)
if key in self.special_branches:
fields = []
# some fields are not always available, like `usr_names`
for to_field, from_field in self.special_branches[key].items():
if from_field in branch[key].keys():
fields.append(to_field)
log.debug(fields)
out = branch[key].arrays(fields, aliases=self.special_branches[key])
else:
out = branch[self.aliases.get(key, key)].array()
out = self._branch[self._keymap[key]].lazyarray(
interpretation=interpretation, basketcache=BASKET_CACHE
)
if self._index_chain is not None and key in self._mapper.toawkward:
cache_key = self._mapper.name + "/" + key
if cache_key not in self._awkward_cache:
if len(out) > 20000: # It will take more than 10 seconds
print("Creating cache for '{}'.".format(cache_key))
self._awkward_cache[cache_key] = ak.from_iter(out)
out = self._awkward_cache[cache_key]
return unfold_indices(out, self._index_chain)
def __getitem__(self, item):
"""Slicing magic"""
if isinstance(item, str):
return self.__getkey__(item)
if isinstance(item, (np.int32, np.int64)):
item = int(item)
# if item.__class__.__name__ == "ChunkedArray":
# item = np.array(item)
def __iter__(self):
self._events = self._event_generator()
return self
return self.__class__(
self._tree,
self._mapper,
index_chain=self._index_chain + [item],
keymap=self._keymap,
subbranchmaps=self._subbranchmaps,
awkward_cache=self._awkward_cache,
def _event_generator(self):
events = self._fobj[self.event_path]
group_count_keys = set(
k for k in self.keys() if k.startswith("n_")
) # special keys to make it easy to count subbranch lengths
log.debug("group_count_keys: %s", group_count_keys)
keys = set(
list(
set(self.keys())
- set(self.special_branches.keys())
- set(self.special_aliases)
- group_count_keys
)
+ list(self.aliases.keys())
) # all top-level keys for regular branches
log.debug("keys: %s", keys)
log.debug("aliases: %s", self.aliases)
events_it = events.iterate(
keys, aliases=self.aliases, step_size=self._step_size
)
specials = []
special_keys = (
self.special_branches.keys()
) # dict-key ordering is an implementation detail
log.debug("special_keys: %s", special_keys)
for key in special_keys:
# print(f"adding {key} with keys {self.special_branches[key].keys()} and aliases {self.special_branches[key]}")
specials.append(
events[key].iterate(
self.special_branches[key].keys(),
aliases=self.special_branches[key],
step_size=self._step_size,
)
)
group_counts = {}
for key in group_count_keys:
group_counts[key] = iter(self[key])
log.debug("group_counts: %s", group_counts)
for event_set, *special_sets in zip(events_it, *specials):
for _event, *special_items in zip(event_set, *special_sets):
data = {}
for k in keys:
data[k] = _event[k]
for (k, i) in zip(special_keys, special_items):
data[k] = i
for tokey, fromkey in self.special_aliases.items():
data[tokey] = data[fromkey]
for key in group_counts:
data[key] = next(group_counts[key])
yield self._event_ctor(**data)
def __next__(self):
return next(self._events)
def __len__(self):
if not self._index_chain:
return len(self._branch)
return self._fobj[self.event_path].num_entries
elif isinstance(self._index_chain[-1], (int, np.int32, np.int64)):
if len(self._index_chain) == 1:
try:
return len(self[:])
except IndexError:
return 1
return 1
# try:
# return len(self[:])
# except IndexError:
# return 1
return 1
else:
return len(
unfold_indices(
self._branch[self._keymap["id"]].lazyarray(
basketcache=BASKET_CACHE
),
self._index_chain,
)
)
# ignore the usual index magic and access `id` directly
return len(unfold_indices(self._fobj[self.event_path]["id"].array(), self._index_chain))
@property
def is_single(self):
"""Returns True when a single branch is selected."""
if len(self._index_chain) > 0:
if isinstance(self._index_chain[0], (int, np.int32, np.int64)):
return True
return False
def __actual_len__(self):
"""The raw number of events without any indexing/slicing magic"""
return len(self._fobj[self.event_path]["id"].array())
def __iter__(self):
self._iterator_index = 0
return self
def __repr__(self):
length = len(self)
actual_length = self.__actual_len__()
return f"{self.__class__.__name__} ({length}{'/' + str(actual_length) if length < actual_length else ''} events)"
def __next__(self):
idx = self._iterator_index
self._iterator_index += 1
if idx >= len(self):
raise StopIteration
return self[idx]
@property
def uuid(self):
return self._uuid
def __str__(self):
length = len(self)
return "{} ({}) with {} element{}".format(
self.__class__.__name__,
self._mapper.name,
length,
"s" if length > 1 else "",
)
def close(self):
self._fobj.close()
def __repr__(self):
length = len(self)
return "<{}[{}]: {} element{}>".format(
self.__class__.__name__,
self._mapper.name,
length,
"s" if length > 1 else "",
)
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
......@@ -13,6 +13,7 @@ class TestGSGHeader(unittest.TestCase):
def setUp(self):
self.header = GSG_READER.header
@unittest.skip
def test_str_byte_type(self):
assert isinstance(self.header["gSeaGenVer"], str)
assert isinstance(self.header["GenieVer"], str)
......@@ -21,6 +22,7 @@ class TestGSGHeader(unittest.TestCase):
assert isinstance(self.header["Flux1"], str)
assert isinstance(self.header["Flux2"], str)
@unittest.skip
def test_values(self):
assert self.header["RunNu"] == 1
assert self.header["RanSeed"] == 3662074
......@@ -55,6 +57,7 @@ class TestGSGHeader(unittest.TestCase):
assert self.header["NNu"] == 2
self.assertListEqual(self.header["NuList"].tolist(), [-14, 14])
@unittest.skip
def test_unsupported_header(self):
f = GSGReader(data_path("online/km3net_online.root"))
with self.assertWarns(UserWarning):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment