Skip to content
Snippets Groups Projects
Commit bd53990f authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Introduce the Branch class to wrap awkward.Arrays

parent a56ed3af
No related branches found
No related tags found
1 merge request!47Resolve "uproot4 integration"
Pipeline #16318 failed
...@@ -143,10 +143,8 @@ class EventReader: ...@@ -143,10 +143,8 @@ class EventReader:
keys=self.keys(), keys=self.keys(),
event_ctor=self._event_ctor, event_ctor=self._event_ctor,
) )
# group counts, for e.g. n_events, n_hits etc.
if isinstance(key, str) and key.startswith( if isinstance(key, str) and key.startswith("n_"):
"n_"
): # group counts, for e.g. n_events, n_hits etc.
key = self._keyfor(key.split("n_")[1]) key = self._keyfor(key.split("n_")[1])
arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4")) arr = self._fobj[self.event_path][key].array(uproot.AsDtype(">i4"))
return unfold_indices(arr, self._index_chain) return unfold_indices(arr, self._index_chain)
...@@ -163,11 +161,10 @@ class EventReader: ...@@ -163,11 +161,10 @@ class EventReader:
if from_field in branch[key].keys(): if from_field in branch[key].keys():
fields.append(to_field) fields.append(to_field)
log.debug(fields) log.debug(fields)
out = branch[key].arrays(fields, aliases=self.special_branches[key]) # out = branch[key].arrays(fields, aliases=self.special_branches[key])
return Branch(branch[key], fields, self.special_branches[key], self._index_chain)
else: else:
out = branch[self.aliases.get(key, key)].array() return unfold_indices(branch[self.aliases.get(key, key)].array(), self._index_chain)
return unfold_indices(out, self._index_chain)
def __iter__(self): def __iter__(self):
self._events = self._event_generator() self._events = self._event_generator()
...@@ -269,3 +266,52 @@ class EventReader: ...@@ -269,3 +266,52 @@ class EventReader:
def __exit__(self, *args): def __exit__(self, *args):
self.close() self.close()
class Branch:
"""Helper class for nested branches likes tracks/hits"""
def __init__(self, branch, fields, aliases, index_chain):
self._branch = branch
self.fields = fields
self._aliases = aliases
self._index_chain = index_chain
def __getattr__(self, attr):
if attr not in self._aliases:
raise AttributeError(f"No field named {attr}. Available fields: {self.fields}")
return unfold_indices(self._branch[self._aliases[attr]].array(), self._index_chain)
def __getitem__(self, key):
return self.__class__(self._branch, self.fields, self._aliases, self._index_chain + [key])
def __len__(self):
if not self._index_chain:
return self._branch.num_entries
elif isinstance(self._index_chain[-1], (int, np.int32, np.int64)):
if len(self._index_chain) == 1:
return 1
# try:
# return len(self[:])
# except IndexError:
# return 1
return 1
else:
# ignore the usual index magic and access `id` directly
return len(self.id)
def __actual_len__(self):
"""The raw number of events without any indexing/slicing magic"""
return len(self._branch[self._aliases["id"]].array())
def __repr__(self):
length = len(self)
actual_length = self.__actual_len__()
return f"{self.__class__.__name__} ({length}{'/' + str(actual_length) if length < actual_length else ''} {self._branch.name})"
@property
def ndim(self):
if not self._index_chain:
return 2
elif any(isinstance(i, (int, np.int32, np.int64)) for i in self._index_chain):
return 1
return 2
...@@ -209,12 +209,7 @@ def get_multiplicity(tracks, rec_stages): ...@@ -209,12 +209,7 @@ def get_multiplicity(tracks, rec_stages):
""" """
masked_tracks = tracks[mask(tracks.rec_stages, sequence=rec_stages)] masked_tracks = tracks[mask(tracks.rec_stages, sequence=rec_stages)]
try: out = count_nested(masked_tracks.rec_stages, axis=tracks.ndim - 1)
axis = tracks.ndim
except AttributeError:
axis = 0
out = count_nested(masked_tracks.rec_stages, axis=axis)
return out return out
...@@ -269,10 +264,8 @@ def best_track(tracks, startend=None, minmax=None, stages=None): ...@@ -269,10 +264,8 @@ def best_track(tracks, startend=None, minmax=None, stages=None):
if minmax is not None: if minmax is not None:
m1 = mask(tracks.rec_stages, minmax=minmax) m1 = mask(tracks.rec_stages, minmax=minmax)
try: original_ndim = tracks.ndim
axis = tracks.ndim axis = 1 if original_ndim == 2 else 0
except AttributeError:
axis = 0
tracks = tracks[m1] tracks = tracks[m1]
...@@ -284,10 +277,8 @@ def best_track(tracks, startend=None, minmax=None, stages=None): ...@@ -284,10 +277,8 @@ def best_track(tracks, startend=None, minmax=None, stages=None):
m3 = ak.argmax(tracks.lik, axis=axis, keepdims=True) m3 = ak.argmax(tracks.lik, axis=axis, keepdims=True)
out = tracks[m3] out = tracks[m3]
if isinstance(out, ak.highlevel.Record): if original_ndim == 1:
return namedtuple("BestTrack", out.fields)( return out[0]
*[getattr(out, a)[0] for a in out.fields]
)
return out[:, 0] return out[:, 0]
......
...@@ -381,7 +381,7 @@ class TestOfflineTracks(unittest.TestCase): ...@@ -381,7 +381,7 @@ class TestOfflineTracks(unittest.TestCase):
) )
def test_repr(self): def test_repr(self):
assert "10 * " in repr(self.tracks) assert "10" in repr(self.tracks)
def test_slicing(self): def test_slicing(self):
tracks = self.tracks tracks = self.tracks
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment