Commit 20fd4765 authored by Tamas Gal

Getting ready

parent ebc24b5c
Merge request !39: WIP: Resolve "uproot4 integration"
Pipeline #16161 failed
from collections import namedtuple
import uproot4 as uproot
import warnings
import uproot4 as uproot
import numpy as np
import awkward1 as ak
from .definitions import mc_header
from .tools import cached_property, to_num
from .tools import cached_property, to_num, unfold_indices
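The newly imported unfold_indices helper is what applies an accumulated index chain to the arrays read further down. A minimal sketch of what such a helper could do, purely for orientation (hypothetical simplification, not the actual km3io.tools implementation):

def _unfold_indices_sketch(arr, index_chain):
    # Hypothetical stand-in for km3io.tools.unfold_indices:
    # apply each recorded index or slice to the array, in order.
    for index in index_chain:
        arr = arr[index]
    return arr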
class OfflineReader:
@@ -70,46 +72,69 @@ class OfflineReader:
"mc_tracks": "mc_trks",
}
def __init__(self, file_path, step_size=2000):
def __init__(self, f, index_chain=None, step_size=2000, keys=None, aliases=None, event_ctor=None):
"""OfflineReader class is an offline ROOT file wrapper
Parameters
----------
file_path : path-like object
Path to the file of interest. It can be a str or any python
path-like object that points to the file.
f: str or uproot4.reading.ReadOnlyDirectory (from uproot4.open)
Path to the file of interest or uproot4 filedescriptor.
step_size: int, optional
Number of events to read into the cache when iterating.
Choosing higher numbers may improve the speed but also increases
the memory overhead.
index_chain: list, optional
Keeps track of index chaining.
keys: list or set, optional
Branch keys.
aliases: dict, optional
Branch key aliases.
event_ctor: class or namedtuple, optional
Event constructor.
"""
self._fobj = uproot.open(file_path)
self.step_size = step_size
self._filename = file_path
if isinstance(f, str):
self._fobj = uproot.open(f)
self._filepath = f
elif isinstance(f, uproot.reading.ReadOnlyDirectory):
self._fobj = f
self._filepath = f._file.file_path
else:
raise TypeError("Unsupported file descriptor.")
self._step_size = step_size
self._uuid = self._fobj._file.uuid
self._iterator_index = 0
self._keys = None
self._grouped_counts = {} # TODO: e.g. {"events": [3, 66, 34]}
if "E/Evt/AAObject/usr" in self._fobj:
if ak.count(f["E/Evt/AAObject/usr"].array()) > 0:
self.aliases.update({
"usr": "AAObject/usr",
"usr_names": "AAObject/usr_names",
})
self._initialise_keys()
self._event_ctor = namedtuple(
self.item_name,
set(
list(self.keys())
+ list(self.aliases)
+ list(self.special_branches)
+ list(self.special_aliases)
),
)
self._keys = keys
self._event_ctor = event_ctor
self._index_chain = [] if index_chain is None else index_chain
if aliases is not None:
self.aliases = aliases
else:
# Check for usr-awesomeness backward compatibility crap
print("Found usr data")
if "E/Evt/AAObject/usr" in self._fobj:
if ak.count(f["E/Evt/AAObject/usr"].array()) > 0:
self.aliases.update(
{
"usr": "AAObject/usr",
"usr_names": "AAObject/usr_names",
}
)
if self._keys is None:
self._initialise_keys()
if self._event_ctor is None:
self._event_ctor = namedtuple(
self.item_name,
set(
list(self.keys())
+ list(self.aliases)
+ list(self.special_branches)
+ list(self.special_aliases)
),
)
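For orientation, the reworked constructor accepts either a path or an already opened uproot4 directory, and the extra keyword arguments let __getitem__ below clone the reader while carrying the index chain along. A hedged usage sketch (the file name is a placeholder and the top-level import path is assumed):

import uproot4 as uproot
from km3io import OfflineReader  # assumed import path

r = OfflineReader("some_offline_file.root")  # hypothetical file path
r_alt = OfflineReader(uproot.open("some_offline_file.root"))  # or pass the open directory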
def _initialise_keys(self):
skip_keys = set(self.skip_keys)
@@ -144,9 +169,23 @@ class OfflineReader:
)
def __getitem__(self, key):
if key.startswith("n_"): # group counts, for e.g. n_events, n_hits etc.
# indexing
if isinstance(key, (slice, int, np.int32, np.int64)):
if not isinstance(key, slice):
key = int(key)
return self.__class__(
self._fobj,
index_chain=self._index_chain + [key],
step_size=self._step_size,
aliases=self.aliases,
keys=self.keys(),
event_ctor=self._event_ctor
)
if isinstance(key, str) and key.startswith("n_"): # group counts, for e.g. n_events, n_hits etc.
key = self._keyfor(key.split("n_")[1])
return self._fobj[self.event_path][key].array(uproot.AsDtype(">i4"))
arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4"))
return unfold_indices(arr, self._index_chain)
key = self._keyfor(key)
branch = self._fobj[self.event_path]
@@ -154,10 +193,13 @@ class OfflineReader:
# We are explicitly grabbing just a predefined set of subbranches
# and also alias them to be backwards compatible (and attribute-accessible)
if key in self.special_branches:
return branch[key].arrays(
out = branch[key].arrays(
self.special_branches[key].keys(), aliases=self.special_branches[key]
)
return branch[self.aliases.get(key, key)].array()
else:
out = branch[self.aliases.get(key, key)].array()
return unfold_indices(out, self._index_chain)
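With the indexing branch above, slicing or integer indexing returns a new reader whose _index_chain records the selection; the chain is only applied, via unfold_indices, when a key is actually read. A rough illustration of the intended behaviour, reusing the hypothetical reader r from the sketch above (the corresponding consistency tests are still skipped in this WIP commit):

sub = r[2:5]     # new OfflineReader with _index_chain == [slice(2, 5)]
single = sub[0]  # chains a second index: [slice(2, 5), 0]
# Reading a key on the sliced reader should agree with slicing the full array:
assert sub["n_hits"].tolist() == r["n_hits"][2:5].tolist()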
def __iter__(self):
self._iterator_index = 0
@@ -167,13 +209,18 @@ class OfflineReader:
def _event_generator(self):
events = self._fobj[self.event_path]
group_count_keys = set(k for k in self.keys() if k.startswith("n_"))
keys = set(list(
set(self.keys())
- set(self.special_branches.keys())
- set(self.special_aliases)
- group_count_keys
) + list(self.aliases.keys()))
events_it = events.iterate(keys, aliases=self.aliases, step_size=self.step_size)
keys = set(
list(
set(self.keys())
- set(self.special_branches.keys())
- set(self.special_aliases)
- group_count_keys
)
+ list(self.aliases.keys())
)
events_it = events.iterate(
keys, aliases=self.aliases, step_size=self._step_size
)
specials = []
special_keys = (
self.special_branches.keys()
@@ -183,7 +230,7 @@ class OfflineReader:
events[key].iterate(
self.special_branches[key].keys(),
aliases=self.special_branches[key],
step_size=self.step_size,
step_size=self._step_size,
)
)
group_counts = {}
@@ -206,7 +253,29 @@ class OfflineReader:
return next(self._events)
def __len__(self):
return self._fobj[self.event_path].num_entries
if not self._index_chain:
return self._fobj[self.event_path].num_entries
elif isinstance(self._index_chain[-1], (int, np.int32, np.int64)):
            # A trailing integer index selects a single event
            # (TODO: nested chains may need len(self[:]) instead).
            return 1
else:
# ignore the usual index magic and access `id` directly
            return len(
                unfold_indices(
                    self._fobj[self.event_path]["id"].array(), self._index_chain
                )
            )
def __actual_len__(self):
"""The raw number of events without any indexing/slicing magic"""
return len(self._fobj[self.event_path]["id"].array())
def __repr__(self):
length = len(self)
actual_length = self.__actual_len__()
return f"{self.__class__.__name__} ({length}{'/' + str(actual_length) if length < actual_length else ''} events)"
@property
def uuid(self):
@@ -149,12 +149,6 @@ class TestOfflineEvents(unittest.TestCase):
def test_len(self):
assert self.n_events == len(self.events)
@unittest.skip
def test_attributes_available(self):
for key in self.events._keymap.keys():
print(f"checking {key}")
getattr(self.events, key)
def test_attributes(self):
assert self.n_events == len(self.events.det_id)
self.assertListEqual(self.det_id, list(self.events.det_id))
@@ -165,7 +159,6 @@ class TestOfflineEvents(unittest.TestCase):
self.assertListEqual(self.t_sec, list(self.events.t_sec))
self.assertListEqual(self.t_ns, list(self.events.t_ns))
@unittest.skip
def test_keys(self):
assert np.allclose(self.n_hits, self.events["n_hits"].tolist())
assert np.allclose(self.n_tracks, self.events["n_tracks"].tolist())
@@ -182,38 +175,37 @@ class TestOfflineEvents(unittest.TestCase):
self.assertListEqual(self.t_sec[s], list(s_events.t_sec))
self.assertListEqual(self.t_ns[s], list(s_events.t_ns))
@unittest.skip
def test_slicing_consistency(self):
for s in [slice(1, 3), slice(2, 7, 3)]:
assert np.allclose(
self.events[s].n_hits.tolist(), self.events.n_hits[s].tolist()
)
@unittest.skip
def test_index_consistency(self):
for i in [0, 2, 5]:
assert np.allclose(
self.events[i].n_hits.tolist(), self.events.n_hits[i].tolist()
self.events[i].n_hits, self.events.n_hits[i]
)
@unittest.skip
def test_index_chaining(self):
assert np.allclose(
self.events[3:5].n_hits.tolist(), self.events.n_hits[3:5].tolist()
)
assert np.allclose(
self.events[3:5][0].n_hits.tolist(), self.events.n_hits[3:5][0].tolist()
self.events[3:5][0].n_hits, self.events.n_hits[3:5][0]
)
@unittest.skip
def test_index_chaining_on_nested_branches_aka_records(self):
assert np.allclose(
self.events[3:5].hits[1].dom_id[4].tolist(),
self.events.hits[3:5][1][4].dom_id.tolist(),
self.events[3:5].hits[1].dom_id[4],
self.events.hits[3:5][1][4].dom_id,
)
assert np.allclose(
self.events.hits[3:5][1][4].dom_id.tolist(),
self.events[3:5][1][4].hits.dom_id.tolist(),
)
@unittest.skip
def test_fancy_indexing(self):
mask = self.events.n_tracks > 55
tracks = self.events.tracks[mask]
@@ -305,9 +297,6 @@ class TestOfflineHits(unittest.TestCase):
self.assertTrue(all(c >= 0 for c in ak.min(self.hits.channel_id, axis=1)))
self.assertTrue(all(c < 31 for c in ak.max(self.hits.channel_id, axis=1)))
def test_str(self):
assert str(self.n_hits) in str(self.hits)
def test_repr(self):
assert str(self.n_hits) in repr(self.hits)
@@ -344,19 +333,24 @@ class TestOfflineHits(unittest.TestCase):
)
assert np.allclose(
OFFLINE_FILE.events[idx].hits.dom_id[: self.n_hits].tolist(),
dom_ids[: self.n_hits].tolist(),
dom_ids[: self.n_hits],
)
for idx, ts in self.t.items():
assert np.allclose(
self.hits[idx].t[: self.n_hits].tolist(), ts[: self.n_hits].tolist()
self.hits[idx].t[: self.n_hits].tolist(), ts[: self.n_hits]
)
assert np.allclose(
OFFLINE_FILE.events[idx].hits.t[: self.n_hits].tolist(),
ts[: self.n_hits].tolist(),
ts[: self.n_hits],
)
def test_keys(self):
assert "dom_id" in self.hits.keys()
def test_fields(self):
assert "dom_id" in self.hits.fields
assert "channel_id" in self.hits.fields
assert "t" in self.hits.fields
assert "tot" in self.hits.fields
assert "trig" in self.hits.fields
assert "id" in self.hits.fields
class TestOfflineTracks(unittest.TestCase):
@@ -366,9 +360,9 @@ class TestOfflineTracks(unittest.TestCase):
self.tracks_numucc = OFFLINE_NUMUCC
self.n_events = 10
def test_attributes_available(self):
for key in self.tracks._keymap.keys():
getattr(self.tracks, key)
    def test_fields(self):
        for field in [
            "id", "pos_x", "pos_y", "pos_z", "dir_x", "dir_y", "dir_z",
            "t", "E", "len", "lik", "rec_type", "rec_stages", "fitinf",
        ]:
            getattr(self.tracks, field)
@unittest.skip
def test_attributes(self):
@@ -383,8 +377,9 @@ class TestOfflineTracks(unittest.TestCase):
)
def test_repr(self):
assert " 10 " in repr(self.tracks)
assert "10 * " in repr(self.tracks)
@unittest.skip
def test_slicing(self):
tracks = self.tracks
self.assertEqual(10, len(tracks)) # 10 events
@@ -404,6 +399,7 @@ class TestOfflineTracks(unittest.TestCase):
list(tracks.E[:, 0][_slice]), list(tracks[_slice].E[:, 0])
)
@unittest.skip
def test_nested_indexing(self):
self.assertAlmostEqual(
self.f.events.tracks.fitinf[3:5][1][9][2],
@@ -427,7 +423,7 @@ class TestBranchIndexingMagic(unittest.TestCase):
def setUp(self):
self.events = OFFLINE_FILE.events
def test_foo(self):
def test_slicing_magic(self):
self.assertEqual(318, self.events[2:4].n_hits[0])
assert np.allclose(
self.events[3].tracks.dir_z[10], self.events.tracks.dir_z[3, 10]
@@ -437,6 +433,8 @@ class TestBranchIndexingMagic(unittest.TestCase):
self.events.tracks.pos_y[3:6, 0].tolist(),
)
@unittest.skip
def test_selecting_specific_items_via_a_list(self):
# test selecting with a list
self.assertEqual(3, len(self.events[[0, 2, 3]]))
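Several tests above are marked with @unittest.skip while the uproot4 integration is in progress; the fancy-indexing ones outline behaviour the reader does not support yet. Roughly what they expect to work eventually (hedged sketch based on the skipped test bodies; list and mask selection are not handled by the __getitem__ shown above):

events = OFFLINE_FILE.events
mask = events.n_tracks > 55   # boolean mask over events (test_fancy_indexing)
tracks = events.tracks[mask]  # tracks of the selected events only
subset = events[[0, 2, 3]]    # intended: pick specific events via a list
# expected once supported: len(subset) == 3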