Skip to content
Snippets Groups Projects
Commit 080d9d54 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Black

parent 6bdee7cb
No related branches found
No related tags found
1 merge request!47Resolve "uproot4 integration"
Pipeline #16337 failed
...@@ -87,7 +87,7 @@ class EventReader: ...@@ -87,7 +87,7 @@ class EventReader:
def _initialise_keys(self): def _initialise_keys(self):
skip_keys = set(self.skip_keys) skip_keys = set(self.skip_keys)
all_keys = set(self._fobj[self.event_path].keys()) all_keys = set(self._fobj[self.event_path].keys())
toplevel_keys = set(k.split("/")[0] for k in all_keys) toplevel_keys = set(k.split("/")[0] for k in all_keys)
valid_aliases = {} valid_aliases = {}
for fromkey, tokey in self.aliases.items(): for fromkey, tokey in self.aliases.items():
...@@ -169,9 +169,13 @@ class EventReader: ...@@ -169,9 +169,13 @@ class EventReader:
if from_field in branch[key].keys(): if from_field in branch[key].keys():
fields.append(to_field) fields.append(to_field)
log.debug(fields) log.debug(fields)
return Branch(branch[key], fields, self.nested_branches[key], self._index_chain) return Branch(
branch[key], fields, self.nested_branches[key], self._index_chain
)
else: else:
return unfold_indices(branch[self.aliases.get(key, key)].array(), self._index_chain) return unfold_indices(
branch[self.aliases.get(key, key)].array(), self._index_chain
)
def __iter__(self, chunkwise=False): def __iter__(self, chunkwise=False):
self._events = self._event_generator(chunkwise=chunkwise) self._events = self._event_generator(chunkwise=chunkwise)
...@@ -180,7 +184,9 @@ class EventReader: ...@@ -180,7 +184,9 @@ class EventReader:
def _get_iterator_limits(self): def _get_iterator_limits(self):
"""Determines start and stop, used for event iteration""" """Determines start and stop, used for event iteration"""
if len(self._index_chain) > 1: if len(self._index_chain) > 1:
raise NotImplementedError("iteration is currently not supported with nested slices") raise NotImplementedError(
"iteration is currently not supported with nested slices"
)
if self._index_chain: if self._index_chain:
s = self._index_chain[0] s = self._index_chain[0]
if not isinstance(s, slice): if not isinstance(s, slice):
...@@ -189,7 +195,9 @@ class EventReader: ...@@ -189,7 +195,9 @@ class EventReader:
start = s.start start = s.start
stop = s.stop stop = s.stop
else: else:
raise NotImplementedError("iteration is only supported with single steps") raise NotImplementedError(
"iteration is only supported with single steps"
)
else: else:
start = None start = None
stop = None stop = None
...@@ -218,7 +226,11 @@ class EventReader: ...@@ -218,7 +226,11 @@ class EventReader:
log.debug("keys: %s", keys) log.debug("keys: %s", keys)
log.debug("aliases: %s", self.aliases) log.debug("aliases: %s", self.aliases)
events_it = events.iterate( events_it = events.iterate(
keys, aliases=self.aliases, step_size=self._step_size, entry_start=start, entry_stop=stop keys,
aliases=self.aliases,
step_size=self._step_size,
entry_start=start,
entry_stop=stop,
) )
nested = [] nested = []
nested_keys = ( nested_keys = (
...@@ -232,7 +244,7 @@ class EventReader: ...@@ -232,7 +244,7 @@ class EventReader:
aliases=self.nested_branches[key], aliases=self.nested_branches[key],
step_size=self._step_size, step_size=self._step_size,
entry_start=start, entry_start=start,
entry_stop=stop entry_stop=stop,
) )
) )
group_counts = {} group_counts = {}
...@@ -301,6 +313,7 @@ class EventReader: ...@@ -301,6 +313,7 @@ class EventReader:
class Branch: class Branch:
"""Helper class for nested branches likes tracks/hits""" """Helper class for nested branches likes tracks/hits"""
def __init__(self, branch, fields, aliases, index_chain): def __init__(self, branch, fields, aliases, index_chain):
self._branch = branch self._branch = branch
self.fields = fields self.fields = fields
...@@ -309,7 +322,9 @@ class Branch: ...@@ -309,7 +322,9 @@ class Branch:
def __getattr__(self, attr): def __getattr__(self, attr):
if attr not in self._aliases: if attr not in self._aliases:
raise AttributeError(f"No field named {attr}. Available fields: {self.fields}") raise AttributeError(
f"No field named {attr}. Available fields: {self.fields}"
)
key = self._aliases[attr] key = self._aliases[attr]
if self._index_chain: if self._index_chain:
...@@ -318,7 +333,9 @@ class Branch: ...@@ -318,7 +333,9 @@ class Branch:
# optimise single-element and slice lookups # optimise single-element and slice lookups
start = idx0 start = idx0
stop = idx0 + 1 stop = idx0 + 1
arr = ak.flatten(self._branch[key].array(entry_start=start, entry_stop=stop)) arr = ak.flatten(
self._branch[key].array(entry_start=start, entry_stop=stop)
)
return unfold_indices(arr, self._index_chain[1:]) return unfold_indices(arr, self._index_chain[1:])
if isinstance(idx0, slice): if isinstance(idx0, slice):
if idx0.step is None or idx0.step == 1: if idx0.step is None or idx0.step == 1:
...@@ -330,7 +347,9 @@ class Branch: ...@@ -330,7 +347,9 @@ class Branch:
return unfold_indices(self._branch[key].array(), self._index_chain) return unfold_indices(self._branch[key].array(), self._index_chain)
def __getitem__(self, key): def __getitem__(self, key):
return self.__class__(self._branch, self.fields, self._aliases, self._index_chain + [key]) return self.__class__(
self._branch, self.fields, self._aliases, self._index_chain + [key]
)
def __len__(self): def __len__(self):
if not self._index_chain: if not self._index_chain:
......
...@@ -116,12 +116,14 @@ class TestGSGEvents(unittest.TestCase): ...@@ -116,12 +116,14 @@ class TestGSGEvents(unittest.TestCase):
self.assertListEqual(event.Id_tr.tolist(), [4, 5, 10, 11, 12]) self.assertListEqual(event.Id_tr.tolist(), [4, 5, 10, 11, 12])
self.assertListEqual(event.Pdg_tr.tolist(), [22, -13, 2112, -211, 111]) self.assertListEqual(event.Pdg_tr.tolist(), [22, -13, 2112, -211, 111])
[ [
self.assertAlmostEqual(x, y) for x, y in zip( self.assertAlmostEqual(x, y)
event.E_tr, for x, y in zip(
[0.00618, 4.88912206, 2.33667201, 1.0022909, 1.17186997]) event.E_tr, [0.00618, 4.88912206, 2.33667201, 1.0022909, 1.17186997]
)
] ]
[ [
self.assertAlmostEqual(x, y) for x, y in zip( self.assertAlmostEqual(x, y)
for x, y in zip(
event.Vx_tr, event.Vx_tr,
[ [
-337.67895799, -337.67895799,
...@@ -133,7 +135,8 @@ class TestGSGEvents(unittest.TestCase): ...@@ -133,7 +135,8 @@ class TestGSGEvents(unittest.TestCase):
) )
] ]
[ [
self.assertAlmostEqual(x, y) for x, y in zip( self.assertAlmostEqual(x, y)
for x, y in zip(
event.Vy_tr, event.Vy_tr,
[ [
-203.90999969, -203.90999969,
...@@ -145,36 +148,31 @@ class TestGSGEvents(unittest.TestCase): ...@@ -145,36 +148,31 @@ class TestGSGEvents(unittest.TestCase):
) )
] ]
[ [
self.assertAlmostEqual(x, y) for x, y in zip( self.assertAlmostEqual(x, y)
for x, y in zip(
event.Vz_tr, event.Vz_tr,
[ [416.08845294, 416.08845294, 416.08845294, 416.08845294, 416.08845294],
416.08845294, 416.08845294, 416.08845294, 416.08845294,
416.08845294
],
) )
] ]
[ [
self.assertAlmostEqual(x, y) for x, y in zip( self.assertAlmostEqual(x, y)
for x, y in zip(
event.Dx_tr, event.Dx_tr,
[ [0.06766196, -0.63563065, -0.70627586, -0.76364544, -0.80562216],
0.06766196, -0.63563065, -0.70627586, -0.76364544,
-0.80562216
],
) )
] ]
[ [
self.assertAlmostEqual(x, y) for x, y in zip( self.assertAlmostEqual(x, y)
for x, y in zip(
event.Dy_tr, event.Dy_tr,
[0.33938809, -0.4846643, 0.50569058, -0.04136113, 0.10913917], [0.33938809, -0.4846643, 0.50569058, -0.04136113, 0.10913917],
) )
] ]
[ [
self.assertAlmostEqual(x, y) for x, y in zip( self.assertAlmostEqual(x, y)
for x, y in zip(
event.Dz_tr, event.Dz_tr,
[ [-0.93820978, -0.6008945, -0.49543056, -0.64430963, -0.58228994],
-0.93820978, -0.6008945, -0.49543056, -0.64430963,
-0.58228994
],
) )
] ]
[self.assertAlmostEqual(x, y) for x, y in zip(event.T_tr, 5 * [0.0])] [self.assertAlmostEqual(x, y) for x, y in zip(event.T_tr, 5 * [0.0])]
...@@ -549,7 +549,6 @@ class TestIsCC(unittest.TestCase): ...@@ -549,7 +549,6 @@ class TestIsCC(unittest.TestCase):
class TestUsr(unittest.TestCase): class TestUsr(unittest.TestCase):
def test_event_usr(self): def test_event_usr(self):
assert np.allclose( assert np.allclose(
[118.6302815337638, 44.33580521344907, 99.93916717621543], [118.6302815337638, 44.33580521344907, 99.93916717621543],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment