Skip to content
Snippets Groups Projects
Commit 802be667 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Black and cleanup

parent 44bd42d3
No related branches found
No related tags found
1 merge request: !47 Resolve "uproot4 integration"
Pipeline #16255 failed
...@@ -79,7 +79,15 @@ class OfflineReader: ...@@ -79,7 +79,15 @@ class OfflineReader:
"mc_tracks": "mc_trks", "mc_tracks": "mc_trks",
} }
def __init__(self, f, index_chain=None, step_size=2000, keys=None, aliases=None, event_ctor=None): def __init__(
self,
f,
index_chain=None,
step_size=2000,
keys=None,
aliases=None,
event_ctor=None,
):
"""OfflineReader class is an offline ROOT file wrapper """OfflineReader class is an offline ROOT file wrapper
Parameters Parameters
...@@ -187,10 +195,12 @@ class OfflineReader: ...@@ -187,10 +195,12 @@ class OfflineReader:
step_size=self._step_size, step_size=self._step_size,
aliases=self.aliases, aliases=self.aliases,
keys=self.keys(), keys=self.keys(),
event_ctor=self._event_ctor event_ctor=self._event_ctor,
) )
if isinstance(key, str) and key.startswith("n_"): # group counts, for e.g. n_events, n_hits etc. if isinstance(key, str) and key.startswith(
"n_"
): # group counts, for e.g. n_events, n_hits etc.
key = self._keyfor(key.split("n_")[1]) key = self._keyfor(key.split("n_")[1])
arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4")) arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4"))
return unfold_indices(arr, self._index_chain) return unfold_indices(arr, self._index_chain)
...@@ -207,9 +217,7 @@ class OfflineReader: ...@@ -207,9 +217,7 @@ class OfflineReader:
if from_field in branch[key].keys(): if from_field in branch[key].keys():
fields.append(to_field) fields.append(to_field)
log.debug(fields) log.debug(fields)
out = branch[key].arrays( out = branch[key].arrays(fields, aliases=self.special_branches[key])
fields, aliases=self.special_branches[key]
)
else: else:
out = branch[self.aliases.get(key, key)].array() out = branch[self.aliases.get(key, key)].array()
...@@ -220,9 +228,57 @@ class OfflineReader: ...@@ -220,9 +228,57 @@ class OfflineReader:
return self return self
def _event_generator(self):
    """Yield one event object per entry of the event tree.

    Three groups of data are iterated in lockstep and merged per event:
    the regular top-level branches (plus ``self.aliases``), every branch
    listed in ``self.special_branches``, and the ``n_*`` group counts
    obtained through ``self[key]``.  The merged mapping is handed to
    ``self._event_ctor`` as keyword arguments.  ``self._step_size`` is
    forwarded as ``step_size`` to every ``iterate`` call.
    """
    tree = self._fobj[self.event_path]

    # "n_*" keys are special keys to make it easy to count subbranch lengths.
    group_count_keys = {name for name in self.keys() if name.startswith("n_")}
    log.debug("group_count_keys: %s", group_count_keys)

    # All top-level keys for regular branches: whatever is neither a special
    # branch, a special alias nor a group count, plus the alias names.
    plain_keys = (
        set(self.keys())
        - set(self.special_branches.keys())
        - set(self.special_aliases)
        - group_count_keys
    )
    keys = set(list(plain_keys) + list(self.aliases.keys()))
    log.debug("keys: %s", keys)
    log.debug("aliases: %s", self.aliases)
    events_it = tree.iterate(keys, aliases=self.aliases, step_size=self._step_size)

    # The same keys view is walked again when the per-event items are
    # assigned below, so both passes see the branches in the same order
    # (dict-key ordering is an implementation detail).
    special_keys = self.special_branches.keys()
    log.debug("special_keys: %s", special_keys)
    specials = [
        tree[name].iterate(
            self.special_branches[name].keys(),
            aliases=self.special_branches[name],
            step_size=self._step_size,
        )
        for name in special_keys
    ]

    group_counts = {name: iter(self[name]) for name in group_count_keys}
    log.debug("group_counts: %s", group_counts)

    # Outer zip: one chunk from every iterator; inner zip: one event from
    # every chunk.
    for event_set, *special_sets in zip(events_it, *specials):
        for _event, *special_items in zip(event_set, *special_sets):
            data = {name: _event[name] for name in keys}
            data.update(zip(special_keys, special_items))
            for tokey, fromkey in self.special_aliases.items():
                data[tokey] = data[fromkey]
            for name, counts in group_counts.items():
                data[name] = next(counts)
            yield self._event_ctor(**data)
