diff --git a/km3io/rootio.py b/km3io/rootio.py index 46c599f4ae86aa2277b8b567c39dba71b829d5d9..ccabe0d0b5f8f246e939b45484eea081c3db7023 100644 --- a/km3io/rootio.py +++ b/km3io/rootio.py @@ -87,7 +87,7 @@ class EventReader: def _initialise_keys(self): skip_keys = set(self.skip_keys) - all_keys = set(self._fobj[self.event_path].keys()) + all_keys = set(self._fobj[self.event_path].keys()) toplevel_keys = set(k.split("/")[0] for k in all_keys) valid_aliases = {} for fromkey, tokey in self.aliases.items(): @@ -169,9 +169,13 @@ class EventReader: if from_field in branch[key].keys(): fields.append(to_field) log.debug(fields) - return Branch(branch[key], fields, self.nested_branches[key], self._index_chain) + return Branch( + branch[key], fields, self.nested_branches[key], self._index_chain + ) else: - return unfold_indices(branch[self.aliases.get(key, key)].array(), self._index_chain) + return unfold_indices( + branch[self.aliases.get(key, key)].array(), self._index_chain + ) def __iter__(self, chunkwise=False): self._events = self._event_generator(chunkwise=chunkwise) @@ -180,7 +184,9 @@ class EventReader: def _get_iterator_limits(self): """Determines start and stop, used for event iteration""" if len(self._index_chain) > 1: - raise NotImplementedError("iteration is currently not supported with nested slices") + raise NotImplementedError( + "iteration is currently not supported with nested slices" + ) if self._index_chain: s = self._index_chain[0] if not isinstance(s, slice): @@ -189,7 +195,9 @@ class EventReader: start = s.start stop = s.stop else: - raise NotImplementedError("iteration is only supported with single steps") + raise NotImplementedError( + "iteration is only supported with single steps" + ) else: start = None stop = None @@ -218,7 +226,11 @@ class EventReader: log.debug("keys: %s", keys) log.debug("aliases: %s", self.aliases) events_it = events.iterate( - keys, aliases=self.aliases, step_size=self._step_size, entry_start=start, entry_stop=stop + keys, + aliases=self.aliases, + step_size=self._step_size, + entry_start=start, + entry_stop=stop, ) nested = [] nested_keys = ( @@ -232,7 +244,7 @@ class EventReader: aliases=self.nested_branches[key], step_size=self._step_size, entry_start=start, - entry_stop=stop + entry_stop=stop, ) ) group_counts = {} @@ -301,6 +313,7 @@ class EventReader: class Branch: """Helper class for nested branches likes tracks/hits""" + def __init__(self, branch, fields, aliases, index_chain): self._branch = branch self.fields = fields @@ -309,7 +322,9 @@ class Branch: def __getattr__(self, attr): if attr not in self._aliases: - raise AttributeError(f"No field named {attr}. Available fields: {self.fields}") + raise AttributeError( + f"No field named {attr}. Available fields: {self.fields}" + ) key = self._aliases[attr] if self._index_chain: @@ -318,7 +333,9 @@ class Branch: # optimise single-element and slice lookups start = idx0 stop = idx0 + 1 - arr = ak.flatten(self._branch[key].array(entry_start=start, entry_stop=stop)) + arr = ak.flatten( + self._branch[key].array(entry_start=start, entry_stop=stop) + ) return unfold_indices(arr, self._index_chain[1:]) if isinstance(idx0, slice): if idx0.step is None or idx0.step == 1: @@ -330,7 +347,9 @@ class Branch: return unfold_indices(self._branch[key].array(), self._index_chain) def __getitem__(self, key): - return self.__class__(self._branch, self.fields, self._aliases, self._index_chain + [key]) + return self.__class__( + self._branch, self.fields, self._aliases, self._index_chain + [key] + ) def __len__(self): if not self._index_chain: diff --git a/tests/test_gseagen.py b/tests/test_gseagen.py index f87fda3e512f818e288aa7194367de79b26eb72e..8cb0073ebc97bc001775a0a4a62cdb0b30a936e8 100644 --- a/tests/test_gseagen.py +++ b/tests/test_gseagen.py @@ -116,12 +116,14 @@ class TestGSGEvents(unittest.TestCase): self.assertListEqual(event.Id_tr.tolist(), [4, 5, 10, 11, 12]) self.assertListEqual(event.Pdg_tr.tolist(), [22, -13, 2112, -211, 111]) [ - self.assertAlmostEqual(x, y) for x, y in zip( - event.E_tr, - [0.00618, 4.88912206, 2.33667201, 1.0022909, 1.17186997]) + self.assertAlmostEqual(x, y) + for x, y in zip( + event.E_tr, [0.00618, 4.88912206, 2.33667201, 1.0022909, 1.17186997] + ) ] [ - self.assertAlmostEqual(x, y) for x, y in zip( + self.assertAlmostEqual(x, y) + for x, y in zip( event.Vx_tr, [ -337.67895799, @@ -133,7 +135,8 @@ class TestGSGEvents(unittest.TestCase): ) ] [ - self.assertAlmostEqual(x, y) for x, y in zip( + self.assertAlmostEqual(x, y) + for x, y in zip( event.Vy_tr, [ -203.90999969, @@ -145,36 +148,31 @@ class TestGSGEvents(unittest.TestCase): ) ] [ - self.assertAlmostEqual(x, y) for x, y in zip( + self.assertAlmostEqual(x, y) + for x, y in zip( event.Vz_tr, - [ - 416.08845294, 416.08845294, 416.08845294, 416.08845294, - 416.08845294 - ], + [416.08845294, 416.08845294, 416.08845294, 416.08845294, 416.08845294], ) ] [ - self.assertAlmostEqual(x, y) for x, y in zip( + self.assertAlmostEqual(x, y) + for x, y in zip( event.Dx_tr, - [ - 0.06766196, -0.63563065, -0.70627586, -0.76364544, - -0.80562216 - ], + [0.06766196, -0.63563065, -0.70627586, -0.76364544, -0.80562216], ) ] [ - self.assertAlmostEqual(x, y) for x, y in zip( + self.assertAlmostEqual(x, y) + for x, y in zip( event.Dy_tr, [0.33938809, -0.4846643, 0.50569058, -0.04136113, 0.10913917], ) ] [ - self.assertAlmostEqual(x, y) for x, y in zip( + self.assertAlmostEqual(x, y) + for x, y in zip( event.Dz_tr, - [ - -0.93820978, -0.6008945, -0.49543056, -0.64430963, - -0.58228994 - ], + [-0.93820978, -0.6008945, -0.49543056, -0.64430963, -0.58228994], ) ] [self.assertAlmostEqual(x, y) for x, y in zip(event.T_tr, 5 * [0.0])] diff --git a/tests/test_tools.py b/tests/test_tools.py index 43d12ef9b9939c9c94d33832835f637fd6168fe6..f7aa9de8cc5a2003138525d1b6da8f9a375a683c 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -549,7 +549,6 @@ class TestIsCC(unittest.TestCase): class TestUsr(unittest.TestCase): - def test_event_usr(self): assert np.allclose( [118.6302815337638, 44.33580521344907, 99.93916717621543],