From 8b50fcc4e09a17061943d0e3cae2c1df818eddf7 Mon Sep 17 00:00:00 2001 From: Tamas Gal <tgal@km3net.de> Date: Mon, 19 Oct 2020 14:58:10 +0200 Subject: [PATCH] Make black --- doc/conf.py | 43 ++-- examples/plot_offline_events.py | 4 +- examples/plot_offline_hits.py | 2 +- examples/plot_offline_tracks.py | 2 +- examples/plot_offline_usr.py | 2 +- km3io/definitions.py | 1 + km3io/gseagen.py | 8 +- km3io/offline.py | 166 +++++++------ km3io/online.py | 211 ++++++++++------ km3io/patches.py | 4 +- km3io/rootio.py | 121 +++++---- km3io/tools.py | 65 +++-- km3io/utils/kprinttree.py | 5 +- setup.py | 32 ++- tests/test_gseagen.py | 93 ++++--- tests/test_offline.py | 278 +++++++++++++-------- tests/test_online.py | 428 ++++++++++++++++++++++++++------ tests/test_tools.py | 57 +++-- 18 files changed, 1007 insertions(+), 515 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index ec4c757..cd79de6 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -21,11 +21,11 @@ from pkg_resources import get_distribution # -- Project information ----------------------------------------------------- -version = get_distribution('km3io').version -short_version = '.'.join(version.split('.')[:2]) -project = 'km3io {}'.format(short_version) -copyright = '{0}, Zineb Aly and Tamas Gal'.format(date.today().year) -author = 'Zineb Aly, Tamas Gal, Johannes Schumann' +version = get_distribution("km3io").version +short_version = ".".join(version.split(".")[:2]) +project = "km3io {}".format(short_version) +copyright = "{0}, Zineb Aly and Tamas Gal".format(date.today().year) +author = "Zineb Aly, Tamas Gal, Johannes Schumann" # -- General configuration --------------------------------------------------- @@ -33,32 +33,35 @@ author = 'Zineb Aly, Tamas Gal, Johannes Schumann' # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.viewcode', - 'autoapi.extension', 'numpydoc', 'sphinx_gallery.gen_gallery' + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.viewcode", + "autoapi.extension", + "numpydoc", + "sphinx_gallery.gen_gallery", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # AutoAPI -autoapi_type = 'python' -autoapi_dirs = ['../km3io'] -autoapi_options = ['members', 'undoc-members', 'show-module-summary'] +autoapi_type = "python" +autoapi_dirs = ["../km3io"] +autoapi_options = ["members", "undoc-members", "show-module-summary"] autoapi_include_summaries = True # Gallery sphinx_gallery_conf = { - 'backreferences_dir': 'modules/generated', - 'default_thumb_file': '_static/default_gallery_thumbnail.png', - 'examples_dirs': '../examples', # path to your example scripts - 'gallery_dirs': - 'auto_examples', # path to where to save gallery generated output - 'show_memory': True, + "backreferences_dir": "modules/generated", + "default_thumb_file": "_static/default_gallery_thumbnail.png", + "examples_dirs": "../examples", # path to your example scripts + "gallery_dirs": "auto_examples", # path to where to save gallery generated output + "show_memory": True, } # -- Options for HTML output ------------------------------------------------- @@ -66,11 +69,11 @@ sphinx_gallery_conf = { # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] html_title = "km3io {}".format(version) diff --git a/examples/plot_offline_events.py b/examples/plot_offline_events.py index 35ac85d..26d7c73 100644 --- a/examples/plot_offline_events.py +++ b/examples/plot_offline_events.py @@ -21,7 +21,7 @@ r = ki.OfflineReader(data_path("offline/numucc.root")) ##################################################### # Accessing the file header # ------------------------- -# Note that not all file headers are supported, so don't be surprised if +# Note that not all file headers are supported, so don't be surprised if # nothing is returned when the file header is called (this can happen if your file # was produced with old versions of aanet). @@ -119,5 +119,3 @@ print(r.events.n_mc_hits[mask]) # or: print(r.events.n_mc_tracks[mask]) - - diff --git a/examples/plot_offline_hits.py b/examples/plot_offline_hits.py index fab625d..05972cd 100644 --- a/examples/plot_offline_hits.py +++ b/examples/plot_offline_hits.py @@ -65,7 +65,7 @@ print(channel_ids) ##################################################### # Accessing the mc_hits data # -------------------------- -# similarly, you can access mc_hits data in any key of interest by +# similarly, you can access mc_hits data in any key of interest by # following the same procedure as for hits: mc_pmt_ids = mc_hits.pmt_id diff --git a/examples/plot_offline_tracks.py b/examples/plot_offline_tracks.py index 42cf746..8f0c4b0 100644 --- a/examples/plot_offline_tracks.py +++ b/examples/plot_offline_tracks.py @@ -64,7 +64,7 @@ print(likelihood) ##################################################### # Accessing the mc_tracks data # ---------------------------- -# similarly, you can access mc_tracks data in any key of interest by +# similarly, you can access mc_tracks data in any key of interest by # following the same procedure as for tracks: cos_zenith = mc_tracks.dir_z diff --git a/examples/plot_offline_usr.py b/examples/plot_offline_usr.py index c335f49..9d7959b 100644 --- a/examples/plot_offline_usr.py +++ b/examples/plot_offline_usr.py @@ -33,4 +33,4 @@ print(usr.DeltaPosZ) ##################################################### # or -print(usr['RecoQuality']) +print(usr["RecoQuality"]) diff --git a/km3io/definitions.py b/km3io/definitions.py index 358a668..9aaf4dc 100644 --- a/km3io/definitions.py +++ b/km3io/definitions.py @@ -11,6 +11,7 @@ from km3io._definitions.w2list_gseagen import data as w2list_gseagen class AttrDict(dict): """A dictionary which allows access to its key through attributes.""" + def __init__(self, *args, **kwargs): super(AttrDict, self).__init__(*args, **kwargs) self.__dict__ = self diff --git a/km3io/gseagen.py b/km3io/gseagen.py index dddfaad..21772ab 100644 --- a/km3io/gseagen.py +++ b/km3io/gseagen.py @@ -8,13 +8,15 @@ import numpy as np import warnings from .rootio import Branch, BranchMapper from .tools import cached_property + MAIN_TREE_NAME = "Events" class GSGReader: """reader for gSeaGen ROOT files""" + def __init__(self, file_path=None, fobj=None): - """ GSGReader class is a gSeaGen ROOT file wrapper + """GSGReader class is a gSeaGen ROOT file wrapper Parameters ---------- @@ -26,14 +28,14 @@ class GSGReader: @cached_property def header(self): - header_key = 'Header' + header_key = "Header" if header_key in self._fobj: header = {} for k, v in self._fobj[header_key].items(): v = v.array()[0] if isinstance(v, bytes): try: - v = v.decode('utf-8') + v = v.decode("utf-8") except UnicodeDecodeError: pass header[k.decode("utf-8")] = v diff --git a/km3io/offline.py b/km3io/offline.py index 9f193d9..37241f6 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -13,64 +13,86 @@ MAIN_TREE_NAME = "E" EXCLUDE_KEYS = ["AAObject", "t", "fBits", "fUniqueID"] # 110 MB based on the size of the largest basket found so far in km3net -BASKET_CACHE_SIZE = 110 * 1024**2 +BASKET_CACHE_SIZE = 110 * 1024 ** 2 BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE) def _nested_mapper(key): """Maps a key in the ROOT file to another key (e.g. trks.pos.x -> pos_x)""" - return '_'.join(key.split('.')[1:]) - - -EVENTS_MAP = BranchMapper(name="events", - key="Evt", - extra={ - 't_sec': 't.fSec', - 't_ns': 't.fNanoSec' - }, - exclude=EXCLUDE_KEYS, - update={ - 'n_hits': 'hits', - 'n_mc_hits': 'mc_hits', - 'n_tracks': 'trks', - 'n_mc_tracks': 'mc_trks' - }) + return "_".join(key.split(".")[1:]) + + +EVENTS_MAP = BranchMapper( + name="events", + key="Evt", + extra={"t_sec": "t.fSec", "t_ns": "t.fNanoSec"}, + exclude=EXCLUDE_KEYS, + update={ + "n_hits": "hits", + "n_mc_hits": "mc_hits", + "n_tracks": "trks", + "n_mc_tracks": "mc_trks", + }, +) SUBBRANCH_MAPS = [ - BranchMapper(name="tracks", - key="trks", - extra={}, - exclude=EXCLUDE_KEYS + - ['trks.usr_data', 'trks.usr', 'trks.fUniqueID', 'trks.fBits'], - attrparser=_nested_mapper, - flat=False, - toawkward=['fitinf', 'rec_stages']), - BranchMapper(name="mc_tracks", - key="mc_trks", - exclude=EXCLUDE_KEYS + [ - 'mc_trks.rec_stages', 'mc_trks.fitinf', - 'mc_trks.fUniqueID', 'mc_trks.fBits' - ], - attrparser=_nested_mapper, - toawkward=['usr', 'usr_names'], - flat=False), - BranchMapper(name="hits", - key="hits", - exclude=EXCLUDE_KEYS + [ - 'hits.usr', 'hits.pmt_id', 'hits.origin', 'hits.a', - 'hits.pure_a', 'hits.fUniqueID', 'hits.fBits' - ], - attrparser=_nested_mapper, - flat=False), - BranchMapper(name="mc_hits", - key="mc_hits", - exclude=EXCLUDE_KEYS + [ - 'mc_hits.usr', 'mc_hits.dom_id', 'mc_hits.channel_id', - 'mc_hits.tdc', 'mc_hits.tot', 'mc_hits.trig', - 'mc_hits.fUniqueID', 'mc_hits.fBits' - ], - attrparser=_nested_mapper, - flat=False), + BranchMapper( + name="tracks", + key="trks", + extra={}, + exclude=EXCLUDE_KEYS + + ["trks.usr_data", "trks.usr", "trks.fUniqueID", "trks.fBits"], + attrparser=_nested_mapper, + flat=False, + toawkward=["fitinf", "rec_stages"], + ), + BranchMapper( + name="mc_tracks", + key="mc_trks", + exclude=EXCLUDE_KEYS + + [ + "mc_trks.rec_stages", + "mc_trks.fitinf", + "mc_trks.fUniqueID", + "mc_trks.fBits", + ], + attrparser=_nested_mapper, + toawkward=["usr", "usr_names"], + flat=False, + ), + BranchMapper( + name="hits", + key="hits", + exclude=EXCLUDE_KEYS + + [ + "hits.usr", + "hits.pmt_id", + "hits.origin", + "hits.a", + "hits.pure_a", + "hits.fUniqueID", + "hits.fBits", + ], + attrparser=_nested_mapper, + flat=False, + ), + BranchMapper( + name="mc_hits", + key="mc_hits", + exclude=EXCLUDE_KEYS + + [ + "mc_hits.usr", + "mc_hits.dom_id", + "mc_hits.channel_id", + "mc_hits.tdc", + "mc_hits.tot", + "mc_hits.trig", + "mc_hits.fUniqueID", + "mc_hits.fBits", + ], + attrparser=_nested_mapper, + flat=False, + ), ] @@ -82,6 +104,7 @@ class OfflineBranch(Branch): class Usr: """Helper class to access AAObject `usr` stuff (only for events.usr)""" + def __init__(self, mapper, branch, index_chain=None): self._mapper = mapper self._name = mapper.name @@ -90,7 +113,7 @@ class Usr: self._usr_names = [] self._usr_idx_lookup = {} - self._usr_key = 'usr' if mapper.flat else mapper.key + '.usr' + self._usr_key = "usr" if mapper.flat else mapper.key + ".usr" self._initialise() @@ -101,17 +124,19 @@ class Usr: # which has a different strucuter and key (usr_data) # We do not support those (yet) except (KeyError, IndexError): - print("The `usr` fields could not be parsed for the '{}' branch.". - format(self._name)) + print( + "The `usr` fields could not be parsed for the '{}' branch.".format( + self._name + ) + ) return self._usr_names = [ n.decode("utf-8") - for n in self._branch[self._usr_key + '_names'].lazyarray()[0] + for n in self._branch[self._usr_key + "_names"].lazyarray()[0] ] self._usr_idx_lookup = { - name: index - for index, name in enumerate(self._usr_names) + name: index for index, name in enumerate(self._usr_names) } data = self._branch[self._usr_key].lazyarray() @@ -126,9 +151,9 @@ class Usr: def __getitem__(self, item): if self._index_chain: - return unfold_indices( - self._usr_data, self._index_chain)[:, - self._usr_idx_lookup[item]] + return unfold_indices(self._usr_data, self._index_chain)[ + :, self._usr_idx_lookup[item] + ] else: return self._usr_data[:, self._usr_idx_lookup[item]] @@ -139,7 +164,7 @@ class Usr: entries = [] for name in self.keys(): entries.append("{}: {}".format(name, self[name])) - return '\n'.join(entries) + return "\n".join(entries) def __repr__(self): return "<{}[{}]>".format(self.__class__.__name__, self._name) @@ -147,8 +172,9 @@ class Usr: class OfflineReader: """reader for offline ROOT files""" + def __init__(self, file_path=None): - """ OfflineReader class is an offline ROOT file wrapper + """OfflineReader class is an offline ROOT file wrapper Parameters ---------- @@ -178,17 +204,16 @@ class OfflineReader: @cached_property def events(self): """The `E` branch, containing all offline events.""" - return OfflineBranch(self._tree, - mapper=EVENTS_MAP, - subbranchmaps=SUBBRANCH_MAPS) + return OfflineBranch( + self._tree, mapper=EVENTS_MAP, subbranchmaps=SUBBRANCH_MAPS + ) @cached_property def header(self): """The file header""" - if 'Head' in self._fobj: + if "Head" in self._fobj: header = {} - for n, x in self._fobj['Head']._map_3c_string_2c_string_3e_.items( - ): + for n, x in self._fobj["Head"]._map_3c_string_2c_string_3e_.items(): header[n.decode("utf-8")] = x.decode("utf-8").strip() return Header(header) else: @@ -197,6 +222,7 @@ class OfflineReader: class Header: """The header""" + def __init__(self, header): self._data = {} @@ -221,8 +247,8 @@ class Header: continue self._data[attribute] = Constructor( - **{f: to_num(v) - for (f, v) in zip(fields, values)}) + **{f: to_num(v) for (f, v) in zip(fields, values)} + ) for attribute, value in self._data.items(): setattr(self, attribute, value) diff --git a/km3io/online.py b/km3io/online.py index ba1d9e5..3440a40 100644 --- a/km3io/online.py +++ b/km3io/online.py @@ -5,9 +5,9 @@ import numpy as np import numba as nb -TIMESLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024**2 # [byte] -SUMMARYSLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024**2 # [byte] -BASKET_CACHE_SIZE = 110 * 1024**2 +TIMESLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024 ** 2 # [byte] +SUMMARYSLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024 ** 2 # [byte] +BASKET_CACHE_SIZE = 110 * 1024 ** 2 BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE) # Parameters for PMT rate conversions, since the rates in summary slices are @@ -21,13 +21,10 @@ RATE_FACTOR = np.log(MAXIMAL_RATE_HZ / MINIMAL_RATE_HZ) / 255 CHANNEL_BITS_TEMPLATE = np.zeros(31, dtype=bool) -@nb.vectorize([ - nb.int32(nb.int8), - nb.int32(nb.int16), - nb.int32(nb.int32), - nb.int32(nb.int64) -]) -def get_rate(value): #pragma: no cover +@nb.vectorize( + [nb.int32(nb.int8), nb.int32(nb.int16), nb.int32(nb.int32), nb.int32(nb.int64)] +) +def get_rate(value): # pragma: no cover """Return the rate in Hz from the short int value""" if value == 0: return 0 @@ -35,11 +32,10 @@ def get_rate(value): #pragma: no cover return MINIMAL_RATE_HZ * np.exp(value * RATE_FACTOR) -@nb.guvectorize("void(i8, b1[:], b1[:])", - "(), (n) -> (n)", - target="parallel", - nopython=True) -def unpack_bits(value, bits_template, out): #pragma: no cover +@nb.guvectorize( + "void(i8, b1[:], b1[:])", "(), (n) -> (n)", target="parallel", nopython=True +) +def unpack_bits(value, bits_template, out): # pragma: no cover """Return a boolean array for a value's bit representation. This function also accepts arrays as input, the output shape will be @@ -110,6 +106,7 @@ def has_udp_trailer(value): class OnlineReader: """Reader for online ROOT files""" + def __init__(self, filename): self._fobj = uproot.open(filename) self._filename = filename @@ -137,20 +134,41 @@ class OnlineReader: tree = self._fobj["KM3NET_EVENT"] headers = tree["KM3NETDAQ::JDAQEventHeader"].array( - uproot.interpret(tree["KM3NETDAQ::JDAQEventHeader"], - cntvers=True)) + uproot.interpret(tree["KM3NETDAQ::JDAQEventHeader"], cntvers=True) + ) snapshot_hits = tree["snapshotHits"].array( - uproot.asjagged(uproot.astable( - uproot.asdtype([("dom_id", ">i4"), ("channel_id", "u1"), - ("time", "<u4"), ("tot", "u1")])), - skipbytes=10)) + uproot.asjagged( + uproot.astable( + uproot.asdtype( + [ + ("dom_id", ">i4"), + ("channel_id", "u1"), + ("time", "<u4"), + ("tot", "u1"), + ] + ) + ), + skipbytes=10, + ) + ) triggered_hits = tree["triggeredHits"].array( - uproot.asjagged(uproot.astable( - uproot.asdtype([("dom_id", ">i4"), ("channel_id", "u1"), - ("time", "<u4"), ("tot", "u1"), - (" cnt", "u4"), (" vers", "u2"), - ("trigger_mask", ">u8")])), - skipbytes=10)) + uproot.asjagged( + uproot.astable( + uproot.asdtype( + [ + ("dom_id", ">i4"), + ("channel_id", "u1"), + ("time", "<u4"), + ("tot", "u1"), + (" cnt", "u4"), + (" vers", "u2"), + ("trigger_mask", ">u8"), + ] + ) + ), + skipbytes=10, + ) + ) self._events = OnlineEvents(headers, snapshot_hits, triggered_hits) return self._events @@ -169,6 +187,7 @@ class OnlineReader: class SummarySlices: """A wrapper for summary slices""" + def __init__(self, fobj): self._fobj = fobj self._slices = None @@ -196,23 +215,35 @@ class SummarySlices: def _read_summaryslices(self): """Reads a lazyarray of summary slices""" - tree = self._fobj[b'KM3NET_SUMMARYSLICE'][b'KM3NET_SUMMARYSLICE'] - return tree[b'vector<KM3NETDAQ::JDAQSummaryFrame>'].lazyarray( - uproot.asjagged(uproot.astable( - uproot.asdtype([("dom_id", "i4"), ("dq_status", "u4"), - ("hrv", "u4"), ("fifo", "u4"), - ("status3", "u4"), ("status4", "u4")] + - [(c, "u1") for c in self._ch_selector])), - skipbytes=10), + tree = self._fobj[b"KM3NET_SUMMARYSLICE"][b"KM3NET_SUMMARYSLICE"] + return tree[b"vector<KM3NETDAQ::JDAQSummaryFrame>"].lazyarray( + uproot.asjagged( + uproot.astable( + uproot.asdtype( + [ + ("dom_id", "i4"), + ("dq_status", "u4"), + ("hrv", "u4"), + ("fifo", "u4"), + ("status3", "u4"), + ("status4", "u4"), + ] + + [(c, "u1") for c in self._ch_selector] + ) + ), + skipbytes=10, + ), basketcache=uproot.cache.ThreadSafeArrayCache( - SUMMARYSLICE_FRAME_BASKET_CACHE_SIZE)) + SUMMARYSLICE_FRAME_BASKET_CACHE_SIZE + ), + ) def _read_headers(self): """Reads a lazyarray of summary slice headers""" - tree = self._fobj[b'KM3NET_SUMMARYSLICE'][b'KM3NET_SUMMARYSLICE'] - return tree[b'KM3NETDAQ::JDAQSummarysliceHeader'].lazyarray( - uproot.interpret(tree[b'KM3NETDAQ::JDAQSummarysliceHeader'], - cntvers=True)) + tree = self._fobj[b"KM3NET_SUMMARYSLICE"][b"KM3NET_SUMMARYSLICE"] + return tree[b"KM3NETDAQ::JDAQSummarysliceHeader"].lazyarray( + uproot.interpret(tree[b"KM3NETDAQ::JDAQSummarysliceHeader"], cntvers=True) + ) def __str__(self): return "Number of summaryslices: {}".format(len(self.headers)) @@ -220,6 +251,7 @@ class SummarySlices: class Timeslices: """A simple wrapper for timeslices""" + def __init__(self, fobj): self._fobj = fobj self._timeslices = {} @@ -228,36 +260,50 @@ class Timeslices: def _read_streams(self): """Read the L0, L1, L2 and SN streams if available""" streams = set( - s.split(b"KM3NET_TIMESLICE_")[1].split(b';')[0] - for s in self._fobj.keys() if b"KM3NET_TIMESLICE_" in s) + s.split(b"KM3NET_TIMESLICE_")[1].split(b";")[0] + for s in self._fobj.keys() + if b"KM3NET_TIMESLICE_" in s + ) for stream in streams: - tree = self._fobj[b'KM3NET_TIMESLICE_' + - stream][b'KM3NETDAQ::JDAQTimeslice'] - headers = tree[b'KM3NETDAQ::JDAQTimesliceHeader'][ - b'KM3NETDAQ::JDAQHeader'][b'KM3NETDAQ::JDAQChronometer'] + tree = self._fobj[b"KM3NET_TIMESLICE_" + stream][ + b"KM3NETDAQ::JDAQTimeslice" + ] + headers = tree[b"KM3NETDAQ::JDAQTimesliceHeader"][b"KM3NETDAQ::JDAQHeader"][ + b"KM3NETDAQ::JDAQChronometer" + ] if len(headers) == 0: continue - superframes = tree[b'vector<KM3NETDAQ::JDAQSuperFrame>'] - hits_dtype = np.dtype([("pmt", "u1"), ("tdc", "<u4"), - ("tot", "u1")]) + superframes = tree[b"vector<KM3NETDAQ::JDAQSuperFrame>"] + hits_dtype = np.dtype([("pmt", "u1"), ("tdc", "<u4"), ("tot", "u1")]) hits_buffer = superframes[ - b'vector<KM3NETDAQ::JDAQSuperFrame>.buffer'].lazyarray( - uproot.asjagged(uproot.astable(uproot.asdtype(hits_dtype)), - skipbytes=6), - basketcache=uproot.cache.ThreadSafeArrayCache( - TIMESLICE_FRAME_BASKET_CACHE_SIZE)) - self._timeslices[stream.decode("ascii")] = (headers, superframes, - hits_buffer) - setattr(self, stream.decode("ascii"), - TimesliceStream(headers, superframes, hits_buffer)) + b"vector<KM3NETDAQ::JDAQSuperFrame>.buffer" + ].lazyarray( + uproot.asjagged( + uproot.astable(uproot.asdtype(hits_dtype)), skipbytes=6 + ), + basketcache=uproot.cache.ThreadSafeArrayCache( + TIMESLICE_FRAME_BASKET_CACHE_SIZE + ), + ) + self._timeslices[stream.decode("ascii")] = ( + headers, + superframes, + hits_buffer, + ) + setattr( + self, + stream.decode("ascii"), + TimesliceStream(headers, superframes, hits_buffer), + ) def stream(self, stream, idx): ts = self._timeslices[stream] return Timeslice(*ts, idx, stream) def __str__(self): - return "Available timeslice streams: {}".format(', '.join( - s for s in self._timeslices.keys())) + return "Available timeslice streams: {}".format( + ", ".join(s for s in self._timeslices.keys()) + ) def __repr__(self): return str(self) @@ -291,6 +337,7 @@ class TimesliceStream: class Timeslice: """A wrapper for a timeslice""" + def __init__(self, header, superframe, hits_buffer, idx, stream): self.header = header self._frames = {} @@ -310,49 +357,59 @@ class Timeslice: """Populate a dictionary of frames with the module ID as key""" hits_buffer = self._hits_buffer[self._idx] n_hits = self._superframe[ - b'vector<KM3NETDAQ::JDAQSuperFrame>.numberOfHits'].lazyarray( - basketcache=BASKET_CACHE)[self._idx] + b"vector<KM3NETDAQ::JDAQSuperFrame>.numberOfHits" + ].lazyarray(basketcache=BASKET_CACHE)[self._idx] try: module_ids = self._superframe[ - b'vector<KM3NETDAQ::JDAQSuperFrame>.id'].lazyarray( - basketcache=BASKET_CACHE)[self._idx] + b"vector<KM3NETDAQ::JDAQSuperFrame>.id" + ].lazyarray(basketcache=BASKET_CACHE)[self._idx] except KeyError: - module_ids = self._superframe[ - b'vector<KM3NETDAQ::JDAQSuperFrame>.KM3NETDAQ::JDAQModuleIdentifier'].lazyarray( + module_ids = ( + self._superframe[ + b"vector<KM3NETDAQ::JDAQSuperFrame>.KM3NETDAQ::JDAQModuleIdentifier" + ] + .lazyarray( uproot.asjagged( - uproot.astable(uproot.asdtype([("dom_id", ">i4")]))), - basketcache=BASKET_CACHE)[self._idx].dom_id + uproot.astable(uproot.asdtype([("dom_id", ">i4")])) + ), + basketcache=BASKET_CACHE, + )[self._idx] + .dom_id + ) idx = 0 for module_id, n_hits in zip(module_ids, n_hits): - self._frames[module_id] = hits_buffer[idx:idx + n_hits] + self._frames[module_id] = hits_buffer[idx : idx + n_hits] idx += n_hits def __len__(self): if self._n_frames is None: self._n_frames = len( - self._superframe[b'vector<KM3NETDAQ::JDAQSuperFrame>.id']. - lazyarray(basketcache=BASKET_CACHE)[self._idx]) + self._superframe[b"vector<KM3NETDAQ::JDAQSuperFrame>.id"].lazyarray( + basketcache=BASKET_CACHE + )[self._idx] + ) return self._n_frames def __str__(self): return "{} timeslice with {} frames.".format(self._stream, len(self)) def __repr__(self): - return "<{}: {} entries>".format(self.__class__.__name__, - len(self.header)) + return "<{}: {} entries>".format(self.__class__.__name__, len(self.header)) class OnlineEvents: """A simple wrapper for online events""" + def __init__(self, headers, snapshot_hits, triggered_hits): self.headers = headers self.snapshot_hits = snapshot_hits self.triggered_hits = triggered_hits def __getitem__(self, item): - return OnlineEvent(self.headers[item], self.snapshot_hits[item], - self.triggered_hits[item]) + return OnlineEvent( + self.headers[item], self.snapshot_hits[item], self.triggered_hits[item] + ) def __len__(self): return len(self.headers) @@ -366,6 +423,7 @@ class OnlineEvents: class OnlineEvent: """A wrapper for a online event""" + def __init__(self, header, snapshot_hits, triggered_hits): self.header = header self.snapshot_hits = snapshot_hits @@ -373,7 +431,8 @@ class OnlineEvent: def __str__(self): return "Online event with {} snapshot hits and {} triggered hits".format( - len(self.snapshot_hits), len(self.triggered_hits)) + len(self.snapshot_hits), len(self.triggered_hits) + ) def __repr__(self): return str(self) diff --git a/km3io/patches.py b/km3io/patches.py index 3ee5b34..7df3124 100644 --- a/km3io/patches.py +++ b/km3io/patches.py @@ -7,11 +7,11 @@ old_getitem = ak.ChunkedArray.__getitem__ def new_getitem(self, item): """Monkey patch the getitem in awkward.ChunkedArray to apply - awkward1.Array masks on awkward.ChunkedArray""" + awkward1.Array masks on awkward.ChunkedArray""" if isinstance(item, (ak1.Array, ak.ChunkedArray)): return ak1.Array(self)[item] else: return old_getitem(self, item) -ak.ChunkedArray.__getitem__ = new_getitem \ No newline at end of file +ak.ChunkedArray.__getitem__ = new_getitem diff --git a/km3io/rootio.py b/km3io/rootio.py index f12c39e..6e30552 100644 --- a/km3io/rootio.py +++ b/km3io/rootio.py @@ -6,7 +6,7 @@ import uproot from .tools import unfold_indices # 110 MB based on the size of the largest basket found so far in km3net -BASKET_CACHE_SIZE = 110 * 1024**2 +BASKET_CACHE_SIZE = 110 * 1024 ** 2 BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE) @@ -38,16 +38,19 @@ class BranchMapper: List of keys to convert to awkward arrays (recommended for doubly ragged arrays) """ - def __init__(self, - name, - key, - extra=None, - exclude=None, - update=None, - attrparser=None, - flat=True, - interpretations=None, - toawkward=None): + + def __init__( + self, + name, + key, + extra=None, + exclude=None, + update=None, + attrparser=None, + flat=True, + interpretations=None, + toawkward=None, + ): self.name = name self.key = key @@ -62,13 +65,16 @@ class BranchMapper: class Branch: """Branch accessor class""" - def __init__(self, - tree, - mapper, - index_chain=None, - subbranchmaps=None, - keymap=None, - awkward_cache=None): + + def __init__( + self, + tree, + mapper, + index_chain=None, + subbranchmaps=None, + keymap=None, + awkward_cache=None, + ): self._tree = tree self._mapper = mapper self._index_chain = [] if index_chain is None else index_chain @@ -89,10 +95,12 @@ class Branch: if subbranchmaps is not None: for mapper in subbranchmaps: - subbranch = self.__class__(self._tree, - mapper=mapper, - index_chain=self._index_chain, - awkward_cache=self._awkward_cache) + subbranch = self.__class__( + self._tree, + mapper=mapper, + index_chain=self._index_chain, + awkward_cache=self._awkward_cache, + ) self._subbranches.append(subbranch) for subbranch in self._subbranches: setattr(self, subbranch._mapper.name, subbranch) @@ -100,12 +108,12 @@ class Branch: def _initialise_keys(self): """Create the keymap and instance attributes for branch keys""" # TODO: this could be a cached property - keys = set(k.decode('utf-8') - for k in self._branch.keys()) - set(self._mapper.exclude) + keys = set(k.decode("utf-8") for k in self._branch.keys()) - set( + self._mapper.exclude + ) self._keymap = { - **{self._mapper.attrparser(k): k - for k in keys}, - **self._mapper.extra + **{self._mapper.attrparser(k): k for k in keys}, + **self._mapper.extra, } self._keymap.update(self._mapper.update) for k in self._mapper.update.values(): @@ -129,23 +137,28 @@ class Branch: def __getkey__(self, key): interpretation = self._mapper.interpretations.get(key) - if key == 'usr_names': + if key == "usr_names": # TODO this will be fixed soon in uproot, # see https://github.com/scikit-hep/uproot/issues/465 interpretation = uproot.asgenobj( uproot.SimpleArray(uproot.STLVector(uproot.STLString())), - self._branch[self._keymap[key]]._context, 6) + self._branch[self._keymap[key]]._context, + 6, + ) - if key == 'usr': + if key == "usr": # triple jagged array is wrongly parsed in uproot interpretation = uproot.asgenobj( - uproot.SimpleArray(uproot.STLVector(uproot.asdtype('>f8'))), - self._branch[self._keymap[key]]._context, 6) + uproot.SimpleArray(uproot.STLVector(uproot.asdtype(">f8"))), + self._branch[self._keymap[key]]._context, + 6, + ) out = self._branch[self._keymap[key]].lazyarray( - interpretation=interpretation, basketcache=BASKET_CACHE) + interpretation=interpretation, basketcache=BASKET_CACHE + ) if self._index_chain is not None and key in self._mapper.toawkward: - cache_key = self._mapper.name + '/' + key + cache_key = self._mapper.name + "/" + key if cache_key not in self._awkward_cache: if len(out) > 20000: # It will take more than 10 seconds print("Creating cache for '{}'.".format(cache_key)) @@ -164,12 +177,14 @@ class Branch: # if item.__class__.__name__ == "ChunkedArray": # item = np.array(item) - return self.__class__(self._tree, - self._mapper, - index_chain=self._index_chain + [item], - keymap=self._keymap, - subbranchmaps=self._subbranchmaps, - awkward_cache=self._awkward_cache) + return self.__class__( + self._tree, + self._mapper, + index_chain=self._index_chain + [item], + keymap=self._keymap, + subbranchmaps=self._subbranchmaps, + awkward_cache=self._awkward_cache, + ) def __len__(self): if not self._index_chain: @@ -184,8 +199,12 @@ class Branch: else: return len( unfold_indices( - self._branch[self._keymap['id']].lazyarray( - basketcache=BASKET_CACHE), self._index_chain)) + self._branch[self._keymap["id"]].lazyarray( + basketcache=BASKET_CACHE + ), + self._index_chain, + ) + ) @property def is_single(self): @@ -208,12 +227,18 @@ class Branch: def __str__(self): length = len(self) - return "{} ({}) with {} element{}".format(self.__class__.__name__, - self._mapper.name, length, - 's' if length > 1 else '') + return "{} ({}) with {} element{}".format( + self.__class__.__name__, + self._mapper.name, + length, + "s" if length > 1 else "", + ) def __repr__(self): length = len(self) - return "<{}[{}]: {} element{}>".format(self.__class__.__name__, - self._mapper.name, length, - 's' if length > 1 else '') + return "<{}[{}]: {} element{}>".format( + self.__class__.__name__, + self._mapper.name, + length, + "s" if length > 1 else "", + ) diff --git a/km3io/tools.py b/km3io/tools.py index 7c50752..a422264 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -11,12 +11,13 @@ from km3io.definitions import w2list_genhen as kw2gen from km3io.definitions import w2list_gseagen as kw2gsg # 110 MB based on the size of the largest basket found so far in km3net -BASKET_CACHE_SIZE = 110 * 1024**2 +BASKET_CACHE_SIZE = 110 * 1024 ** 2 BASKET_CACHE = uproot.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE) class cached_property: """A simple cache decorator for properties.""" + def __init__(self, function): self.function = function @@ -36,8 +37,10 @@ def unfold_indices(obj, indices): except IndexError: raise IndexError( "IndexError while accessing an item from '{}' at depth {} ({}) " - "using the index chain {}".format(repr(original_obj), depth, - idx, indices)) + "using the index chain {}".format( + repr(original_obj), depth, idx, indices + ) + ) return obj @@ -76,7 +79,7 @@ def unique(array, dtype=np.int64): entry_idx += 1 out[entry_idx] = current last = current - return out[:entry_idx + 1] + return out[: entry_idx + 1] @nb.jit(nopython=True) @@ -103,7 +106,7 @@ def get_w2list_param(events, generator, param): events class in offline neutrino files. generator : str the name of the software generating neutrinos, it is either - 'genhen' or 'gseagen'. + 'genhen' or 'gseagen'. param : str the name of the parameters found in w2list as defined in the KM3NeT-Dataformat for both genhen and gseagen. @@ -185,7 +188,7 @@ def get_multiplicity(tracks, rec_stages): Parameters ---------- tracks : km3io.offline.OfflineBranch - tracks or a subste of tracks. + tracks or a subste of tracks. rec_stages : list Reconstruction stages (the ordering is respected) e.g. [1, 2, 3, 4, 5]. @@ -234,12 +237,10 @@ def best_track(tracks, startend=None, minmax=None, stages=None): inputs = (stages, startend, minmax) if all(v is None for v in inputs): - raise ValueError( - "either stages, startend or minmax must be specified.") + raise ValueError("either stages, startend or minmax must be specified.") if stages is not None and (startend is not None or minmax is not None): - raise ValueError( - "Please specify either a range or a set of rec stages.") + raise ValueError("Please specify either a range or a set of rec stages.") if stages is not None and startend is None and minmax is None: selected_tracks = tracks[mask(tracks, stages=stages)] @@ -264,8 +265,7 @@ def _longest_tracks(tracks): tracks_nesting_level = 1 len_stages = count_nested(tracks.rec_stages, axis=stages_nesting_level) - longest = tracks[len_stages == ak1.max(len_stages, - axis=tracks_nesting_level)] + longest = tracks[len_stages == ak1.max(len_stages, axis=tracks_nesting_level)] return longest @@ -311,12 +311,10 @@ def mask(tracks, stages=None, startend=None, minmax=None): inputs = (stages, startend, minmax) if all(v is None for v in inputs): - raise ValueError( - "either stages, startend or minmax must be specified.") + raise ValueError("either stages, startend or minmax must be specified.") if stages is not None and (startend is not None or minmax is not None): - raise ValueError( - "Please specify either a range or a set of rec stages.") + raise ValueError("Please specify either a range or a set of rec stages.") if stages is not None and startend is None and minmax is None: if isinstance(stages, list): @@ -335,7 +333,7 @@ def mask(tracks, stages=None, startend=None, minmax=None): def _mask_rec_stages_between_start_end(tracks, start, end): """Mask tracks.rec_stages that start exactly with start and end exactly - with end. ie [start, a, b ...,z , end] """ + with end. ie [start, a, b ...,z , end]""" builder = ak1.ArrayBuilder() if tracks.is_single: _find_between_single(tracks.rec_stages, start, end, builder) @@ -366,7 +364,7 @@ def _find_between(rec_stages, start, end, builder): @nb.jit(nopython=True) def _find_between_single(rec_stages, start, end, builder): """Find tracks.rec_stages where rec_stages[0] == start and - rec_stages[-1] == end in a single track. """ + rec_stages[-1] == end in a single track.""" builder.begin_list() for s in rec_stages: @@ -486,9 +484,9 @@ def best_jmuon(tracks): km3io.offline.OfflineBranch the longest + highest likelihood track reconstructed with JMUON. """ - mask = _mask_rec_stages_in_range_min_max(tracks, - min_stage=krec.JMUONBEGIN, - max_stage=krec.JMUONEND) + mask = _mask_rec_stages_in_range_min_max( + tracks, min_stage=krec.JMUONBEGIN, max_stage=krec.JMUONEND + ) return _max_lik_track(_longest_tracks(tracks[mask])) @@ -506,9 +504,9 @@ def best_jshower(tracks): km3io.offline.OfflineBranch the longest + highest likelihood track reconstructed with JSHOWER. """ - mask = _mask_rec_stages_in_range_min_max(tracks, - min_stage=krec.JSHOWERBEGIN, - max_stage=krec.JSHOWEREND) + mask = _mask_rec_stages_in_range_min_max( + tracks, min_stage=krec.JSHOWERBEGIN, max_stage=krec.JSHOWEREND + ) return _max_lik_track(_longest_tracks(tracks[mask])) @@ -526,9 +524,9 @@ def best_aashower(tracks): km3io.offline.OfflineBranch the longest + highest likelihood track reconstructed with AASHOWER. """ - mask = _mask_rec_stages_in_range_min_max(tracks, - min_stage=krec.AASHOWERBEGIN, - max_stage=krec.AASHOWEREND) + mask = _mask_rec_stages_in_range_min_max( + tracks, min_stage=krec.AASHOWERBEGIN, max_stage=krec.AASHOWEREND + ) return _max_lik_track(_longest_tracks(tracks[mask])) @@ -546,16 +544,14 @@ def best_dusjshower(tracks): km3io.offline.OfflineBranch the longest + highest likelihood track reconstructed with DUSJSHOWER. """ - mask = _mask_rec_stages_in_range_min_max(tracks, - min_stage=krec.DUSJSHOWERBEGIN, - max_stage=krec.DUSJSHOWEREND) + mask = _mask_rec_stages_in_range_min_max( + tracks, min_stage=krec.DUSJSHOWERBEGIN, max_stage=krec.DUSJSHOWEREND + ) return _max_lik_track(_longest_tracks(tracks[mask])) -def _mask_rec_stages_in_range_min_max(tracks, - min_stage=None, - max_stage=None): +def _mask_rec_stages_in_range_min_max(tracks, min_stage=None, max_stage=None): """Mask tracks where rec_stages are withing the range(min, max). Parameters @@ -578,8 +574,7 @@ def _mask_rec_stages_in_range_min_max(tracks, builder = ak1.ArrayBuilder() if tracks.is_single: - _find_in_range_single(tracks.rec_stages, min_stage, max_stage, - builder) + _find_in_range_single(tracks.rec_stages, min_stage, max_stage, builder) return (builder.snapshot() == 1)[0] else: _find_in_range(tracks.rec_stages, min_stage, max_stage, builder) diff --git a/km3io/utils/kprinttree.py b/km3io/utils/kprinttree.py index bed9200..7b5af9d 100644 --- a/km3io/utils/kprinttree.py +++ b/km3io/utils/kprinttree.py @@ -30,10 +30,11 @@ def print_tree(filename): def main(): from docopt import docopt + args = docopt(__doc__) - print_tree(args['-f']) + print_tree(args["-f"]) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/setup.py b/setup.py index 5b6c464..28592ad 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ The km3io setup script. from setuptools import setup import sys -with open('requirements.txt') as fobj: +with open("requirements.txt") as fobj: requirements = [l.strip() for l in fobj.readlines()] if sys.version_info[:2] == (3, 5): requirements.append("llvmlite==0.31.0") @@ -19,27 +19,25 @@ except UnicodeDecodeError: long_description = "km3io, a library to read KM3NeT files without ROOT" setup( - name='km3io', - url='http://git.km3net.de/km3py/km3io', - description='KM3NeT I/O without ROOT', + name="km3io", + url="http://git.km3net.de/km3py/km3io", + description="KM3NeT I/O without ROOT", long_description=long_description, - author='Zineb Aly, Tamas Gal, Johannes Schumann', - author_email='zaly@km3net.de, tgal@km3net.de, johannes.schumann@fau.de', - packages=['km3io'], + author="Zineb Aly, Tamas Gal, Johannes Schumann", + author_email="zaly@km3net.de, tgal@km3net.de, johannes.schumann@fau.de", + packages=["km3io"], include_package_data=True, - platforms='any', - setup_requires=['setuptools_scm'], + platforms="any", + setup_requires=["setuptools_scm"], use_scm_version=True, install_requires=requirements, - python_requires='>=3.5', - entry_points={ - 'console_scripts': ['KPrintTree=km3io.utils.kprinttree:main'] - }, + python_requires=">=3.5", + entry_points={"console_scripts": ["KPrintTree=km3io.utils.kprinttree:main"]}, classifiers=[ - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'Programming Language :: Python', + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Programming Language :: Python", ], ) -__author__ = 'Zineb Aly, Tamas Gal and Johannes Schumann' +__author__ = "Zineb Aly, Tamas Gal and Johannes Schumann" diff --git a/tests/test_gseagen.py b/tests/test_gseagen.py index 842d006..4b55e89 100644 --- a/tests/test_gseagen.py +++ b/tests/test_gseagen.py @@ -14,21 +14,21 @@ class TestGSGHeader(unittest.TestCase): self.header = GSG_READER.header def test_str_byte_type(self): - assert isinstance(self.header['gSeaGenVer'], str) - assert isinstance(self.header['GenieVer'], str) - assert isinstance(self.header['gSeaGenVer'], str) - assert isinstance(self.header['InpXSecFile'], str) - assert isinstance(self.header['Flux1'], str) - assert isinstance(self.header['Flux2'], str) + assert isinstance(self.header["gSeaGenVer"], str) + assert isinstance(self.header["GenieVer"], str) + assert isinstance(self.header["gSeaGenVer"], str) + assert isinstance(self.header["InpXSecFile"], str) + assert isinstance(self.header["Flux1"], str) + assert isinstance(self.header["Flux2"], str) def test_values(self): assert self.header["RunNu"] == 1 assert self.header["RanSeed"] == 3662074 - self.assertAlmostEqual(self.header["NTot"], 1000.) - self.assertAlmostEqual(self.header["EvMin"], 5.) - self.assertAlmostEqual(self.header["EvMax"], 50.) - self.assertAlmostEqual(self.header["CtMin"], -1.) - self.assertAlmostEqual(self.header["CtMax"], 1.) + self.assertAlmostEqual(self.header["NTot"], 1000.0) + self.assertAlmostEqual(self.header["EvMin"], 5.0) + self.assertAlmostEqual(self.header["EvMax"], 50.0) + self.assertAlmostEqual(self.header["CtMin"], -1.0) + self.assertAlmostEqual(self.header["CtMax"], 1.0) self.assertAlmostEqual(self.header["Alpha"], 1.4) assert self.header["NBin"] == 1 self.assertAlmostEqual(self.header["Can1"], 0.0) @@ -46,11 +46,11 @@ class TestGSGHeader(unittest.TestCase): self.assertAlmostEqual(self.header["SiteDepth"], 2425.0) self.assertAlmostEqual(self.header["SiteLatitude"], 0.747) self.assertAlmostEqual(self.header["SiteLongitude"], 0.10763) - self.assertAlmostEqual(self.header["SeaBottomRadius"], 6368000.) - assert round(self.header["GlobalGenWeight"] - 6.26910765e+08, 0) == 0 + self.assertAlmostEqual(self.header["SeaBottomRadius"], 6368000.0) + assert round(self.header["GlobalGenWeight"] - 6.26910765e08, 0) == 0 self.assertAlmostEqual(self.header["RhoSW"], 1.03975) self.assertAlmostEqual(self.header["RhoSR"], 2.65) - self.assertAlmostEqual(self.header["TGen"], 31556926.) + self.assertAlmostEqual(self.header["TGen"], 31556926.0) assert not self.header["PropMode"] assert self.header["NNu"] == 2 self.assertListEqual(self.header["NuList"].tolist(), [-14, 14]) @@ -109,46 +109,63 @@ class TestGSGEvents(unittest.TestCase): self.assertListEqual(event.Id_tr.tolist(), [4, 5, 10, 11, 12]) self.assertListEqual(event.Pdg_tr.tolist(), [22, -13, 2112, -211, 111]) [ - self.assertAlmostEqual(x, y) for x, y in zip( - event.E_tr, - [0.00618, 4.88912206, 2.33667201, 1.0022909, 1.17186997]) + self.assertAlmostEqual(x, y) + for x, y in zip( + event.E_tr, [0.00618, 4.88912206, 2.33667201, 1.0022909, 1.17186997] + ) ] [ self.assertAlmostEqual(x, y) - for x, y in zip(event.Vx_tr, [ - -337.67895799, -337.67895799, -337.67895799, -337.67895799, - -337.67895799 - ]) + for x, y in zip( + event.Vx_tr, + [ + -337.67895799, + -337.67895799, + -337.67895799, + -337.67895799, + -337.67895799, + ], + ) ] [ self.assertAlmostEqual(x, y) - for x, y in zip(event.Vy_tr, [ - -203.90999969, -203.90999969, -203.90999969, -203.90999969, - -203.90999969 - ]) + for x, y in zip( + event.Vy_tr, + [ + -203.90999969, + -203.90999969, + -203.90999969, + -203.90999969, + -203.90999969, + ], + ) ] [ self.assertAlmostEqual(x, y) - for x, y in zip(event.Vz_tr, [ - 416.08845294, 416.08845294, 416.08845294, 416.08845294, - 416.08845294 - ]) + for x, y in zip( + event.Vz_tr, + [416.08845294, 416.08845294, 416.08845294, 416.08845294, 416.08845294], + ) ] [ self.assertAlmostEqual(x, y) - for x, y in zip(event.Dx_tr, [ - 0.06766196, -0.63563065, -0.70627586, -0.76364544, -0.80562216 - ]) + for x, y in zip( + event.Dx_tr, + [0.06766196, -0.63563065, -0.70627586, -0.76364544, -0.80562216], + ) ] [ - self.assertAlmostEqual(x, y) for x, y in zip( + self.assertAlmostEqual(x, y) + for x, y in zip( event.Dy_tr, - [0.33938809, -0.4846643, 0.50569058, -0.04136113, 0.10913917]) + [0.33938809, -0.4846643, 0.50569058, -0.04136113, 0.10913917], + ) ] [ self.assertAlmostEqual(x, y) - for x, y in zip(event.Dz_tr, [ - -0.93820978, -0.6008945, -0.49543056, -0.64430963, -0.58228994 - ]) + for x, y in zip( + event.Dz_tr, + [-0.93820978, -0.6008945, -0.49543056, -0.64430963, -0.58228994], + ) ] - [self.assertAlmostEqual(x, y) for x, y in zip(event.T_tr, 5 * [0.])] + [self.assertAlmostEqual(x, y) for x, y in zip(event.T_tr, 5 * [0.0])] diff --git a/tests/test_offline.py b/tests/test_offline.py index ec152f5..b99cb8b 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -11,10 +11,10 @@ OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root")) OFFLINE_USR = OfflineReader(data_path("offline/usr-sample.root")) OFFLINE_MC_TRACK_USR = OfflineReader( data_path( - 'offline/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root' - )) -OFFLINE_NUMUCC = OfflineReader( - data_path("offline/numucc.root")) # with mc data + "offline/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root" + ) +) +OFFLINE_NUMUCC = OfflineReader(data_path("offline/numucc.root")) # with mc data class TestOfflineReader(unittest.TestCase): @@ -45,7 +45,7 @@ class TestHeader(unittest.TestCase): OFFLINE_FILE.header def test_missing_key_definitions(self): - head = {'a': '1 2 3', 'b': '4', 'c': 'd'} + head = {"a": "1 2 3", "b": "4", "c": "d"} header = Header(head) @@ -53,10 +53,10 @@ class TestHeader(unittest.TestCase): assert 2 == header.a.field_1 assert 3 == header.a.field_2 assert 4 == header.b - assert 'd' == header.c + assert "d" == header.c def test_missing_values(self): - head = {'can': '1'} + head = {"can": "1"} header = Header(head) @@ -65,7 +65,7 @@ class TestHeader(unittest.TestCase): assert header.can.r is None def test_additional_values_compared_to_definition(self): - head = {'can': '1 2 3 4'} + head = {"can": "1 2 3 4"} header = Header(head) @@ -76,10 +76,10 @@ class TestHeader(unittest.TestCase): def test_header(self): head = { - 'DAQ': '394', - 'PDF': '4', - 'can': '0 1027 888.4', - 'undefined': '1 2 test 3.4' + "DAQ": "394", + "PDF": "4", + "can": "0 1027 888.4", + "undefined": "1 2 test 3.4", } header = Header(head) @@ -120,12 +120,28 @@ class TestOfflineEvents(unittest.TestCase): self.n_hits = [176, 125, 318, 157, 83, 60, 71, 84, 255, 105] self.n_tracks = [56, 55, 56, 56, 56, 56, 56, 56, 54, 56] self.t_sec = [ - 1567036818, 1567036818, 1567036820, 1567036816, 1567036816, - 1567036816, 1567036822, 1567036818, 1567036818, 1567036820 + 1567036818, + 1567036818, + 1567036820, + 1567036816, + 1567036816, + 1567036816, + 1567036822, + 1567036818, + 1567036818, + 1567036820, ] self.t_ns = [ - 200000000, 300000000, 200000000, 500000000, 500000000, 500000000, - 200000000, 500000000, 500000000, 400000000 + 200000000, + 300000000, + 200000000, + 500000000, + 500000000, + 500000000, + 200000000, + 500000000, + 500000000, + 400000000, ] def test_len(self): @@ -144,10 +160,10 @@ class TestOfflineEvents(unittest.TestCase): self.assertListEqual(self.t_ns, list(self.events.t_ns)) def test_keys(self): - assert np.allclose(self.n_hits, self.events['n_hits']) - assert np.allclose(self.n_tracks, self.events['n_tracks']) - assert np.allclose(self.t_sec, self.events['t_sec']) - assert np.allclose(self.t_ns, self.events['t_ns']) + assert np.allclose(self.n_hits, self.events["n_hits"]) + assert np.allclose(self.n_tracks, self.events["n_tracks"]) + assert np.allclose(self.t_sec, self.events["t_sec"]) + assert np.allclose(self.t_ns, self.events["t_ns"]) def test_slicing(self): s = slice(2, 8, 2) @@ -168,12 +184,13 @@ class TestOfflineEvents(unittest.TestCase): def test_index_chaining(self): assert np.allclose(self.events[3:5].n_hits, self.events.n_hits[3:5]) - assert np.allclose(self.events[3:5][0].n_hits, - self.events.n_hits[3:5][0]) - assert np.allclose(self.events[3:5].hits[1].dom_id[4], - self.events.hits[3:5][1][4].dom_id) - assert np.allclose(self.events.hits[3:5][1][4].dom_id, - self.events[3:5][1][4].hits.dom_id) + assert np.allclose(self.events[3:5][0].n_hits, self.events.n_hits[3:5][0]) + assert np.allclose( + self.events[3:5].hits[1].dom_id[4], self.events.hits[3:5][1][4].dom_id + ) + assert np.allclose( + self.events.hits[3:5][1][4].dom_id, self.events[3:5][1][4].hits.dom_id + ) def test_fancy_indexing(self): mask = self.events.n_tracks > 55 @@ -206,23 +223,55 @@ class TestOfflineHits(unittest.TestCase): self.n_hits = 10 self.dom_id = { 0: [ - 806451572, 806451572, 806451572, 806451572, 806455814, - 806455814, 806455814, 806483369, 806483369, 806483369 + 806451572, + 806451572, + 806451572, + 806451572, + 806455814, + 806455814, + 806455814, + 806483369, + 806483369, + 806483369, ], 5: [ - 806455814, 806487219, 806487219, 806487219, 806487226, - 808432835, 808432835, 808432835, 808432835, 808432835 - ] + 806455814, + 806487219, + 806487219, + 806487219, + 806487226, + 808432835, + 808432835, + 808432835, + 808432835, + 808432835, + ], } self.t = { 0: [ - 70104010., 70104016., 70104192., 70104123., 70103096., - 70103797., 70103796., 70104191., 70104223., 70104181. + 70104010.0, + 70104016.0, + 70104192.0, + 70104123.0, + 70103096.0, + 70103797.0, + 70103796.0, + 70104191.0, + 70104223.0, + 70104181.0, ], 5: [ - 81861237., 81859608., 81860586., 81861062., 81860357., - 81860627., 81860628., 81860625., 81860627., 81860629. - ] + 81861237.0, + 81859608.0, + 81860586.0, + 81861062.0, + 81860357.0, + 81860627.0, + 81860628.0, + 81860625.0, + 81860627.0, + 81860629.0, + ], } def test_attributes_available(self): @@ -241,10 +290,9 @@ class TestOfflineHits(unittest.TestCase): def test_attributes(self): for idx, dom_id in self.dom_id.items(): - self.assertListEqual(dom_id, - list(self.hits.dom_id[idx][:len(dom_id)])) + self.assertListEqual(dom_id, list(self.hits.dom_id[idx][: len(dom_id)])) for idx, t in self.t.items(): - assert np.allclose(t, self.hits.t[idx][:len(t)]) + assert np.allclose(t, self.hits.t[idx][: len(t)]) def test_slicing(self): s = slice(2, 8, 2) @@ -258,23 +306,25 @@ class TestOfflineHits(unittest.TestCase): def test_slicing_consistency(self): for s in [slice(1, 3), slice(2, 7, 3)]: for idx in range(3): - assert np.allclose(self.hits.dom_id[idx][s], - self.hits[idx].dom_id[s]) - assert np.allclose(OFFLINE_FILE.events[idx].hits.dom_id[s], - self.hits.dom_id[idx][s]) + assert np.allclose(self.hits.dom_id[idx][s], self.hits[idx].dom_id[s]) + assert np.allclose( + OFFLINE_FILE.events[idx].hits.dom_id[s], self.hits.dom_id[idx][s] + ) def test_index_consistency(self): for idx, dom_ids in self.dom_id.items(): - assert np.allclose(self.hits[idx].dom_id[:self.n_hits], - dom_ids[:self.n_hits]) assert np.allclose( - OFFLINE_FILE.events[idx].hits.dom_id[:self.n_hits], - dom_ids[:self.n_hits]) + self.hits[idx].dom_id[: self.n_hits], dom_ids[: self.n_hits] + ) + assert np.allclose( + OFFLINE_FILE.events[idx].hits.dom_id[: self.n_hits], + dom_ids[: self.n_hits], + ) for idx, ts in self.t.items(): - assert np.allclose(self.hits[idx].t[:self.n_hits], - ts[:self.n_hits]) - assert np.allclose(OFFLINE_FILE.events[idx].hits.t[:self.n_hits], - ts[:self.n_hits]) + assert np.allclose(self.hits[idx].t[: self.n_hits], ts[: self.n_hits]) + assert np.allclose( + OFFLINE_FILE.events[idx].hits.t[: self.n_hits], ts[: self.n_hits] + ) def test_keys(self): assert "dom_id" in self.hits.keys() @@ -294,14 +344,14 @@ class TestOfflineTracks(unittest.TestCase): @unittest.skip def test_attributes(self): for idx, dom_id in self.dom_id.items(): - self.assertListEqual(dom_id, - list(self.hits.dom_id[idx][:len(dom_id)])) + self.assertListEqual(dom_id, list(self.hits.dom_id[idx][: len(dom_id)])) for idx, t in self.t.items(): - assert np.allclose(t, self.hits.t[idx][:len(t)]) + assert np.allclose(t, self.hits.t[idx][: len(t)]) def test_item_selection(self): - self.assertListEqual(list(self.tracks[0].dir_z[:2]), - [-0.872885221293917, -0.872885221293917]) + self.assertListEqual( + list(self.tracks[0].dir_z[:2]), [-0.872885221293917, -0.872885221293917] + ) def test_repr(self): assert " 10 " in repr(self.tracks) @@ -315,24 +365,33 @@ class TestOfflineTracks(unittest.TestCase): track_selection_2 = tracks[1:3] assert 2 == len(track_selection_2) for _slice in [ - slice(0, 0), - slice(0, 1), - slice(0, 2), - slice(1, 5), - slice(3, -2) + slice(0, 0), + slice(0, 1), + slice(0, 2), + slice(1, 5), + slice(3, -2), ]: - self.assertListEqual(list(tracks.E[:, 0][_slice]), - list(tracks[_slice].E[:, 0])) + self.assertListEqual( + list(tracks.E[:, 0][_slice]), list(tracks[_slice].E[:, 0]) + ) def test_nested_indexing(self): - self.assertAlmostEqual(self.f.events.tracks.fitinf[3:5][1][9][2], - self.f.events[3:5].tracks[1].fitinf[9][2]) - self.assertAlmostEqual(self.f.events.tracks.fitinf[3:5][1][9][2], - self.f.events[3:5][1][9][2].tracks.fitinf) - self.assertAlmostEqual(self.f.events.tracks.fitinf[3:5][1][9][2], - self.f.events[3:5][1].tracks[9][2].fitinf) - self.assertAlmostEqual(self.f.events.tracks.fitinf[3:5][1][9][2], - self.f.events[3:5][1].tracks[9].fitinf[2]) + self.assertAlmostEqual( + self.f.events.tracks.fitinf[3:5][1][9][2], + self.f.events[3:5].tracks[1].fitinf[9][2], + ) + self.assertAlmostEqual( + self.f.events.tracks.fitinf[3:5][1][9][2], + self.f.events[3:5][1][9][2].tracks.fitinf, + ) + self.assertAlmostEqual( + self.f.events.tracks.fitinf[3:5][1][9][2], + self.f.events[3:5][1].tracks[9][2].fitinf, + ) + self.assertAlmostEqual( + self.f.events.tracks.fitinf[3:5][1][9][2], + self.f.events[3:5][1].tracks[9].fitinf[2], + ) class TestBranchIndexingMagic(unittest.TestCase): @@ -341,10 +400,12 @@ class TestBranchIndexingMagic(unittest.TestCase): def test_foo(self): self.assertEqual(318, self.events[2:4].n_hits[0]) - assert np.allclose(self.events[3].tracks.dir_z[10], - self.events.tracks.dir_z[3, 10]) - assert np.allclose(self.events[3:6].tracks.pos_y[:, 0], - self.events.tracks.pos_y[3:6, 0]) + assert np.allclose( + self.events[3].tracks.dir_z[10], self.events.tracks.dir_z[3, 10] + ) + assert np.allclose( + self.events[3:6].tracks.pos_y[:, 0], self.events.tracks.pos_y[3:6, 0] + ) # test selecting with a list self.assertEqual(3, len(self.events[[0, 2, 3]])) @@ -358,29 +419,48 @@ class TestUsr(unittest.TestCase): print(self.f.events.usr) def test_keys_flat(self): - self.assertListEqual([ - 'RecoQuality', 'RecoNDF', 'CoC', 'ToT', 'ChargeAbove', - 'ChargeBelow', 'ChargeRatio', 'DeltaPosZ', 'FirstPartPosZ', - 'LastPartPosZ', 'NSnapHits', 'NTrigHits', 'NTrigDOMs', - 'NTrigLines', 'NSpeedVetoHits', 'NGeometryVetoHits', - 'ClassficationScore' - ], self.f.events.usr.keys()) + self.assertListEqual( + [ + "RecoQuality", + "RecoNDF", + "CoC", + "ToT", + "ChargeAbove", + "ChargeBelow", + "ChargeRatio", + "DeltaPosZ", + "FirstPartPosZ", + "LastPartPosZ", + "NSnapHits", + "NTrigHits", + "NTrigDOMs", + "NTrigLines", + "NSpeedVetoHits", + "NGeometryVetoHits", + "ClassficationScore", + ], + self.f.events.usr.keys(), + ) def test_getitem_flat(self): assert np.allclose( [118.6302815337638, 44.33580521344907, 99.93916717621543], - self.f.events.usr['CoC']) + self.f.events.usr["CoC"], + ) assert np.allclose( [37.51967774166617, -10.280346193553832, 13.67595659707355], - self.f.events.usr['DeltaPosZ']) + self.f.events.usr["DeltaPosZ"], + ) def test_attributes_flat(self): assert np.allclose( [118.6302815337638, 44.33580521344907, 99.93916717621543], - self.f.events.usr.CoC) + self.f.events.usr.CoC, + ) assert np.allclose( [37.51967774166617, -10.280346193553832, 13.67595659707355], - self.f.events.usr.DeltaPosZ) + self.f.events.usr.DeltaPosZ, + ) class TestMcTrackUsr(unittest.TestCase): @@ -391,21 +471,25 @@ class TestMcTrackUsr(unittest.TestCase): n_tracks = len(self.f.events) for i in range(3): self.assertListEqual( - [b'bx', b'by', b'ichan', b'cc'], - self.f.events.mc_tracks.usr_names[i][0].tolist()) + [b"bx", b"by", b"ichan", b"cc"], + self.f.events.mc_tracks.usr_names[i][0].tolist(), + ) self.assertListEqual( - [b'energy_lost_in_can'], - self.f.events.mc_tracks.usr_names[i][1].tolist()) + [b"energy_lost_in_can"], + self.f.events.mc_tracks.usr_names[i][1].tolist(), + ) def test_usr(self): - assert np.allclose([0.0487, 0.0588, 3, 2], - self.f.events.mc_tracks.usr[0][0].tolist(), - atol=0.0001) - assert np.allclose([0.147, 0.4, 3, 2], - self.f.events.mc_tracks.usr[1][0].tolist(), - atol=0.001) + assert np.allclose( + [0.0487, 0.0588, 3, 2], + self.f.events.mc_tracks.usr[0][0].tolist(), + atol=0.0001, + ) + assert np.allclose( + [0.147, 0.4, 3, 2], self.f.events.mc_tracks.usr[1][0].tolist(), atol=0.001 + ) class TestNestedMapper(unittest.TestCase): def test_nested_mapper(self): - self.assertEqual('pos_x', _nested_mapper("trks.pos.x")) + self.assertEqual("pos_x", _nested_mapper("trks.pos.x")) diff --git a/tests/test_online.py b/tests/test_online.py index 24bb9f3..50e5689 100644 --- a/tests/test_online.py +++ b/tests/test_online.py @@ -6,7 +6,14 @@ import unittest from km3net_testdata import data_path -from km3io.online import OnlineReader, get_rate, has_udp_trailer, get_udp_max_sequence_number, get_channel_flags, get_number_udp_packets +from km3io.online import ( + OnlineReader, + get_rate, + has_udp_trailer, + get_udp_max_sequence_number, + get_channel_flags, + get_number_udp_packets, +) from km3io.tools import to_num ONLINE_FILE = data_path("online/km3net_online.root") @@ -42,12 +49,10 @@ class TestOnlineEvent(unittest.TestCase): self.event = OnlineReader(ONLINE_FILE).events[0] def test_str(self): - assert re.match(".*event.*96.*snapshot.*18.*triggered", - str(self.event)) + assert re.match(".*event.*96.*snapshot.*18.*triggered", str(self.event)) def test_repr(self): - assert re.match(".*event.*96.*snapshot.*18.*triggered", - self.event.__repr__()) + assert re.match(".*event.*96.*snapshot.*18.*triggered", self.event.__repr__()) class TestOnlineEventsSnapshotHits(unittest.TestCase): @@ -74,11 +79,11 @@ class TestOnlineEventsSnapshotHits(unittest.TestCase): def test_data_values(self): hits = self.events.snapshot_hits - self.assertListEqual([806451572, 806451572, 806455814], - list(hits.dom_id[0][:3])) + self.assertListEqual( + [806451572, 806451572, 806455814], list(hits.dom_id[0][:3]) + ) self.assertListEqual([10, 13, 0], list(hits.channel_id[0][:3])) - self.assertListEqual([30733918, 30733916, 30733256], - list(hits.time[0][:3])) + self.assertListEqual([30733918, 30733916, 30733256], list(hits.time[0][:3])) def test_channel_ids_have_valid_values(self): hits = self.events.snapshot_hits @@ -113,11 +118,11 @@ class TestOnlineEventsTriggeredHits(unittest.TestCase): def test_data_values(self): hits = self.events.triggered_hits - self.assertListEqual([806451572, 806451572, 808432835], - list(hits.dom_id[0][:3])) + self.assertListEqual( + [806451572, 806451572, 808432835], list(hits.dom_id[0][:3]) + ) self.assertListEqual([10, 13, 1], list(hits.channel_id[0][:3])) - self.assertListEqual([30733918, 30733916, 30733429], - list(hits.time[0][:3])) + self.assertListEqual([30733918, 30733916, 30733429], list(hits.time[0][:3])) self.assertListEqual([16, 16, 4], list(hits.trigger_mask[0][:3])) def test_channel_ids_have_valid_values(self): @@ -174,8 +179,7 @@ class TestSummaryslices(unittest.TestCase): assert 3 == len(self.ss.headers) self.assertListEqual([44, 44, 44], list(self.ss.headers.detector_id)) self.assertListEqual([6633, 6633, 6633], list(self.ss.headers.run)) - self.assertListEqual([126, 127, 128], - list(self.ss.headers.frame_index)) + self.assertListEqual([126, 127, 128], list(self.ss.headers.frame_index)) assert 806451572 == self.ss.slices[0].dom_id[0] def test_slices(self): @@ -190,7 +194,7 @@ class TestSummaryslices(unittest.TestCase): 808981510: True, 808981523: False, 808981672: False, - 808974773: False + 808974773: False, } for dom_id, fifo_status in dct_fifo_stat.items(): frame = s[s.dom_id == dom_id] @@ -209,7 +213,7 @@ class TestSummaryslices(unittest.TestCase): 808432835: True, 808435278: True, 808447180: True, - 808447186: True + 808447186: True, } for dom_id, udp_trailer in dct_udp_trailer.items(): frame = s[s.dom_id == dom_id] @@ -249,7 +253,7 @@ class TestSummaryslices(unittest.TestCase): 808981672: False, 808981812: True, 808981864: False, - 808982018: False + 808982018: False, } for dom_id, high_rate_veto in dct_high_rate_veto.items(): frame = s[s.dom_id == dom_id] @@ -285,12 +289,13 @@ class TestSummaryslices(unittest.TestCase): 809524432: 21, 809526097: 23, 809544058: 21, - 809544061: 23 + 809544061: 23, } for dom_id, max_sequence_number in dct_seq_numbers.items(): frame = s[s.dom_id == dom_id] - assert get_udp_max_sequence_number( - frame.dq_status[0]) == max_sequence_number + assert ( + get_udp_max_sequence_number(frame.dq_status[0]) == max_sequence_number + ) def test_number_udp_packets(self): s = self.ss.slices[0] @@ -315,7 +320,7 @@ class TestSummaryslices(unittest.TestCase): 808961655: 20, 808964815: 20, 808964852: 28, - 808969848: 21 + 808969848: 21, } for dom_id, n_udp_packets in dct_n_packets.items(): frame = s[s.dom_id == dom_id] @@ -325,83 +330,351 @@ class TestSummaryslices(unittest.TestCase): s = self.ss.slices[0] dct_hrv_flags = { 809524432: [ - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, ], 809526097: [ - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - True, False, False, False, False, False, False, False, True, - False, False, False, False + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + False, + False, + False, + False, + False, + False, + False, + True, + False, + False, + False, + False, ], 809544058: [ - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, ], 809544061: [ - False, True, False, False, False, True, False, False, False, - False, False, False, False, False, False, True, False, False, - False, False, False, True, False, False, False, False, False, - False, False, False, False - ] + False, + True, + False, + False, + False, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + False, + False, + False, + False, + False, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], } for dom_id, hrv_flags in dct_hrv_flags.items(): frame = s[s.dom_id == dom_id] - assert any([ - a == b - for a, b in zip(get_channel_flags(frame.hrv[0]), hrv_flags) - ]) + assert any( + [a == b for a, b in zip(get_channel_flags(frame.hrv[0]), hrv_flags)] + ) def test_fifo_flags(self): s = self.ss.slices[0] dct_fifo_flags = { 808982547: [ - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, ], 808984711: [ - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, ], 808996773: [ - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, ], 808997793: [ - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, ], 809006037: [ - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, False, False, False, False, - False, False, False, False + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, ], 808981510: [ - False, False, False, False, False, False, False, False, False, - False, False, False, False, False, True, True, False, False, - False, True, False, True, True, True, True, True, True, False, - False, True, False - ] + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + True, + False, + False, + False, + True, + False, + True, + True, + True, + True, + True, + True, + False, + False, + True, + False, + ], } for dom_id, fifo_flags in dct_fifo_flags.items(): frame = s[s.dom_id == dom_id] - assert any([ - a == b - for a, b in zip(get_channel_flags(frame.fifo[0]), fifo_flags) - ]) + assert any( + [a == b for a, b in zip(get_channel_flags(frame.fifo[0]), fifo_flags)] + ) def test_str(self): print(str(self.ss)) @@ -418,16 +691,13 @@ class TestGetChannelFlags_Issue59(unittest.TestCase): Entry = namedtuple("Entry", fieldnames) with open( - data_path( - "online/KM3NeT_00000049_00008456.summaryslice-167941.txt") + data_path("online/KM3NeT_00000049_00008456.summaryslice-167941.txt") ) as fobj: - ref_entries = [ - Entry(*list(l.strip().split())) for l in fobj.readlines() - ] + ref_entries = [Entry(*list(l.strip().split())) for l in fobj.readlines()] r = OnlineReader( - data_path( - "online/KM3NeT_00000049_00008456.summaryslice-167941.root")) + data_path("online/KM3NeT_00000049_00008456.summaryslice-167941.root") + ) summaryslice = r.summaryslices.slices[0] for ours, ref in zip(summaryslice, ref_entries): diff --git a/tests/test_tools.py b/tests/test_tools.py index b23c283..d5cfd00 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -10,11 +10,24 @@ from km3io.definitions import fitparameters as kfit from km3io import OfflineReader -from km3io.tools import (to_num, cached_property, unfold_indices, unique, - uniquecount, fitinf, count_nested, _find, mask, - best_track, get_w2list_param, get_multiplicity, - best_jmuon, best_jshower, best_aashower, - best_dusjshower) +from km3io.tools import ( + to_num, + cached_property, + unfold_indices, + unique, + uniquecount, + fitinf, + count_nested, + _find, + mask, + best_track, + get_w2list_param, + get_multiplicity, + best_jmuon, + best_jshower, + best_aashower, + best_dusjshower, +) OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root")) @@ -45,7 +58,8 @@ class TestBestTrackSelection(unittest.TestCase): self.one_event = OFFLINE_FILE.events[0] def test_best_track_selection_from_multiple_events_with_explicit_stages_in_list( - self): + self, + ): best = best_track(self.events.tracks, stages=[1, 3, 5, 4]) assert len(best) == 10 @@ -76,7 +90,8 @@ class TestBestTrackSelection(unittest.TestCase): assert best3.rec_stages[3] is None def test_best_track_selection_from_multiple_events_with_explicit_stages_in_set( - self): + self, + ): best = best_track(self.events.tracks, stages={1, 3, 4, 5}) assert len(best) == 10 @@ -169,8 +184,7 @@ class TestBestTrackSelection(unittest.TestCase): assert best3.rec_stages[0] == [1, 3, 5, 4] def test_best_track_on_slices_one_event(self): - tracks_slice = self.one_event.tracks[self.one_event.tracks.rec_type == - 4000] + tracks_slice = self.one_event.tracks[self.one_event.tracks.rec_type == 4000] # test stages with list best = best_track(tracks_slice, stages=[1, 3, 5, 4]) @@ -238,9 +252,7 @@ class TestBestTrackSelection(unittest.TestCase): def test_best_track_raises_when_too_many_inputs(self): with self.assertRaises(ValueError): - best_track(self.events.tracks, - startend=(1, 4), - stages=[1, 3, 5, 4]) + best_track(self.events.tracks, startend=(1, 4), stages=[1, 3, 5, 4]) class TestBestJmuon(unittest.TestCase): @@ -312,8 +324,7 @@ class TestBestDusjshower(unittest.TestCase): class TestGetMultiplicity(unittest.TestCase): def test_get_multiplicity(self): - multiplicity = get_multiplicity(OFFLINE_FILE.events.tracks, - [1, 3, 5, 4]) + multiplicity = get_multiplicity(OFFLINE_FILE.events.tracks, [1, 3, 5, 4]) assert len(multiplicity) == 10 assert multiplicity[0] == 1 @@ -322,8 +333,7 @@ class TestGetMultiplicity(unittest.TestCase): assert multiplicity[3] == 1 # test with no nexisting rec_stages - multiplicity2 = get_multiplicity(OFFLINE_FILE.events.tracks, - [1, 2, 3, 4, 5]) + multiplicity2 = get_multiplicity(OFFLINE_FILE.events.tracks, [1, 2, 3, 4, 5]) assert len(multiplicity2) == 10 assert multiplicity2[0] == 0 @@ -343,8 +353,13 @@ class TestCountNested(unittest.TestCase): class TestRecStagesMasks(unittest.TestCase): def setUp(self): - self.nested = ak.Array([[[1, 2, 3], [1, 2, 3], [1]], [[0], [1, 2, 3]], - [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]]]) + self.nested = ak.Array( + [ + [[1, 2, 3], [1, 2, 3], [1]], + [[0], [1, 2, 3]], + [[0], [0, 1, 3], [0], [1, 2, 3], [1, 2, 3]], + ] + ) self.tracks = OFFLINE_FILE.events.tracks @@ -418,8 +433,7 @@ class TestUnique(unittest.TestCase): max_range = 100 for i in range(23): low = np.random.randint(0, max_range) - high = np.random.randint(low + 1, - low + 2 + np.random.randint(max_range)) + high = np.random.randint(low + 1, low + 2 + np.random.randint(max_range)) n = np.random.randint(max_range) arr = np.random.randint(low, high, n).astype(dtype) np_reference = np.sort(np.unique(arr)) @@ -496,8 +510,7 @@ class TestUnfoldIndices(unittest.TestCase): assert data[indices[0]][indices[1]] == unfold_indices(data, indices) indices = [slice(1, 9, 2), slice(1, 4), 2] - assert data[indices[0]][indices[1]][indices[2]] == unfold_indices( - data, indices) + assert data[indices[0]][indices[1]][indices[2]] == unfold_indices(data, indices) def test_unfold_indices_raises_index_error(self): data = range(10) -- GitLab