def __next__(self): def __next__(self):
return next(self._events) return next(self._events)
...@@ -246,7 +302,6 @@ class OfflineReader: ...@@ -246,7 +302,6 @@ class OfflineReader:
"""The raw number of events without any indexing/slicing magic""" """The raw number of events without any indexing/slicing magic"""
return len(self._fobj[self.event_path]["id"].array()) return len(self._fobj[self.event_path]["id"].array())
def __repr__(self): def __repr__(self):
length = len(self) length = len(self)
actual_length = self.__actual_len__() actual_length = self.__actual_len__()
......
...@@ -275,7 +275,7 @@ def best_track(tracks, startend=None, minmax=None, stages=None): ...@@ -275,7 +275,7 @@ def best_track(tracks, startend=None, minmax=None, stages=None):
tracks = tracks[m1] tracks = tracks[m1]
rec_stage_lengths = ak.num(tracks.rec_stages, axis=axis+1) rec_stage_lengths = ak.num(tracks.rec_stages, axis=axis + 1)
max_rec_stage_length = ak.max(rec_stage_lengths, axis=axis) max_rec_stage_length = ak.max(rec_stage_lengths, axis=axis)
m2 = rec_stage_lengths == max_rec_stage_length m2 = rec_stage_lengths == max_rec_stage_length
tracks = tracks[m2] tracks = tracks[m2]
...@@ -284,7 +284,9 @@ def best_track(tracks, startend=None, minmax=None, stages=None): ...@@ -284,7 +284,9 @@ def best_track(tracks, startend=None, minmax=None, stages=None):
out = tracks[m3] out = tracks[m3]
if isinstance(out, ak.highlevel.Record): if isinstance(out, ak.highlevel.Record):
return namedtuple("BestTrack", out.fields)(*[getattr(out, a)[0] for a in out.fields]) return namedtuple("BestTrack", out.fields)(
*[getattr(out, a)[0] for a in out.fields]
)
return out return out
...@@ -308,20 +310,22 @@ def mask(arr, sequence=None, startend=None, minmax=None, atleast=None): ...@@ -308,20 +310,22 @@ def mask(arr, sequence=None, startend=None, minmax=None, atleast=None):
inputs = (sequence, startend, minmax, atleast) inputs = (sequence, startend, minmax, atleast)
if all(v is None for v in inputs): if all(v is None for v in inputs):
raise ValueError("either sequence, startend, minmax or atleast must be specified.") raise ValueError(
"either sequence, startend, minmax or atleast must be specified."
)
builder = ak.ArrayBuilder() builder = ak.ArrayBuilder()
_mask(arr, builder, sequence, startend, minmax, atleast) _mask(arr, builder, sequence, startend, minmax, atleast)
return builder.snapshot() return builder.snapshot()
#nb.njit # TODO: not supported in awkward yet
# see https://github.com/scikit-hep/awkward-1.0/issues/572 # @nb.njit
def _mask(arr, builder, sequence=None, startend=None, minmax=None, atleast=None): def _mask(arr, builder, sequence=None, startend=None, minmax=None, atleast=None):
if arr.ndim == 2: # recursion stop if arr.ndim == 2: # recursion stop
if startend is not None: if startend is not None:
start, end = startend start, end = startend
for els in arr: for els in arr:
if ak.count(els) > 0 and els[0] == start and els[-1] == end: if len(els) > 0 and els[0] == start and els[-1] == end:
builder.boolean(True) builder.boolean(True)
else: else:
builder.boolean(False) builder.boolean(False)
...@@ -362,7 +366,6 @@ def _mask(arr, builder, sequence=None, startend=None, minmax=None, atleast=None) ...@@ -362,7 +366,6 @@ def _mask(arr, builder, sequence=None, startend=None, minmax=None, atleast=None)
builder.end_list() builder.end_list()
def best_jmuon(tracks): def best_jmuon(tracks):
"""Select the best JMUON track.""" """Select the best JMUON track."""
return best_track(tracks, minmax=(krec.JMUONBEGIN, krec.JMUONEND)) return best_track(tracks, minmax=(krec.JMUONBEGIN, krec.JMUONEND))
......
...@@ -183,17 +183,13 @@ class TestOfflineEvents(unittest.TestCase): ...@@ -183,17 +183,13 @@ class TestOfflineEvents(unittest.TestCase):
def test_index_consistency(self):
    """``events[i].n_hits`` and ``events.n_hits[i]`` must agree."""
    for idx in (0, 2, 5):
        via_event = self.events[idx].n_hits
        via_branch = self.events.n_hits[idx]
        assert np.allclose(via_event, via_branch)
