Skip to content
Snippets Groups Projects
Verified Commit a5c5443f authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Remove uproot3 related stuff

parent 9ede3782
No related branches found
No related tags found
1 merge request!86Remove online support, uproot3 and update to awkward2 and uproot5
......@@ -73,10 +73,6 @@ dev =
sphinxcontrib-versioning
wheel
[options.entry_points]
console_scripts =
KPrintTree = km3io.utils.kprinttree:main
[options.package_data]
* = *.mplstyle, *.py.typed
......
......@@ -16,11 +16,6 @@ import os
# This needs to be done before import numpy
os.environ["KMP_WARNINGS"] = "off"
with warnings.catch_warnings():
for warning_category in (FutureWarning, DeprecationWarning):
warnings.simplefilter("ignore", category=warning_category)
import uproot3
from .offline import OfflineReader
from .online import OnlineReader
from .acoustics import RawAcousticReader
import binascii
from collections import namedtuple
import os
import uproot
import uproot3
import numpy as np
import numba as nb
TIMESLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024**2 # [byte]
SUMMARYSLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024**2 # [byte]
BASKET_CACHE_SIZE = 110 * 1024**2
BASKET_CACHE = uproot3.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
# Parameters for PMT rate conversions, since the rates in summary slices are
# stored as a single byte to save space. The values from 0-255 can be decoded
# using the `get_rate(value)` function, which will yield the actual rate
......@@ -220,12 +213,12 @@ class OnlineReader:
"""Reader for online ROOT files"""
def __init__(self, filename):
self._fobj = uproot3.open(filename)
self._fobj = uproot.open(filename)
self._filename = filename
self._events = None
self._timeslices = None
self._summaryslices = None
self._uuid = binascii.hexlify(self._fobj._context.uuid).decode("ascii")
self._uuid = self._fobj.parent.uuid.hex
@property
def uuid(self):
......@@ -284,12 +277,6 @@ class OnlineReader:
self._events = OnlineEvents(headers, snapshot_hits, triggered_hits)
return self._events
@property
def timeslices(self):
if self._timeslices is None:
self._timeslices = Timeslices(self._fobj)
return self._timeslices
@property
def summaryslices(self):
if self._summaryslices is None:
......@@ -299,155 +286,6 @@ class OnlineReader:
return self._summaryslices
class Timeslices:
"""A simple wrapper for timeslices"""
def __init__(self, fobj):
self._fobj = fobj
self._timeslices = {}
self._read_streams()
def _read_streams(self):
"""Read the L0, L1, L2 and SN streams if available"""
streams = set(
s.split(b"KM3NET_TIMESLICE_")[1].split(b";")[0]
for s in self._fobj.keys()
if b"KM3NET_TIMESLICE_" in s
)
for stream in streams:
tree = self._fobj[b"KM3NET_TIMESLICE_" + stream][
b"KM3NETDAQ::JDAQTimeslice"
]
headers = tree[b"KM3NETDAQ::JDAQTimesliceHeader"][b"KM3NETDAQ::JDAQHeader"][
b"KM3NETDAQ::JDAQChronometer"
]
if len(headers) == 0:
continue
superframes = tree[b"vector<KM3NETDAQ::JDAQSuperFrame>"]
hits_dtype = np.dtype([("pmt", "u1"), ("tdc", "<u4"), ("tot", "u1")])
hits_buffer = superframes[
b"vector<KM3NETDAQ::JDAQSuperFrame>.buffer"
].lazyarray(
uproot3.asjagged(
uproot3.astable(uproot3.asdtype(hits_dtype)), skipbytes=6
),
basketcache=uproot3.cache.ThreadSafeArrayCache(
TIMESLICE_FRAME_BASKET_CACHE_SIZE
),
)
self._timeslices[stream.decode("ascii")] = (
headers,
superframes,
hits_buffer,
)
setattr(
self,
stream.decode("ascii"),
TimesliceStream(headers, superframes, hits_buffer),
)
def stream(self, stream, idx):
ts = self._timeslices[stream]
return Timeslice(*ts, idx, stream)
def __str__(self):
return "Available timeslice streams: {}".format(
", ".join(s for s in self._timeslices.keys())
)
def __repr__(self):
return str(self)
class TimesliceStream:
def __init__(self, headers, superframes, hits_buffer):
# self.headers = headers.lazyarray(
# uproot3.asjagged(uproot3.astable(
# uproot3.asdtype(
# np.dtype([('a', 'i4'), ('b', 'i4'), ('c', 'i4'),
# ('d', 'i4'), ('e', 'i4')]))),
# skipbytes=6),
# basketcache=uproot3.cache.ThreadSafeArrayCache(
# TIMESLICE_FRAME_BASKET_CACHE_SIZE))
self.headers = headers
self.superframes = superframes
self._hits_buffer = hits_buffer
# def frames(self):
# n_hits = self._superframe[
# b'vector<KM3NETDAQ::JDAQSuperFrame>.numberOfHits'].lazyarray(
# basketcache=BASKET_CACHE)[self._idx]
# module_ids = self._superframe[
# b'vector<KM3NETDAQ::JDAQSuperFrame>.id'].lazyarray(basketcache=BASKET_CACHE)[self._idx]
# idx = 0
# for module_id, n_hits in zip(module_ids, n_hits):
# self._frames[module_id] = hits_buffer[idx:idx + n_hits]
# idx += n_hits
class Timeslice:
"""A wrapper for a timeslice"""
def __init__(self, header, superframe, hits_buffer, idx, stream):
self.header = header
self._frames = {}
self._superframe = superframe
self._hits_buffer = hits_buffer
self._idx = idx
self._stream = stream
self._n_frames = None
@property
def frames(self):
if not self._frames:
self._read_frames()
return self._frames
def _read_frames(self):
"""Populate a dictionary of frames with the module ID as key"""
hits_buffer = self._hits_buffer[self._idx]
n_hits = self._superframe[
b"vector<KM3NETDAQ::JDAQSuperFrame>.numberOfHits"
].lazyarray(basketcache=BASKET_CACHE)[self._idx]
try:
module_ids = self._superframe[
b"vector<KM3NETDAQ::JDAQSuperFrame>.id"
].lazyarray(basketcache=BASKET_CACHE)[self._idx]
except KeyError:
module_ids = (
self._superframe[
b"vector<KM3NETDAQ::JDAQSuperFrame>.KM3NETDAQ::JDAQModuleIdentifier"
]
.lazyarray(
uproot3.asjagged(
uproot3.astable(uproot3.asdtype([("dom_id", ">i4")]))
),
basketcache=BASKET_CACHE,
)[self._idx]
.dom_id
)
idx = 0
for module_id, n_hits in zip(module_ids, n_hits):
self._frames[module_id] = hits_buffer[idx : idx + n_hits]
idx += n_hits
def __len__(self):
if self._n_frames is None:
self._n_frames = len(
self._superframe[b"vector<KM3NETDAQ::JDAQSuperFrame>.id"].lazyarray(
basketcache=BASKET_CACHE
)[self._idx]
)
return self._n_frames
def __str__(self):
return "{} timeslice with {} frames.".format(self._stream, len(self))
def __repr__(self):
return "<{}: {} entries>".format(self.__class__.__name__, len(self.header))
class OnlineEvents:
"""A simple wrapper for online events"""
......
......@@ -60,7 +60,7 @@ class EventReader:
else:
raise TypeError("Unsupported file descriptor.")
self._step_size = step_size
self._uuid = self._fobj._file.uuid
self._uuid = self._fobj.parent.uuid
self._iterator_index = 0
self._keys = keys
self._event_ctor = event_ctor
......
......@@ -3,7 +3,6 @@ from collections import namedtuple
import numba as nb
import numpy as np
import awkward as ak
import uproot3
import km3io.definitions
from km3io.definitions import reconstruction as krec
......@@ -12,10 +11,6 @@ from km3io.definitions import fitparameters as kfit
from km3io.definitions import w2list_genhen as kw2gen
from km3io.definitions import w2list_gseagen as kw2gsg
# 110 MB based on the size of the largest basket found so far in km3net
BASKET_CACHE_SIZE = 110 * 1024**2
BASKET_CACHE = uproot3.cache.ThreadSafeArrayCache(BASKET_CACHE_SIZE)
class cached_property:
"""A simple cache decorator for properties."""
......
#!/usr/bin/env python
# coding=utf-8
# Filename: kprinttree.py
# Author: Tamas Gal <tgal@km3net.de>
"""
Print the available ROOT trees.
Usage:
KPrintTree -f FILENAME
KPrintTree (-h | --help)
Options:
-f FILENAME The file to print (;
-h --help Show this screen.
"""
import warnings
with warnings.catch_warnings():
for warning_category in (FutureWarning, DeprecationWarning):
warnings.simplefilter("ignore", category=warning_category)
import uproot3
def print_tree(filename):
f = uproot3.open(filename)
for key in f.keys():
try:
print("{:<30} : {:>9} items".format(key.decode(), len(f[key])))
except (TypeError, KeyError):
print("{}".format(key.decode()))
except NotImplementedError:
print("{} (TStreamerSTL)".format(key.decode()))
def main():
from docopt import docopt
args = docopt(__doc__)
print_tree(args["-f"])
if __name__ == "__main__":
main()
......@@ -21,156 +21,15 @@ from km3io.tools import to_num
ONLINE_FILE = data_path("online/km3net_online.root")
class TestOnlineReaderContextManager(unittest.TestCase):
def test_context_manager(self):
with OnlineReader(ONLINE_FILE) as r:
assert r._filename == ONLINE_FILE
# class TestOnlineReaderContextManager(unittest.TestCase):
# def test_context_manager(self):
# with OnlineReader(ONLINE_FILE) as r:
# assert r._filename == ONLINE_FILE
class TestUUID(unittest.TestCase):
def test_uuid(self):
assert OnlineReader(ONLINE_FILE).uuid == "00010c85603008c611ea971772f09e86beef"
class TestOnlineEvents(unittest.TestCase):
def setUp(self):
self.events = OnlineReader(ONLINE_FILE).events
def test_index_lookup(self):
assert 3 == len(self.events)
def test_str(self):
assert re.match(".*events.*3", str(self.events))
def test_repr(self):
assert re.match(".*events.*3", self.events.__repr__())
class TestOnlineEvent(unittest.TestCase):
def setUp(self):
self.event = OnlineReader(ONLINE_FILE).events[0]
def test_str(self):
assert re.match(".*event.*96.*snapshot.*18.*triggered", str(self.event))
def test_repr(self):
assert re.match(".*event.*96.*snapshot.*18.*triggered", self.event.__repr__())
class TestOnlineEventsSnapshotHits(unittest.TestCase):
def setUp(self):
self.events = OnlineReader(ONLINE_FILE).events
self.lengths = {0: 96, 1: 124, -1: 78}
self.total_item_count = 298
def test_reading_snapshot_hits(self):
hits = self.events.snapshot_hits
for event_id, length in self.lengths.items():
assert length == len(hits[event_id].dom_id)
assert length == len(hits[event_id].channel_id)
assert length == len(hits[event_id].time)
def test_total_item_counts(self):
hits = self.events.snapshot_hits
assert self.total_item_count == sum(hits.dom_id.count())
assert self.total_item_count == sum(hits.channel_id.count())
assert self.total_item_count == sum(hits.time.count())
def test_data_values(self):
hits = self.events.snapshot_hits
self.assertListEqual(
[806451572, 806451572, 806455814], list(hits.dom_id[0][:3])
)
self.assertListEqual([10, 13, 0], list(hits.channel_id[0][:3]))
self.assertListEqual([30733918, 30733916, 30733256], list(hits.time[0][:3]))
def test_channel_ids_have_valid_values(self):
hits = self.events.snapshot_hits
# channel IDs are always between [0, 30]
assert all(c >= 0 for c in hits.channel_id.min())
assert all(c < 31 for c in hits.channel_id.max())
class TestOnlineEventsTriggeredHits(unittest.TestCase):
def setUp(self):
self.events = OnlineReader(ONLINE_FILE).events
self.lengths = {0: 18, 1: 53, -1: 9}
self.total_item_count = 80
def test_data_lengths(self):
hits = self.events.triggered_hits
for event_id, length in self.lengths.items():
assert length == len(hits[event_id].dom_id)
assert length == len(hits[event_id].channel_id)
assert length == len(hits[event_id].time)
assert length == len(hits[event_id].trigger_mask)
def test_total_item_counts(self):
hits = self.events.triggered_hits
assert self.total_item_count == sum(hits.dom_id.count())
assert self.total_item_count == sum(hits.channel_id.count())
assert self.total_item_count == sum(hits.time.count())
def test_data_values(self):
hits = self.events.triggered_hits
self.assertListEqual(
[806451572, 806451572, 808432835], list(hits.dom_id[0][:3])
)
self.assertListEqual([10, 13, 1], list(hits.channel_id[0][:3]))
self.assertListEqual([30733918, 30733916, 30733429], list(hits.time[0][:3]))
self.assertListEqual([16, 16, 4], list(hits.trigger_mask[0][:3]))
def test_channel_ids_have_valid_values(self):
hits = self.events.triggered_hits
# channel IDs are always between [0, 30]
assert all(c >= 0 for c in hits.channel_id.min())
assert all(c < 31 for c in hits.channel_id.max())
class TestTimeslices(unittest.TestCase):
def setUp(self):
self.ts = OnlineReader(ONLINE_FILE).timeslices
def test_data_lengths(self):
assert 3 == len(self.ts._timeslices["L1"][0])
assert 3 == len(self.ts._timeslices["SN"][0])
with self.assertRaises(KeyError):
assert 0 == len(self.ts._timeslices["L2"][0])
with self.assertRaises(KeyError):
assert 0 == len(self.ts._timeslices["L0"][0])
def test_streams(self):
self.ts.stream("L1", 0)
self.ts.stream("SN", 0)
def test_reading_frames(self):
assert 8 == len(self.ts.stream("SN", 1).frames[808447186])
def test_str(self):
s = str(self.ts)
assert "L1" in s
assert "SN" in s
class TestTimeslice(unittest.TestCase):
def setUp(self):
self.ts = OnlineReader(ONLINE_FILE).timeslices
self.n_frames = {"L1": [69, 69, 69], "SN": [64, 66, 68]}
def test_str(self):
for stream, n_frames in self.n_frames.items():
print(stream, n_frames)
for i in range(len(n_frames)):
s = str(self.ts.stream(stream, i))
assert re.match("{}.*{}".format(stream, n_frames[i]), s)
assert OnlineReader(ONLINE_FILE).uuid == "0c85603008c611ea971772f09e86beef"
class TestSummaryslices(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment