Skip to content
Snippets Groups Projects
Commit d2c13d91 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Merge branch 'refactor-offline-prepare-for-merging' into 'master'

Refactor offline I/O

See merge request !27
parents 52a777d5 fcca3704
No related branches found
No related tags found
1 merge request!27Refactor offline I/O
Pipeline #10195 passed with warnings
#!/usr/bin/env python3
from .mc_header import data as mc_header
from .trigger import data as trigger
from .fitparameters import data as fitparameters
from .reconstruction import data as reconstruction
#!/usr/bin/env python3
data = {
"DAQ": "livetime",
"seed": "program level iseed",
"PM1_type_area": "type area TTS",
"PDF": "i1 i2",
"model": "interaction muon scattering numberOfEnergyBins",
"can": "zmin zmax r",
"genvol": "zmin zmax r volume numberOfEvents",
"merge": "time gain",
"coord_origin": "x y z",
"translate": "x y z",
"genhencut": "gDir Emin",
"k40": "rate time",
"norma": "primaryFlux numberOfPrimaries",
"livetime": "numberOfSeconds errorOfSeconds",
"flux": "type key file_1 file_2",
"spectrum": "alpha",
"fixedcan": "xcenter ycenter zmin zmax radius",
"start_run": "run_id",
}
for key in "cut_primary cut_seamuon cut_in cut_nu".split():
data[key] = "Emin Emax cosTmin cosTmax"
for key in "generator physics simul".split():
data[key] = "program version date time"
for key in data.keys():
data[key] = data[key].split()
This diff is collapsed.
...@@ -19,6 +19,21 @@ class cached_property: ...@@ -19,6 +19,21 @@ class cached_property:
return prop return prop
def _unfold_indices(obj, indices):
"""Unfolds an index chain and returns the corresponding item"""
original_obj = obj
for depth, idx in enumerate(indices):
try:
obj = obj[idx]
except IndexError:
print(
"IndexError while accessing an item from '{}' at depth {} ({}) "
"using the index chain {}".format(repr(original_obj), depth,
idx, indices))
raise
return obj
BranchMapper = namedtuple( BranchMapper = namedtuple(
"BranchMapper", "BranchMapper",
['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat']) ['name', 'key', 'extra', 'exclude', 'update', 'attrparser', 'flat'])
...@@ -29,17 +44,19 @@ class Branch: ...@@ -29,17 +44,19 @@ class Branch:
def __init__(self, def __init__(self,
tree, tree,
mapper, mapper,
index=None, index_chain=None,
subbranchmaps=None, subbranchmaps=None,
keymap=None): keymap=None):
self._tree = tree self._tree = tree
self._mapper = mapper self._mapper = mapper
self._index = index self._index_chain = [] if index_chain is None else index_chain
self._keymap = None self._keymap = None
self._branch = tree[mapper.key] self._branch = tree[mapper.key]
self._subbranches = [] self._subbranches = []
self._subbranchmaps = subbranchmaps self._subbranchmaps = subbranchmaps
self._iterator_index = 0
if keymap is None: if keymap is None:
self._initialise_keys() # self._initialise_keys() #
else: else:
...@@ -49,7 +66,7 @@ class Branch: ...@@ -49,7 +66,7 @@ class Branch:
for mapper in subbranchmaps: for mapper in subbranchmaps:
subbranch = self.__class__(self._tree, subbranch = self.__class__(self._tree,
mapper=mapper, mapper=mapper,
index=self._index) index_chain=self._index_chain)
self._subbranches.append(subbranch) self._subbranches.append(subbranch)
for subbranch in self._subbranches: for subbranch in self._subbranches:
setattr(self, subbranch._mapper.name, subbranch) setattr(self, subbranch._mapper.name, subbranch)
...@@ -57,8 +74,8 @@ class Branch: ...@@ -57,8 +74,8 @@ class Branch:
def _initialise_keys(self): def _initialise_keys(self):
"""Create the keymap and instance attributes for branch keys""" """Create the keymap and instance attributes for branch keys"""
# TODO: this could be a cached property # TODO: this could be a cached property
keys = set(k.decode('utf-8') for k in self._branch.keys()) - set( keys = set(k.decode('utf-8')
self._mapper.exclude) for k in self._branch.keys()) - set(self._mapper.exclude)
self._keymap = { self._keymap = {
**{self._mapper.attrparser(k): k **{self._mapper.attrparser(k): k
for k in keys}, for k in keys},
...@@ -86,42 +103,46 @@ class Branch: ...@@ -86,42 +103,46 @@ class Branch:
def __getkey__(self, key): def __getkey__(self, key):
out = self._branch[self._keymap[key]].lazyarray( out = self._branch[self._keymap[key]].lazyarray(
basketcache=BASKET_CACHE) basketcache=BASKET_CACHE)
if self._index is not None: return _unfold_indices(out, self._index_chain)
out = out[self._index]
return out
def __getitem__(self, item): def __getitem__(self, item):
"""Slicing magic""" """Slicing magic"""
if isinstance(item, (int, slice)):
return self.__class__(self._tree,
self._mapper,
index=item,
keymap=self._keymap,
subbranchmaps=self._subbranchmaps)
if isinstance(item, tuple):
return self[item[0]][item[1]]
if isinstance(item, str): if isinstance(item, str):
return self.__getkey__(item) return self.__getkey__(item)
return self.__class__(self._tree, return self.__class__(self._tree,
self._mapper, self._mapper,
index=np.array(item), index_chain=self._index_chain + [item],
keymap=self._keymap, keymap=self._keymap,
subbranchmaps=self._subbranchmaps) subbranchmaps=self._subbranchmaps)
def __len__(self): def __len__(self):
if self._index is None: if not self._index_chain:
return len(self._branch) return len(self._branch)
elif isinstance(self._index, int): elif isinstance(self._index_chain[-1], int):
return 1 return 1
else: else:
return len(self._branch[self._keymap['id']].lazyarray( return len(
basketcache=BASKET_CACHE)[self._index]) _unfold_indices(
self._branch[self._keymap['id']].lazyarray(
basketcache=BASKET_CACHE), self._index_chain))
def __iter__(self):
self._iterator_index = 0
return self
def __next__(self):
idx = self._iterator_index
self._iterator_index += 1
if idx >= len(self):
raise StopIteration
return self[idx]
def __str__(self): def __str__(self):
return "Number of elements: {}".format(len(self._branch)) length = len(self)
return "{} ({}) with {} element{}".format(self.__class__.__name__,
self._mapper.name, length,
's' if length > 1 else '')
def __repr__(self): def __repr__(self):
length = len(self) length = len(self)
......
...@@ -5,12 +5,12 @@ import unittest ...@@ -5,12 +5,12 @@ import unittest
from km3io.daq import DAQReader, get_rate, has_udp_trailer, get_udp_max_sequence_number, get_channel_flags, get_number_udp_packets from km3io.daq import DAQReader, get_rate, has_udp_trailer, get_udp_max_sequence_number, get_channel_flags, get_number_udp_packets
SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "samples") SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "samples")
DAQ_FILE = DAQReader(os.path.join(SAMPLES_DIR, "daq_v1.0.0.root"))
class TestDAQEvents(unittest.TestCase): class TestDAQEvents(unittest.TestCase):
def setUp(self): def setUp(self):
self.events = DAQ_FILE.events self.events = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).events
def test_index_lookup(self): def test_index_lookup(self):
assert 3 == len(self.events) assert 3 == len(self.events)
...@@ -24,7 +24,8 @@ class TestDAQEvents(unittest.TestCase): ...@@ -24,7 +24,8 @@ class TestDAQEvents(unittest.TestCase):
class TestDAQEvent(unittest.TestCase): class TestDAQEvent(unittest.TestCase):
def setUp(self): def setUp(self):
self.event = DAQ_FILE.events[0] self.event = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).events[0]
def test_str(self): def test_str(self):
assert re.match(".*event.*96.*snapshot.*18.*triggered", assert re.match(".*event.*96.*snapshot.*18.*triggered",
...@@ -37,7 +38,8 @@ class TestDAQEvent(unittest.TestCase): ...@@ -37,7 +38,8 @@ class TestDAQEvent(unittest.TestCase):
class TestDAQEventsSnapshotHits(unittest.TestCase): class TestDAQEventsSnapshotHits(unittest.TestCase):
def setUp(self): def setUp(self):
self.events = DAQ_FILE.events self.events = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).events
self.lengths = {0: 96, 1: 124, -1: 78} self.lengths = {0: 96, 1: 124, -1: 78}
self.total_item_count = 298 self.total_item_count = 298
...@@ -75,7 +77,8 @@ class TestDAQEventsSnapshotHits(unittest.TestCase): ...@@ -75,7 +77,8 @@ class TestDAQEventsSnapshotHits(unittest.TestCase):
class TestDAQEventsTriggeredHits(unittest.TestCase): class TestDAQEventsTriggeredHits(unittest.TestCase):
def setUp(self): def setUp(self):
self.events = DAQ_FILE.events self.events = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).events
self.lengths = {0: 18, 1: 53, -1: 9} self.lengths = {0: 18, 1: 53, -1: 9}
self.total_item_count = 80 self.total_item_count = 80
...@@ -115,7 +118,8 @@ class TestDAQEventsTriggeredHits(unittest.TestCase): ...@@ -115,7 +118,8 @@ class TestDAQEventsTriggeredHits(unittest.TestCase):
class TestDAQTimeslices(unittest.TestCase): class TestDAQTimeslices(unittest.TestCase):
def setUp(self): def setUp(self):
self.ts = DAQ_FILE.timeslices self.ts = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).timeslices
def test_data_lengths(self): def test_data_lengths(self):
assert 3 == len(self.ts._timeslices["L1"][0]) assert 3 == len(self.ts._timeslices["L1"][0])
...@@ -140,7 +144,8 @@ class TestDAQTimeslices(unittest.TestCase): ...@@ -140,7 +144,8 @@ class TestDAQTimeslices(unittest.TestCase):
class TestDAQTimeslice(unittest.TestCase): class TestDAQTimeslice(unittest.TestCase):
def setUp(self): def setUp(self):
self.ts = DAQ_FILE.timeslices self.ts = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).timeslices
self.n_frames = {"L1": [69, 69, 69], "SN": [64, 66, 68]} self.n_frames = {"L1": [69, 69, 69], "SN": [64, 66, 68]}
def test_str(self): def test_str(self):
...@@ -153,7 +158,8 @@ class TestDAQTimeslice(unittest.TestCase): ...@@ -153,7 +158,8 @@ class TestDAQTimeslice(unittest.TestCase):
class TestSummaryslices(unittest.TestCase): class TestSummaryslices(unittest.TestCase):
def setUp(self): def setUp(self):
self.ss = DAQ_FILE.summaryslices self.ss = DAQReader(os.path.join(SAMPLES_DIR,
"daq_v1.0.0.root")).summaryslices
def test_headers(self): def test_headers(self):
assert 3 == len(self.ss.headers) assert 3 == len(self.ss.headers)
......
This diff is collapsed.
#!/usr/bin/env python3 #!/usr/bin/env python3
import unittest import unittest
from km3io.tools import _to_num, cached_property from km3io.tools import _to_num, cached_property, _unfold_indices
class TestToNum(unittest.TestCase): class TestToNum(unittest.TestCase):
def test_to_num(self): def test_to_num(self):
...@@ -19,3 +20,21 @@ class TestCachedProperty(unittest.TestCase): ...@@ -19,3 +20,21 @@ class TestCachedProperty(unittest.TestCase):
pass pass
self.assertTrue(isinstance(Test.prop, cached_property)) self.assertTrue(isinstance(Test.prop, cached_property))
class TestUnfoldIndices(unittest.TestCase):
def test_unfold_indices(self):
data = range(10)
indices = [slice(2, 5), 0]
assert data[indices[0]][indices[1]] == _unfold_indices(data, indices)
indices = [slice(1, 9, 2), slice(1, 4), 2]
assert data[indices[0]][indices[1]][indices[2]] == _unfold_indices(
data, indices)
def test_unfold_indices_raises_index_error(self):
data = range(10)
indices = [slice(2, 5), 99]
with self.assertRaises(IndexError):
_unfold_indices(data, indices)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment