Commit d2c13d91 authored by Tamas Gal

Merge branch 'refactor-offline-prepare-for-merging' into 'master'

Refactor offline I/O

See merge request !27
Parents: 52a777d5 fcca3704
Pipeline #10195 passed with warnings
#!/usr/bin/env python3
from .mc_header import data as mc_header
from .trigger import data as trigger
from .fitparameters import data as fitparameters
from .reconstruction import data as reconstruction
#!/usr/bin/env python3
data = {
"DAQ": "livetime",
"seed": "program level iseed",
"PM1_type_area": "type area TTS",
"PDF": "i1 i2",
"model": "interaction muon scattering numberOfEnergyBins",
"can": "zmin zmax r",
"genvol": "zmin zmax r volume numberOfEvents",
"merge": "time gain",
"coord_origin": "x y z",
"translate": "x y z",
"genhencut": "gDir Emin",
"k40": "rate time",
"norma": "primaryFlux numberOfPrimaries",
"livetime": "numberOfSeconds errorOfSeconds",
"flux": "type key file_1 file_2",
"spectrum": "alpha",
"fixedcan": "xcenter ycenter zmin zmax radius",
"start_run": "run_id",
}
for key in "cut_primary cut_seamuon cut_in cut_nu".split():
data[key] = "Emin Emax cosTmin cosTmax"
for key in "generator physics simul".split():
data[key] = "program version date time"
for key in data.keys():
data[key] = data[key].split()
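# Illustration only (not part of the module): after the loops above, every
# entry maps a header tag to its ordered field names, e.g.
#   data["can"] == ["zmin", "zmax", "r"]
#   data["cut_nu"] == ["Emin", "Emax", "cosTmin", "cosTmax"]
# so a raw header line like "can: 0 1027 888.4" can be unpacked field by field:
#   dict(zip(data["can"], "0 1027 888.4".split()))
#   # -> {'zmin': '0', 'zmax': '1027', 'r': '888.4'}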
from collections import namedtuple
import uproot
import numpy as np
import warnings
import km3io.definitions.trigger
import km3io.definitions.fitparameters
import km3io.definitions.reconstruction
from .definitions import mc_header
from .tools import Branch, BranchMapper, cached_property, _to_num, _unfold_indices
MAIN_TREE_NAME = "E"
EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"]
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2
class cached_property:
"""A simple cache decorator for properties."""
def __init__(self, function):
self.function = function
def __get__(self, obj, cls):
if obj is None:
return self
prop = obj.__dict__[self.function.__name__] = self.function(obj)
return prop
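# Usage sketch (illustration only): the wrapped method runs once per
# instance; the result is stored in the instance __dict__, which shadows
# this non-data descriptor on subsequent lookups.
#   class Spam:
#       @cached_property
#       def eggs(self):
#           print("computing...")
#           return 42
#   spam = Spam()
#   spam.eggs  # prints "computing...", returns 42
#   spam.eggs  # returns 42 straight from spam.__dict__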
class OfflineKeys:
"""wrapper for offline keys"""
def __init__(self, tree):
"""OfflineKeys is a class that reads all the available keys in an offline
file and adapts the keys format to Python format.
Parameters
----------
tree : uproot.TTree
The main ROOT tree.
"""
self._tree = tree
def __str__(self):
return '\n'.join([
"Events keys are:\n\t" + "\n\t".join(self.events_keys),
"Hits keys are:\n\t" + '\n\t'.join(self.hits_keys),
"Tracks keys are:\n\t" + '\n\t'.join(self.tracks_keys),
"Mc hits keys are:\n\t" + '\n\t'.join(self.mc_hits_keys),
"Mc tracks keys are:\n\t" + '\n\t'.join(self.mc_tracks_keys)
])
def __repr__(self):
return "<{}>".format(self.__class__.__name__)
def _get_keys(self, tree, fake_branches=None):
"""Get tree keys except those in fake_branches
Parameters
----------
tree : uproot.Tree
The tree to look for keys
fake_branches : list of str or None
The fake branches to ignore
Returns
-------
list of str
The keys of the tree.
"""
keys = []
for key in tree.keys():
key = key.decode('utf-8')
if fake_branches is not None and key in fake_branches:
continue
keys.append(key)
return keys
BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
def _nested_mapper(key):
"""Maps a key in the ROOT file to another key (e.g. trks.pos.x -> pos_x)"""
return '_'.join(key.split('.')[1:])
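# Illustration only:
#   _nested_mapper("trks.pos.x")  # -> "pos_x"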
EVENTS_MAP = BranchMapper(name="events",
key="Evt",
extra={
't_sec': 't.fSec',
't_ns': 't.fNanoSec'
},
exclude=EXCLUDE_KEYS,
update={
'n_hits': 'hits',
'n_mc_hits': 'mc_hits',
'n_tracks': 'trks',
'n_mc_tracks': 'mc_trks'
},
attrparser=lambda a: a,
flat=True)
SUBBRANCH_MAPS = [
BranchMapper(name="tracks",
key="trks",
extra={},
exclude=EXCLUDE_KEYS +
['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'],
update={},
attrparser=_nested_mapper,
flat=False),
BranchMapper(name="mc_tracks",
key="mc_trks",
extra={},
exclude=EXCLUDE_KEYS + [
'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.rec_stages',
'mc_trks.fitinf', 'mc_trks.fUniqueID', 'mc_trks.fBits'
],
update={},
attrparser=_nested_mapper,
flat=False),
BranchMapper(name="hits",
key="hits",
extra={},
exclude=EXCLUDE_KEYS + [
'hits.usr', 'hits.pmt_id', 'hits.origin', 'hits.a',
'hits.pure_a', 'hits.fUniqueID', 'hits.fBits'
],
update={},
attrparser=_nested_mapper,
flat=False),
BranchMapper(name="mc_hits",
key="mc_hits",
extra={},
exclude=EXCLUDE_KEYS + [
'mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id',
'mc_hits.tdc', 'mc_hits.tot', 'mc_hits.trig',
'mc_hits.fUniqueID', 'mc_hits.fBits'
],
update={},
attrparser=_nested_mapper,
flat=False),
]
class OfflineBranch(Branch):
@cached_property
def events_keys(self):
"""reads events keys from an offline file.
Returns
-------
list of str
list of all events keys found in an offline file,
except those found in fake branches.
"""
fake_branches = ['Evt', 'AAObject', 'TObject', 't']
t_baskets = ['t.fSec', 't.fNanoSec']
tree = self._tree['Evt']
return self._get_keys(tree, fake_branches) + t_baskets
def usr(self):
return Usr(self._mapper, self._branch, index_chain=self._index_chain)
@cached_property
def hits_keys(self):
"""reads hits keys from an offline file.
Returns
-------
list of str
list of all hits keys found in an offline file,
except those found in fake branches.
"""
fake_branches = ['hits.usr', 'hits.usr_names']
return self._get_keys(self._tree['hits'], fake_branches)
@cached_property
def tracks_keys(self):
"""reads tracks keys from an offline file.
Returns
-------
list of str
list of all tracks keys found in an offline file,
except those found in fake branches.
"""
# a solution can be tree['trks.usr_data'].array(
# uproot.asdtype(">i4"))
fake_branches = ['trks.usr_data', 'trks.usr', 'trks.usr_names']
return self._get_keys(self._tree['Evt']['trks'], fake_branches)
class Usr:
"""Helper class to access AAObject `usr` stuff"""
def __init__(self, mapper, branch, index_chain=None):
self._mapper = mapper
self._name = mapper.name
self._index_chain = [] if index_chain is None else index_chain
self._branch = branch
self._usr_names = []
self._usr_idx_lookup = {}
@cached_property
def mc_hits_keys(self):
"""reads mc hits keys from an offline file.
Returns
-------
list of str
list of all mc hits keys found in an offline file,
except those found in fake branches.
"""
fake_branches = ['mc_hits.usr', 'mc_hits.usr_names']
return self._get_keys(self._tree['Evt']['mc_hits'], fake_branches)
self._usr_key = 'usr' if mapper.flat else mapper.key + '.usr'
@cached_property
def mc_tracks_keys(self):
"""reads mc tracks keys from an offline file.
Returns
-------
list of str
list of all mc tracks keys found in an offline file,
except those found in fake branches.
"""
fake_branches = [
'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names'
self._initialise()
def _initialise(self):
try:
self._branch[self._usr_key]
# This will raise a KeyError in old aanet files
# which have a different structure and key (usr_data).
# We do not support those (yet)
except (KeyError, IndexError):
print("The `usr` fields could not be parsed for the '{}' branch.".
format(self._name))
return
if self._mapper.flat:
self._initialise_flat()
def _initialise_flat(self):
# Here, we assume that every event has the same names in the same order
# to massively increase the performance. This needs a triple check to
# ensure it's always the case.
self._usr_names = [
n.decode("utf-8")
for n in self._branch[self._usr_key + '_names'].lazyarray()[0]
]
return self._get_keys(self._tree['Evt']['mc_trks'], fake_branches)
self._usr_idx_lookup = {
name: index
for index, name in enumerate(self._usr_names)
}
@cached_property
def valid_keys(self):
"""constructs a list of all valid keys to be read from an offline event file.
Returns
-------
list of str
list of all valid keys.
"""
return (self.events_keys + self.hits_keys + self.tracks_keys +
self.mc_tracks_keys + self.mc_hits_keys)
data = self._branch[self._usr_key].lazyarray()
@cached_property
def fit_keys(self):
"""constructs a list of fit parameters, not yet outsourced in an offline file.
if self._index_chain:
data = _unfold_indices(data, self._index_chain)
Returns
-------
list of str
list of all "trks.fitinf" keys.
"""
return sorted(km3io.definitions.fitparameters.data,
key=km3io.definitions.fitparameters.data.get,
reverse=False)
self._usr_data = data
@cached_property
def cut_hits_keys(self):
"""adapts hits keys for instance variables format in a Python class.
for name in self._usr_names:
setattr(self, name, self[name])
Returns
-------
list of str
list of adapted hits keys.
"""
return [k.split('hits.')[1].replace('.', '_') for k in self.hits_keys]
# def _initialise_nested(self):
# self._usr_names = [
# n.decode("utf-8") for n in self.branch['usr_names'].lazyarray(
# # TODO this will be fixed soon in uproot,
# # see https://github.com/scikit-hep/uproot/issues/465
# uproot.asgenobj(
# uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
# self.branch['usr_names']._context, 6),
# basketcache=BASKET_CACHE)[0]
# ]
@cached_property
def cut_tracks_keys(self):
"""adapts tracks keys for instance variables format in a Python class.
def __getitem__(self, item):
if self._mapper.flat:
return self.__getitem_flat__(item)
return self.__getitem_nested__(item)
def __getitem_flat__(self, item):
if self._index_chain:
return _unfold_indices(
self._usr_data, self._index_chain)[:,
self._usr_idx_lookup[item]]
else:
return self._usr_data[:, self._usr_idx_lookup[item]]
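# Illustration only: in the flat layout each event carries the same usr
# names in the same order, so the usr data behaves like a 2D table and a
# name lookup is just a column index, e.g.
#   self._usr_data[:, self._usr_idx_lookup["CoC"]]
# ("CoC" is one of the usr names appearing in the test samples).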
def __getitem_nested__(self, item):
data = self._branch[self._usr_key + '_names'].lazyarray(
# TODO this will be fixed soon in uproot,
# see https://github.com/scikit-hep/uproot/issues/465
uproot.asgenobj(
uproot.SimpleArray(uproot.STLVector(uproot.STLString())),
self._branch[self._usr_key + '_names']._context, 6),
basketcache=BASKET_CACHE)
return _unfold_indices(data, self._index_chain)
Returns
-------
list of str
list of adapted tracks keys.
"""
return [
k.split('trks.')[1].replace('.', '_') for k in self.tracks_keys
]
def keys(self):
return self._usr_names
@cached_property
def cut_events_keys(self):
"""adapts events keys for instance variables format in a Python class.
def __str__(self):
entries = []
for name in self.keys():
entries.append("{}: {}".format(name, self[name]))
return '\n'.join(entries)
Returns
-------
list of str
list of adapted events keys.
"""
return [k.replace('.', '_') for k in self.events_keys]
def __repr__(self):
return "<{}[{}]>".format(self.__class__.__name__, self._name)
class OfflineReader:
"""reader for offline ROOT files"""
def __init__(self, file_path=None, fobj=None, data=None):
def __init__(self, file_path=None):
""" OfflineReader class is an offline ROOT file wrapper
Parameters
@@ -213,681 +192,67 @@ class OfflineReader:
path-like object that points to the file.
"""
if file_path is not None:
self._fobj = uproot.open(file_path)
self._tree = self._fobj[MAIN_TREE_NAME]
self._data = self._tree.lazyarrays(
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))
else:
self._fobj = fobj
self._tree = self._fobj[MAIN_TREE_NAME]
self._data = data
self._fobj = uproot.open(file_path)
self._tree = self._fobj[MAIN_TREE_NAME]
@classmethod
def from_index(cls, source, index):
"""Create an instance with a subtree of a given index
Parameters
----------
source: ROOTDirectory
The source file.
index: index or slice
The index or slice to create the subtree.
"""
instance = cls(fobj=source._fobj, data=source._data[index])
return instance
def __getitem__(self, index):
return OfflineReader.from_index(source=self, index=index)
def __len__(self):
return len(self._data)
@cached_property
def events(self):
"""The `E` branch, containing all offline events."""
return OfflineBranch(self._tree,
mapper=EVENTS_MAP,
subbranchmaps=SUBBRANCH_MAPS)
@cached_property
def header(self):
"""The file header"""
if 'Head' in self._fobj:
header = {}
for n, x in self._fobj['Head']._map_3c_string_2c_string_3e_.items(
):
header[n.decode("utf-8")] = x.decode("utf-8").strip()
return header
return Header(header)
else:
warnings.warn("Your file header has an unsupported format")
@cached_property
def keys(self):
"""wrapper for all keys in an offline file.
Returns
-------
Class
OfflineKeys.
"""
return OfflineKeys(self._tree)
@cached_property
def events(self):
"""wrapper for offline events.
Returns
-------
Class
OfflineEvents.
"""
return OfflineEvents(
self.keys.cut_events_keys,
[self._data[key] for key in self.keys.events_keys])
@cached_property
def hits(self):
"""wrapper for offline hits.
Returns
-------
Class
OfflineHits.
"""
return OfflineHits(self.keys.cut_hits_keys,
[self._data[key] for key in self.keys.hits_keys])
@cached_property
def tracks(self):
"""wrapper for offline tracks.
Returns
-------
Class
OfflineTracks.
"""
return OfflineTracks(
self.keys.cut_tracks_keys,
[self._data[key] for key in self.keys.tracks_keys])
@cached_property
def mc_hits(self):
"""wrapper for offline mc hits.
Returns
-------
Class
OfflineHits.
"""
return OfflineHits(self.keys.cut_hits_keys,
[self._data[key] for key in self.keys.mc_hits_keys])
@cached_property
def mc_tracks(self):
"""wrapper for offline mc tracks.
Returns
-------
Class
OfflineTracks.
"""
return OfflineTracks(
self.keys.cut_tracks_keys,
[self._data[key] for key in self.keys.mc_tracks_keys])
@cached_property
def usr(self):
return Usr(self._tree)
def get_best_reco(self):
"""returns the best reconstructed track fit data. The best fit is defined
as the track fit with the maximum reconstruction stages. When "nan" is
returned, it means that the reconstruction parameter of interest is not
found. for example, in the case of muon simulations: if [1, 2] are the
reconstruction stages, then only the fit parameters corresponding to the
stages [1, 2] are found in the Offline files, the remaining fit parameters
corresponding to the stages 3, 4, 5 are all filled with nan.
Returns
-------
numpy recarray
a recarray of the best track fit data (reconstruction data).
"""
keys = ", ".join(self.keys.fit_keys[:-1])
empty_fit_info = np.array(
[match for match in self._find_empty(self.tracks.fitinf)])
fit_info = [
i for i, j in zip(self.tracks.fitinf, empty_fit_info[:, 1])
if j is not None
]
stages = self._get_max_reco_stages(self.tracks.rec_stages)
fit_data = np.array([i[j] for i, j in zip(fit_info, stages[:, 2])])
rows_size = len(max(fit_data, key=len))
equal_size_data = np.vstack([
np.hstack([i, np.zeros(rows_size - len(i)) + np.nan])
for i in fit_data
])
return np.core.records.fromarrays(equal_size_data.transpose(),
names=keys)
def _get_max_reco_stages(self, reco_stages):
"""find the longest reconstructed track based on the maximum size of
reconstructed stages.
Parameters
----------
reco_stages : chunked array
chunked array of all the reconstruction stages of all tracks.
In km3io, it is accessed with
km3io.OfflineReader(my_file).tracks.rec_stages .
Returns
-------
numpy array
array with 3 columns: the longest reco_stages, the length of the
longest reco_stages and the position of the longest reco_stages.
"""
empty_reco_stages = np.array(
[match for match in self._find_empty(reco_stages)])
max_reco_stages = np.array(
[[max(i, key=len),
len(max(i, key=len)),
i.index(max(i, key=len))]
for i, j in zip(reco_stages, empty_reco_stages[:, 1])
if j is not None])
return max_reco_stages
def get_reco_fit(self, stages, mc=False):
"""construct a numpy recarray of the fit information (reconstruction
data) of the tracks reconstructed following the reconstruction stages
of interest.
Parameters
----------
stages : list
list of reconstruction stages of interest. for example
[1, 2, 3, 4, 5].
mc : bool, optional
default is False to look for fit data in the tracks tree in offline files
(not the mc tracks tree). mc=True to look for fit data from the mc tracks
tree in offline files.
Returns
-------
numpy recarray
a recarray of the fit information (reconstruction data) of
the tracks of interest.
Raises
------
ValueError
ValueError raised when the reconstruction stages of interest
are not found in the file.
"""
keys = ", ".join(self.keys.fit_keys[:-1])
if mc is False:
rec_stages = np.array(
[match for match in self._find_rec_stages(stages, mc=False)])
fitinf = self.tracks.fitinf
if mc is True:
rec_stages = np.array(
[match for match in self._find_rec_stages(stages, mc=True)])
fitinf = self.mc_tracks.fitinf
mask = rec_stages[:, 1] != None
if np.all(rec_stages[:, 1] == None):
raise ValueError(
"The stages {} are not found in your file.".format(
str(stages)))
else:
fit_data = np.array(
[i[k] for i, k in zip(fitinf[mask], rec_stages[:, 1][mask])])
rec_array = np.core.records.fromarrays(fit_data.transpose(),
names=keys)
return rec_array
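# Hypothetical usage sketch (`reader` stands for any OfflineReader
# instance; the stage list mirrors the docstring example):
#   fit = reader.get_reco_fit([1, 2, 3, 4, 5])
#   fit["JGANDALF_BETA0_RAD"]
# A stage list that matches no track raises ValueError.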
def get_reco_hits(self, stages, keys, mc=False):
"""construct a dictionary of hits class data based on the reconstruction
stages of interest. For example, if the reconstruction stages of interest
are [1, 2, 3, 4, 5], then get_reco_hits method will select the hits data
from the events that were reconstructed following these stages (i.e.
[1, 2, 3, 4, 5]).
class Header:
"""The header"""
def __init__(self, header):
self._data = {}
Parameters
----------
stages : list
list of reconstruction stages of interest. for example
[1, 2, 3, 4, 5].
keys : list of str
list of the hits class attributes.
mc : bool, optional
default is False to look for hits data in the hits tree in offline files
(not the mc_hits tree). mc=True to look for mc hits data in the mc hits
tree in offline files.
Returns
-------
dict
dictionary of lazyarrays containing data for each hits attribute requested.
Raises
------
ValueError
ValueError raised when the reconstruction stages of interest
are not found in the file.
"""
lazy_d = {}
for attribute, fields in header.items():
values = fields.split()
fields = mc_header.get(attribute, [])
if mc is False:
rec_stages = np.array(
[match for match in self._find_rec_stages(stages, mc=False)])
hits_data = self.hits
n_values = len(values)
n_fields = len(fields)
if mc is True:
rec_stages = np.array(
[match for match in self._find_rec_stages(stages, mc=True)])
hits_data = self.mc_hits
mask = rec_stages[:, 1] != None
if np.all(rec_stages[:, 1] == None):
raise ValueError(
"The stages {} are not found in your file.".format(
str(stages)))
else:
for key in keys:
lazy_d[key] = getattr(hits_data, key)[mask]
return lazy_d
def get_reco_events(self, stages, keys, mc=False):
"""construct a dictionary of events class data based on the reconstruction
stages of interest. For example, if the reconstruction stages of interest
are [1, 2, 3, 4, 5], then get_reco_events method will select the events data
that were reconstructed following these stages (i.e [1, 2, 3, 4, 5]).
Parameters
----------
stages : list
list of reconstruction stages of interest. for example
[1, 2, 3, 4, 5].
keys : list of str
list of the events class attributes.
mc : bool, optional
default is False to look for the reconstruction stages in the tracks tree
in offline files (not the mc tracks tree). mc=True to look for the reconstruction
data in the mc tracks tree in offline files.
Returns
-------
dict
dictionary of lazyarrays containing data for each events attribute requested.
Raises
------
ValueError
ValueError raised when the reconstruction stages of interest
are not found in the file.
"""
lazy_d = {}
if mc is False:
rec_stages = np.array(
[match for match in self._find_rec_stages(stages, mc=False)])
if mc is True:
rec_stages = np.array(
[match for match in self._find_rec_stages(stages, mc=True)])
mask = rec_stages[:, 1] != None
if np.all(rec_stages[:, 1] == None):
raise ValueError(
"The stages {} are not found in your file.".format(
str(stages)))
else:
for key in keys:
lazy_d[key] = getattr(self.events, key)[mask]
return lazy_d
def get_reco_tracks(self, stages, keys, mc=False):
"""construct a dictionary of tracks class data based on the reconstruction
stages of interest. For example, if the reconstruction stages of interest
are [1, 2, 3, 4, 5], then get_reco_tracks method will select tracks data
from the events that were reconstructed following these stages (i.e.
[1, 2, 3, 4, 5]).
Parameters
----------
stages : list
list of reconstruction stages of interest. for example
[1, 2, 3, 4, 5].
keys : list of str
list of the tracks class attributes.
mc : bool, optional
default is False to look for tracks data in the tracks tree in offline files
(not the mc tracks tree). mc=True to look for tracks data in the mc tracks
tree in offline files.
Returns
-------
dict
dictionary of lazyarrays containing data for each tracks attribute requested.
Raises
------
ValueError
ValueError raised when the reconstruction stages of interest
are not found in the file.
"""
lazy_d = {}
if mc is False:
rec_stages = np.array(
[match for match in self._find_rec_stages(stages, mc=False)])
tracks_data = self.tracks
if mc is True:
rec_stages = np.array(
[match for match in self._find_rec_stages(stages, mc=True)])
tracks_data = self.mc_tracks
mask = rec_stages[:, 1] != None
if np.all(rec_stages[:, 1] == None):
raise ValueError(
"The stages {} are not found in your file.".format(
str(stages)))
else:
for key in keys:
lazy_d[key] = np.array([
i[k] for i, k in zip(
getattr(tracks_data, key)[mask], rec_stages[:,
1][mask])
])
return lazy_d
def _find_rec_stages(self, stages, mc=False):
"""find the index of reconstruction stages of interest in a
list of multiple reconstruction stages.
Parameters
----------
stages : list
list of reconstruction stages of interest. for example
[1, 2, 3, 4, 5].
mc : bool, optional
default is False to look for reconstruction stages in the tracks tree in
offline files (not the mc tracks tree). mc=True to look for reconstruction
stages in the mc tracks tree in offline files.
Yields
------
generator
the track id and the index of the reconstruction stages of
interest if found. If the reconstruction stages of interest
are not found, None is returned as the stages index.
"""
if mc is False:
stages_data = self.tracks.rec_stages
if mc is True:
stages_data = self.mc_tracks.rec_stages
for trk_index, rec_stages in enumerate(stages_data):
try:
stages_index = rec_stages.index(stages)
except ValueError:
stages_index = None
yield trk_index, stages_index
if n_values == 1 and n_fields == 0:
self._data[attribute] = _to_num(values[0])
continue
n_max = max(n_values, n_fields)
values += [None] * (n_max - n_values)
fields += ["field_{}".format(i) for i in range(n_fields, n_max)]
def _find_empty(self, array):
"""finds empty lists/arrays in an awkward array
Constructor = namedtuple(attribute, fields)
Parameters
----------
array : awkward array
Awkward array of data of interest. For example:
km3io.OfflineReader(my_file).tracks.fitinf .
Yields
------
generator
the item index and the position of the first empty list. When the
item itself is empty, None is yielded as the position. When the item
is not empty but contains no empty list, False is yielded instead.
"""
for i, rs in enumerate(array):
try:
if len(rs) == 0:
j = None
if len(rs) != 0:
j = rs.index([])
except ValueError:
j = False # rs not empty but [] not found
yield i, j
if not values:
continue
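# Worked illustration of _find_empty (plain lists, for clarity):
#   array = [[[1], []], [], [[2]]]
#   list(_find_empty(array)) == [(0, 1), (1, None), (2, False)]
# item 0 has an empty sublist at position 1, item 1 is itself empty,
# item 2 is non-empty but contains no empty sublist.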
class Usr:
"""Helper class to access AAObject usr stuff"""
def __init__(self, tree):
# Here, we assume that every event has the same names in the same order
# to massively increase the performance. This needs a triple check to
# ensure it's always the case; the usr format is simply a very bad design.
try:
self._usr_names = [
n.decode("utf-8") for n in tree['Evt']['usr_names'].array()[0]
]
except (KeyError, IndexError): # e.g. old aanet files
self._usr_names = []
else:
self._usr_idx_lookup = {
name: index
for index, name in enumerate(self._usr_names)
}
self._usr_data = tree['Evt']['usr'].lazyarray(
basketcache=uproot.cache.ThreadSafeArrayCache(
BASKET_CACHE_SIZE))
for name in self._usr_names:
setattr(self, name, self[name])
def __getitem__(self, item):
return self._usr_data[:, self._usr_idx_lookup[item]]
def keys(self):
return self._usr_names
def __str__(self):
entries = []
for name in self.keys():
entries.append("{}: {}".format(name, self[name]))
return '\n'.join(entries)
class OfflineEvents:
"""wrapper for offline events"""
def __init__(self, keys, values):
"""wrapper for offline events.
Parameters
----------
keys : list of str
list of valid events keys.
values : list of arrays
list of arrays containing events data.
"""
self._keys = keys
self._values = values
for k, v in zip(self._keys, self._values):
setattr(self, k, v)
def __getitem__(self, item):
return OfflineEvent(self._keys, [v[item] for v in self._values])
def __len__(self):
try:
return len(self._values[0])
except IndexError:
return 0
def __str__(self):
return "Number of events: {}".format(len(self))
def __repr__(self):
return "<{}: {} parsed events>".format(self.__class__.__name__,
len(self))
class OfflineEvent:
"""wrapper for an offline event"""
def __init__(self, keys, values):
"""wrapper for one offline event.
Parameters
----------
keys : list of str
list of valid events keys.
values : list of arrays
list of arrays containing event data.
"""
self._keys = keys
self._values = values
for k, v in zip(self._keys, self._values):
setattr(self, k, v)
def __str__(self):
return "offline event:\n\t" + "\n\t".join([
"{:15} {:^10} {:>10}".format(k, ':', str(v))
for k, v in zip(self._keys, self._values)
])
class OfflineHits:
"""wrapper for offline hits"""
def __init__(self, keys, values):
"""wrapper for offline hits.
Parameters
----------
keys : list of str
list of cropped hits keys.
values : list of arrays
list of arrays containing hits data.
"""
self._keys = keys
self._values = values
for k, v in zip(self._keys, self._values):
setattr(self, k, v)
def __getitem__(self, item):
return OfflineHit(self._keys, [v[item] for v in self._values])
def __len__(self):
try:
return len(self._values[0])
except IndexError:
return 0
def __str__(self):
return "Number of hits: {}".format(len(self))
def __repr__(self):
return "<{}: {} parsed elements>".format(self.__class__.__name__,
len(self))
class OfflineHit:
"""wrapper for an offline hit"""
def __init__(self, keys, values):
"""wrapper for one offline hit.
Parameters
----------
keys : list of str
list of cropped hits keys.
values : list of arrays
list of arrays containing hit data.
"""
self._keys = keys
self._values = values
for k, v in zip(self._keys, self._values):
setattr(self, k, v)
def __str__(self):
return "offline hit:\n\t" + "\n\t".join([
"{:15} {:^10} {:>10}".format(k, ':', str(v))
for k, v in zip(self._keys, self._values)
])
def __getitem__(self, item):
return self._values[item]
class OfflineTracks:
"""wrapper for offline tracks"""
def __init__(self, keys, values):
"""wrapper for offline tracks
Parameters
----------
keys : list of str
list of cropped tracks keys.
values : list of arrays
list of arrays containing tracks data.
"""
self._keys = keys
self._values = values
for k, v in zip(self._keys, self._values):
setattr(self, k, v)
def __getitem__(self, item):
return OfflineTrack(self._keys, [v[item] for v in self._values])
def __len__(self):
try:
return len(self._values[0])
except IndexError:
return 0
def __str__(self):
return "Number of tracks: {}".format(len(self))
def __repr__(self):
return "<{}: {} parsed elements>".format(self.__class__.__name__,
len(self))
self._data[attribute] = Constructor(
**{f: _to_num(v)
for (f, v) in zip(fields, values)})
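# Worked illustration of the padding above: for 'can': '1 2 3 4' the
# defined fields are ['zmin', 'zmax', 'r'], so the surplus value gets a
# generated name and header.can == can(zmin=1, zmax=2, r=3, field_3=4);
# for 'can': '1' the missing values are padded, i.e. zmax and r are None.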
class OfflineTrack:
"""wrapper for an offline track"""
def __init__(self, keys, values):
"""wrapper for one offline track.
Parameters
----------
keys : list of str
list of cropped tracks keys.
values : list of arrays
list of arrays containing track data.
"""
self._keys = keys
self._values = values
for k, v in zip(self._keys, self._values):
setattr(self, k, v)
for attribute, value in self._data.items():
setattr(self, attribute, value)
def __str__(self):
return "offline track:\n\t" + "\n\t".join([
"{:30} {:^2} {:>26}".format(k, ':', str(v))
for k, v in zip(self._keys, self._values) if k not in ['fitinf']
]) + "\n\t" + "\n\t".join([
"{:30} {:^2} {:>26}".format(k, ':', str(
getattr(self, 'fitinf')[v]))
for k, v in km3io.definitions.fitparameters.data.items()
if len(getattr(self, 'fitinf')) > v
]) # I don't like 18 being explicit here
def __getitem__(self, item):
return self._values[item]
lines = ["MC Header:"]
keys = set(mc_header.keys())
for key, value in self._data.items():
if key in keys:
lines.append(" {}".format(value))
else:
lines.append(" {}: {}".format(key, value))
return "\n".join(lines)
@@ -19,6 +19,21 @@ class cached_property:
return prop
def _unfold_indices(obj, indices):
"""Unfolds an index chain and returns the corresponding item"""
original_obj = obj
for depth, idx in enumerate(indices):
try:
obj = obj[idx]
except IndexError:
print(
"IndexError while accessing an item from '{}' at depth {} ({}) "
"using the index chain {}".format(repr(original_obj), depth,
idx, indices))
raise
return obj
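# Illustration only: the chain is applied left to right, so
#   _unfold_indices(list(range(10)), [slice(2, 5), 0])
# equals list(range(10))[2:5][0] == 2.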
BranchMapper = namedtuple(
"BranchMapper",
['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat'])
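# The fields describe how a ROOT branch is exposed: `name` is the Python
# attribute, `key` the ROOT branch name, `extra` adds keys, `exclude`
# hides keys, `update` renames them, `attrparser` turns ROOT keys into
# attribute names and `flat` marks a non-nested branch. A minimal
# hypothetical mapper (mirroring EVENTS_MAP in offline.py):
#   BranchMapper(name="events", key="Evt", extra={}, exclude=[],
#                update={}, attrparser=lambda k: k, flat=True)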
@@ -29,17 +44,19 @@ class Branch:
def __init__(self,
tree,
mapper,
index=None,
index_chain=None,
subbranchmaps=None,
keymap=None):
self._tree = tree
self._mapper = mapper
self._index = index
self._index_chain = [] if index_chain is None else index_chain
self._keymap = None
self._branch = tree[mapper.key]
self._subbranches = []
self._subbranchmaps = subbranchmaps
self._iterator_index = 0
if keymap is None:
self._initialise_keys()
else:
@@ -49,7 +66,7 @@ class Branch:
for mapper in subbranchmaps:
subbranch = self.__class__(self._tree,
mapper=mapper,
index=self._index)
index_chain=self._index_chain)
self._subbranches.append(subbranch)
for subbranch in self._subbranches:
setattr(self, subbranch._mapper.name, subbranch)
@@ -57,8 +74,8 @@ class Branch:
def _initialise_keys(self):
"""Create the keymap and instance attributes for branch keys"""
# TODO: this could be a cached property
keys = set(k.decode('utf-8') for k in self._branch.keys()) - set(
self._mapper.exclude)
keys = set(k.decode('utf-8')
for k in self._branch.keys()) - set(self._mapper.exclude)
self._keymap = {
**{self._mapper.attrparser(k): k
for k in keys},
@@ -86,42 +103,46 @@ class Branch:
def __getkey__(self, key):
out = self._branch[self._keymap[key]].lazyarray(
basketcache=BASKET_CACHE)
if self._index is not None:
out = out[self._index]
return out
return _unfold_indices(out, self._index_chain)
def __getitem__(self, item):
"""Slicing magic"""
if isinstance(item, (int, slice)):
return self.__class__(self._tree,
self._mapper,
index=item,
keymap=self._keymap,
subbranchmaps=self._subbranchmaps)
if isinstance(item, tuple):
return self[item[0]][item[1]]
if isinstance(item, str):
return self.__getkey__(item)
return self.__class__(self._tree,
self._mapper,
index=np.array(item),
index_chain=self._index_chain + [item],
keymap=self._keymap,
subbranchmaps=self._subbranchmaps)
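# Sketch of the slicing semantics (illustration only): integer and slice
# lookups are appended to the index chain and applied lazily on array
# access, so branch[3:5][0]['dir_z'] resolves to the same values as
# branch['dir_z'][3:5][0].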
def __len__(self):
if self._index is None:
if not self._index_chain:
return len(self._branch)
elif isinstance(self._index, int):
elif isinstance(self._index_chain[-1], int):
return 1
else:
return len(self._branch[self._keymap['id']].lazyarray(
basketcache=BASKET_CACHE)[self._index])
return len(
_unfold_indices(
self._branch[self._keymap['id']].lazyarray(
basketcache=BASKET_CACHE), self._index_chain))
def __iter__(self):
self._iterator_index = 0
return self
def __next__(self):
idx = self._iterator_index
self._iterator_index += 1
if idx >= len(self):
raise StopIteration
return self[idx]
def __str__(self):
return "Number of elements: {}".format(len(self._branch))
length = len(self)
return "{} ({}) with {} element{}".format(self.__class__.__name__,
self._mapper.name, length,
's' if length > 1 else '')
def __repr__(self):
length = len(self)
@@ -5,12 +5,12 @@ import unittest
from km3io.daq import DAQReader, get_rate, has_udp_trailer, get_udp_max_sequence_number, get_channel_flags, get_number_udp_packets
SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "samples")
DAQ_FILE = DAQReader(os.path.join(SAMPLES_DIR, "daq_v1.0.0.root"))
class TestDAQEvents(unittest.TestCase):
def setUp(self):
self.events = DAQ_FILE.events
self.events = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).events
def test_index_lookup(self):
assert 3 == len(self.events)
@@ -24,7 +24,8 @@ class TestDAQEvents(unittest.TestCase):
class TestDAQEvent(unittest.TestCase):
def setUp(self):
self.event = DAQ_FILE.events[0]
self.event = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).events[0]
def test_str(self):
assert re.match(".*event.*96.*snapshot.*18.*triggered",
@@ -37,7 +38,8 @@ class TestDAQEvent(unittest.TestCase):
class TestDAQEventsSnapshotHits(unittest.TestCase):
def setUp(self):
self.events = DAQ_FILE.events
self.events = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).events
self.lengths = {0: 96, 1: 124, -1: 78}
self.total_item_count = 298
@@ -75,7 +77,8 @@ class TestDAQEventsSnapshotHits(unittest.TestCase):
class TestDAQEventsTriggeredHits(unittest.TestCase):
def setUp(self):
self.events = DAQ_FILE.events
self.events = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).events
self.lengths = {0: 18, 1: 53, -1: 9}
self.total_item_count = 80
@@ -115,7 +118,8 @@ class TestDAQEventsTriggeredHits(unittest.TestCase):
class TestDAQTimeslices(unittest.TestCase):
def setUp(self):
self.ts = DAQ_FILE.timeslices
self.ts = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).timeslices
def test_data_lengths(self):
assert 3 == len(self.ts._timeslices["L1"][0])
@@ -140,7 +144,8 @@ class TestDAQTimeslices(unittest.TestCase):
class TestDAQTimeslice(unittest.TestCase):
def setUp(self):
self.ts = DAQ_FILE.timeslices
self.ts = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).timeslices
self.n_frames = {"L1": [69, 69, 69], "SN": [64, 66, 68]}
def test_str(self):
@@ -153,7 +158,8 @@ class TestSummaryslices(unittest.TestCase):
class TestSummaryslices(unittest.TestCase):
def setUp(self):
self.ss = DAQ_FILE.summaryslices
self.ss = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).summaryslices
def test_headers(self):
assert 3 == len(self.ss.headers)
@@ -2,389 +2,374 @@ import unittest
import numpy as np
from pathlib import Path
from km3io.offline import OfflineEvents, OfflineHits, OfflineTracks
from km3io import OfflineReader
from km3io.offline import _nested_mapper, Header
SAMPLES_DIR = Path(__file__).parent / 'samples'
OFFLINE_FILE = SAMPLES_DIR / 'aanet_v2.0.0.root'
OFFLINE_USR = SAMPLES_DIR / 'usr-sample.root'
OFFLINE_NUMUCC = SAMPLES_DIR / "numucc.root" # with mc data
class TestOfflineKeys(unittest.TestCase):
def setUp(self):
self.keys = OfflineReader(OFFLINE_FILE).keys
def test_events_keys(self):
# there are 22 "valid" events keys
self.assertEqual(len(self.keys.events_keys), 22)
self.assertEqual(len(self.keys.cut_events_keys), 22)
def test_hits_keys(self):
# there are 20 "valid" hits keys
self.assertEqual(len(self.keys.hits_keys), 20)
self.assertEqual(len(self.keys.mc_hits_keys), 20)
self.assertEqual(len(self.keys.cut_hits_keys), 20)
def test_tracks_keys(self):
# there are 22 "valid" tracks keys
self.assertEqual(len(self.keys.tracks_keys), 22)
self.assertEqual(len(self.keys.mc_tracks_keys), 22)
self.assertEqual(len(self.keys.cut_tracks_keys), 22)
def test_valid_keys(self):
# there are 106 valid keys: 22*2 + 22 + 20*2
# (fit keys are excluded)
self.assertEqual(len(self.keys.valid_keys), 106)
def test_fit_keys(self):
# there are 18 fit keys
self.assertEqual(len(self.keys.fit_keys), 18)
OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root')
OFFLINE_USR = OfflineReader(SAMPLES_DIR / 'usr-sample.root')
OFFLINE_NUMUCC = OfflineReader(SAMPLES_DIR / "numucc.root") # with mc data
class TestOfflineReader(unittest.TestCase):
def setUp(self):
self.r = OfflineReader(OFFLINE_FILE)
self.nu = OfflineReader(OFFLINE_NUMUCC)
self.Nevents = 10
self.r = OFFLINE_FILE
self.nu = OFFLINE_NUMUCC
self.n_events = 10
def test_number_events(self):
Nevents = len(self.r)
# check that there are 10 events
self.assertEqual(Nevents, self.Nevents)
def test_find_empty(self):
fitinf = self.nu.tracks.fitinf
rec_stages = self.nu.tracks.rec_stages
empty_fitinf = np.array(
[match for match in self.nu._find_empty(fitinf)])
empty_stages = np.array(
[match for match in self.nu._find_empty(rec_stages)])
self.assertListEqual(empty_fitinf[:5, 1].tolist(),
[23, 14, 14, 4, None])
self.assertListEqual(empty_stages[:5, 1].tolist(),
[False, False, False, False, None])
def test_find_rec_stages(self):
stages = np.array(
[match for match in self.nu._find_rec_stages([1, 2, 3, 4, 5])])
self.assertListEqual(stages[:5, 1].tolist(), [0, 0, 0, 0, None])
def test_get_reco_fit(self):
JGANDALF_BETA0_RAD = [
0.0020367251782607574, 0.003306725805622178, 0.0057877124222254885,
0.015581698352185896
]
reco_fit = self.nu.get_reco_fit([1, 2, 3, 4, 5])['JGANDALF_BETA0_RAD']
self.assertListEqual(JGANDALF_BETA0_RAD, reco_fit[:4].tolist())
with self.assertRaises(ValueError):
self.nu.get_reco_fit([1000, 4512, 5625], mc=True)
def test_get_reco_hits(self):
doms = self.nu.get_reco_hits([1, 2, 3, 4, 5], ["dom_id"])["dom_id"]
mc_doms = self.nu.get_reco_hits([], ["dom_id"], mc=True)["dom_id"]
self.assertEqual(doms.size, 9)
self.assertEqual(mc_doms.size, 10)
self.assertListEqual(doms[0][0:4].tolist(),
self.nu.hits[0].dom_id[0:4].tolist())
self.assertListEqual(mc_doms[0][0:4].tolist(),
self.nu.mc_hits[0].dom_id[0:4].tolist())
with self.assertRaises(ValueError):
self.nu.get_reco_hits([1000, 4512, 5625], ["dom_id"])
def test_get_reco_tracks(self):
assert self.n_events == len(self.r.events)
pos = self.nu.get_reco_tracks([1, 2, 3, 4, 5], ["pos_x"])["pos_x"]
mc_pos = self.nu.get_reco_tracks([], ["pos_x"], mc=True)["pos_x"]
self.assertEqual(pos.size, 9)
self.assertEqual(mc_pos.size, 10)
self.assertEqual(pos[0], self.nu.tracks[0].pos_x[0])
self.assertEqual(mc_pos[0], self.nu.mc_tracks[0].pos_x[0])
with self.assertRaises(ValueError):
self.nu.get_reco_tracks([1000, 4512, 5625], ["pos_x"])
def test_get_reco_events(self):
hits = self.nu.get_reco_events([1, 2, 3, 4, 5], ["hits"])["hits"]
mc_hits = self.nu.get_reco_events([], ["mc_hits"], mc=True)["mc_hits"]
self.assertEqual(hits.size, 9)
self.assertEqual(mc_hits.size, 10)
self.assertListEqual(hits[0:4].tolist(),
self.nu.events.hits[0:4].tolist())
self.assertListEqual(mc_hits[0:4].tolist(),
self.nu.events.mc_hits[0:4].tolist())
with self.assertRaises(ValueError):
self.nu.get_reco_events([1000, 4512, 5625], ["hits"])
def test_get_max_reco_stages(self):
rec_stages = self.nu.tracks.rec_stages
max_reco = self.nu._get_max_reco_stages(rec_stages)
self.assertEqual(len(max_reco.tolist()), 9)
self.assertListEqual(max_reco[0].tolist(), [[1, 2, 3, 4, 5], 5, 0])
def test_best_reco(self):
JGANDALF_BETA1_RAD = [
0.0014177681261476852, 0.002094094517471032, 0.003923368624980349,
0.009491461076780453
]
best = self.nu.get_best_reco()
self.assertEqual(best.size, 9)
self.assertEqual(best['JGANDALF_BETA1_RAD'][:4].tolist(),
JGANDALF_BETA1_RAD)
def test_reading_header(self):
# head is the supported format
head = OfflineReader(OFFLINE_NUMUCC).header
self.assertEqual(float(head['DAQ']), 394)
self.assertEqual(float(head['kcut']), 2)
class TestHeader(unittest.TestCase):
def test_str_header(self):
assert "MC Header" in str(OFFLINE_NUMUCC.header)
def test_warning_if_unsupported_header(self):
# test the warning for unsupported fheader format
with self.assertWarns(UserWarning):
self.r.header
OFFLINE_FILE.header
def test_missing_key_definitions(self):
head = {'a': '1 2 3', 'b': '4', 'c': 'd'}
header = Header(head)
assert 1 == header.a.field_0
assert 2 == header.a.field_1
assert 3 == header.a.field_2
assert 4 == header.b
assert 'd' == header.c
def test_missing_values(self):
head = {'can': '1'}
header = Header(head)
assert 1 == header.can.zmin
assert header.can.zmax is None
assert header.can.r is None
def test_additional_values_compared_to_definition(self):
head = {'can': '1 2 3 4'}
header = Header(head)
assert 1 == header.can.zmin
assert 2 == header.can.zmax
assert 3 == header.can.r
assert 4 == header.can.field_3
def test_header(self):
head = {
'DAQ': '394',
'PDF': '4',
'can': '0 1027 888.4',
'undefined': '1 2 test 3.4'
}
header = Header(head)
assert 394 == header.DAQ.livetime
assert 4 == header.PDF.i1
assert header.PDF.i2 is None
assert 0 == header.can.zmin
assert 1027 == header.can.zmax
assert 888.4 == header.can.r
assert 1 == header.undefined.field_0
assert 2 == header.undefined.field_1
assert "test" == header.undefined.field_2
assert 3.4 == header.undefined.field_3
def test_reading_header_from_sample_file(self):
head = OFFLINE_NUMUCC.header
assert 394 == head.DAQ.livetime
assert 4 == head.PDF.i1
assert 58 == head.PDF.i2
assert 0 == head.coord_origin.x
assert 0 == head.coord_origin.y
assert 0 == head.coord_origin.z
assert 100 == head.cut_nu.Emin
assert 100000000.0 == head.cut_nu.Emax
assert -1 == head.cut_nu.cosTmin
assert 1 == head.cut_nu.cosTmax
assert "diffuse" == head.sourcemode
assert 100000.0 == head.ngen
class TestOfflineEvents(unittest.TestCase):
def setUp(self):
self.events = OfflineReader(OFFLINE_FILE).events
self.hits = {0: 176, 1: 125, -1: 105}
self.Nevents = 10
self.events = OFFLINE_FILE.events
self.n_events = 10
self.det_id = [44] * self.n_events
self.n_hits = [176, 125, 318, 157, 83, 60, 71, 84, 255, 105]
self.n_tracks = [56, 55, 56, 56, 56, 56, 56, 56, 54, 56]
self.t_sec = [
1567036818, 1567036818, 1567036820, 1567036816, 1567036816,
1567036816, 1567036822, 1567036818, 1567036818, 1567036820
]
self.t_ns = [
200000000, 300000000, 200000000, 500000000, 500000000, 500000000,
200000000, 500000000, 500000000, 400000000
]
def test_reading_hits(self):
# test item selection
for event_id, hit in self.hits.items():
self.assertEqual(hit, self.events.hits[event_id])
def test_len(self):
assert self.n_events == len(self.events)
def reading_tracks(self):
self.assertListEqual(list(self.events.trks[:3]), [56, 55, 56])
def test_attributes_available(self):
for key in self.events._keymap.keys():
getattr(self.events, key)
def test_item_selection(self):
for event_id, hit in self.hits.items():
self.assertEqual(hit, self.events[event_id].hits)
def test_attributes(self):
assert self.n_events == len(self.events.det_id)
self.assertListEqual(self.det_id, list(self.events.det_id))
self.assertListEqual(self.n_hits, list(self.events.n_hits))
self.assertListEqual(self.n_tracks, list(self.events.n_tracks))
self.assertListEqual(self.t_sec, list(self.events.t_sec))
self.assertListEqual(self.t_ns, list(self.events.t_ns))
def test_len(self):
self.assertEqual(len(self.events), self.Nevents)
def test_keys(self):
assert np.allclose(self.n_hits, self.events['n_hits'])
assert np.allclose(self.n_tracks, self.events['n_tracks'])
assert np.allclose(self.t_sec, self.events['t_sec'])
assert np.allclose(self.t_ns, self.events['t_ns'])
def test_IndexError(self):
# test handling IndexError with empty lists/arrays
self.assertEqual(len(OfflineEvents(['whatever'], [])), 0)
def test_slicing(self):
s = slice(2, 8, 2)
s_events = self.events[s]
assert 3 == len(s_events)
self.assertListEqual(self.n_hits[s], list(s_events.n_hits))
self.assertListEqual(self.n_tracks[s], list(s_events.n_tracks))
self.assertListEqual(self.t_sec[s], list(s_events.t_sec))
self.assertListEqual(self.t_ns[s], list(s_events.t_ns))
def test_slicing_consistency(self):
for s in [slice(1, 3), slice(2, 7, 3)]:
assert np.allclose(self.events[s].n_hits, self.events.n_hits[s])
def test_index_consistency(self):
for i in [0, 2, 5]:
assert np.allclose(self.events[i].n_hits, self.events.n_hits[i])
def test_index_chaining(self):
assert np.allclose(self.events[3:5].n_hits, self.events.n_hits[3:5])
assert np.allclose(self.events[3:5][0].n_hits,
self.events.n_hits[3:5][0])
assert np.allclose(self.events[3:5].hits[1].dom_id[4],
self.events.hits[3:5][1][4].dom_id)
assert np.allclose(self.events.hits[3:5][1][4].dom_id,
self.events[3:5][1][4].hits.dom_id)
def test_iteration(self):
i = 0
for event in self.events:
i += 1
assert 10 == i
def test_iteration_2(self):
n_hits = [e.n_hits for e in self.events]
assert np.allclose(n_hits, self.events.n_hits)
def test_str(self):
self.assertEqual(str(self.events), 'Number of events: 10')
assert str(self.n_events) in str(self.events)
def test_repr(self):
self.assertEqual(repr(self.events),
'<OfflineEvents: 10 parsed events>')
class TestOfflineEvent(unittest.TestCase):
def test_event(self):
self.event = OfflineReader(OFFLINE_FILE).events[0]
assert str(self.n_events) in repr(self.events)
class TestOfflineHits(unittest.TestCase):
def setUp(self):
self.hits = OfflineReader(OFFLINE_FILE).hits
self.lengths = {0: 176, 1: 125, -1: 105}
self.total_item_count = 1434
self.r_mc = OfflineReader(OFFLINE_NUMUCC)
self.Nevents = 10
def test_item_selection(self):
self.assertListEqual(list(self.hits[0].dom_id[:3]),
[806451572, 806451572, 806451572])
def test_IndexError(self):
# test handling IndexError with empty lists/arrays
self.assertEqual(len(OfflineHits(['whatever'], [])), 0)
def test_repr(self):
self.assertEqual(repr(self.hits), '<OfflineHits: 10 parsed elements>')
self.hits = OFFLINE_FILE.events.hits
self.n_hits = 10
self.dom_id = {
0: [
806451572, 806451572, 806451572, 806451572, 806455814,
806455814, 806455814, 806483369, 806483369, 806483369
],
5: [
806455814, 806487219, 806487219, 806487219, 806487226,
808432835, 808432835, 808432835, 808432835, 808432835
]
}
self.t = {
0: [
70104010., 70104016., 70104192., 70104123., 70103096.,
70103797., 70103796., 70104191., 70104223., 70104181.
],
5: [
81861237., 81859608., 81860586., 81861062., 81860357.,
81860627., 81860628., 81860625., 81860627., 81860629.
]
}
def test_attributes_available(self):
for key in self.hits._keymap.keys():
getattr(self.hits, key)
def test_channel_ids(self):
self.assertTrue(all(c >= 0 for c in self.hits.channel_id.min()))
self.assertTrue(all(c < 31 for c in self.hits.channel_id.max()))
def test_str(self):
self.assertEqual(str(self.hits), 'Number of hits: 10')
def test_reading_dom_id(self):
dom_ids = self.hits.dom_id
for event_id, length in self.lengths.items():
self.assertEqual(length, len(dom_ids[event_id]))
self.assertEqual(self.total_item_count, sum(dom_ids.count()))
self.assertListEqual([806451572, 806451572, 806451572],
list(dom_ids[0][:3]))
def test_reading_channel_id(self):
channel_ids = self.hits.channel_id
for event_id, length in self.lengths.items():
self.assertEqual(length, len(channel_ids[event_id]))
self.assertEqual(self.total_item_count, sum(channel_ids.count()))
assert str(self.n_hits) in str(self.hits)
self.assertListEqual([8, 9, 14], list(channel_ids[0][:3]))
# channel IDs are always between [0, 30]
self.assertTrue(all(c >= 0 for c in channel_ids.min()))
self.assertTrue(all(c < 31 for c in channel_ids.max()))
def test_reading_times(self):
ts = self.hits.t
for event_id, length in self.lengths.items():
self.assertEqual(length, len(ts[event_id]))
self.assertEqual(self.total_item_count, sum(ts.count()))
self.assertListEqual([70104010.0, 70104016.0, 70104192.0],
list(ts[0][:3]))
def test_reading_mc_pmt_id(self):
pmt_ids = self.r_mc.mc_hits.pmt_id
lengths = {0: 58, 2: 28, -1: 48}
def test_repr(self):
assert str(self.n_hits) in repr(self.hits)
for hit_id, length in lengths.items():
self.assertEqual(length, len(pmt_ids[hit_id]))
def test_attributes(self):
for idx, dom_id in self.dom_id.items():
self.assertListEqual(dom_id,
list(self.hits.dom_id[idx][:len(dom_id)]))
for idx, t in self.t.items():
assert np.allclose(t, self.hits.t[idx][:len(t)])
self.assertEqual(self.Nevents, len(pmt_ids))
def test_slicing(self):
s = slice(2, 8, 2)
s_hits = self.hits[s]
assert 3 == len(s_hits)
for idx, dom_id in self.dom_id.items():
self.assertListEqual(dom_id[s], list(self.hits.dom_id[idx][s]))
for idx, t in self.t.items():
self.assertListEqual(t[s], list(self.hits.t[idx][s]))
def test_slicing_consistency(self):
for s in [slice(1, 3), slice(2, 7, 3)]:
for idx in range(3):
assert np.allclose(self.hits.dom_id[idx][s],
self.hits[idx].dom_id[s])
assert np.allclose(OFFLINE_FILE.events[idx].hits.dom_id[s],
self.hits.dom_id[idx][s])
def test_index_consistency(self):
for idx, dom_ids in self.dom_id.items():
assert np.allclose(self.hits[idx].dom_id[:self.n_hits],
dom_ids[:self.n_hits])
assert np.allclose(
OFFLINE_FILE.events[idx].hits.dom_id[:self.n_hits],
dom_ids[:self.n_hits])
for idx, ts in self.t.items():
assert np.allclose(self.hits[idx].t[:self.n_hits],
ts[:self.n_hits])
assert np.allclose(OFFLINE_FILE.events[idx].hits.t[:self.n_hits],
ts[:self.n_hits])
self.assertListEqual([677, 687, 689], list(pmt_ids[0][:3]))
def test_keys(self):
assert "dom_id" in self.hits.keys()
class TestOfflineHit(unittest.TestCase):
class TestOfflineTracks(unittest.TestCase):
def setUp(self):
self.hit = OfflineReader(OFFLINE_FILE)[0].hits[0]
self.f = OFFLINE_FILE
self.tracks = OFFLINE_FILE.events.tracks
self.tracks_numucc = OFFLINE_NUMUCC
self.n_events = 10
def test_item_selection(self):
self.assertEqual(self.hit[0], self.hit.id)
self.assertEqual(self.hit[1], self.hit.dom_id)
def test_attributes_available(self):
for key in self.tracks._keymap.keys():
getattr(self.tracks, key)
class TestOfflineTracks(unittest.TestCase):
def setUp(self):
self.tracks = OfflineReader(OFFLINE_FILE).tracks
self.r_mc = OfflineReader(OFFLINE_NUMUCC)
self.Nevents = 10
@unittest.skip
def test_attributes(self):
for idx, dom_id in self.dom_id.items():
self.assertListEqual(dom_id,
list(self.hits.dom_id[idx][:len(dom_id)]))
for idx, t in self.t.items():
assert np.allclose(t, self.hits.t[idx][:len(t)])
def test_item_selection(self):
self.assertListEqual(list(self.tracks[0].dir_z[:2]),
[-0.872885221293917, -0.872885221293917])
def test_IndexError(self):
# test handling IndexError with empty lists/arrays
self.assertEqual(len(OfflineTracks(['whatever'], [])), 0)
def test_repr(self):
self.assertEqual(repr(self.tracks),
'<OfflineTracks: 10 parsed elements>')
def test_str(self):
self.assertEqual(str(self.tracks), 'Number of tracks: 10')
def test_reading_tracks_dir_z(self):
dir_z = self.tracks.dir_z
tracks_dir_z = {0: 56, 1: 55, 8: 54}
for track_id, n_dir in tracks_dir_z.items():
self.assertEqual(n_dir, len(dir_z[track_id]))
# check that there are 10 arrays of tracks.dir_z info
self.assertEqual(len(dir_z), self.Nevents)
def test_reading_mc_tracks_dir_z(self):
dir_z = self.r_mc.mc_tracks.dir_z
tracks_dir_z = {0: 11, 1: 25, 8: 13}
for track_id, n_dir in tracks_dir_z.items():
self.assertEqual(n_dir, len(dir_z[track_id]))
# check that there are 10 arrays of tracks.dir_z info
self.assertEqual(len(dir_z), self.Nevents)
self.assertListEqual([0.230189, 0.230189, 0.218663],
list(dir_z[0][:3]))
assert " 10 " in repr(self.tracks)
def test_slicing(self):
tracks = self.tracks
assert 10 == len(tracks)
# track_selection = tracks[2:7]
# assert 5 == len(track_selection)
# track_selection_2 = tracks[1:3]
# assert 2 == len(track_selection_2)
# for _slice in [
# slice(0, 0),
# slice(0, 1),
# slice(0, 2),
# slice(1, 5),
# slice(3, -2)
# ]:
# self.assertListEqual(list(tracks.E[:, 0][_slice]),
# list(tracks[_slice].E[:, 0]))
class TestOfflineTrack(unittest.TestCase):
self.assertEqual(10, len(tracks))
self.assertEqual(1, len(tracks[0]))
track_selection = tracks[2:7]
assert 5 == len(track_selection)
track_selection_2 = tracks[1:3]
assert 2 == len(track_selection_2)
for _slice in [
slice(0, 0),
slice(0, 1),
slice(0, 2),
slice(1, 5),
slice(3, -2)
]:
self.assertListEqual(list(tracks.E[:, 0][_slice]),
list(tracks[_slice].E[:, 0]))
def test_nested_indexing(self):
self.assertAlmostEqual(
self.f.events.tracks.fitinf[3:5][1][9][2],
self.f.events[3:5].tracks[1].fitinf[9][2])
self.assertAlmostEqual(
self.f.events.tracks.fitinf[3:5][1][9][2],
self.f.events[3:5][1][9][2].tracks.fitinf)
self.assertAlmostEqual(
self.f.events.tracks.fitinf[3:5][1][9][2],
self.f.events[3:5][1].tracks[9][2].fitinf)
self.assertAlmostEqual(
self.f.events.tracks.fitinf[3:5][1][9][2],
self.f.events[3:5][1].tracks[9].fitinf[2])
class TestBranchIndexingMagic(unittest.TestCase):
def setUp(self):
self.track = OfflineReader(OFFLINE_FILE)[0].tracks[0]
self.events = OFFLINE_FILE.events
def test_item_selection(self):
self.assertEqual(self.track[0], self.track.fUniqueID)
self.assertEqual(self.track[10], self.track.E)
def test_foo(self):
self.assertEqual(318, self.events[2:4].n_hits[0])
assert np.allclose(self.events[3].tracks.dir_z[10],
self.events.tracks.dir_z[3, 10])
assert np.allclose(self.events[3:6].tracks.pos_y[:, 0],
self.events.tracks.pos_y[3:6, 0])
def test_str(self):
self.assertEqual(str(self.track).split('\n\t')[0], 'offline track:')
# test selecting with a list
self.assertEqual(3, len(self.events[[0, 2, 3]]))
class TestUsr(unittest.TestCase):
def setUp(self):
self.f = OfflineReader(OFFLINE_USR)
def test_str(self):
print(self.f.usr)
self.f = OFFLINE_USR
def test_nonexistent_usr(self):
f = OfflineReader(SAMPLES_DIR / "daq_v1.0.0.root")
self.assertListEqual([], f.usr.keys())
def test_str_flat(self):
print(self.f.events.usr)
def test_keys(self):
def test_keys_flat(self):
self.assertListEqual([
'RecoQuality', 'RecoNDF', 'CoC', 'ToT', 'ChargeAbove',
'ChargeBelow', 'ChargeRatio', 'DeltaPosZ', 'FirstPartPosZ',
'LastPartPosZ', 'NSnapHits', 'NTrigHits', 'NTrigDOMs',
'NTrigLines', 'NSpeedVetoHits', 'NGeometryVetoHits',
'ClassficationScore'
], self.f.usr.keys())
], self.f.events.usr.keys())
def test_getitem(self):
def test_getitem_flat(self):
assert np.allclose(
[118.6302815337638, 44.33580521344907, 99.93916717621543],
self.f.usr['CoC'])
self.f.events.usr['CoC'])
assert np.allclose(
[37.51967774166617, -10.280346193553832, 13.67595659707355],
self.f.usr['DeltaPosZ'])
self.f.events.usr['DeltaPosZ'])
def test_attributes(self):
@unittest.skip
def test_keys_nested(self):
self.assertListEqual(["a"], self.f.events.mc_tracks.usr.keys())
def test_attributes_flat(self):
assert np.allclose(
[118.6302815337638, 44.33580521344907, 99.93916717621543],
self.f.usr.CoC)
self.f.events.usr.CoC)
assert np.allclose(
[37.51967774166617, -10.280346193553832, 13.67595659707355],
self.f.usr.DeltaPosZ)
self.f.events.usr.DeltaPosZ)
class TestNestedMapper(unittest.TestCase):
def test_nested_mapper(self):
self.assertEqual('pos_x', _nested_mapper("trks.pos.x"))
#!/usr/bin/env python3
import unittest
from km3io.tools import _to_num, cached_property
from km3io.tools import _to_num, cached_property, _unfold_indices
class TestToNum(unittest.TestCase):
def test_to_num(self):
@@ -19,3 +20,21 @@ class TestCachedProperty(unittest.TestCase):
pass
self.assertTrue(isinstance(Test.prop, cached_property))
class TestUnfoldIndices(unittest.TestCase):
def test_unfold_indices(self):
data = range(10)
indices = [slice(2, 5), 0]
assert data[indices[0]][indices[1]] == _unfold_indices(data, indices)
indices = [slice(1, 9, 2), slice(1, 4), 2]
assert data[indices[0]][indices[1]][indices[2]] == _unfold_indices(
data, indices)
def test_unfold_indices_raises_index_error(self):
data = range(10)
indices = [slice(2, 5), 99]
with self.assertRaises(IndexError):
_unfold_indices(data, indices)