diff --git a/README.rst b/README.rst index 990f6043bb171893044f839e1cc9fdf6ece2c0bc..f64e88fa6696932573258a56df31aced38b79729 100644 --- a/README.rst +++ b/README.rst @@ -928,6 +928,34 @@ to get the reconstruction data of interest, for example ['JENERGY_ENERGY']: 208.6103912 , 1336.52338666, 998.87632267, 1206.54345674, 16.28973662]) +to get a dictionary of the corresponding hits data (for example dom ids and hits ids) + +.. code-block:: python3 + + >>> r.get_reco_hits([1,2,3,4,5], ["dom_id", "id"])) + {'dom_id': <ChunkedArray [[102 102 102 ... 11517 11518 11518] [101 101 101 ... 11517 11518 11518] [101 101 102 ... 11518 11518 11518] [101 102 102 ... 11516 11517 11518] [101 101 102 ... 11517 11518 11518] [101 101 102 ... 11517 11517 11518] [101 101 102 ... 11516 11516 11517] ...] at 0x7f553ab7f3d0>, + 'id': <ChunkedArray [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ...] at 0x7f553ab7f890>} + +to get a dictionary of the corresponding tracks data (for example position x and y) + +.. code-block:: python3 + + >>> r.get_reco_tracks([1, 2, 3, 4, 5], ["pos_x", "pos_y"]) + + {'pos_x': array([-647.39638136, 448.98490051, 451.12336854, 174.23666051,207.24223984, -460.75770881, -522.58197621, 324.16230509, + -436.2319534 ]), + 'pos_y': array([-138.62068609, 77.58887593, 251.08805881, -114.60614519, 143.61947974, 86.85012087, -263.14983599, -203.14263572, + 467.75113594])} + +to get a dictionary of the corresponding events data (for example det_id and run_id) + +.. code-block:: python3 + + >>> r.get_reco_events([1, 2, 3, 4, 5], ["run_id", "det_id"]) + + {'run_id': <ChunkedArray [1 1 1 1 1 1 1 ...] at 0x7f553b5b2710>, + 'det_id': <ChunkedArray [20 20 20 20 20 20 20 ...] at 0x7f5558030750>} + **Note**: When the reconstruction stages of interest are not found in all your data file, an error is raised. diff --git a/km3io/offline.py b/km3io/offline.py index a9326d901a538b83b7d8322a7fc7e63e251d48ec..4677316320d9bdcdf130cc32ec921438204af46f 100644 --- a/km3io/offline.py +++ b/km3io/offline.py @@ -313,7 +313,8 @@ class Reader: do not contain data. Therefore, the keys corresponding to these fake branches are not read. """ - if key not in self.keys.valid_keys and not isinstance(key, int): + keys = self.keys.valid_keys + if key not in keys and not isinstance(key, int): raise KeyError( "'{}' is not a valid key or is a fake branch.".format(key)) return self._data[key] @@ -577,6 +578,127 @@ class OfflineReader: names=keys) return rec_array + def get_reco_hits(self, stages, keys): + """construct a dictionary of hits class data based on the reconstruction + stages of interest. For example, if the reconstruction stages of interest + are [1, 2, 3, 4, 5], then get_reco_hits method will select the hits data + from the events that were reconstructed following these stages (i.e + [1, 2, 3, 4, 5]). + + Parameters + ---------- + stages : list + list of reconstruction stages of interest. for example + [1, 2, 3, 4, 5]. + keys : list of str + list of the hits class attributes. + + Returns + ------- + dict + dictionary of lazyarrays containing data for each hits attribute requested. + + Raises + ------ + ValueError + ValueError raised when the reconstruction stages of interest + are not found in the file. + """ + lazy_d = {} + rec_stages = np.array( + [match for match in self._find_rec_stages(stages)]) + mask = rec_stages[:, 1] != None + if np.all(rec_stages[:, 1] == None): + raise ValueError( + "The stages {} are not found in your file.".format( + str(stages))) + else: + for key in keys: + lazy_d[key] = getattr(self.hits, key)[mask] + return lazy_d + + def get_reco_events(self, stages, keys): + """construct a dictionary of events class data based on the reconstruction + stages of interest. For example, if the reconstruction stages of interest + are [1, 2, 3, 4, 5], then get_reco_events method will select the events data + that were reconstructed following these stages (i.e [1, 2, 3, 4, 5]). + + Parameters + ---------- + stages : list + list of reconstruction stages of interest. for example + [1, 2, 3, 4, 5]. + keys : list of str + list of the events class attributes. + + Returns + ------- + dict + dictionary of lazyarrays containing data for each events attribute requested. + + Raises + ------ + ValueError + ValueError raised when the reconstruction stages of interest + are not found in the file. + """ + lazy_d = {} + rec_stages = np.array( + [match for match in self._find_rec_stages(stages)]) + mask = rec_stages[:, 1] != None + if np.all(rec_stages[:, 1] == None): + raise ValueError( + "The stages {} are not found in your file.".format( + str(stages))) + else: + for key in keys: + lazy_d[key] = getattr(self.events, key)[mask] + return lazy_d + + def get_reco_tracks(self, stages, keys): + """construct a dictionary of tracks class data based on the reconstruction + stages of interest. For example, if the reconstruction stages of interest + are [1, 2, 3, 4, 5], then get_reco_tracks method will select tracks data + from the events that were reconstructed following these stages (i.e + [1, 2, 3, 4, 5]). + + Parameters + ---------- + stages : list + list of reconstruction stages of interest. for example + [1, 2, 3, 4, 5]. + keys : list of str + list of the tracks class attributes. + + Returns + ------- + dict + dictionary of lazyarrays containing data for each tracks attribute requested. + + Raises + ------ + ValueError + ValueError raised when the reconstruction stages of interest + are not found in the file. + """ + lazy_d = {} + rec_stages = np.array( + [match for match in self._find_rec_stages(stages)]) + mask = rec_stages[:, 1] != None + if np.all(rec_stages[:, 1] == None): + raise ValueError( + "The stages {} are not found in your file.".format( + str(stages))) + else: + for key in keys: + lazy_d[key] = np.array([ + i[k] for i, k in zip( + getattr(self.tracks, key)[mask], rec_stages[:, + 1][mask]) + ]) + + return lazy_d + def _find_rec_stages(self, stages): """find the index of reconstruction stages of interest in a list of multiple reconstruction stages. diff --git a/tests/test_offline.py b/tests/test_offline.py index 5257b4d99b834f5a33f7624041ce145351e7d674..ca5802e2e2ce905dfce2cec4105d4067e1e0797c 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -199,6 +199,35 @@ class TestOfflineReader(unittest.TestCase): with self.assertRaises(ValueError): self.nu.get_reco_fit([1000, 4512, 5625]) + def test_get_reco_hits(self): + + doms = self.nu.get_reco_hits([1, 2, 3, 4, 5], ["dom_id"])["dom_id"] + + self.assertEqual(doms.size, 9) + self.assertListEqual(doms[0][0:4].tolist(), + self.nu.hits[0].dom_id[0:4].tolist()) + with self.assertRaises(ValueError): + self.nu.get_reco_hits([1000, 4512, 5625], ["dom_id"]) + + def test_get_reco_tracks(self): + + pos = self.nu.get_reco_tracks([1, 2, 3, 4, 5], ["pos_x"])["pos_x"] + + self.assertEqual(pos.size, 9) + self.assertEqual(pos[0], self.nu.tracks[0].pos_x[0]) + with self.assertRaises(ValueError): + self.nu.get_reco_tracks([1000, 4512, 5625], ["pos_x"]) + + def test_get_reco_events(self): + + hits = self.nu.get_reco_events([1, 2, 3, 4, 5], ["hits"])["hits"] + + self.assertEqual(hits.size, 9) + self.assertListEqual(hits[0:4].tolist(), + self.nu.events.hits[0:4].tolist()) + with self.assertRaises(ValueError): + self.nu.get_reco_events([1000, 4512, 5625], ["hits"]) + def test_get_max_reco_stages(self): rec_stages = self.nu.tracks.rec_stages max_reco = self.nu._get_max_reco_stages(rec_stages)