diff --git a/km3io/offline.py b/km3io/offline.py index 53c1a1a254e9afaa591fd9b682ba20079684c535..82f48c0eccdd9aa179dab4841def3964bd2fd052 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -5,36 +5,35 @@ import km3io.definitions.trigger import km3io.definitions.fitparameters import km3io.definitions.reconstruction +MAIN_TREE_NAME = "E" # 110 MB based on the size of the largest basket found so far in km3net BASKET_CACHE_SIZE = 110 * 1024**2 +class cached_property: + """A simple cache decorator for properties.""" + def __init__(self, function): + self.function = function + + def __get__(self, obj, cls): + if obj is None: + return self + prop = obj.__dict__[self.function.__name__] = self.function(obj) + return prop + + class OfflineKeys: """wrapper for offline keys""" - def __init__(self, file_path): + def __init__(self, tree): """OfflineKeys is a class that reads all the available keys in an offline file and adapts the keys format to Python format. Parameters ---------- - file_path : path-like object - Path to the offline file of interest. It can be a str or any python - path-like object that points to the file of ineterst. + tree : uproot.TTree + The main ROOT tree. """ - self._file_path = file_path - self._events_keys = None - self._hits_keys = None - self._tracks_keys = None - self._mc_hits_keys = None - self._mc_tracks_keys = None - self._valid_keys = None - self._fit_keys = None - self._cut_hits_keys = None - self._cut_tracks_keys = None - self._cut_events_keys = None - self._trigger = None - self._fitparameters = None - self._reconstruction = None + self._tree = tree def __str__(self): return '\n'.join([ @@ -46,10 +45,32 @@ class OfflineKeys: ]) def __repr__(self): - return str(self) - # return f'{self.__class__.__name__}("{self._file_path}")' + return "<{}>".format(self.__class__.__name__) + + def _get_keys(self, tree, fake_branches=None): + """Get tree keys except those in fake_branches + + Parameters + ---------- + tree : uproot.Tree + The tree to look for keys + fake_branches : list of str or None + The fake branches to ignore + + Returns + ------- + list of str + The keys of the tree. + """ + keys = [] + for key in tree.keys(): + key = key.decode('utf-8') + if fake_branches is not None and key in fake_branches: + continue + keys.append(key) + return keys - @property + @cached_property def events_keys(self): """reads events keys from an offline file. @@ -59,17 +80,12 @@ class OfflineKeys: list of all events keys found in an offline file, except those found in fake branches. """ - if self._events_keys is None: - fake_branches = ['Evt', 'AAObject', 'TObject', 't'] - t_baskets = ['t.fSec', 't.fNanoSec'] - tree = uproot.open(self._file_path)['E']['Evt'] - self._events_keys = [ - key.decode('utf-8') for key in tree.keys() - if key.decode('utf-8') not in fake_branches - ] + t_baskets - return self._events_keys - - @property + fake_branches = ['Evt', 'AAObject', 'TObject', 't'] + t_baskets = ['t.fSec', 't.fNanoSec'] + tree = self._tree['Evt'] + return self._get_keys(self._tree['Evt'], fake_branches) + t_baskets + + @cached_property def hits_keys(self): """reads hits keys from an offline file. @@ -79,18 +95,10 @@ class OfflineKeys: list of all hits keys found in an offline file, except those found in fake branches. 
""" - if self._hits_keys is None: - fake_branches = [ - 'hits.usr', 'hits.usr_names' - ] # to be treated like trks.usr and trks.usr_names - tree = uproot.open(self._file_path)['E']['hits'] - self._hits_keys = [ - key.decode('utf8') for key in tree.keys() - if key.decode('utf8') not in fake_branches - ] - return self._hits_keys + fake_branches = ['hits.usr', 'hits.usr_names'] + return self._get_keys(self._tree['hits'], fake_branches) - @property + @cached_property def tracks_keys(self): """reads tracks keys from an offline file. @@ -100,20 +108,12 @@ class OfflineKeys: list of all tracks keys found in an offline file, except those found in fake branches. """ - if self._tracks_keys is None: - # a solution can be tree['trks.usr_data'].array( - # uproot.asdtype(">i4")) - fake_branches = [ - 'trks.usr_data', 'trks.usr', 'trks.usr_names' - ] # can be accessed using tree['trks.usr_names'].array() - tree = uproot.open(self._file_path)['E']['Evt']['trks'] - self._tracks_keys = [ - key.decode('utf8') for key in tree.keys() - if key.decode('utf8') not in fake_branches - ] - return self._tracks_keys + # a solution can be tree['trks.usr_data'].array( + # uproot.asdtype(">i4")) + fake_branches = ['trks.usr_data', 'trks.usr', 'trks.usr_names'] + return self._get_keys(self._tree['Evt']['trks'], fake_branches) - @property + @cached_property def mc_hits_keys(self): """reads mc hits keys from an offline file. @@ -123,16 +123,10 @@ class OfflineKeys: list of all mc hits keys found in an offline file, except those found in fake branches. """ - if self._mc_hits_keys is None: - fake_branches = ['mc_hits.usr', 'mc_hits.usr_names'] - tree = uproot.open(self._file_path)['E']['Evt']['mc_hits'] - self._mc_hits_keys = [ - key.decode('utf8') for key in tree.keys() - if key.decode('utf8') not in fake_branches - ] - return self._mc_hits_keys + fake_branches = ['mc_hits.usr', 'mc_hits.usr_names'] + return self._get_keys(self._tree['Evt']['mc_hits'], fake_branches) - @property + @cached_property def mc_tracks_keys(self): """reads mc tracks keys from an offline file. @@ -142,32 +136,23 @@ class OfflineKeys: list of all mc tracks keys found in an offline file, except those found in fake branches. """ - if self._mc_tracks_keys is None: - fake_branches = [ - 'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names' - ] # same solution as above can be used - tree = uproot.open(self._file_path)['E']['Evt']['mc_trks'] - self._mc_tracks_keys = [ - key.decode('utf8') for key in tree.keys() - if key.decode('utf8') not in fake_branches - ] - return self._mc_tracks_keys + fake_branches = [ + 'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names' + ] + return self._get_keys(self._tree['Evt']['mc_trks'], fake_branches) - @property + @cached_property def valid_keys(self): """constructs a list of all valid keys to be read from an offline event file. Returns ------- list of str list of all valid keys. - """ - if self._valid_keys is None: - self._valid_keys = (self.events_keys + self.hits_keys + - self.tracks_keys + self.mc_tracks_keys + - self.mc_hits_keys) - return self._valid_keys - - @property + """ + return (self.events_keys + self.hits_keys + self.tracks_keys + + self.mc_tracks_keys + self.mc_hits_keys) + + @cached_property def fit_keys(self): """constructs a list of fit parameters, not yet outsourced in an offline file. @@ -176,14 +161,11 @@ class OfflineKeys: list of str list of all "trks.fitinf" keys. 
""" - if self._fit_keys is None: - self._fit_keys = sorted(self.fitparameters, - key=self.fitparameters.get, - reverse=False) - # self._fit_keys = [*fit.keys()] - return self._fit_keys - - @property + return sorted(km3io.definitions.fitparameters.data, + key=km3io.definitions.fitparameters.data.get, + reverse=False) + + @cached_property def cut_hits_keys(self): """adapts hits keys for instance variables format in a Python class. @@ -192,13 +174,9 @@ class OfflineKeys: list of str list of adapted hits keys. """ - if self._cut_hits_keys is None: - self._cut_hits_keys = [ - k.split('hits.')[1].replace('.', '_') for k in self.hits_keys - ] - return self._cut_hits_keys + return [k.split('hits.')[1].replace('.', '_') for k in self.hits_keys] - @property + @cached_property def cut_tracks_keys(self): """adapts tracks keys for instance variables format in a Python class. @@ -207,13 +185,11 @@ class OfflineKeys: list of str list of adapted tracks keys. """ - if self._cut_tracks_keys is None: - self._cut_tracks_keys = [ - k.split('trks.')[1].replace('.', '_') for k in self.tracks_keys - ] - return self._cut_tracks_keys + return [ + k.split('trks.')[1].replace('.', '_') for k in self.tracks_keys + ] - @property + @cached_property def cut_events_keys(self): """adapts events keys for instance variables format in a Python class. @@ -222,170 +198,64 @@ class OfflineKeys: list of str list of adapted events keys. """ - if self._cut_events_keys is None: - self._cut_events_keys = [ - k.replace('.', '_') for k in self.events_keys - ] - return self._cut_events_keys + return [k.replace('.', '_') for k in self.events_keys] - @property - def trigger(self): - """trigger parameters and their index from km3net-Dataformat. - Returns - ------- - dict - dictionary of trigger parameters and their index in an Offline - file. - """ - if self._trigger is None: - self._trigger = km3io.definitions.trigger.data - return self._trigger - - @property - def reconstruction(self): - """reconstruction parameters and their index from km3net-Dataformat. - - Returns - ------- - dict - dictionary of reconstruction parameters and their index in an - Offline file. - """ - if self._reconstruction is None: - self._reconstruction = km3io.definitions.reconstruction.data - return self._reconstruction - - @property - def fitparameters(self): - """fit parameters parameters and their index from km3net-Dataformat. - - Returns - ------- - dict - dictionary of fit parameters and their index in an Offline - file. - """ - if self._fitparameters is None: - self._fitparameters = km3io.definitions.fitparameters.data - return self._fitparameters - - -class Reader: - """Reader for one offline ROOT file""" - def __init__(self, file_path): - """ Reader class is an offline ROOT file reader. This class is a - "very" low level I/O. +class OfflineReader: + """reader for offline ROOT files""" + def __init__(self, file_path=None, fobj=None, data=None): + """ OfflineReader class is an offline ROOT file wrapper Parameters ---------- file_path : path-like object Path to the file of interest. It can be a str or any python - path-like object that points to the file of ineterst. - """ - self._file_path = file_path - self._data = uproot.open(self._file_path)['E'].lazyarrays( - basketcache=uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)) - self._keys = None - - def __getitem__(self, key): - """reads data stored in the branch of interest in an Evt tree. - - Parameters - ---------- - key : str - name of the branch of interest in event data. 
+ path-like object that points to the file. - Returns - ------- - lazyarray - Lazyarray of all data stored in the branch of interest. A lazyarray - is an array-like object that reads data on demand. Here, only the - first and last chunks of data are read in memory, and not all data - in the array. The output can be used with all `Numpy's universal - functions <https://docs.scipy.org/doc/numpy/reference/ufuncs.html>` - . - - Raises - ------ - KeyError - Some branches in an offline file structure are "fake branches" and - do not contain data. Therefore, the keys corresponding to these - fake branches are not read. """ - keys = self.keys.valid_keys - if key not in keys and not isinstance(key, int): - raise KeyError( - "'{}' is not a valid key or is a fake branch.".format(key)) - return self._data[key] - - def __len__(self): - return len(self._data) - - def __repr__(self): - return "<{}: {} entries>".format(self.__class__.__name__, len(self)) - - @property - def keys(self): - """wrapper for all keys in an offline file. - - Returns - ------- - Class - OfflineKeys. - """ - if self._keys is None: - self._keys = OfflineKeys(self._file_path) - return self._keys - + if file_path is not None: + self._fobj = uproot.open(file_path) + self._tree = self._fobj[MAIN_TREE_NAME] + self._data = self._tree.lazyarrays( + basketcache=uproot.cache.ThreadSafeArrayCache( + BASKET_CACHE_SIZE)) + else: + self._fobj = fobj + self._tree = self._fobj[MAIN_TREE_NAME] + self._data = data -class OfflineReader: - """reader for offline ROOT files""" - def __init__(self, file_path, data=None): - """ OfflineReader class is an offline ROOT file wrapper + @classmethod + def from_index(cls, source, index): + """Create an instance with a subtree of a given index Parameters ---------- - file_path : path-like object - Path to the file of interest. It can be a str or any python - path-like object that points to the file of ineterst. + source: ROOTDirectory + The source file. + index: index or slice + The index or slice to create the subtree. """ - self._file_path = file_path - if data is not None: - self._data = data - else: - self._data = uproot.open(self._file_path)['E'].lazyarrays( - basketcache=uproot.cache.ThreadSafeArrayCache( - BASKET_CACHE_SIZE)) - self._events = None - self._hits = None - self._tracks = None - self._mc_hits = None - self._mc_tracks = None - self._keys = None - self._best_reco = None - self._header = None - self._usr = None + instance = cls(fobj=source._fobj, data=source._data[index]) + return instance - def __getitem__(self, item): - return OfflineReader(file_path=self._file_path, data=self._data[item]) + def __getitem__(self, index): + return OfflineReader.from_index(source=self, index=index) def __len__(self): return len(self._data) - @property + @cached_property def header(self): - if self._header is None: - fobj = uproot.open(self._file_path) - if 'Head' in fobj: - self._header = {} - for n, x in fobj['Head']._map_3c_string_2c_string_3e_.items(): - self._header[n.decode("utf-8")] = x.decode("utf-8").strip() - else: - warnings.warn("Your file header has an unsupported format") - return self._header - - @property + if 'Head' in self._fobj: + header = {} + for n, x in self._fobj['Head']._map_3c_string_2c_string_3e_.items( + ): + header[n.decode("utf-8")] = x.decode("utf-8").strip() + return header + else: + warnings.warn("Your file header has an unsupported format") + + @cached_property def keys(self): """wrapper for all keys in an offline file. @@ -394,11 +264,9 @@ class OfflineReader: Class OfflineKeys. 
""" - if self._keys is None: - self._keys = OfflineKeys(self._file_path) - return self._keys + return OfflineKeys(self._tree) - @property + @cached_property def events(self): """wrapper for offline events. @@ -407,13 +275,11 @@ class OfflineReader: Class OfflineEvents. """ - if self._events is None: - self._events = OfflineEvents( - self.keys.cut_events_keys, - [self._data[key] for key in self.keys.events_keys]) - return self._events + return OfflineEvents( + self.keys.cut_events_keys, + [self._data[key] for key in self.keys.events_keys]) - @property + @cached_property def hits(self): """wrapper for offline hits. @@ -422,13 +288,10 @@ class OfflineReader: Class OfflineHits. """ - if self._hits is None: - self._hits = OfflineHits( - self.keys.cut_hits_keys, - [self._data[key] for key in self.keys.hits_keys]) - return self._hits + return OfflineHits(self.keys.cut_hits_keys, + [self._data[key] for key in self.keys.hits_keys]) - @property + @cached_property def tracks(self): """wrapper for offline tracks. @@ -437,14 +300,11 @@ class OfflineReader: Class OfflineTracks. """ - if self._tracks is None: - self._tracks = OfflineTracks( - self.keys.cut_tracks_keys, - [self._data[key] for key in self.keys.tracks_keys], - fitparameters=self.keys.fitparameters) - return self._tracks - - @property + return OfflineTracks( + self.keys.cut_tracks_keys, + [self._data[key] for key in self.keys.tracks_keys]) + + @cached_property def mc_hits(self): """wrapper for offline mc hits. @@ -453,13 +313,10 @@ class OfflineReader: Class OfflineHits. """ - if self._mc_hits is None: - self._mc_hits = OfflineHits( - self.keys.cut_hits_keys, - [self._data[key] for key in self.keys.mc_hits_keys]) - return self._mc_hits + return OfflineHits(self.keys.cut_hits_keys, + [self._data[key] for key in self.keys.mc_hits_keys]) - @property + @cached_property def mc_tracks(self): """wrapper for offline mc tracks. @@ -468,21 +325,15 @@ class OfflineReader: Class OfflineTracks. """ - if self._mc_tracks is None: - self._mc_tracks = OfflineTracks( - self.keys.cut_tracks_keys, - [self._data[key] for key in self.keys.mc_tracks_keys], - fitparameters=self.keys.fitparameters) - return self._mc_tracks - - @property + return OfflineTracks( + self.keys.cut_tracks_keys, + [self._data[key] for key in self.keys.mc_tracks_keys]) + + @cached_property def usr(self): - if self._usr is None: - self._usr = Usr(self._file_path) - return self._usr + return Usr(self._tree) - @property - def best_reco(self): + def get_best_reco(self): """returns the best reconstructed track fit data. The best fit is defined as the track fit with the maximum reconstruction stages. When "nan" is returned, it means that the reconstruction parameter of interest is not @@ -496,24 +347,22 @@ class OfflineReader: numpy recarray a recarray of the best track fit data (reconstruction data). 
""" - if self._best_reco is None: - keys = ", ".join(self.keys.fit_keys[:-1]) - empty_fit_info = np.array( - [match for match in self._find_empty(self.tracks.fitinf)]) - fit_info = [ - i for i, j in zip(self.tracks.fitinf, empty_fit_info[:, 1]) - if j is not None - ] - stages = self._get_max_reco_stages(self.tracks.rec_stages) - fit_data = np.array([i[j] for i, j in zip(fit_info, stages[:, 2])]) - rows_size = len(max(fit_data, key=len)) - equal_size_data = np.vstack([ - np.hstack([i, np.zeros(rows_size - len(i)) + np.nan]) - for i in fit_data - ]) - self._best_reco = np.core.records.fromarrays( - equal_size_data.transpose(), names=keys) - return self._best_reco + keys = ", ".join(self.keys.fit_keys[:-1]) + empty_fit_info = np.array( + [match for match in self._find_empty(self.tracks.fitinf)]) + fit_info = [ + i for i, j in zip(self.tracks.fitinf, empty_fit_info[:, 1]) + if j is not None + ] + stages = self._get_max_reco_stages(self.tracks.rec_stages) + fit_data = np.array([i[j] for i, j in zip(fit_info, stages[:, 2])]) + rows_size = len(max(fit_data, key=len)) + equal_size_data = np.vstack([ + np.hstack([i, np.zeros(rows_size - len(i)) + np.nan]) + for i in fit_data + ]) + return np.core.records.fromarrays(equal_size_data.transpose(), + names=keys) def _get_max_reco_stages(self, reco_stages): """find the longest reconstructed track based on the maximum size of @@ -825,15 +674,13 @@ class OfflineReader: class Usr: """Helper class to access AAObject usr stuff""" - def __init__(self, filepath): - self._f = uproot.open(filepath) + def __init__(self, tree): # Here, we assume that every event has the same names in the same order # to massively increase the performance. This needs triple check if it's # always the case; the usr-format is simply a very bad design. try: self._usr_names = [ - n.decode("utf-8") - for n in self._f['E']['Evt']['usr_names'].array()[0] + n.decode("utf-8") for n in tree['Evt']['usr_names'].array()[0] ] except (KeyError, IndexError): # e.g. old aanet files self._usr_names = [] @@ -842,7 +689,7 @@ class Usr: name: index for index, name in enumerate(self._usr_names) } - self._usr_data = self._f['E']['Evt']['usr'].lazyarray( + self._usr_data = tree['Evt']['usr'].lazyarray( basketcache=uproot.cache.ThreadSafeArrayCache( BASKET_CACHE_SIZE)) for name in self._usr_names: @@ -918,9 +765,6 @@ class OfflineEvent: for k, v in zip(self._keys, self._values) ]) - def __repr__(self): - return str(self) - class OfflineHits: """wrapper for offline hits""" @@ -982,19 +826,10 @@ class OfflineHit: def __getitem__(self, item): return self._values[item] - def __repr__(self): - return str(self) - - # def _is_empty(array): - # if array.size: - # return False - # else: - # return True - class OfflineTracks: """wrapper for offline tracks""" - def __init__(self, keys, values, fitparameters=None): + def __init__(self, keys, values): """wrapper for offline tracks Parameters @@ -1003,20 +838,14 @@ class OfflineTracks: list of cropped tracks keys. values : list of arrays list of arrays containting tracks data. - fitparameters : None, optional - dictionary of tracks fit information (not yet outsourced in offline - files). 
""" self._keys = keys self._values = values - if fitparameters is not None: - self._fitparameters = fitparameters for k, v in zip(self._keys, self._values): setattr(self, k, v) def __getitem__(self, item): - return OfflineTrack(self._keys, [v[item] for v in self._values], - fitparameters=self._fitparameters) + return OfflineTrack(self._keys, [v[item] for v in self._values]) def __len__(self): try: @@ -1034,7 +863,7 @@ class OfflineTracks: class OfflineTrack: """wrapper for an offline track""" - def __init__(self, keys, values, fitparameters=None): + def __init__(self, keys, values): """wrapper for one offline track. Parameters @@ -1043,14 +872,9 @@ class OfflineTrack: list of cropped tracks keys. values : list of arrays list of arrays containting track data. - fitparameters : None, optional - dictionary of tracks fit information (not yet outsourced in offline - files). """ self._keys = keys self._values = values - if fitparameters is not None: - self._fitparameters = fitparameters for k, v in zip(self._keys, self._values): setattr(self, k, v) @@ -1061,12 +885,9 @@ class OfflineTrack: ]) + "\n\t" + "\n\t".join([ "{:30} {:^2} {:>26}".format(k, ':', str( getattr(self, 'fitinf')[v])) - for k, v in self._fitparameters.items() + for k, v in km3io.definitions.fitparameters.data.items() if len(getattr(self, 'fitinf')) > v ]) # I don't like 18 being explicit here def __getitem__(self, item): return self._values[item] - - def __repr__(self): - return str(self) diff --git a/notebooks/Reader_tutorial.ipynb b/notebooks/Reader_tutorial.ipynb deleted file mode 100644 index 808e5452d4a8f70f75f245b175ec6aea5147ba10..0000000000000000000000000000000000000000 --- a/notebooks/Reader_tutorial.ipynb +++ /dev/null @@ -1,440 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['/home/zineb/km3net/km3net/km3io/notebooks', '/home/zineb/miniconda3/envs/km3pipe/lib/python37.zip', '/home/zineb/miniconda3/envs/km3pipe/lib/python3.7', '/home/zineb/miniconda3/envs/km3pipe/lib/python3.7/lib-dynload', '', '/home/zineb/miniconda3/envs/km3pipe/lib/python3.7/site-packages', '/home/zineb/km3net/km3net/km3io', '/home/zineb/miniconda3/envs/km3pipe/lib/python3.7/site-packages/IPython/extensions', '/home/zineb/.ipython', '/home/zineb/km3net/km3net/km3io']\n" - ] - } - ], - "source": [ - "# Add file to current python path\n", - "from pathlib import Path\n", - "import sys\n", - "sys.path.append(str(Path.cwd().parent))\n", - "Path.cwd()\n", - "print(sys.path)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from km3io.aanet import Reader" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# test samples directory - aanet test file\n", - "files_path = Path.cwd().parent / 'tests/samples' \n", - "aanet_file = files_path / 'aanet_v2.0.0.root'" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "<Reader: 10 entries>" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "reader = Reader(aanet_file)\n", - "reader" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['trks.fUniqueID',\n", - " 'trks.fBits',\n", - " 'trks.id',\n", - " 'trks.pos.x',\n", - " 
'trks.pos.y',\n", - " 'trks.pos.z',\n", - " 'trks.dir.x',\n", - " 'trks.dir.y',\n", - " 'trks.dir.z',\n", - " 'trks.t',\n", - " 'trks.E',\n", - " 'trks.len',\n", - " 'trks.lik',\n", - " 'trks.type',\n", - " 'trks.rec_type',\n", - " 'trks.rec_stages',\n", - " 'trks.status',\n", - " 'trks.mother_id',\n", - " 'trks.fitinf',\n", - " 'trks.hit_ids',\n", - " 'trks.error_matrix',\n", - " 'trks.comment']" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "reader.keys.tracks_keys" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "<Table [<Row 0> <Row 1> <Row 2> ... <Row 7> <Row 8> <Row 9>] at 0x7fed2d98a750>" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# big lazyarray with ALL file data!\n", - "lazy_data = reader._data\n", - "lazy_data" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "<ChunkedArray [5971 5971 5971 ... 5971 5971 5971] at 0x7fed2d91f9d0>" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# getting the run_id for a specific event (event 5 for example)\n", - "reader['run_id']" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "60" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# one can check how many hits are in event 5\n", - "reader[5]['hits']" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "56" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# one can also check how many tracks are in event 5\n", - "reader[5]['trks']" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "# # the user is reminded to always specify the \"correct\" event/hits/tracks \n", - "# # key in the Aanet event file\n", - "# try:\n", - "# reader['whatever']\n", - "# except KeyError as e:\n", - "# print(e)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Now let's explore in more details the hits:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "<ChunkedArray [[806451572 806451572 806451572 ... 809544061 809544061 809544061] [806451572 806451572 806451572 ... 809524432 809526097 809544061] [806451572 806451572 806451572 ... 809544061 809544061 809544061] ... [806451572 806455814 806465101 ... 809526097 809544058 809544061] [806455814 806455814 806455814 ... 809544061 809544061 809544061] [806455814 806455814 806455814 ... 
809544058 809544058 809544061]] at 0x7fed2d8af450>" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# reading all data from a specific branch in hits data: for example \n", - "# 'hits.dom_id'\n", - "reader['hits.dom_id']" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([806455814, 806487219, 806487219, 806487219, 806487226, 808432835,\n", - " 808432835, 808432835, 808432835, 808432835, 808432835, 808432835,\n", - " 808451904, 808451904, 808451907, 808451907, 808469129, 808469129,\n", - " 808469129, 808493910, 808949744, 808949744, 808951460, 808951460,\n", - " 808956908, 808961655, 808964908, 808969848, 808969857, 808972593,\n", - " 808972593, 808972598, 808972598, 808972698, 808972698, 808974758,\n", - " 808974811, 808976377, 808981510, 808981523, 808981812, 808982005,\n", - " 808982005, 808982018, 808982077, 808982077, 808982547, 809007627,\n", - " 809521500, 809521500, 809521500, 809524432, 809526097, 809526097,\n", - " 809526097, 809526097, 809526097, 809526097, 809526097, 809544058],\n", - " dtype=int32)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# the user can access hits.dom_id data for a specific event (for example \n", - "# event 5)\n", - "reader['hits.dom_id'][5]" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "60" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# We previsouly checked (using reader[5]['hits']) that event\n", - "# 5 has 60 hits, now we can see that reader['hits.dom_id'][5]\n", - "# has exaclty 60 dom ids as well! \n", - "len(reader['hits.dom_id'][5])" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "806455814" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# one can access a dom id of the first hit \n", - "reader['hits.dom_id'][5][0]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Now let's explore in more details the tracks:" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "<ChunkedArray [[-0.872885221293917 -0.872885221293917 -0.872885221293917 ... -0.6631226836266504 -0.5680647731737454 -0.5680647731737454] [-0.8351996698137462 -0.8351996698137462 -0.8351996698137462 ... -0.7485107718446855 -0.8229838871876581 -0.239315690284641] [-0.989148723802379 -0.989148723802379 -0.989148723802379 ... -0.9350162572437829 -0.88545604390297 -0.88545604390297] ... [-0.5704611045902105 -0.5704611045902105 -0.5704611045902105 ... -0.9350162572437829 -0.4647231989130516 -0.4647231989130516] [-0.9779941383490359 -0.9779941383490359 -0.9779941383490359 ... -0.88545604390297 -0.88545604390297 -0.8229838871876581] [-0.7396916780974963 -0.7396916780974963 -0.7396916780974963 ... 
-0.6631226836266504 -0.7485107718446855 -0.7485107718446855]] at 0x7fed2d8b8850>" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# reading all data from a specific branch in tracks data:\n", - "reader['trks.dir.z']" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([-0.60246049, -0.60246049, -0.60246049, -0.51420541, -0.5475772 ,\n", - " -0.5772408 , -0.56068238, -0.64907684, -0.67781799, -0.66565114,\n", - " -0.63014839, -0.64566464, -0.62691012, -0.58465493, -0.59287533,\n", - " -0.63655091, -0.63771247, -0.73446841, -0.7456636 , -0.70941246,\n", - " -0.66312268, -0.66312268, -0.56806477, -0.56806477, -0.66312268,\n", - " -0.66312268, -0.74851077, -0.74851077, -0.66312268, -0.74851077,\n", - " -0.56806477, -0.74851077, -0.66312268, -0.74851077, -0.56806477,\n", - " -0.66312268, -0.56806477, -0.66312268, -0.56806477, -0.56806477,\n", - " -0.66312268, -0.74851077, -0.66312268, -0.93501626, -0.56806477,\n", - " -0.74851077, -0.66312268, -0.56806477, -0.82298389, -0.74851077,\n", - " -0.66312268, -0.56806477, -0.82298389, -0.56806477, -0.66312268,\n", - " -0.97094183])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# the user can access trks.dir.z data for a specific event (for example \n", - "# event 5)\n", - "reader['trks.dir.z'][5]" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "56" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# We previsouly checked (using reader[5]['trks']) that event\n", - "# 5 has 56 tracks, now we can see that reader['trks.dir.z'][5]\n", - "# has exaclty 56 values of trks.dir.z as well! 
\n", - "len(reader['trks.dir.z'][5])" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "-0.6024604933159441" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# one can access the first trks.dir.z from event 5 using \n", - "reader['trks.dir.z'][5][0]" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/tests/test_offline.py b/tests/test_offline.py index f0ac1ba03b190547a063468d0bd3d53c0ec6d431..60ae196939242ee9b3a82f84f35733147ba53910 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -2,7 +2,7 @@ import unittest import numpy as np from pathlib import Path -from km3io.offline import Reader, OfflineEvents, OfflineHits, OfflineTracks +from km3io.offline import OfflineEvents, OfflineHits, OfflineTracks from km3io import OfflineReader SAMPLES_DIR = Path(__file__).parent / 'samples' @@ -15,12 +15,6 @@ class TestOfflineKeys(unittest.TestCase): def setUp(self): self.keys = OfflineReader(OFFLINE_FILE).keys - def test_repr(self): - reader_repr = repr(self.keys) - - # check that there are 106 keys + 5 extra str - self.assertEqual(len(reader_repr.split('\n')), 111) - def test_events_keys(self): # there are 22 "valid" events keys self.assertEqual(len(self.keys.events_keys), 22) @@ -47,121 +41,12 @@ class TestOfflineKeys(unittest.TestCase): # there are 18 fit keys self.assertEqual(len(self.keys.fit_keys), 18) - def test_trigger(self): - # there are 4 trigger keys in v1.1.2 of km3net-Dataformat - trigger = self.keys.trigger - keys = [ - 'JTRIGGER3DSHOWER', 'JTRIGGERMXSHOWER', 'JTRIGGER3DMUON', - 'JTRIGGERNB' - ] - values = [1, 2, 4, 5] - - for k, v in zip(keys, values): - self.assertEqual(v, trigger[k]) - - def test_reconstruction(self): - # there are 34 parameters in v1.1.2 of km3net-Dataformat - reco = self.keys.reconstruction - keys = [ - 'JPP_RECONSTRUCTION_TYPE', 'JMUONFIT', 'JMUONBEGIN', 'JMUONPREFIT', - 'JMUONSIMPLEX', 'JMUONGANDALF', 'JMUONENERGY', 'JMUONSTART' - ] - values = [4000, 0, 0, 1, 2, 3, 4, 5] - - self.assertEqual(34, len([*reco.keys()])) - for k, v in zip(keys, values): - self.assertEqual(v, reco[k]) - - def test_fitparameters(self): - # there are 18 parameters in v1.1.2 of km3net-Dataformat - fit = self.keys.fitparameters - values = [i for i in range(18)] - - self.assertEqual(18, len([*fit.keys()])) - for k, v in fit.items(): - self.assertEqual(values[v], fit[k]) - - -class TestReader(unittest.TestCase): - def setUp(self): - self.r = Reader(OFFLINE_FILE) - self.lengths = {0: 176, 1: 125, -1: 105} - self.total_item_count = 1434 - - def test_reading_dom_id(self): - dom_ids = self.r["hits.dom_id"] - - for event_id, length in self.lengths.items(): - self.assertEqual(length, len(dom_ids[event_id])) - - self.assertEqual(self.total_item_count, sum(dom_ids.count())) - - self.assertListEqual([806451572, 806451572, 806451572], - list(dom_ids[0][:3])) - - def test_reading_channel_id(self): - channel_ids = self.r["hits.channel_id"] - - for event_id, length in self.lengths.items(): - self.assertEqual(length, len(channel_ids[event_id])) - - 
self.assertEqual(self.total_item_count, sum(channel_ids.count())) - - self.assertListEqual([8, 9, 14], list(channel_ids[0][:3])) - - # channel IDs are always between [0, 30] - self.assertTrue(all(c >= 0 for c in channel_ids.min())) - self.assertTrue(all(c < 31 for c in channel_ids.max())) - - def test_reading_times(self): - ts = self.r["hits.t"] - - for event_id, length in self.lengths.items(): - self.assertEqual(length, len(ts[event_id])) - - self.assertEqual(self.total_item_count, sum(ts.count())) - - self.assertListEqual([70104010.0, 70104016.0, 70104192.0], - list(ts[0][:3])) - - def test_reading_keys(self): - # there are 106 "valid" keys in an offline file - self.assertEqual(len(self.r.keys.valid_keys), 106) - - # there are 20 hits keys - self.assertEqual(len(self.r.keys.hits_keys), 20) - self.assertEqual(len(self.r.keys.mc_hits_keys), 20) - - # there are 22 tracks keys - self.assertEqual(len(self.r.keys.tracks_keys), 22) - self.assertEqual(len(self.r.keys.mc_tracks_keys), 22) - - def test_raising_KeyError(self): - # non valid keys must raise a KeyError - with self.assertRaises(KeyError): - self.r['whatever'] - - def test_number_events(self): - Nevents = len(self.r) - - # check that there are 10 events - self.assertEqual(Nevents, 10) - class TestOfflineReader(unittest.TestCase): def setUp(self): self.r = OfflineReader(OFFLINE_FILE) self.nu = OfflineReader(OFFLINE_NUMUCC) self.Nevents = 10 - self.selected_data = OfflineReader(OFFLINE_FILE, - data=self.r._data[0])._data - - def test_item_selection(self): - # test class instance with data=None option - self.assertEqual(len(self.selected_data), len(self.r._data[0])) - - # test item selection (here we test with hits=176) - self.assertEqual(self.r[0].events.hits, self.selected_data['hits']) def test_number_events(self): Nevents = len(self.r) @@ -259,7 +144,7 @@ class TestOfflineReader(unittest.TestCase): 0.0014177681261476852, 0.002094094517471032, 0.003923368624980349, 0.009491461076780453 ] - best = self.nu.best_reco + best = self.nu.get_best_reco() self.assertEqual(best.size, 9) self.assertEqual(best['JGANDALF_BETA1_RAD'][:4].tolist(), @@ -311,15 +196,9 @@ class TestOfflineEvents(unittest.TestCase): class TestOfflineEvent(unittest.TestCase): - def setUp(self): + def test_event(self): self.event = OfflineReader(OFFLINE_FILE).events[0] - def test_str(self): - self.assertEqual(repr(self.event).split('\n\t')[0], 'offline event:') - self.assertEqual( - repr(self.event).split('\n\t')[2], - 'det_id : 44') - class TestOfflineHits(unittest.TestCase): def setUp(self): @@ -399,12 +278,6 @@ class TestOfflineHit(unittest.TestCase): self.assertEqual(self.hit[0], self.hit.id) self.assertEqual(self.hit[1], self.hit.dom_id) - def test_str(self): - self.assertEqual(repr(self.hit).split('\n\t')[0], 'offline hit:') - self.assertEqual( - repr(self.hit).split('\n\t')[2], - 'dom_id : 806451572') - class TestOfflineTracks(unittest.TestCase): def setUp(self): @@ -477,8 +350,7 @@ class TestOfflineTrack(unittest.TestCase): self.assertEqual(self.track[10], self.track.E) def test_str(self): - self.assertEqual(repr(self.track).split('\n\t')[0], 'offline track:') - self.assertTrue("JGANDALF_LAMBDA" in repr(self.track)) + self.assertEqual(str(self.track).split('\n\t')[0], 'offline track:') class TestUsr(unittest.TestCase):
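A minimal usage sketch of the reworked OfflineReader API from this change set, assuming the aanet sample file referenced by the test suite and the removed tutorial notebook; the file path is illustrative and the printed shapes depend on the file actually opened.

from km3io import OfflineReader

# Open an offline ROOT file; the main "E" tree is wrapped lazily.
reader = OfflineReader("tests/samples/aanet_v2.0.0.root")  # illustrative path

print(len(reader))              # number of events in the file
print(reader.keys.tracks_keys)  # branch names such as 'trks.dir.z'

# Properties are cached on first access by the cached_property descriptor,
# so repeated lookups return the same wrapper object.
assert reader.keys is reader.keys

# Indexing returns a new OfflineReader built via from_index(); it shares the
# open file object but wraps only the selected slice of the lazyarrays.
first_event = reader[0]
print(first_event.events.hits)  # hit count of the first event

# The best-track fit data is now a method instead of a property.
best = reader.get_best_reco()   # numpy recarray of fit parameters
print(best.dtype.names[:5])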