def test_index_chaining(self):
    """Slicing events before or after reading ``n_hits`` must agree."""
    window = slice(3, 5)
    assert np.allclose(
        self.events[window].n_hits.tolist(),
        self.events.n_hits[window].tolist(),
    )
    # Chaining a second (integer) index onto the slice must agree as well.
    assert np.allclose(self.events[window][0].n_hits, self.events.n_hits[window][0])
@unittest.skip @unittest.skip
def test_index_chaining_on_nested_branches_aka_records(self): def test_index_chaining_on_nested_branches_aka_records(self):
...@@ -361,7 +357,22 @@ class TestOfflineTracks(unittest.TestCase): ...@@ -361,7 +357,22 @@ class TestOfflineTracks(unittest.TestCase):
self.n_events = 10 self.n_events = 10
def test_fields(self):
    """Every listed field name is reachable as an attribute of the tracks."""
    expected_fields = (
        "id",
        "pos_x",
        "pos_y",
        "pos_z",
        "dir_x",
        "dir_y",
        "dir_z",
        "t",
        "E",
        "len",
        "lik",
        "rec_type",
        "rec_stages",
        "fitinf",
    )
    # The attribute lookup itself is the check; the value is discarded.
    for name in expected_fields:
        getattr(self.tracks, name)
def test_item_selection(self): def test_item_selection(self):
......
...@@ -230,7 +230,6 @@ class TestBestTrackSelection(unittest.TestCase): ...@@ -230,7 +230,6 @@ class TestBestTrackSelection(unittest.TestCase):
assert len(best) == 5 assert len(best) == 5
import pdb; pdb.set_trace()
assert best.lik == ak.max(tracks_slice.lik) assert best.lik == ak.max(tracks_slice.lik)
assert best.rec_stages[0].tolist() == [1, 3, 5, 4] assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
...@@ -419,6 +418,27 @@ class TestRecStagesMasks(unittest.TestCase): ...@@ -419,6 +418,27 @@ class TestRecStagesMasks(unittest.TestCase):
mask(self.tracks) mask(self.tracks)
class TestMask(unittest.TestCase):
    """Check ``mask(..., minmax=(1, 4))`` on arrays of increasing nesting depth."""

    def test_minmax_2dim_mask(self):
        """A 2-dim input gives a flat list of booleans."""
        nested = ak.Array([[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]])
        expected = [True, False, False]
        self.assertListEqual(mask(nested, minmax=(1, 4)).tolist(), expected)

    def test_minmax_3dim_mask(self):
        """A 3-dim input gives one list of booleans per outer entry."""
        nested = ak.Array([[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]])
        expected = [[True, False, False], [True]]
        self.assertListEqual(mask(nested, minmax=(1, 4)).tolist(), expected)

    def test_minmax_4dim_mask(self):
        """A 4-dim input keeps the two outer levels of nesting."""
        nested = ak.Array(
            [[[[1, 2, 3, 4], [3, 4, 5], [1, 2, 5]], [[1, 2, 3]]], [[[1, 9], [3, 3]]]]
        )
        expected = [[[True, False, False], [True]], [[False, True]]]
        self.assertListEqual(mask(nested, minmax=(1, 4)).tolist(), expected)
class TestUnique(unittest.TestCase): class TestUnique(unittest.TestCase):
def run_random_test_with_dtype(self, dtype): def run_random_test_with_dtype(self, dtype):
max_range = 100 max_range = 100
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment