diff --git a/km3io/aanet.py b/km3io/aanet.py index 026b6e31e159e5c152631d16436c142e5ef57c30..3960d88f7fb49c053f66cc9803f68ee32a40d090 100644 --- a/km3io/aanet.py +++ b/km3io/aanet.py @@ -1,28 +1,251 @@ import uproot +# 110 MB based on the size of the largest basket found so far in km3net +BASKET_CACHE_SIZE = 110 * 1024**2 -class AanetReader: - """Reader for one Aanet ROOT file""" + +class AanetKeys: + """wrapper for aanet keys""" def __init__(self, file_path): - """ AanetReader class is a Aanet ROOT file wrapper + """AanetKeys is a class that reads all the available keys in an aanet + file and adapts the keys format to Python format. Parameters ---------- file_path : path-like object - Path to the file of interest. It can be a str or any python + Path to the aanet file of interest. It can be a str or any python path-like object that points to the file of ineterst. """ - self.file_path = file_path - self.data = uproot.open(self.file_path)['E'] - self.lazy_data = self.data.lazyarrays() + self._file_path = file_path self._events_keys = None self._hits_keys = None self._tracks_keys = None self._mc_hits_keys = None self._mc_tracks_keys = None + self._valid_keys = None + self._fit_keys = None + self._cut_hits_keys = None + self._cut_tracks_keys = None + self._cut_events_keys = None + + def __str__(self): + return '\n'.join([ + "Events keys are:\n\t" + "\n\t".join(self.events_keys), + "Hits keys are:\n\t" + '\n\t'.join(self.hits_keys), + "Tracks keys are:\n\t" + '\n\t'.join(self.tracks_keys), + "Mc hits keys are:\n\t" + '\n\t'.join(self.mc_hits_keys), + "Mc tracks keys are:\n\t" + '\n\t'.join(self.mc_tracks_keys) + ]) + + def __repr__(self): + return str(self) + # return f'{self.__class__.__name__}("{self._file_path}")' + + @property + def events_keys(self): + """reads events keys from an aanet file. + + Returns + ------- + list of str + list of all events keys found in an aanet file, + except those found in fake branches. + """ + if self._events_keys is None: + fake_branches = ['Evt', 'AAObject', 'TObject', 't'] + t_baskets = ['t.fSec', 't.fNanoSec'] + tree = uproot.open(self._file_path)['E']['Evt'] + self._events_keys = [ + key.decode('utf-8') for key in tree.keys() + if key.decode('utf-8') not in fake_branches + ] + t_baskets + return self._events_keys + + @property + def hits_keys(self): + """reads hits keys from an aanet file. + + Returns + ------- + list of str + list of all hits keys found in an aanet file, + except those found in fake branches. + """ + if self._hits_keys is None: + fake_branches = [ + 'hits.usr', 'hits.usr_names' + ] # to be treated like trks.usr and trks.usr_names + tree = uproot.open(self._file_path)['E']['hits'] + self._hits_keys = [ + key.decode('utf8') for key in tree.keys() + if key.decode('utf8') not in fake_branches + ] + return self._hits_keys + + @property + def tracks_keys(self): + """reads tracks keys from an aanet file. + + Returns + ------- + list of str + list of all tracks keys found in an aanet file, + except those found in fake branches. + """ + if self._tracks_keys is None: + # a solution can be tree['trks.usr_data'].array( + # uproot.asdtype(">i4")) + fake_branches = [ + 'trks.usr_data', 'trks.usr', 'trks.usr_names' + ] # can be accessed using tree['trks.usr_names'].array() + tree = uproot.open(self._file_path)['E']['Evt']['trks'] + self._tracks_keys = [ + key.decode('utf8') for key in tree.keys() + if key.decode('utf8') not in fake_branches + ] + return self._tracks_keys + + @property + def mc_hits_keys(self): + """reads mc hits keys from an aanet file. + + Returns + ------- + list of str + list of all mc hits keys found in an aanet file, + except those found in fake branches. + """ + if self._mc_hits_keys is None: + fake_branches = ['mc_hits.usr', 'mc_hits.usr_names'] + tree = uproot.open(self._file_path)['E']['Evt']['mc_hits'] + self._mc_hits_keys = [ + key.decode('utf8') for key in tree.keys() + if key.decode('utf8') not in fake_branches + ] + return self._mc_hits_keys + + @property + def mc_tracks_keys(self): + """reads mc tracks keys from an aanet file. + + Returns + ------- + list of str + list of all mc tracks keys found in an aanet file, + except those found in fake branches. + """ + if self._mc_tracks_keys is None: + fake_branches = [ + 'mc_trks.usr_data', 'mc_trks.usr', 'mc_trks.usr_names' + ] # same solution as above can be used + tree = uproot.open(self._file_path)['E']['Evt']['mc_trks'] + self._mc_tracks_keys = [ + key.decode('utf8') for key in tree.keys() + if key.decode('utf8') not in fake_branches + ] + return self._mc_tracks_keys + + @property + def valid_keys(self): + """constructs a list of all valid keys to be read from an Aanet event file. + Returns + ------- + list of str + list of all valid keys. + """ + if self._valid_keys is None: + self._valid_keys = (self.events_keys + self.hits_keys + + self.tracks_keys + self.mc_tracks_keys + + self.mc_hits_keys) + return self._valid_keys + + @property + def fit_keys(self): + """constructs a list of fit parameters, not yet outsourced in an aanet file. + + Returns + ------- + list of str + list of all "trks.fitinf" keys. + """ + if self._fit_keys is None: + # these are hardcoded because they are not outsourced in aanet file + self._fit_keys = [ + 'JGANDALF_BETA0_RAD', 'JGANDALF_BETA1_RAD', 'JGANDALF_CHI2', + 'JGANDALF_NUMBER_OF_HITS', 'JENERGY_ENERGY', 'JENERGY_CHI2', + 'JGANDALF_LAMBDA', 'JGANDALF_NUMBER_OF_ITERATIONS', + 'JSTART_NPE_MIP', 'JSTART_NPE_MIP_TOTAL', + 'JSTART_LENGTH_METRES', 'JVETO_NPE', 'JVETO_NUMBER_OF_HITS', + 'JENERGY_MUON_RANGE_METRES', 'JENERGY_NOISE_LIKELIHOOD', + 'JENERGY_NDF', 'JENERGY_NUMBER_OF_HITS', 'JCOPY_Z_M' + ] + return self._fit_keys + + @property + def cut_hits_keys(self): + """adapts hits keys for instance variables format in a Python class. + + Returns + ------- + list of str + list of adapted hits keys. + """ + if self._cut_hits_keys is None: + self._cut_hits_keys = [ + k.split('hits.')[1].replace('.', '_') for k in self.hits_keys + ] + return self._cut_hits_keys + + @property + def cut_tracks_keys(self): + """adapts tracks keys for instance variables format in a Python class. + + Returns + ------- + list of str + list of adapted tracks keys. + """ + if self._cut_tracks_keys is None: + self._cut_tracks_keys = [ + k.split('trks.')[1].replace('.', '_') for k in self.tracks_keys + ] + return self._cut_tracks_keys + + @property + def cut_events_keys(self): + """adapts events keys for instance variables format in a Python class. + + Returns + ------- + list of str + list of adapted events keys. + """ + if self._cut_events_keys is None: + self._cut_events_keys = [ + k.replace('.', '_') for k in self.events_keys + ] + return self._cut_events_keys + + +class Reader: + """Reader for one Aanet ROOT file""" + def __init__(self, file_path): + """ AanetReader class is a Aanet ROOT file reader. This class is a + "very" low level I/O. + + Parameters + ---------- + file_path : path-like object + Path to the file of interest. It can be a str or any python + path-like object that points to the file of ineterst. + """ + self._file_path = file_path + self._data = uproot.open(self._file_path)['E'].lazyarrays( + basketcache=uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)) + self._keys = None def __getitem__(self, key): - """reads data stored in the branch of interest in an event tree. + """reads data stored in the branch of interest in an Evt tree. Parameters ---------- @@ -41,78 +264,360 @@ class AanetReader: Raises ------ - KeyEroor + KeyError Some branches in an Aanet file structure are "fake branches" and do not contain data. Therefore, the keys corresponding to these fake branches are not read. """ - if key not in self.keys() and not isinstance(key, int): + if key not in self.keys.valid_keys and not isinstance(key, int): raise KeyError( "'{}' is not a valid key or is a fake branch.".format(key)) - return self.lazy_data[key] + return self._data[key] def __len__(self): - return len(self.lazy_data) + return len(self._data) def __repr__(self): - return '\n'.join([ - "Number of events: {}".format(self.__len__()), - "Events keys are:\n\t" + '\n\t'.join(self.events_keys), - "Hits keys are:\n\t" + '\n\t'.join(self.hits_keys), - "Tracks keys are:\n\t" + '\n\t'.join(self.tracks_keys), - "Mc hits keys are:\n\t" + '\n\t'.join(self.mc_hits_keys), - "Mc tracks keys are:\n\t" + '\n\t'.join(self.mc_tracks_keys) - ]) + return "<{}: {} entries>".format(self.__class__.__name__, len(self)) + @property def keys(self): - """constructs a list of all valid keys to be read from an Aanet event file. + """wrapper for all keys in an aanet file. Returns ------- - list - list of all valid keys. + Class + AanetKeys. + """ + if self._keys is None: + self._keys = AanetKeys(self._file_path) + return self._keys + + +class AanetReader: + """reader for Aanet ROOT files""" + def __init__(self, file_path, data=None): + """ AanetReader class is an aanet ROOT file wrapper + + Parameters + ---------- + file_path : path-like object + Path to the file of interest. It can be a str or any python + path-like object that points to the file of ineterst. """ - return self.events_keys + self.hits_keys + self.tracks_keys + self.mc_tracks_keys + self.mc_hits_keys + self._file_path = file_path + if data is not None: + self._data = data + else: + self._data = uproot.open(self._file_path)['E'].lazyarrays( + basketcache=uproot.cache.ThreadSafeArrayCache( + BASKET_CACHE_SIZE)) + self._events = None + self._hits = None + self._tracks = None + self._mc_hits = None + self._mc_tracks = None + self._keys = None + + def __getitem__(self, item): + return AanetReader(file_path=self._file_path, data=self._data[item]) + + def __len__(self): + return len(self._data) @property - def events_keys(self): - if self._events_keys is None: - fake_branches = ['Evt', 'AAObject', 'TObject', 't'] - t_baskets = ['t.fSec', 't.fNanoSec'] - self._events_keys = [ - key.decode('utf-8') for key in self.data['Evt'].keys() - if key.decode('utf-8') not in fake_branches - ] + t_baskets - return self._events_keys + def keys(self): + """wrapper for all keys in an aanet file. + + Returns + ------- + Class + AanetKeys. + """ + if self._keys is None: + self._keys = AanetKeys(self._file_path) + return self._keys @property - def hits_keys(self): - if self._hits_keys is None: - hits_tree = self.data['Evt']['hits'] - self._hits_keys = [key.decode('utf8') for key in hits_tree.keys()] - return self._hits_keys + def events(self): + """wrapper for aanet events. + + Returns + ------- + Class + AanetEvents. + """ + if self._events is None: + self._events = AanetEvents( + self.keys.cut_events_keys, + [self._data[key] for key in self.keys.events_keys]) + return self._events @property - def tracks_keys(self): - if self._tracks_keys is None: - tracks_tree = self.data['Evt']['trks'] - self._tracks_keys = [ - key.decode('utf8') for key in tracks_tree.keys() - ] - return self._tracks_keys + def hits(self): + """wrapper for aanet hits. + + Returns + ------- + Class + AanetHits. + """ + if self._hits is None: + self._hits = AanetHits( + self.keys.cut_hits_keys, + [self._data[key] for key in self.keys.hits_keys]) + return self._hits @property - def mc_hits_keys(self): - if self._mc_hits_keys is None: - mc_hits_tree = self.data['Evt']['mc_hits'] - self._mc_hits_keys = [key.decode('utf8') for key in mc_hits_tree.keys()] - return self._mc_hits_keys + def tracks(self): + """wrapper for aanet tracks. + + Returns + ------- + Class + AanetTracks. + """ + if self._tracks is None: + self._tracks = AanetTracks( + self.keys.cut_tracks_keys, + [self._data[key] for key in self.keys.tracks_keys], + fit_keys=self.keys.fit_keys) + return self._tracks @property - def mc_tracks_keys(self): - if self._mc_tracks_keys is None: - mc_tracks_tree = self.data['Evt']['mc_trks'] - self._mc_tracks_keys = [ - key.decode('utf8') for key in mc_tracks_tree.keys() - ] - return self._mc_tracks_keys + def mc_hits(self): + """wrapper for aanet mc hits. + + Returns + ------- + Class + AanetHits. + """ + if self._mc_hits is None: + self._mc_hits = AanetHits( + self.keys.cut_hits_keys, + [self._data[key] for key in self.keys.mc_hits_keys]) + return self._mc_hits + + @property + def mc_tracks(self): + """wrapper for aanet mc tracks. + + Returns + ------- + Class + AanetTracks. + """ + if self._mc_tracks is None: + self._mc_tracks = AanetTracks( + self.keys.cut_tracks_keys, + [self._data[key] for key in self.keys.mc_tracks_keys]) + return self._mc_tracks + + +class AanetEvents: + """wrapper for Aanet events""" + def __init__(self, keys, values): + """wrapper for aanet events. + + Parameters + ---------- + keys : list of str + list of valid events keys. + values : list of arrays + list of arrays containting events data. + """ + self._keys = keys + self._values = values + for k, v in zip(self._keys, self._values): + setattr(self, k, v) + + def __getitem__(self, item): + return AanetEvent(self._keys, [v[item] for v in self._values]) + + def __len__(self): + try: + return len(self._values[0]) + except IndexError: + return 0 + + def __str__(self): + return "Number of events: {}".format(len(self)) + + def __repr__(self): + return "<{}: {} parsed events>".format(self.__class__.__name__, + len(self)) + + +class AanetEvent: + """wrapper for an Aanet event""" + def __init__(self, keys, values): + """wrapper for one aanet event. + + Parameters + ---------- + keys : list of str + list of valid events keys. + values : list of arrays + list of arrays containting event data. + """ + self._keys = keys + self._values = values + for k, v in zip(self._keys, self._values): + setattr(self, k, v) + + def __str__(self): + return "Aanet event:\n\t" + "\n\t".join([ + "{:15} {:^10} {:>10}".format(k, ':', str(v)) + for k, v in zip(self._keys, self._values) + ]) + + def __repr__(self): + return str(self) + + +class AanetHits: + """wrapper for Aanet hits, manages the display of all hits in one event""" + def __init__(self, keys, values): + """wrapper for aanet hits. + + Parameters + ---------- + keys : list of str + list of cropped hits keys. + values : list of arrays + list of arrays containting hits data. + """ + self._keys = keys + self._values = values + for k, v in zip(self._keys, self._values): + setattr(self, k, v) + + def __getitem__(self, item): + return AanetHit(self._keys, [v[item] for v in self._values]) + + def __len__(self): + try: + return len(self._values[0]) + except IndexError: + return 0 + + def __str__(self): + return "Number of hits: {}".format(len(self)) + + def __repr__(self): + return "<{}: {} parsed elements>".format(self.__class__.__name__, + len(self)) + + +class AanetHit: + """wrapper for an Aanet hit""" + def __init__(self, keys, values): + """wrapper for one aanet hit. + + Parameters + ---------- + keys : list of str + list of cropped hits keys. + values : list of arrays + list of arrays containting hit data. + """ + self._keys = keys + self._values = values + for k, v in zip(self._keys, self._values): + setattr(self, k, v) + + def __str__(self): + return "Aanet hit:\n\t" + "\n\t".join([ + "{:15} {:^10} {:>10}".format(k, ':', str(v)) + for k, v in zip(self._keys, self._values) + ]) + + def __getitem__(self, item): + return self._values[item] + + def __repr__(self): + return str(self) + + # def _is_empty(array): + # if array.size: + # return False + # else: + # return True + + +class AanetTracks: + """wrapper for Aanet tracks""" + def __init__(self, keys, values, fit_keys=None): + """Summary + + Parameters + ---------- + keys : TYPE + Description + values : TYPE + Description + fit_keys : None, optional + list of tracks fit information (not yet outsourced in aanet files) + """ + self._keys = keys + self._values = values + if fit_keys is not None: + self._fit_keys = fit_keys + for k, v in zip(self._keys, self._values): + setattr(self, k, v) + + def __getitem__(self, item): + return AanetTrack(self._keys, [v[item] for v in self._values], + fit_keys=self._fit_keys) + + def __len__(self): + try: + return len(self._values[0]) + except IndexError: + return 0 + + def __str__(self): + return "Number of tracks: {}".format(len(self)) + + def __repr__(self): + return "<{}: {} parsed elements>".format(self.__class__.__name__, + len(self)) + + +class AanetTrack: + """wrapper for an Aanet track""" + def __init__(self, keys, values, fit_keys=None): + """wrapper for one aanet track. + + Parameters + ---------- + keys : list of str + list of cropped tracks keys. + values : list of arrays + list of arrays containting track data. + fit_keys : None, optional + list of tracks fit information (not yet outsourced in aanet files). + """ + self._keys = keys + self._values = values + if fit_keys is not None: + self._fit_keys = fit_keys + for k, v in zip(self._keys, self._values): + setattr(self, k, v) + + def __str__(self): + return "Aanet track:\n\t" + "\n\t".join([ + "{:30} {:^2} {:>26}".format(k, ':', str(v)) + for k, v in zip(self._keys, self._values) if k not in ['fitinf'] + ]) + "\n\t" + "\n\t".join([ + "{:30} {:^2} {:>26}".format(k, ':', str(v)) + for k, v in zip(self._fit_keys, self._values[18] + ) # I don't like 18 being explicit here + ]) + + def __getitem__(self, item): + return self._values[item] + + def __repr__(self): + return str(self) diff --git a/notebooks/AanetReader_tutorial.ipynb b/notebooks/AanetReader_tutorial.ipynb index 8614d484fcbc0fc629af9ae491e029de8f5f57c6..3a225d24ba811a5fe7897e8d9ff8712218e37970 100644 --- a/notebooks/AanetReader_tutorial.ipynb +++ b/notebooks/AanetReader_tutorial.ipynb @@ -4,7 +4,15 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['/home/zineb/km3net/km3net/km3io/notebooks', '/home/zineb/miniconda3/envs/km3pipe/lib/python37.zip', '/home/zineb/miniconda3/envs/km3pipe/lib/python3.7', '/home/zineb/miniconda3/envs/km3pipe/lib/python3.7/lib-dynload', '', '/home/zineb/miniconda3/envs/km3pipe/lib/python3.7/site-packages', '/home/zineb/km3net/km3net/km3io', '/home/zineb/miniconda3/envs/km3pipe/lib/python3.7/site-packages/IPython/extensions', '/home/zineb/.ipython', '/home/zineb/km3net/km3net/km3io']\n" + ] + } + ], "source": [ "# Add file to current python path\n", "from pathlib import Path\n", @@ -25,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -36,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 12, "metadata": { "scrolled": true }, @@ -44,7 +52,6 @@ { "data": { "text/plain": [ - "Number of events: 10\n", "Events keys are:\n", "\tid\n", "\tdet_id\n", @@ -113,10 +120,56 @@ "\ttrks.fitinf\n", "\ttrks.hit_ids\n", "\ttrks.error_matrix\n", - "\ttrks.comment" + "\ttrks.comment\n", + "Mc hits keys are:\n", + "\tmc_hits.id\n", + "\tmc_hits.dom_id\n", + "\tmc_hits.channel_id\n", + "\tmc_hits.tdc\n", + "\tmc_hits.tot\n", + "\tmc_hits.trig\n", + "\tmc_hits.pmt_id\n", + "\tmc_hits.t\n", + "\tmc_hits.a\n", + "\tmc_hits.pos.x\n", + "\tmc_hits.pos.y\n", + "\tmc_hits.pos.z\n", + "\tmc_hits.dir.x\n", + "\tmc_hits.dir.y\n", + "\tmc_hits.dir.z\n", + "\tmc_hits.pure_t\n", + "\tmc_hits.pure_a\n", + "\tmc_hits.type\n", + "\tmc_hits.origin\n", + "\tmc_hits.pattern_flags\n", + "Mc tracks keys are:\n", + "\tmc_trks.fUniqueID\n", + "\tmc_trks.fBits\n", + "\tmc_trks.usr_data\n", + "\tmc_trks.usr_names\n", + "\tmc_trks.id\n", + "\tmc_trks.pos.x\n", + "\tmc_trks.pos.y\n", + "\tmc_trks.pos.z\n", + "\tmc_trks.dir.x\n", + "\tmc_trks.dir.y\n", + "\tmc_trks.dir.z\n", + "\tmc_trks.t\n", + "\tmc_trks.E\n", + "\tmc_trks.len\n", + "\tmc_trks.lik\n", + "\tmc_trks.type\n", + "\tmc_trks.rec_type\n", + "\tmc_trks.rec_stages\n", + "\tmc_trks.status\n", + "\tmc_trks.mother_id\n", + "\tmc_trks.fitinf\n", + "\tmc_trks.hit_ids\n", + "\tmc_trks.error_matrix\n", + "\tmc_trks.comment" ] }, - "execution_count": 25, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -134,7 +187,30 @@ { "data": { "text/plain": [ - "<Table [<Row 0> <Row 1> <Row 2> ... <Row 7> <Row 8> <Row 9>] at 0x7fb2342f1f10>" + "['trks.fUniqueID',\n", + " 'trks.fBits',\n", + " 'trks.usr_data',\n", + " 'trks.usr_names',\n", + " 'trks.id',\n", + " 'trks.pos.x',\n", + " 'trks.pos.y',\n", + " 'trks.pos.z',\n", + " 'trks.dir.x',\n", + " 'trks.dir.y',\n", + " 'trks.dir.z',\n", + " 'trks.t',\n", + " 'trks.E',\n", + " 'trks.len',\n", + " 'trks.lik',\n", + " 'trks.type',\n", + " 'trks.rec_type',\n", + " 'trks.rec_stages',\n", + " 'trks.status',\n", + " 'trks.mother_id',\n", + " 'trks.fitinf',\n", + " 'trks.hit_ids',\n", + " 'trks.error_matrix',\n", + " 'trks.comment']" ] }, "execution_count": 5, @@ -143,9 +219,7 @@ } ], "source": [ - "# big lazyarray with ALL file data!\n", - "lazy_data = reader.lazy_data\n", - "lazy_data" + "reader.tracks_keys" ] }, { @@ -156,7 +230,7 @@ { "data": { "text/plain": [ - "5971" + "<Table [<Row 0> <Row 1> <Row 2> ... <Row 7> <Row 8> <Row 9>] at 0x7f83f9d62c10>" ] }, "execution_count": 6, @@ -165,8 +239,9 @@ } ], "source": [ - "# getting the run_id for a specific event (event 5 for example)\n", - "reader[5]['run_id']" + "# big lazyarray with ALL file data!\n", + "lazy_data = reader._lazy_data\n", + "lazy_data" ] }, { @@ -177,7 +252,7 @@ { "data": { "text/plain": [ - "60" + "<ChunkedArray [5971 5971 5971 ... 5971 5971 5971] at 0x7f83f9d2cdd0>" ] }, "execution_count": 7, @@ -186,8 +261,8 @@ } ], "source": [ - "# one can check how many hits are in event 5\n", - "reader[5]['hits']" + "# getting the run_id for a specific event (event 5 for example)\n", + "reader['run_id']" ] }, { @@ -198,7 +273,7 @@ { "data": { "text/plain": [ - "56" + "60" ] }, "execution_count": 8, @@ -207,8 +282,8 @@ } ], "source": [ - "# one can also check how many tracks are in event 5\n", - "reader[5]['trks']" + "# one can check how many hits are in event 5\n", + "reader[5]['hits']" ] }, { @@ -217,20 +292,33 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "\"'whatever' is not a valid key or is a fake branch.\"\n" - ] + "data": { + "text/plain": [ + "56" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "# the user is reminded to always specify the \"correct\" event/hits/tracks \n", - "# key in the Aanet event file\n", - "try:\n", - " reader['whatever']\n", - "except KeyError as e:\n", - " print(e)" + "# one can also check how many tracks are in event 5\n", + "reader[5]['trks']" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# # the user is reminded to always specify the \"correct\" event/hits/tracks \n", + "# # key in the Aanet event file\n", + "# try:\n", + "# reader['whatever']\n", + "# except KeyError as e:\n", + "# print(e)" ] }, { @@ -242,16 +330,16 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "<ChunkedArray [[806451572 806451572 806451572 ... 809544061 809544061 809544061] [806451572 806451572 806451572 ... 809524432 809526097 809544061] [806451572 806451572 806451572 ... 809544061 809544061 809544061] ... [806451572 806455814 806465101 ... 809526097 809544058 809544061] [806455814 806455814 806455814 ... 809544061 809544061 809544061] [806455814 806455814 806455814 ... 809544058 809544058 809544061]] at 0x7fb23426e190>" + "<ChunkedArray [[806451572 806451572 806451572 ... 809544061 809544061 809544061] [806451572 806451572 806451572 ... 809524432 809526097 809544061] [806451572 806451572 806451572 ... 809544061 809544061 809544061] ... [806451572 806455814 806465101 ... 809526097 809544058 809544061] [806455814 806455814 806455814 ... 809544061 809544061 809544061] [806455814 806455814 806455814 ... 809544058 809544058 809544061]] at 0x7f83f9a13890>" ] }, - "execution_count": 10, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -264,7 +352,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -283,7 +371,7 @@ " dtype=int32)" ] }, - "execution_count": 11, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -296,7 +384,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -305,7 +393,7 @@ "60" ] }, - "execution_count": 12, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -319,7 +407,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -328,7 +416,7 @@ "806455814" ] }, - "execution_count": 13, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -347,16 +435,16 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "<ChunkedArray [[-0.872885221293917 -0.872885221293917 -0.872885221293917 ... -0.6631226836266504 -0.5680647731737454 -0.5680647731737454] [-0.8351996698137462 -0.8351996698137462 -0.8351996698137462 ... -0.7485107718446855 -0.8229838871876581 -0.239315690284641] [-0.989148723802379 -0.989148723802379 -0.989148723802379 ... -0.9350162572437829 -0.88545604390297 -0.88545604390297] ... [-0.5704611045902105 -0.5704611045902105 -0.5704611045902105 ... -0.9350162572437829 -0.4647231989130516 -0.4647231989130516] [-0.9779941383490359 -0.9779941383490359 -0.9779941383490359 ... -0.88545604390297 -0.88545604390297 -0.8229838871876581] [-0.7396916780974963 -0.7396916780974963 -0.7396916780974963 ... -0.6631226836266504 -0.7485107718446855 -0.7485107718446855]] at 0x7fb234277590>" + "<ChunkedArray [[-0.872885221293917 -0.872885221293917 -0.872885221293917 ... -0.6631226836266504 -0.5680647731737454 -0.5680647731737454] [-0.8351996698137462 -0.8351996698137462 -0.8351996698137462 ... -0.7485107718446855 -0.8229838871876581 -0.239315690284641] [-0.989148723802379 -0.989148723802379 -0.989148723802379 ... -0.9350162572437829 -0.88545604390297 -0.88545604390297] ... [-0.5704611045902105 -0.5704611045902105 -0.5704611045902105 ... -0.9350162572437829 -0.4647231989130516 -0.4647231989130516] [-0.9779941383490359 -0.9779941383490359 -0.9779941383490359 ... -0.88545604390297 -0.88545604390297 -0.8229838871876581] [-0.7396916780974963 -0.7396916780974963 -0.7396916780974963 ... -0.6631226836266504 -0.7485107718446855 -0.7485107718446855]] at 0x7f83f9b65f90>" ] }, - "execution_count": 14, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -368,7 +456,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -388,7 +476,7 @@ " -0.97094183])" ] }, - "execution_count": 15, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -401,7 +489,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -410,7 +498,7 @@ "56" ] }, - "execution_count": 16, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -424,7 +512,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -433,7 +521,7 @@ "-0.6024604933159441" ] }, - "execution_count": 17, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } diff --git a/tests/samples/numucc.root b/tests/samples/numucc.root new file mode 100644 index 0000000000000000000000000000000000000000..dc3071ceffea0f13767eaf5169d438c7a7db624e Binary files /dev/null and b/tests/samples/numucc.root differ diff --git a/tests/test_aanet.py b/tests/test_aanet.py index 2b307df1b4404680d212eb25c23c86d19f103ff9..c4c4895aa42e74b3e8832d25d9d42623219632ff 100644 --- a/tests/test_aanet.py +++ b/tests/test_aanet.py @@ -1,15 +1,54 @@ import unittest from pathlib import Path +from km3io.aanet import Reader, AanetEvents, AanetHits, AanetTracks from km3io import AanetReader SAMPLES_DIR = Path(__file__).parent / 'samples' AANET_FILE = SAMPLES_DIR / 'aanet_v2.0.0.root' +AANET_NUMUCC = SAMPLES_DIR / "numucc.root" # with mc data -class TestAanetReader(unittest.TestCase): +class TestAanetKeys(unittest.TestCase): def setUp(self): - self.r = AanetReader(AANET_FILE) + self.keys = AanetReader(AANET_FILE).keys + + def test_repr(self): + reader_repr = repr(self.keys) + + # check that there are 106 keys + 5 extra str + self.assertEqual(len(reader_repr.split('\n')), 111) + + def test_events_keys(self): + # there are 22 "valid" events keys + self.assertEqual(len(self.keys.events_keys), 22) + self.assertEqual(len(self.keys.cut_events_keys), 22) + + def test_hits_keys(self): + # there are 20 "valid" hits keys + self.assertEqual(len(self.keys.hits_keys), 20) + self.assertEqual(len(self.keys.mc_hits_keys), 20) + self.assertEqual(len(self.keys.cut_hits_keys), 20) + + def test_tracks_keys(self): + # there are 22 "valid" tracks keys + self.assertEqual(len(self.keys.tracks_keys), 22) + self.assertEqual(len(self.keys.mc_tracks_keys), 22) + self.assertEqual(len(self.keys.cut_tracks_keys), 22) + + def test_valid_keys(self): + # there are 106 valid keys: 22*2 + 22 + 20*2 + # (fit keys are excluded) + self.assertEqual(len(self.keys.valid_keys), 106) + + def test_fit_keys(self): + # there are 18 fit keys + self.assertEqual(len(self.keys.fit_keys), 18) + + +class TestReader(unittest.TestCase): + def setUp(self): + self.r = Reader(AANET_FILE) self.lengths = {0: 176, 1: 125, -1: 105} self.total_item_count = 1434 @@ -48,27 +87,18 @@ class TestAanetReader(unittest.TestCase): self.assertListEqual([70104010.0, 70104016.0, 70104192.0], list(ts[0][:3])) - def test_reading_hits_keys(self): - keys = self.r.hits_keys - mc_keys = self.r.mc_hits_keys - - # there are 20 hits keys - self.assertEqual(len(keys), 20) - self.assertEqual(len(mc_keys), 20) - - def test_reading_tracks_keys(self): - keys = self.r.tracks_keys - mc_keys = self.r.mc_tracks_keys - - # there are 24 tracks keys - self.assertEqual(len(keys), 24) - self.assertEqual(len(mc_keys), 24) def test_reading_keys(self): - all_keys = self.r.keys() + # there are 106 "valid" keys in Aanet file + self.assertEqual(len(self.r.keys.valid_keys), 106) + + # there are 20 hits keys + self.assertEqual(len(self.r.keys.hits_keys), 20) + self.assertEqual(len(self.r.keys.mc_hits_keys), 20) - # there are 110 valid keys in Aanet file - self.assertEqual(len(all_keys), 110) + # there are 22 tracks keys + self.assertEqual(len(self.r.keys.tracks_keys), 22) + self.assertEqual(len(self.r.keys.mc_tracks_keys), 22) def test_raising_KeyError(self): # non valid keys must raise a KeyError @@ -77,9 +107,215 @@ class TestAanetReader(unittest.TestCase): def test_number_events(self): Nevents = len(self.r) - reader_repr = repr(self.r) # check that there are 10 events self.assertEqual(Nevents, 10) - # check that there are 110 keys + 6 extra str - self.assertEqual(len(reader_repr.split('\n')), 116) + + +class TestAanetReader(unittest.TestCase): + def setUp(self): + self.r = AanetReader(AANET_FILE) + self.Nevents = 10 + self.selected_data = AanetReader(AANET_FILE, + data=self.r._data[0])._data + + def test_item_selection(self): + # test class instance with data=None option + self.assertEqual(len(self.selected_data), len(self.r._data[0])) + + # test item selection (here we test with hits=176) + self.assertEqual(self.r[0].events.hits, self.selected_data['hits']) + + def test_number_events(self): + Nevents = len(self.r) + + # check that there are 10 events + self.assertEqual(Nevents, self.Nevents) + + +class TestAanetEvents(unittest.TestCase): + def setUp(self): + self.events = AanetReader(AANET_FILE).events + self.hits = {0: 176, 1: 125, -1: 105} + self.Nevents = 10 + + def test_reading_hits(self): + # test item selection + for event_id, hit in self.hits.items(): + self.assertEqual(hit, self.events.hits[event_id]) + + def reading_tracks(self): + self.assertListEqual(list(self.events.trks[:3]), [56, 55, 56]) + + def test_item_selection(self): + for event_id, hit in self.hits.items(): + self.assertEqual(hit, self.events[event_id].hits) + + def test_len(self): + self.assertEqual(len(self.events), self.Nevents) + + def test_IndexError(self): + # test handling IndexError with empty lists/arrays + self.assertEqual(len(AanetEvents(['whatever'], [])), 0) + + def test_str(self): + self.assertEqual(str(self.events), 'Number of events: 10') + + def test_repr(self): + self.assertEqual(repr(self.events), '<AanetEvents: 10 parsed events>') + + +class TestAanetEvent(unittest.TestCase): + def setUp(self): + self.event = AanetReader(AANET_FILE).events[0] + + def test_str(self): + self.assertEqual(repr(self.event).split('\n\t')[0], 'Aanet event:') + self.assertEqual( + repr(self.event).split('\n\t')[2], + 'det_id : 44') + + +class TestAanetHits(unittest.TestCase): + def setUp(self): + self.hits = AanetReader(AANET_FILE).hits + self.lengths = {0: 176, 1: 125, -1: 105} + self.total_item_count = 1434 + self.r_mc = AanetReader(AANET_NUMUCC) + self.Nevents = 10 + + def test_item_selection(self): + self.assertListEqual(list(self.hits[0].dom_id[:3]), + [806451572, 806451572, 806451572]) + + def test_IndexError(self): + # test handling IndexError with empty lists/arrays + self.assertEqual(len(AanetHits(['whatever'], [])), 0) + + def test_repr(self): + self.assertEqual(repr(self.hits), '<AanetHits: 10 parsed elements>') + + def test_str(self): + self.assertEqual(str(self.hits), 'Number of hits: 10') + + def test_reading_dom_id(self): + dom_ids = self.hits.dom_id + + for event_id, length in self.lengths.items(): + self.assertEqual(length, len(dom_ids[event_id])) + + self.assertEqual(self.total_item_count, sum(dom_ids.count())) + + self.assertListEqual([806451572, 806451572, 806451572], + list(dom_ids[0][:3])) + + def test_reading_channel_id(self): + channel_ids = self.hits.channel_id + + for event_id, length in self.lengths.items(): + self.assertEqual(length, len(channel_ids[event_id])) + + self.assertEqual(self.total_item_count, sum(channel_ids.count())) + + self.assertListEqual([8, 9, 14], list(channel_ids[0][:3])) + + # channel IDs are always between [0, 30] + self.assertTrue(all(c >= 0 for c in channel_ids.min())) + self.assertTrue(all(c < 31 for c in channel_ids.max())) + + def test_reading_times(self): + ts = self.hits.t + + for event_id, length in self.lengths.items(): + self.assertEqual(length, len(ts[event_id])) + + self.assertEqual(self.total_item_count, sum(ts.count())) + + self.assertListEqual([70104010.0, 70104016.0, 70104192.0], + list(ts[0][:3])) + + def test_reading_mc_pmt_id(self): + pmt_ids = self.r_mc.mc_hits.pmt_id + lengths = {0: 58, 2: 28, -1: 48} + + for hit_id, length in lengths.items(): + self.assertEqual(length, len(pmt_ids[hit_id])) + + self.assertEqual(self.Nevents, len(pmt_ids)) + + self.assertListEqual([677, 687, 689], list(pmt_ids[0][:3])) + + +class TestAanetHit(unittest.TestCase): + def setUp(self): + self.hit = AanetReader(AANET_FILE)[0].hits[0] + + def test_item_selection(self): + self.assertEqual(self.hit[0], self.hit.id) + self.assertEqual(self.hit[1], self.hit.dom_id) + + def test_str(self): + self.assertEqual(repr(self.hit).split('\n\t')[0], 'Aanet hit:') + self.assertEqual( + repr(self.hit).split('\n\t')[2], + 'dom_id : 806451572') + + +class TestAanetTracks(unittest.TestCase): + def setUp(self): + self.tracks = AanetReader(AANET_FILE).tracks + self.r_mc = AanetReader(AANET_NUMUCC) + self.Nevents = 10 + + def test_item_selection(self): + self.assertListEqual(list(self.tracks[0].dir_z[:2]), + [-0.872885221293917, -0.872885221293917]) + + def test_IndexError(self): + # test handling IndexError with empty lists/arrays + self.assertEqual(len(AanetTracks(['whatever'], [])), 0) + + def test_repr(self): + self.assertEqual(repr(self.tracks), + '<AanetTracks: 10 parsed elements>') + + def test_str(self): + self.assertEqual(str(self.tracks), 'Number of tracks: 10') + + def test_reading_tracks_dir_z(self): + dir_z = self.tracks.dir_z + tracks_dir_z = {0: 56, 1: 55, 8: 54} + + for track_id, n_dir in tracks_dir_z.items(): + self.assertEqual(n_dir, len(dir_z[track_id])) + + # check that there are 10 arrays of tracks.dir_z info + self.assertEqual(len(dir_z), self.Nevents) + + def test_reading_mc_tracks_dir_z(self): + dir_z = self.r_mc.mc_tracks.dir_z + tracks_dir_z = {0: 11, 1: 25, 8: 13} + + for track_id, n_dir in tracks_dir_z.items(): + self.assertEqual(n_dir, len(dir_z[track_id])) + + # check that there are 10 arrays of tracks.dir_z info + self.assertEqual(len(dir_z), self.Nevents) + + self.assertListEqual([0.230189, 0.230189, 0.218663], + list(dir_z[0][:3])) + + +class TestAanetTrack(unittest.TestCase): + def setUp(self): + self.track = AanetReader(AANET_FILE)[0].tracks[0] + + def test_item_selection(self): + self.assertEqual(self.track[0], self.track.fUniqueID) + self.assertEqual(self.track[10], self.track.E) + + def test_str(self): + self.assertEqual(repr(self.track).split('\n\t')[0], 'Aanet track:') + self.assertEqual( + repr(self.track).split('\n\t')[28], + 'JGANDALF_LAMBDA : 4.2409761837248484e-12')