diff --git a/km3io/definitions/__init__.py b/km3io/definitions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..430a8773f2d943cbbfe1e8a991f33a6cbfcf7a07 --- /dev/null +++ b/km3io/definitions/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from .mc_header import data as mc_header +from .trigger import data as trigger +from .fitparameters import data as fitparameters +from .reconstruction import data as reconstruction diff --git a/km3io/definitions/mc_header.py b/km3io/definitions/mc_header.py new file mode 100644 index 0000000000000000000000000000000000000000..28e2d1fb939c75319f9941bbde562d5df3b8aa36 --- /dev/null +++ b/km3io/definitions/mc_header.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +data = { + "DAQ": "livetime", + "seed": "program level iseed", + "PM1_type_area": "type area TTS", + "PDF": "i1 i2", + "model": "interaction muon scattering numberOfEnergyBins", + "can": "zmin zmax r", + "genvol": "zmin zmax r volume numberOfEvents", + "merge": "time gain", + "coord_origin": "x y z", + "translate": "x y z", + "genhencut": "gDir Emin", + "k40": "rate time", + "norma": "primaryFlux numberOfPrimaries", + "livetime": "numberOfSeconds errorOfSeconds", + "flux": "type key file_1 file_2", + "spectrum": "alpha", + "fixedcan": "xcenter ycenter zmin zmax radius", + "start_run": "run_id", +} + +for key in "cut_primary cut_seamuon cut_in cut_nu".split(): + data[key] = "Emin Emax cosTmin cosTmax" + +for key in "generator physics simul".split(): + data[key] = "program version date time" + +for key in data.keys(): + data[key] = data[key].split() diff --git a/km3io/offline.py b/km3io/offline.py index 82f48c0eccdd9aa179dab4841def3964bd2fd052..14043bc88432082b4e7dcf3dfe741ae71e617f96 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -1,209 +1,188 @@ +from collections import namedtuple import uproot -import numpy as np import warnings -import km3io.definitions.trigger -import km3io.definitions.fitparameters -import km3io.definitions.reconstruction +from .definitions import mc_header +from .tools import Branch, BranchMapper, cached_property, _to_num, _unfold_indices MAIN_TREE_NAME = "E" +EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"] + # 110 MB based on the size of the largest basket found so far in km3net BASKET_CACHE_SIZE = 110 * 1024**2 - - -class cached_property: - """A simple cache decorator for properties.""" - def __init__(self, function): - self.function = function - - def __get__(self, obj, cls): - if obj is None: - return self - prop = obj.__dict__[self.function.__name__] = self.function(obj) - return prop - - -class OfflineKeys: - """wrapper for offline keys""" - def __init__(self, tree): - """OfflineKeys is a class that reads all the available keys in an offline - file and adapts the keys format to Python format. - - Parameters - ---------- - tree : uproot.TTree - The main ROOT tree. 
- """ - self._tree = tree - - def __str__(self): - return '\n'.join([ - "Events keys are:\n\t" + "\n\t".join(self.events_keys), - "Hits keys are:\n\t" + '\n\t'.join(self.hits_keys), - "Tracks keys are:\n\t" + '\n\t'.join(self.tracks_keys), - "Mc hits keys are:\n\t" + '\n\t'.join(self.mc_hits_keys), - "Mc tracks keys are:\n\t" + '\n\t'.join(self.mc_tracks_keys) - ]) - - def __repr__(self): - return "<{}>".format(self.__class__.__name__) - - def _get_keys(self, tree, fake_branches=None): - """Get tree keys except those in fake_branches - - Parameters - ---------- - tree : uproot.Tree - The tree to look for keys - fake_branches : list of str or None - The fake branches to ignore - - Returns - ------- - list of str - The keys of the tree. - """ - keys = [] - for key in tree.keys(): - key = key.decode('utf-8') - if fake_branches is not None and key in fake_branches: - continue - keys.append(key) - return keys - +BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE) + + +def _nested_mapper(key): + """Maps a key in the ROOT file to another key (e.g. trks.pos.x -> pos_x)""" + return '_'.join(key.split('.')[1:]) + + +EVENTS_MAP = BranchMapper(name="events", + key="Evt", + extra={ + 't_sec': 't.fSec', + 't_ns': 't.fNanoSec' + }, + exclude=EXCLUDE_KEYS, + update={ + 'n_hits': 'hits', + 'n_mc_hits': 'mc_hits', + 'n_tracks': 'trks', + 'n_mc_tracks': 'mc_trks' + }, + attrparser=lambda a: a, + flat=True) + +SUBBRANCH_MAPS = [ + BranchMapper(name="tracks", + key="trks", + extra={}, + exclude=EXCLUDE_KEYS + + ['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'], + update={}, + attrparser=_nested_mapper, + flat=False), + BranchMapper(name="mc_tracks", + key="mc_trks", + extra={}, + exclude=EXCLUDE_KEYS + [ + 'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.rec_stages', + 'mc_trks.fitinf', 'mc_trks.fUniqueID', 'mc_trks.fBits' + ], + update={}, + attrparser=_nested_mapper, + flat=False), + BranchMapper(name="hits", + key="hits", + extra={}, + exclude=EXCLUDE_KEYS + [ + 'hits.usr', 'hits.pmt_id', 'hits.origin', 'hits.a', + 'hits.pure_a', 'hits.fUniqueID', 'hits.fBits' + ], + update={}, + attrparser=_nested_mapper, + flat=False), + BranchMapper(name="mc_hits", + key="mc_hits", + extra={}, + exclude=EXCLUDE_KEYS + [ + 'mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id', + 'mc_hits.tdc', 'mc_hits.tot', 'mc_hits.trig', + 'mc_hits.fUniqueID', 'mc_hits.fBits' + ], + update={}, + attrparser=_nested_mapper, + flat=False), +] + + +class OfflineBranch(Branch): @cached_property - def events_keys(self): - """reads events keys from an offline file. - - Returns - ------- - list of str - list of all events keys found in an offline file, - except those found in fake branches. - """ - fake_branches = ['Evt', 'AAObject', 'TObject', 't'] - t_baskets = ['t.fSec', 't.fNanoSec'] - tree = self._tree['Evt'] - return self._get_keys(self._tree['Evt'], fake_branches) + t_baskets + def usr(self): + return Usr(self._mapper, self._branch, index_chain=self._index_chain) - @cached_property - def hits_keys(self): - """reads hits keys from an offline file. - - Returns - ------- - list of str - list of all hits keys found in an offline file, - except those found in fake branches. - """ - fake_branches = ['hits.usr', 'hits.usr_names'] - return self._get_keys(self._tree['hits'], fake_branches) - @cached_property - def tracks_keys(self): - """reads tracks keys from an offline file. - - Returns - ------- - list of str - list of all tracks keys found in an offline file, - except those found in fake branches. 
- """ - # a solution can be tree['trks.usr_data'].array( - # uproot.asdtype(">i4")) - fake_branches = ['trks.usr_data', 'trks.usr', 'trks.usr_names'] - return self._get_keys(self._tree['Evt']['trks'], fake_branches) +class Usr: + """Helper class to access AAObject `usr` stuff""" + def __init__(self, mapper, branch, index_chain=None): + self._mapper = mapper + self._name = mapper.name + self._index_chain = [] if index_chain is None else index_chain + self._branch = branch + self._usr_names = [] + self._usr_idx_lookup = {} - @cached_property - def mc_hits_keys(self): - """reads mc hits keys from an offline file. - - Returns - ------- - list of str - list of all mc hits keys found in an offline file, - except those found in fake branches. - """ - fake_branches = ['mc_hits.usr', 'mc_hits.usr_names'] - return self._get_keys(self._tree['Evt']['mc_hits'], fake_branches) + self._usr_key = 'usr' if mapper.flat else mapper.key + '.usr' - @cached_property - def mc_tracks_keys(self): - """reads mc tracks keys from an offline file. - - Returns - ------- - list of str - list of all mc tracks keys found in an offline file, - except those found in fake branches. - """ - fake_branches = [ - 'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names' + self._initialise() + + def _initialise(self): + try: + self._branch[self._usr_key] + # This will raise a KeyError in old aanet files + # which has a different strucuter and key (usr_data) + # We do not support those (yet) + except (KeyError, IndexError): + print("The `usr` fields could not be parsed for the '{}' branch.". + format(self._name)) + return + + if self._mapper.flat: + self._initialise_flat() + + def _initialise_flat(self): + # Here, we assume that every event has the same names in the same order + # to massively increase the performance. This needs triple check if + # it's always the case. + self._usr_names = [ + n.decode("utf-8") + for n in self._branch[self._usr_key + '_names'].lazyarray()[0] ] - return self._get_keys(self._tree['Evt']['mc_trks'], fake_branches) + self._usr_idx_lookup = { + name: index + for index, name in enumerate(self._usr_names) + } - @cached_property - def valid_keys(self): - """constructs a list of all valid keys to be read from an offline event file. - Returns - ------- - list of str - list of all valid keys. - """ - return (self.events_keys + self.hits_keys + self.tracks_keys + - self.mc_tracks_keys + self.mc_hits_keys) + data = self._branch[self._usr_key].lazyarray() - @cached_property - def fit_keys(self): - """constructs a list of fit parameters, not yet outsourced in an offline file. + if self._index_chain: + data = _unfold_indices(data, self._index_chain) - Returns - ------- - list of str - list of all "trks.fitinf" keys. - """ - return sorted(km3io.definitions.fitparameters.data, - key=km3io.definitions.fitparameters.data.get, - reverse=False) + self._usr_data = data - @cached_property - def cut_hits_keys(self): - """adapts hits keys for instance variables format in a Python class. + for name in self._usr_names: + setattr(self, name, self[name]) - Returns - ------- - list of str - list of adapted hits keys. 
- """ - return [k.split('hits.')[1].replace('.', '_') for k in self.hits_keys] + # def _initialise_nested(self): + # self._usr_names = [ + # n.decode("utf-8") for n in self.branch['usr_names'].lazyarray( + # # TODO this will be fixed soon in uproot, + # # see https://github.com/scikit-hep/uproot/issues/465 + # uproot.asgenobj( + # uproot.SimpleArray(uproot.STLVector(uproot.STLString())), + # self.branch['usr_names']._context, 6), + # basketcache=BASKET_CACHE)[0] + # ] - @cached_property - def cut_tracks_keys(self): - """adapts tracks keys for instance variables format in a Python class. + def __getitem__(self, item): + if self._mapper.flat: + return self.__getitem_flat__(item) + return self.__getitem_nested__(item) + + def __getitem_flat__(self, item): + if self._index_chain: + return _unfold_indices( + self._usr_data, self._index_chain)[:, + self._usr_idx_lookup[item]] + else: + return self._usr_data[:, self._usr_idx_lookup[item]] + + def __getitem_nested__(self, item): + data = self._branch[self._usr_key + '_names'].lazyarray( + # TODO this will be fixed soon in uproot, + # see https://github.com/scikit-hep/uproot/issues/465 + uproot.asgenobj( + uproot.SimpleArray(uproot.STLVector(uproot.STLString())), + self._branch[self._usr_key + '_names']._context, 6), + basketcache=BASKET_CACHE) + return _unfold_indices(data, self._index_chain) - Returns - ------- - list of str - list of adapted tracks keys. - """ - return [ - k.split('trks.')[1].replace('.', '_') for k in self.tracks_keys - ] + def keys(self): + return self._usr_names - @cached_property - def cut_events_keys(self): - """adapts events keys for instance variables format in a Python class. + def __str__(self): + entries = [] + for name in self.keys(): + entries.append("{}: {}".format(name, self[name])) + return '\n'.join(entries) - Returns - ------- - list of str - list of adapted events keys. - """ - return [k.replace('.', '_') for k in self.events_keys] + def __repr__(self): + return "<{}[{}]>".format(self.__class__.__name__, self._name) class OfflineReader: """reader for offline ROOT files""" - def __init__(self, file_path=None, fobj=None, data=None): + def __init__(self, file_path=None): """ OfflineReader class is an offline ROOT file wrapper Parameters @@ -213,681 +192,67 @@ class OfflineReader: path-like object that points to the file. """ - if file_path is not None: - self._fobj = uproot.open(file_path) - self._tree = self._fobj[MAIN_TREE_NAME] - self._data = self._tree.lazyarrays( - basketcache=uproot.cache.ThreadSafeArrayCache( - BASKET_CACHE_SIZE)) - else: - self._fobj = fobj - self._tree = self._fobj[MAIN_TREE_NAME] - self._data = data + self._fobj = uproot.open(file_path) + self._tree = self._fobj[MAIN_TREE_NAME] - @classmethod - def from_index(cls, source, index): - """Create an instance with a subtree of a given index - - Parameters - ---------- - source: ROOTDirectory - The source file. - index: index or slice - The index or slice to create the subtree. 
- """ - instance = cls(fobj=source._fobj, data=source._data[index]) - return instance - - def __getitem__(self, index): - return OfflineReader.from_index(source=self, index=index) - - def __len__(self): - return len(self._data) + @cached_property + def events(self): + """The `E` branch, containing all offline events.""" + return OfflineBranch(self._tree, + mapper=EVENTS_MAP, + subbranchmaps=SUBBRANCH_MAPS) @cached_property def header(self): + """The file header""" if 'Head' in self._fobj: header = {} for n, x in self._fobj['Head']._map_3c_string_2c_string_3e_.items( ): header[n.decode("utf-8")] = x.decode("utf-8").strip() - return header + return Header(header) else: warnings.warn("Your file header has an unsupported format") - @cached_property - def keys(self): - """wrapper for all keys in an offline file. - - Returns - ------- - Class - OfflineKeys. - """ - return OfflineKeys(self._tree) - - @cached_property - def events(self): - """wrapper for offline events. - - Returns - ------- - Class - OfflineEvents. - """ - return OfflineEvents( - self.keys.cut_events_keys, - [self._data[key] for key in self.keys.events_keys]) - - @cached_property - def hits(self): - """wrapper for offline hits. - - Returns - ------- - Class - OfflineHits. - """ - return OfflineHits(self.keys.cut_hits_keys, - [self._data[key] for key in self.keys.hits_keys]) - - @cached_property - def tracks(self): - """wrapper for offline tracks. - - Returns - ------- - Class - OfflineTracks. - """ - return OfflineTracks( - self.keys.cut_tracks_keys, - [self._data[key] for key in self.keys.tracks_keys]) - - @cached_property - def mc_hits(self): - """wrapper for offline mc hits. - - Returns - ------- - Class - OfflineHits. - """ - return OfflineHits(self.keys.cut_hits_keys, - [self._data[key] for key in self.keys.mc_hits_keys]) - - @cached_property - def mc_tracks(self): - """wrapper for offline mc tracks. - - Returns - ------- - Class - OfflineTracks. - """ - return OfflineTracks( - self.keys.cut_tracks_keys, - [self._data[key] for key in self.keys.mc_tracks_keys]) - - @cached_property - def usr(self): - return Usr(self._tree) - - def get_best_reco(self): - """returns the best reconstructed track fit data. The best fit is defined - as the track fit with the maximum reconstruction stages. When "nan" is - returned, it means that the reconstruction parameter of interest is not - found. for example, in the case of muon simulations: if [1, 2] are the - reconstruction stages, then only the fit parameters corresponding to the - stages [1, 2] are found in the Offline files, the remaining fit parameters - corresponding to the stages 3, 4, 5 are all filled with nan. - - Returns - ------- - numpy recarray - a recarray of the best track fit data (reconstruction data). - """ - keys = ", ".join(self.keys.fit_keys[:-1]) - empty_fit_info = np.array( - [match for match in self._find_empty(self.tracks.fitinf)]) - fit_info = [ - i for i, j in zip(self.tracks.fitinf, empty_fit_info[:, 1]) - if j is not None - ] - stages = self._get_max_reco_stages(self.tracks.rec_stages) - fit_data = np.array([i[j] for i, j in zip(fit_info, stages[:, 2])]) - rows_size = len(max(fit_data, key=len)) - equal_size_data = np.vstack([ - np.hstack([i, np.zeros(rows_size - len(i)) + np.nan]) - for i in fit_data - ]) - return np.core.records.fromarrays(equal_size_data.transpose(), - names=keys) - - def _get_max_reco_stages(self, reco_stages): - """find the longest reconstructed track based on the maximum size of - reconstructed stages. 
- - Parameters - ---------- - reco_stages : chunked array - chunked array of all the reconstruction stages of all tracks. - In km3io, it is accessed with - km3io.OfflineReader(my_file).tracks.rec_stages . - - Returns - ------- - numpy array - array with 3 columns: *list of the maximum reco_stages - *lentgh of the maximum reco_stages - *position of the maximum reco_stages - """ - empty_reco_stages = np.array( - [match for match in self._find_empty(reco_stages)]) - max_reco_stages = np.array( - [[max(i, key=len), - len(max(i, key=len)), - i.index(max(i, key=len))] - for i, j in zip(reco_stages, empty_reco_stages[:, 1]) - if j is not None]) - return max_reco_stages - - def get_reco_fit(self, stages, mc=False): - """construct a numpy recarray of the fit information (reconstruction - data) of the tracks reconstructed following the reconstruction stages - of interest. - Parameters - ---------- - stages : list - list of reconstruction stages of interest. for example - [1, 2, 3, 4, 5]. - mc : bool, optional - default is False to look for fit data in the tracks tree in offline files - (not the mc tracks tree). mc=True to look for fit data from the mc tracks - tree in offline files. - - Returns - ------- - numpy recarray - a recarray of the fit information (reconstruction data) of - the tracks of interest. - - Raises - ------ - ValueError - ValueError raised when the reconstruction stages of interest - are not found in the file. - """ - keys = ", ".join(self.keys.fit_keys[:-1]) - - if mc is False: - rec_stages = np.array( - [match for match in self._find_rec_stages(stages, mc=False)]) - fitinf = self.tracks.fitinf - - if mc is True: - rec_stages = np.array( - [match for match in self._find_rec_stages(stages, mc=True)]) - fitinf = self.mc_tracks.fitinf - - mask = rec_stages[:, 1] != None - - if np.all(rec_stages[:, 1] == None): - raise ValueError( - "The stages {} are not found in your file.".format( - str(stages))) - else: - fit_data = np.array( - [i[k] for i, k in zip(fitinf[mask], rec_stages[:, 1][mask])]) - rec_array = np.core.records.fromarrays(fit_data.transpose(), - names=keys) - return rec_array - - def get_reco_hits(self, stages, keys, mc=False): - """construct a dictionary of hits class data based on the reconstruction - stages of interest. For example, if the reconstruction stages of interest - are [1, 2, 3, 4, 5], then get_reco_hits method will select the hits data - from the events that were reconstructed following these stages (i.e - [1, 2, 3, 4, 5]). +class Header: + """The header""" + def __init__(self, header): + self._data = {} - Parameters - ---------- - stages : list - list of reconstruction stages of interest. for example - [1, 2, 3, 4, 5]. - keys : list of str - list of the hits class attributes. - mc : bool, optional - default is False to look for hits data in the hits tree in offline files - (not the mc_hits tree). mc=True to look for mc hits data in the mc hits - tree in offline files. - - Returns - ------- - dict - dictionary of lazyarrays containing data for each hits attribute requested. - - Raises - ------ - ValueError - ValueError raised when the reconstruction stages of interest - are not found in the file. 
- """ - lazy_d = {} + for attribute, fields in header.items(): + values = fields.split() + fields = mc_header.get(attribute, []) - if mc is False: - rec_stages = np.array( - [match for match in self._find_rec_stages(stages, mc=False)]) - hits_data = self.hits + n_values = len(values) + n_fields = len(fields) - if mc is True: - rec_stages = np.array( - [match for match in self._find_rec_stages(stages, mc=True)]) - hits_data = self.mc_hits - - mask = rec_stages[:, 1] != None - - if np.all(rec_stages[:, 1] == None): - raise ValueError( - "The stages {} are not found in your file.".format( - str(stages))) - else: - for key in keys: - lazy_d[key] = getattr(hits_data, key)[mask] - return lazy_d - - def get_reco_events(self, stages, keys, mc=False): - """construct a dictionary of events class data based on the reconstruction - stages of interest. For example, if the reconstruction stages of interest - are [1, 2, 3, 4, 5], then get_reco_events method will select the events data - that were reconstructed following these stages (i.e [1, 2, 3, 4, 5]). - - Parameters - ---------- - stages : list - list of reconstruction stages of interest. for example - [1, 2, 3, 4, 5]. - keys : list of str - list of the events class attributes. - mc : bool, optional - default is False to look for the reconstruction stages in the tracks tree - in offline files (not the mc tracks tree). mc=True to look for the reconstruction - data in the mc tracks tree in offline files. - - Returns - ------- - dict - dictionary of lazyarrays containing data for each events attribute requested. - - Raises - ------ - ValueError - ValueError raised when the reconstruction stages of interest - are not found in the file. - """ - lazy_d = {} - - if mc is False: - rec_stages = np.array( - [match for match in self._find_rec_stages(stages, mc=False)]) - - if mc is True: - rec_stages = np.array( - [match for match in self._find_rec_stages(stages, mc=True)]) - - mask = rec_stages[:, 1] != None - - if np.all(rec_stages[:, 1] == None): - raise ValueError( - "The stages {} are not found in your file.".format( - str(stages))) - else: - for key in keys: - lazy_d[key] = getattr(self.events, key)[mask] - return lazy_d - - def get_reco_tracks(self, stages, keys, mc=False): - """construct a dictionary of tracks class data based on the reconstruction - stages of interest. For example, if the reconstruction stages of interest - are [1, 2, 3, 4, 5], then get_reco_tracks method will select tracks data - from the events that were reconstructed following these stages (i.e - [1, 2, 3, 4, 5]). - - Parameters - ---------- - stages : list - list of reconstruction stages of interest. for example - [1, 2, 3, 4, 5]. - keys : list of str - list of the tracks class attributes. - mc : bool, optional - default is False to look for tracks data in the tracks tree in offline files - (not the mc tracks tree). mc=True to look for tracks data in the mc tracks - tree in offline files. - - Returns - ------- - dict - dictionary of lazyarrays containing data for each tracks attribute requested. - - Raises - ------ - ValueError - ValueError raised when the reconstruction stages of interest - are not found in the file. 
- """ - lazy_d = {} - - if mc is False: - rec_stages = np.array( - [match for match in self._find_rec_stages(stages, mc=False)]) - tracks_data = self.tracks - - if mc is True: - rec_stages = np.array( - [match for match in self._find_rec_stages(stages, mc=True)]) - tracks_data = self.mc_tracks - - mask = rec_stages[:, 1] != None - - if np.all(rec_stages[:, 1] == None): - raise ValueError( - "The stages {} are not found in your file.".format( - str(stages))) - else: - for key in keys: - lazy_d[key] = np.array([ - i[k] for i, k in zip( - getattr(tracks_data, key)[mask], rec_stages[:, - 1][mask]) - ]) - - return lazy_d - - def _find_rec_stages(self, stages, mc=False): - """find the index of reconstruction stages of interest in a - list of multiple reconstruction stages. - - Parameters - ---------- - stages : list - list of reconstruction stages of interest. for example - [1, 2, 3, 4, 5]. - mc : bool, optional - default is False to look for reconstruction stages in the tracks tree in - offline files (not the mc tracks tree). mc=True to look for reconstruction - stages in the mc tracks tree in offline files. - Yields - ------ - generator - the track id and the index of the reconstruction stages of - interest if found. If the reconstruction stages of interest - are not found, None is returned as the stages index. - """ - if mc is False: - stages_data = self.tracks.rec_stages - - if mc is True: - stages_data = self.mc_tracks.rec_stages - - for trk_index, rec_stages in enumerate(stages_data): - try: - stages_index = rec_stages.index(stages) - except ValueError: - stages_index = None - yield trk_index, stages_index + if n_values == 1 and n_fields == 0: + self._data[attribute] = _to_num(values[0]) continue - yield trk_index, stages_index + n_max = max(n_values, n_fields) + values += [None] * (n_max - n_values) + fields += ["field_{}".format(i) for i in range(n_fields, n_max)] - def _find_empty(self, array): - """finds empty lists/arrays in an awkward array + Constructor = namedtuple(attribute, fields) - Parameters - ---------- - array : awkward array - Awkward array of data of interest. For example: - km3io.OfflineReader(my_file).tracks.fitinf . - - Yields - ------ - generator - the empty list id and the index of the empty list. When - data structure (list) is simply empty, None is written in the - corresponding index. However, when data structure (list) is not - empty and does not contain an empty list, then False is written in the - corresponding index. - """ - for i, rs in enumerate(array): - try: - if len(rs) == 0: - j = None - if len(rs) != 0: - j = rs.index([]) - except ValueError: - j = False # rs not empty but [] not found - yield i, j + if not values: continue - yield i, j - - -class Usr: - """Helper class to access AAObject usr stuff""" - def __init__(self, tree): - # Here, we assume that every event has the same names in the same order - # to massively increase the performance. This needs triple check if it's - # always the case; the usr-format is simply a very bad design. - try: - self._usr_names = [ - n.decode("utf-8") for n in tree['Evt']['usr_names'].array()[0] - ] - except (KeyError, IndexError): # e.g. 
old aanet files - self._usr_names = [] - else: - self._usr_idx_lookup = { - name: index - for index, name in enumerate(self._usr_names) - } - self._usr_data = tree['Evt']['usr'].lazyarray( - basketcache=uproot.cache.ThreadSafeArrayCache( - BASKET_CACHE_SIZE)) - for name in self._usr_names: - setattr(self, name, self[name]) - - def __getitem__(self, item): - return self._usr_data[:, self._usr_idx_lookup[item]] - - def keys(self): - return self._usr_names - - def __str__(self): - entries = [] - for name in self.keys(): - entries.append("{}: {}".format(name, self[name])) - return '\n'.join(entries) - - -class OfflineEvents: - """wrapper for offline events""" - def __init__(self, keys, values): - """wrapper for offline events. - - Parameters - ---------- - keys : list of str - list of valid events keys. - values : list of arrays - list of arrays containting events data. - """ - self._keys = keys - self._values = values - for k, v in zip(self._keys, self._values): - setattr(self, k, v) - - def __getitem__(self, item): - return OfflineEvent(self._keys, [v[item] for v in self._values]) - - def __len__(self): - try: - return len(self._values[0]) - except IndexError: - return 0 - def __str__(self): - return "Number of events: {}".format(len(self)) - - def __repr__(self): - return "<{}: {} parsed events>".format(self.__class__.__name__, - len(self)) - - -class OfflineEvent: - """wrapper for an offline event""" - def __init__(self, keys, values): - """wrapper for one offline event. - - Parameters - ---------- - keys : list of str - list of valid events keys. - values : list of arrays - list of arrays containting event data. - """ - self._keys = keys - self._values = values - for k, v in zip(self._keys, self._values): - setattr(self, k, v) - - def __str__(self): - return "offline event:\n\t" + "\n\t".join([ - "{:15} {:^10} {:>10}".format(k, ':', str(v)) - for k, v in zip(self._keys, self._values) - ]) - - -class OfflineHits: - """wrapper for offline hits""" - def __init__(self, keys, values): - """wrapper for offline hits. - - Parameters - ---------- - keys : list of str - list of cropped hits keys. - values : list of arrays - list of arrays containting hits data. - """ - self._keys = keys - self._values = values - for k, v in zip(self._keys, self._values): - setattr(self, k, v) - - def __getitem__(self, item): - return OfflineHit(self._keys, [v[item] for v in self._values]) - - def __len__(self): - try: - return len(self._values[0]) - except IndexError: - return 0 - - def __str__(self): - return "Number of hits: {}".format(len(self)) - - def __repr__(self): - return "<{}: {} parsed elements>".format(self.__class__.__name__, - len(self)) - - -class OfflineHit: - """wrapper for an offline hit""" - def __init__(self, keys, values): - """wrapper for one offline hit. - - Parameters - ---------- - keys : list of str - list of cropped hits keys. - values : list of arrays - list of arrays containting hit data. - """ - self._keys = keys - self._values = values - for k, v in zip(self._keys, self._values): - setattr(self, k, v) - - def __str__(self): - return "offline hit:\n\t" + "\n\t".join([ - "{:15} {:^10} {:>10}".format(k, ':', str(v)) - for k, v in zip(self._keys, self._values) - ]) - - def __getitem__(self, item): - return self._values[item] - - -class OfflineTracks: - """wrapper for offline tracks""" - def __init__(self, keys, values): - """wrapper for offline tracks - - Parameters - ---------- - keys : list of str - list of cropped tracks keys. 
- values : list of arrays - list of arrays containting tracks data. - """ - self._keys = keys - self._values = values - for k, v in zip(self._keys, self._values): - setattr(self, k, v) - - def __getitem__(self, item): - return OfflineTrack(self._keys, [v[item] for v in self._values]) - - def __len__(self): - try: - return len(self._values[0]) - except IndexError: - return 0 - - def __str__(self): - return "Number of tracks: {}".format(len(self)) - - def __repr__(self): - return "<{}: {} parsed elements>".format(self.__class__.__name__, - len(self)) + self._data[attribute] = Constructor( + **{f: _to_num(v) + for (f, v) in zip(fields, values)}) - -class OfflineTrack: - """wrapper for an offline track""" - def __init__(self, keys, values): - """wrapper for one offline track. - - Parameters - ---------- - keys : list of str - list of cropped tracks keys. - values : list of arrays - list of arrays containting track data. - """ - self._keys = keys - self._values = values - for k, v in zip(self._keys, self._values): - setattr(self, k, v) + for attribute, value in self._data.items(): + setattr(self, attribute, value) def __str__(self): - return "offline track:\n\t" + "\n\t".join([ - "{:30} {:^2} {:>26}".format(k, ':', str(v)) - for k, v in zip(self._keys, self._values) if k not in ['fitinf'] - ]) + "\n\t" + "\n\t".join([ - "{:30} {:^2} {:>26}".format(k, ':', str( - getattr(self, 'fitinf')[v])) - for k, v in km3io.definitions.fitparameters.data.items() - if len(getattr(self, 'fitinf')) > v - ]) # I don't like 18 being explicit here - - def __getitem__(self, item): - return self._values[item] + lines = ["MC Header:"] + keys = set(mc_header.keys()) + for key, value in self._data.items(): + if key in keys: + lines.append(" {}".format(value)) + else: + lines.append(" {}: {}".format(key, value)) + return "\n".join(lines) diff --git a/km3io/tools.py b/km3io/tools.py index 8def8ee42a080fb0b7669d6b45ce4f1e75a23fa4..e443357aa1ec72278b7e44bc9f14ccc4175d0728 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -19,6 +19,21 @@ class cached_property: return prop +def _unfold_indices(obj, indices): + """Unfolds an index chain and returns the corresponding item""" + original_obj = obj + for depth, idx in enumerate(indices): + try: + obj = obj[idx] + except IndexError: + print( + "IndexError while accessing an item from '{}' at depth {} ({}) " + "using the index chain {}".format(repr(original_obj), depth, + idx, indices)) + raise + return obj + + BranchMapper = namedtuple( "BranchMapper", ['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat']) @@ -29,17 +44,19 @@ class Branch: def __init__(self, tree, mapper, - index=None, + index_chain=None, subbranchmaps=None, keymap=None): self._tree = tree self._mapper = mapper - self._index = index + self._index_chain = [] if index_chain is None else index_chain self._keymap = None self._branch = tree[mapper.key] self._subbranches = [] self._subbranchmaps = subbranchmaps + self._iterator_index = 0 + if keymap is None: self._initialise_keys() # else: @@ -49,7 +66,7 @@ class Branch: for mapper in subbranchmaps: subbranch = self.__class__(self._tree, mapper=mapper, - index=self._index) + index_chain=self._index_chain) self._subbranches.append(subbranch) for subbranch in self._subbranches: setattr(self, subbranch._mapper.name, subbranch) @@ -57,8 +74,8 @@ class Branch: def _initialise_keys(self): """Create the keymap and instance attributes for branch keys""" # TODO: this could be a cached property - keys = set(k.decode('utf-8') for k in self._branch.keys()) - set( - 
self._mapper.exclude) + keys = set(k.decode('utf-8') + for k in self._branch.keys()) - set(self._mapper.exclude) self._keymap = { **{self._mapper.attrparser(k): k for k in keys}, @@ -86,42 +103,46 @@ class Branch: def __getkey__(self, key): out = self._branch[self._keymap[key]].lazyarray( basketcache=BASKET_CACHE) - if self._index is not None: - out = out[self._index] - return out + return _unfold_indices(out, self._index_chain) def __getitem__(self, item): """Slicing magic""" - if isinstance(item, (int, slice)): - return self.__class__(self._tree, - self._mapper, - index=item, - keymap=self._keymap, - subbranchmaps=self._subbranchmaps) - - if isinstance(item, tuple): - return self[item[0]][item[1]] - if isinstance(item, str): return self.__getkey__(item) return self.__class__(self._tree, self._mapper, - index=np.array(item), + index_chain=self._index_chain + [item], keymap=self._keymap, subbranchmaps=self._subbranchmaps) def __len__(self): - if self._index is None: + if not self._index_chain: return len(self._branch) - elif isinstance(self._index, int): + elif isinstance(self._index_chain[-1], int): return 1 else: - return len(self._branch[self._keymap['id']].lazyarray( - basketcache=BASKET_CACHE)[self._index]) + return len( + _unfold_indices( + self._branch[self._keymap['id']].lazyarray( + basketcache=BASKET_CACHE), self._index_chain)) + + def __iter__(self): + self._iterator_index = 0 + return self + + def __next__(self): + idx = self._iterator_index + self._iterator_index += 1 + if idx >= len(self): + raise StopIteration + return self[idx] def __str__(self): - return "Number of elements: {}".format(len(self._branch)) + length = len(self) + return "{} ({}) with {} element{}".format(self.__class__.__name__, + self._mapper.name, length, + 's' if length > 1 else '') def __repr__(self): length = len(self) diff --git a/requirements.txt b/requirements.txt index 96ec8f79ccd2c06a830005d59baaf84d601f67f1..d11bcbe8cfbda5719cfee236238e2060d616d81b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ docopt numba +awkward1 uproot>=3.11.1 setuptools_scm diff --git a/tests/samples/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root b/tests/samples/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root new file mode 100644 index 0000000000000000000000000000000000000000..e20ee07444cb94ba1140e7e3dff202ed52761d1f Binary files /dev/null and b/tests/samples/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root differ diff --git a/tests/test_daq.py b/tests/test_daq.py index cff1fbfd5009b8080cac2c8ff8721800dab7e30c..7482e09d51ec170073a6738d0314ca19e4d05986 100644 --- a/tests/test_daq.py +++ b/tests/test_daq.py @@ -5,12 +5,12 @@ import unittest from km3io.daq import DAQReader, get_rate, has_udp_trailer, get_udp_max_sequence_number, get_channel_flags, get_number_udp_packets SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "samples") -DAQ_FILE = DAQReader(os.path.join(SAMPLES_DIR, "daq_v1.0.0.root")) class TestDAQEvents(unittest.TestCase): def setUp(self): - self.events = DAQ_FILE.events + self.events = DAQReader(os.path.join(SAMPLES_DIR, + "daq_v1.0.0.root")).events def test_index_lookup(self): assert 3 == len(self.events) @@ -24,7 +24,8 @@ class TestDAQEvents(unittest.TestCase): class TestDAQEvent(unittest.TestCase): def setUp(self): - self.event = DAQ_FILE.events[0] + self.event = DAQReader(os.path.join(SAMPLES_DIR, + "daq_v1.0.0.root")).events[0] def test_str(self): assert 
re.match(".*event.*96.*snapshot.*18.*triggered", @@ -37,7 +38,8 @@ class TestDAQEvent(unittest.TestCase): class TestDAQEventsSnapshotHits(unittest.TestCase): def setUp(self): - self.events = DAQ_FILE.events + self.events = DAQReader(os.path.join(SAMPLES_DIR, + "daq_v1.0.0.root")).events self.lengths = {0: 96, 1: 124, -1: 78} self.total_item_count = 298 @@ -75,7 +77,8 @@ class TestDAQEventsSnapshotHits(unittest.TestCase): class TestDAQEventsTriggeredHits(unittest.TestCase): def setUp(self): - self.events = DAQ_FILE.events + self.events = DAQReader(os.path.join(SAMPLES_DIR, + "daq_v1.0.0.root")).events self.lengths = {0: 18, 1: 53, -1: 9} self.total_item_count = 80 @@ -115,7 +118,8 @@ class TestDAQEventsTriggeredHits(unittest.TestCase): class TestDAQTimeslices(unittest.TestCase): def setUp(self): - self.ts = DAQ_FILE.timeslices + self.ts = DAQReader(os.path.join(SAMPLES_DIR, + "daq_v1.0.0.root")).timeslices def test_data_lengths(self): assert 3 == len(self.ts._timeslices["L1"][0]) @@ -140,7 +144,8 @@ class TestDAQTimeslices(unittest.TestCase): class TestDAQTimeslice(unittest.TestCase): def setUp(self): - self.ts = DAQ_FILE.timeslices + self.ts = DAQReader(os.path.join(SAMPLES_DIR, + "daq_v1.0.0.root")).timeslices self.n_frames = {"L1": [69, 69, 69], "SN": [64, 66, 68]} def test_str(self): @@ -153,7 +158,8 @@ class TestDAQTimeslice(unittest.TestCase): class TestSummaryslices(unittest.TestCase): def setUp(self): - self.ss = DAQ_FILE.summaryslices + self.ss = DAQReader(os.path.join(SAMPLES_DIR, + "daq_v1.0.0.root")).summaryslices def test_headers(self): assert 3 == len(self.ss.headers) diff --git a/tests/test_offline.py b/tests/test_offline.py index 60ae196939242ee9b3a82f84f35733147ba53910..2dd1ee7fe2f4e0a07fb30ae94919ea4fcb53e8e7 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -2,389 +2,374 @@ import unittest import numpy as np from pathlib import Path -from km3io.offline import OfflineEvents, OfflineHits, OfflineTracks from km3io import OfflineReader +from km3io.offline import _nested_mapper, Header SAMPLES_DIR = Path(__file__).parent / 'samples' -OFFLINE_FILE = SAMPLES_DIR / 'aanet_v2.0.0.root' -OFFLINE_USR = SAMPLES_DIR / 'usr-sample.root' -OFFLINE_NUMUCC = SAMPLES_DIR / "numucc.root" # with mc data - - -class TestOfflineKeys(unittest.TestCase): - def setUp(self): - self.keys = OfflineReader(OFFLINE_FILE).keys - - def test_events_keys(self): - # there are 22 "valid" events keys - self.assertEqual(len(self.keys.events_keys), 22) - self.assertEqual(len(self.keys.cut_events_keys), 22) - - def test_hits_keys(self): - # there are 20 "valid" hits keys - self.assertEqual(len(self.keys.hits_keys), 20) - self.assertEqual(len(self.keys.mc_hits_keys), 20) - self.assertEqual(len(self.keys.cut_hits_keys), 20) - - def test_tracks_keys(self): - # there are 22 "valid" tracks keys - self.assertEqual(len(self.keys.tracks_keys), 22) - self.assertEqual(len(self.keys.mc_tracks_keys), 22) - self.assertEqual(len(self.keys.cut_tracks_keys), 22) - - def test_valid_keys(self): - # there are 106 valid keys: 22*2 + 22 + 20*2 - # (fit keys are excluded) - self.assertEqual(len(self.keys.valid_keys), 106) - - def test_fit_keys(self): - # there are 18 fit keys - self.assertEqual(len(self.keys.fit_keys), 18) +OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root') +OFFLINE_USR = OfflineReader(SAMPLES_DIR / 'usr-sample.root') +OFFLINE_NUMUCC = OfflineReader(SAMPLES_DIR / "numucc.root") # with mc data class TestOfflineReader(unittest.TestCase): def setUp(self): - self.r = OfflineReader(OFFLINE_FILE) 
- self.nu = OfflineReader(OFFLINE_NUMUCC) - self.Nevents = 10 + self.r = OFFLINE_FILE + self.nu = OFFLINE_NUMUCC + self.n_events = 10 def test_number_events(self): - Nevents = len(self.r) - - # check that there are 10 events - self.assertEqual(Nevents, self.Nevents) - - def test_find_empty(self): - fitinf = self.nu.tracks.fitinf - rec_stages = self.nu.tracks.rec_stages - - empty_fitinf = np.array( - [match for match in self.nu._find_empty(fitinf)]) - empty_stages = np.array( - [match for match in self.nu._find_empty(rec_stages)]) - - self.assertListEqual(empty_fitinf[:5, 1].tolist(), - [23, 14, 14, 4, None]) - self.assertListEqual(empty_stages[:5, 1].tolist(), - [False, False, False, False, None]) - - def test_find_rec_stages(self): - stages = np.array( - [match for match in self.nu._find_rec_stages([1, 2, 3, 4, 5])]) - - self.assertListEqual(stages[:5, 1].tolist(), [0, 0, 0, 0, None]) - - def test_get_reco_fit(self): - JGANDALF_BETA0_RAD = [ - 0.0020367251782607574, 0.003306725805622178, 0.0057877124222254885, - 0.015581698352185896 - ] - reco_fit = self.nu.get_reco_fit([1, 2, 3, 4, 5])['JGANDALF_BETA0_RAD'] - - self.assertListEqual(JGANDALF_BETA0_RAD, reco_fit[:4].tolist()) - with self.assertRaises(ValueError): - self.nu.get_reco_fit([1000, 4512, 5625], mc=True) - - def test_get_reco_hits(self): - - doms = self.nu.get_reco_hits([1, 2, 3, 4, 5], ["dom_id"])["dom_id"] - - mc_doms = self.nu.get_reco_hits([], ["dom_id"], mc=True)["dom_id"] - - self.assertEqual(doms.size, 9) - self.assertEqual(mc_doms.size, 10) - - self.assertListEqual(doms[0][0:4].tolist(), - self.nu.hits[0].dom_id[0:4].tolist()) - self.assertListEqual(mc_doms[0][0:4].tolist(), - self.nu.mc_hits[0].dom_id[0:4].tolist()) - - with self.assertRaises(ValueError): - self.nu.get_reco_hits([1000, 4512, 5625], ["dom_id"]) - - def test_get_reco_tracks(self): + assert self.n_events == len(self.r.events) - pos = self.nu.get_reco_tracks([1, 2, 3, 4, 5], ["pos_x"])["pos_x"] - mc_pos = self.nu.get_reco_tracks([], ["pos_x"], mc=True)["pos_x"] - self.assertEqual(pos.size, 9) - self.assertEqual(mc_pos.size, 10) - - self.assertEqual(pos[0], self.nu.tracks[0].pos_x[0]) - self.assertEqual(mc_pos[0], self.nu.mc_tracks[0].pos_x[0]) - - with self.assertRaises(ValueError): - self.nu.get_reco_tracks([1000, 4512, 5625], ["pos_x"]) - - def test_get_reco_events(self): - - hits = self.nu.get_reco_events([1, 2, 3, 4, 5], ["hits"])["hits"] - mc_hits = self.nu.get_reco_events([], ["mc_hits"], mc=True)["mc_hits"] - - self.assertEqual(hits.size, 9) - self.assertEqual(mc_hits.size, 10) - - self.assertListEqual(hits[0:4].tolist(), - self.nu.events.hits[0:4].tolist()) - self.assertListEqual(mc_hits[0:4].tolist(), - self.nu.events.mc_hits[0:4].tolist()) - - with self.assertRaises(ValueError): - self.nu.get_reco_events([1000, 4512, 5625], ["hits"]) - - def test_get_max_reco_stages(self): - rec_stages = self.nu.tracks.rec_stages - max_reco = self.nu._get_max_reco_stages(rec_stages) - - self.assertEqual(len(max_reco.tolist()), 9) - self.assertListEqual(max_reco[0].tolist(), [[1, 2, 3, 4, 5], 5, 0]) - - def test_best_reco(self): - JGANDALF_BETA1_RAD = [ - 0.0014177681261476852, 0.002094094517471032, 0.003923368624980349, - 0.009491461076780453 - ] - best = self.nu.get_best_reco() - - self.assertEqual(best.size, 9) - self.assertEqual(best['JGANDALF_BETA1_RAD'][:4].tolist(), - JGANDALF_BETA1_RAD) - - def test_reading_header(self): - # head is the supported format - head = OfflineReader(OFFLINE_NUMUCC).header - - self.assertEqual(float(head['DAQ']), 394) - 
self.assertEqual(float(head['kcut']), 2) +class TestHeader(unittest.TestCase): + def test_str_header(self): + assert "MC Header" in str(OFFLINE_NUMUCC.header) + def test_warning_if_unsupported_header(self): # test the warning for unsupported fheader format with self.assertWarns(UserWarning): - self.r.header + OFFLINE_FILE.header + + def test_missing_key_definitions(self): + head = {'a': '1 2 3', 'b': '4', 'c': 'd'} + + header = Header(head) + + assert 1 == header.a.field_0 + assert 2 == header.a.field_1 + assert 3 == header.a.field_2 + assert 4 == header.b + assert 'd' == header.c + + def test_missing_values(self): + head = {'can': '1'} + + header = Header(head) + + assert 1 == header.can.zmin + assert header.can.zmax is None + assert header.can.r is None + + def test_additional_values_compared_to_definition(self): + head = {'can': '1 2 3 4'} + + header = Header(head) + + assert 1 == header.can.zmin + assert 2 == header.can.zmax + assert 3 == header.can.r + assert 4 == header.can.field_3 + + def test_header(self): + head = { + 'DAQ': '394', + 'PDF': '4', + 'can': '0 1027 888.4', + 'undefined': '1 2 test 3.4' + } + + header = Header(head) + + assert 394 == header.DAQ.livetime + assert 4 == header.PDF.i1 + assert header.PDF.i2 is None + assert 0 == header.can.zmin + assert 1027 == header.can.zmax + assert 888.4 == header.can.r + assert 1 == header.undefined.field_0 + assert 2 == header.undefined.field_1 + assert "test" == header.undefined.field_2 + assert 3.4 == header.undefined.field_3 + + def test_reading_header_from_sample_file(self): + head = OFFLINE_NUMUCC.header + + assert 394 == head.DAQ.livetime + assert 4 == head.PDF.i1 + assert 58 == head.PDF.i2 + assert 0 == head.coord_origin.x + assert 0 == head.coord_origin.y + assert 0 == head.coord_origin.z + assert 100 == head.cut_nu.Emin + assert 100000000.0 == head.cut_nu.Emax + assert -1 == head.cut_nu.cosTmin + assert 1 == head.cut_nu.cosTmax + assert "diffuse" == head.sourcemode + assert 100000.0 == head.ngen class TestOfflineEvents(unittest.TestCase): def setUp(self): - self.events = OfflineReader(OFFLINE_FILE).events - self.hits = {0: 176, 1: 125, -1: 105} - self.Nevents = 10 + self.events = OFFLINE_FILE.events + self.n_events = 10 + self.det_id = [44] * self.n_events + self.n_hits = [176, 125, 318, 157, 83, 60, 71, 84, 255, 105] + self.n_tracks = [56, 55, 56, 56, 56, 56, 56, 56, 54, 56] + self.t_sec = [ + 1567036818, 1567036818, 1567036820, 1567036816, 1567036816, + 1567036816, 1567036822, 1567036818, 1567036818, 1567036820 + ] + self.t_ns = [ + 200000000, 300000000, 200000000, 500000000, 500000000, 500000000, + 200000000, 500000000, 500000000, 400000000 + ] - def test_reading_hits(self): - # test item selection - for event_id, hit in self.hits.items(): - self.assertEqual(hit, self.events.hits[event_id]) + def test_len(self): + assert self.n_events == len(self.events) - def reading_tracks(self): - self.assertListEqual(list(self.events.trks[:3]), [56, 55, 56]) + def test_attributes_available(self): + for key in self.events._keymap.keys(): + getattr(self.events, key) - def test_item_selection(self): - for event_id, hit in self.hits.items(): - self.assertEqual(hit, self.events[event_id].hits) + def test_attributes(self): + assert self.n_events == len(self.events.det_id) + self.assertListEqual(self.det_id, list(self.events.det_id)) + self.assertListEqual(self.n_hits, list(self.events.n_hits)) + self.assertListEqual(self.n_tracks, list(self.events.n_tracks)) + self.assertListEqual(self.t_sec, list(self.events.t_sec)) + 
self.assertListEqual(self.t_ns, list(self.events.t_ns)) - def test_len(self): - self.assertEqual(len(self.events), self.Nevents) + def test_keys(self): + assert np.allclose(self.n_hits, self.events['n_hits']) + assert np.allclose(self.n_tracks, self.events['n_tracks']) + assert np.allclose(self.t_sec, self.events['t_sec']) + assert np.allclose(self.t_ns, self.events['t_ns']) - def test_IndexError(self): - # test handling IndexError with empty lists/arrays - self.assertEqual(len(OfflineEvents(['whatever'], [])), 0) + def test_slicing(self): + s = slice(2, 8, 2) + s_events = self.events[s] + assert 3 == len(s_events) + self.assertListEqual(self.n_hits[s], list(s_events.n_hits)) + self.assertListEqual(self.n_tracks[s], list(s_events.n_tracks)) + self.assertListEqual(self.t_sec[s], list(s_events.t_sec)) + self.assertListEqual(self.t_ns[s], list(s_events.t_ns)) + + def test_slicing_consistency(self): + for s in [slice(1, 3), slice(2, 7, 3)]: + assert np.allclose(self.events[s].n_hits, self.events.n_hits[s]) + + def test_index_consistency(self): + for i in [0, 2, 5]: + assert np.allclose(self.events[i].n_hits, self.events.n_hits[i]) + + def test_index_chaining(self): + assert np.allclose(self.events[3:5].n_hits, self.events.n_hits[3:5]) + assert np.allclose(self.events[3:5][0].n_hits, + self.events.n_hits[3:5][0]) + assert np.allclose(self.events[3:5].hits[1].dom_id[4], + self.events.hits[3:5][1][4].dom_id) + assert np.allclose(self.events.hits[3:5][1][4].dom_id, + self.events[3:5][1][4].hits.dom_id) + + def test_iteration(self): + i = 0 + for event in self.events: + i += 1 + assert 10 == i + + def test_iteration_2(self): + n_hits = [e.n_hits for e in self.events] + assert np.allclose(n_hits, self.events.n_hits) def test_str(self): - self.assertEqual(str(self.events), 'Number of events: 10') + assert str(self.n_events) in str(self.events) def test_repr(self): - self.assertEqual(repr(self.events), - '<OfflineEvents: 10 parsed events>') - - -class TestOfflineEvent(unittest.TestCase): - def test_event(self): - self.event = OfflineReader(OFFLINE_FILE).events[0] + assert str(self.n_events) in repr(self.events) class TestOfflineHits(unittest.TestCase): def setUp(self): - self.hits = OfflineReader(OFFLINE_FILE).hits - self.lengths = {0: 176, 1: 125, -1: 105} - self.total_item_count = 1434 - self.r_mc = OfflineReader(OFFLINE_NUMUCC) - self.Nevents = 10 - - def test_item_selection(self): - self.assertListEqual(list(self.hits[0].dom_id[:3]), - [806451572, 806451572, 806451572]) - - def test_IndexError(self): - # test handling IndexError with empty lists/arrays - self.assertEqual(len(OfflineHits(['whatever'], [])), 0) - - def test_repr(self): - self.assertEqual(repr(self.hits), '<OfflineHits: 10 parsed elements>') + self.hits = OFFLINE_FILE.events.hits + self.n_hits = 10 + self.dom_id = { + 0: [ + 806451572, 806451572, 806451572, 806451572, 806455814, + 806455814, 806455814, 806483369, 806483369, 806483369 + ], + 5: [ + 806455814, 806487219, 806487219, 806487219, 806487226, + 808432835, 808432835, 808432835, 808432835, 808432835 + ] + } + self.t = { + 0: [ + 70104010., 70104016., 70104192., 70104123., 70103096., + 70103797., 70103796., 70104191., 70104223., 70104181. + ], + 5: [ + 81861237., 81859608., 81860586., 81861062., 81860357., + 81860627., 81860628., 81860625., 81860627., 81860629. 
+ ] + } + + def test_attributes_available(self): + for key in self.hits._keymap.keys(): + getattr(self.hits, key) + + def test_channel_ids(self): + self.assertTrue(all(c >= 0 for c in self.hits.channel_id.min())) + self.assertTrue(all(c < 31 for c in self.hits.channel_id.max())) def test_str(self): - self.assertEqual(str(self.hits), 'Number of hits: 10') - - def test_reading_dom_id(self): - dom_ids = self.hits.dom_id - - for event_id, length in self.lengths.items(): - self.assertEqual(length, len(dom_ids[event_id])) - - self.assertEqual(self.total_item_count, sum(dom_ids.count())) - - self.assertListEqual([806451572, 806451572, 806451572], - list(dom_ids[0][:3])) - - def test_reading_channel_id(self): - channel_ids = self.hits.channel_id - - for event_id, length in self.lengths.items(): - self.assertEqual(length, len(channel_ids[event_id])) - - self.assertEqual(self.total_item_count, sum(channel_ids.count())) + assert str(self.n_hits) in str(self.hits) - self.assertListEqual([8, 9, 14], list(channel_ids[0][:3])) - - # channel IDs are always between [0, 30] - self.assertTrue(all(c >= 0 for c in channel_ids.min())) - self.assertTrue(all(c < 31 for c in channel_ids.max())) - - def test_reading_times(self): - ts = self.hits.t - - for event_id, length in self.lengths.items(): - self.assertEqual(length, len(ts[event_id])) - - self.assertEqual(self.total_item_count, sum(ts.count())) - - self.assertListEqual([70104010.0, 70104016.0, 70104192.0], - list(ts[0][:3])) - - def test_reading_mc_pmt_id(self): - pmt_ids = self.r_mc.mc_hits.pmt_id - lengths = {0: 58, 2: 28, -1: 48} + def test_repr(self): + assert str(self.n_hits) in repr(self.hits) - for hit_id, length in lengths.items(): - self.assertEqual(length, len(pmt_ids[hit_id])) + def test_attributes(self): + for idx, dom_id in self.dom_id.items(): + self.assertListEqual(dom_id, + list(self.hits.dom_id[idx][:len(dom_id)])) + for idx, t in self.t.items(): + assert np.allclose(t, self.hits.t[idx][:len(t)]) - self.assertEqual(self.Nevents, len(pmt_ids)) + def test_slicing(self): + s = slice(2, 8, 2) + s_hits = self.hits[s] + assert 3 == len(s_hits) + for idx, dom_id in self.dom_id.items(): + self.assertListEqual(dom_id[s], list(self.hits.dom_id[idx][s])) + for idx, t in self.t.items(): + self.assertListEqual(t[s], list(self.hits.t[idx][s])) + + def test_slicing_consistency(self): + for s in [slice(1, 3), slice(2, 7, 3)]: + for idx in range(3): + assert np.allclose(self.hits.dom_id[idx][s], + self.hits[idx].dom_id[s]) + assert np.allclose(OFFLINE_FILE.events[idx].hits.dom_id[s], + self.hits.dom_id[idx][s]) + + def test_index_consistency(self): + for idx, dom_ids in self.dom_id.items(): + assert np.allclose(self.hits[idx].dom_id[:self.n_hits], + dom_ids[:self.n_hits]) + assert np.allclose( + OFFLINE_FILE.events[idx].hits.dom_id[:self.n_hits], + dom_ids[:self.n_hits]) + for idx, ts in self.t.items(): + assert np.allclose(self.hits[idx].t[:self.n_hits], + ts[:self.n_hits]) + assert np.allclose(OFFLINE_FILE.events[idx].hits.t[:self.n_hits], + ts[:self.n_hits]) - self.assertListEqual([677, 687, 689], list(pmt_ids[0][:3])) + def test_keys(self): + assert "dom_id" in self.hits.keys() -class TestOfflineHit(unittest.TestCase): +class TestOfflineTracks(unittest.TestCase): def setUp(self): - self.hit = OfflineReader(OFFLINE_FILE)[0].hits[0] + self.f = OFFLINE_FILE + self.tracks = OFFLINE_FILE.events.tracks + self.tracks_numucc = OFFLINE_NUMUCC + self.n_events = 10 - def test_item_selection(self): - self.assertEqual(self.hit[0], self.hit.id) - 
self.assertEqual(self.hit[1], self.hit.dom_id) + def test_attributes_available(self): + for key in self.tracks._keymap.keys(): + getattr(self.tracks, key) - -class TestOfflineTracks(unittest.TestCase): - def setUp(self): - self.tracks = OfflineReader(OFFLINE_FILE).tracks - self.r_mc = OfflineReader(OFFLINE_NUMUCC) - self.Nevents = 10 + @unittest.skip + def test_attributes(self): + for idx, dom_id in self.dom_id.items(): + self.assertListEqual(dom_id, + list(self.hits.dom_id[idx][:len(dom_id)])) + for idx, t in self.t.items(): + assert np.allclose(t, self.hits.t[idx][:len(t)]) def test_item_selection(self): self.assertListEqual(list(self.tracks[0].dir_z[:2]), [-0.872885221293917, -0.872885221293917]) - def test_IndexError(self): - # test handling IndexError with empty lists/arrays - self.assertEqual(len(OfflineTracks(['whatever'], [])), 0) - def test_repr(self): - self.assertEqual(repr(self.tracks), - '<OfflineTracks: 10 parsed elements>') - - def test_str(self): - self.assertEqual(str(self.tracks), 'Number of tracks: 10') - - def test_reading_tracks_dir_z(self): - dir_z = self.tracks.dir_z - tracks_dir_z = {0: 56, 1: 55, 8: 54} - - for track_id, n_dir in tracks_dir_z.items(): - self.assertEqual(n_dir, len(dir_z[track_id])) - - # check that there are 10 arrays of tracks.dir_z info - self.assertEqual(len(dir_z), self.Nevents) - - def test_reading_mc_tracks_dir_z(self): - dir_z = self.r_mc.mc_tracks.dir_z - tracks_dir_z = {0: 11, 1: 25, 8: 13} - - for track_id, n_dir in tracks_dir_z.items(): - self.assertEqual(n_dir, len(dir_z[track_id])) - - # check that there are 10 arrays of tracks.dir_z info - self.assertEqual(len(dir_z), self.Nevents) - - self.assertListEqual([0.230189, 0.230189, 0.218663], - list(dir_z[0][:3])) + assert " 10 " in repr(self.tracks) def test_slicing(self): tracks = self.tracks - assert 10 == len(tracks) - # track_selection = tracks[2:7] - # assert 5 == len(track_selection) - # track_selection_2 = tracks[1:3] - # assert 2 == len(track_selection_2) - # for _slice in [ - # slice(0, 0), - # slice(0, 1), - # slice(0, 2), - # slice(1, 5), - # slice(3, -2) - # ]: - # self.assertListEqual(list(tracks.E[:, 0][_slice]), - # list(tracks[_slice].E[:, 0])) - - -class TestOfflineTrack(unittest.TestCase): + self.assertEqual(10, len(tracks)) + self.assertEqual(1, len(tracks[0])) + track_selection = tracks[2:7] + assert 5 == len(track_selection) + track_selection_2 = tracks[1:3] + assert 2 == len(track_selection_2) + for _slice in [ + slice(0, 0), + slice(0, 1), + slice(0, 2), + slice(1, 5), + slice(3, -2) + ]: + self.assertListEqual(list(tracks.E[:, 0][_slice]), + list(tracks[_slice].E[:, 0])) + + def test_nested_indexing(self): + self.assertAlmostEqual( + self.f.events.tracks.fitinf[3:5][1][9][2], + self.f.events[3:5].tracks[1].fitinf[9][2]) + self.assertAlmostEqual( + self.f.events.tracks.fitinf[3:5][1][9][2], + self.f.events[3:5][1][9][2].tracks.fitinf) + self.assertAlmostEqual( + self.f.events.tracks.fitinf[3:5][1][9][2], + self.f.events[3:5][1].tracks[9][2].fitinf) + self.assertAlmostEqual( + self.f.events.tracks.fitinf[3:5][1][9][2], + self.f.events[3:5][1].tracks[9].fitinf[2]) + + + +class TestBranchIndexingMagic(unittest.TestCase): def setUp(self): - self.track = OfflineReader(OFFLINE_FILE)[0].tracks[0] + self.events = OFFLINE_FILE.events - def test_item_selection(self): - self.assertEqual(self.track[0], self.track.fUniqueID) - self.assertEqual(self.track[10], self.track.E) + def test_foo(self): + self.assertEqual(318, self.events[2:4].n_hits[0]) + assert 
np.allclose(self.events[3].tracks.dir_z[10], + self.events.tracks.dir_z[3, 10]) + assert np.allclose(self.events[3:6].tracks.pos_y[:, 0], + self.events.tracks.pos_y[3:6, 0]) - def test_str(self): - self.assertEqual(str(self.track).split('\n\t')[0], 'offline track:') + # test selecting with a list + self.assertEqual(3, len(self.events[[0, 2, 3]])) class TestUsr(unittest.TestCase): def setUp(self): - self.f = OfflineReader(OFFLINE_USR) - - def test_str(self): - print(self.f.usr) + self.f = OFFLINE_USR - def test_nonexistent_usr(self): - f = OfflineReader(SAMPLES_DIR / "daq_v1.0.0.root") - self.assertListEqual([], f.usr.keys()) + def test_str_flat(self): + print(self.f.events.usr) - def test_keys(self): + def test_keys_flat(self): self.assertListEqual([ 'RecoQuality', 'RecoNDF', 'CoC', 'ToT', 'ChargeAbove', 'ChargeBelow', 'ChargeRatio', 'DeltaPosZ', 'FirstPartPosZ', 'LastPartPosZ', 'NSnapHits', 'NTrigHits', 'NTrigDOMs', 'NTrigLines', 'NSpeedVetoHits', 'NGeometryVetoHits', 'ClassficationScore' - ], self.f.usr.keys()) + ], self.f.events.usr.keys()) - def test_getitem(self): + def test_getitem_flat(self): assert np.allclose( [118.6302815337638, 44.33580521344907, 99.93916717621543], - self.f.usr['CoC']) + self.f.events.usr['CoC']) assert np.allclose( [37.51967774166617, -10.280346193553832, 13.67595659707355], - self.f.usr['DeltaPosZ']) + self.f.events.usr['DeltaPosZ']) - def test_attributes(self): + @unittest.skip + def test_keys_nested(self): + self.assertListEqual(["a"], self.f.events.mc_tracks.usr.keys()) + + def test_attributes_flat(self): assert np.allclose( [118.6302815337638, 44.33580521344907, 99.93916717621543], - self.f.usr.CoC) + self.f.events.usr.CoC) assert np.allclose( [37.51967774166617, -10.280346193553832, 13.67595659707355], - self.f.usr.DeltaPosZ) + self.f.events.usr.DeltaPosZ) + + +class TestNestedMapper(unittest.TestCase): + def test_nested_mapper(self): + self.assertEqual('pos_x', _nested_mapper("trks.pos.x")) diff --git a/tests/test_tools.py b/tests/test_tools.py index 4506869935a5a7f1331560744a6fb79c06e6e37f..857f9c84a7939f0676560afc29cb74bc41d0c85a 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 import unittest -from km3io.tools import _to_num, cached_property +from km3io.tools import _to_num, cached_property, _unfold_indices + class TestToNum(unittest.TestCase): def test_to_num(self): @@ -19,3 +20,21 @@ class TestCachedProperty(unittest.TestCase): pass self.assertTrue(isinstance(Test.prop, cached_property)) + + +class TestUnfoldIndices(unittest.TestCase): + def test_unfold_indices(self): + data = range(10) + + indices = [slice(2, 5), 0] + assert data[indices[0]][indices[1]] == _unfold_indices(data, indices) + + indices = [slice(1, 9, 2), slice(1, 4), 2] + assert data[indices[0]][indices[1]][indices[2]] == _unfold_indices( + data, indices) + + def test_unfold_indices_raises_index_error(self): + data = range(10) + indices = [slice(2, 5), 99] + with self.assertRaises(IndexError): + _unfold_indices(data, indices)
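
For orientation, a minimal usage sketch of the reworked API introduced by this diff, pieced together from the behaviour exercised in the tests above. The file name is a placeholder, and the header/usr lines only apply to files that actually carry that information (cf. the numucc.root and usr-sample.root tests):

    from km3io import OfflineReader

    r = OfflineReader("some_offline_file.root")  # placeholder path

    # all event data now hangs off a single `events` branch
    print(len(r.events))    # number of events
    print(r.events.n_hits)  # per-event hit counts (mapped from the `hits` key)

    # slicing and indexing record an index chain instead of copying data,
    # so chained and direct access are interchangeable:
    fitinf_a = r.events[3:5].tracks[1].fitinf
    fitinf_b = r.events.tracks.fitinf[3:5][1]

    # iteration is provided by the new __iter__/__next__ on Branch
    for event in r.events:
        print(event.n_tracks)

    # the header is parsed into namedtuples using km3io/definitions/mc_header.py
    print(r.header.DAQ.livetime)

    # AAObject `usr` fields, where present, are exposed per branch
    print(r.events.usr['CoC'])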
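
The new Header class carries the most implicit behaviour in this diff: it takes the raw string map from the ROOT file, resolves each attribute against the field definitions in km3io/definitions/mc_header.py, pads missing trailing values with None, and falls back to generated field_0, field_1, ... names for undefined or surplus values. A sketch mirroring the test cases above:

    from km3io.offline import Header

    h = Header({'DAQ': '394', 'can': '0 1027 888.4', 'undefined': '1 2 test 3.4'})

    print(h.DAQ.livetime)       # 394      ('DAQ' is defined as "livetime")
    print(h.can.zmin, h.can.r)  # 0 888.4  ('can' is "zmin zmax r")
    print(h.undefined.field_2)  # 'test'   (unknown keys get field_0, field_1, ...)

    # values missing at the end of a defined field list come back as None
    print(Header({'can': '1'}).can.zmax)  # None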