diff --git a/docs/orcasong.rst b/docs/orcasong.rst index e52cbd43ebee5a9b7cd96db32cb2a0f495aa015c..7ad90195e6c844adbfb0fbbf8ffd1c7684761a2f 100644 --- a/docs/orcasong.rst +++ b/docs/orcasong.rst @@ -130,7 +130,7 @@ calibrate the data on the fly: Adding mc_info ^^^^^^^^^^^^^^ -Define a function ``my_mcinfo_extractor``, which takes as an input a km3pipe blob, +Define a function ``my_extractor``, which takes as an input a km3pipe blob, and outputs a dict mapping str to float. It should contain everything you need later down the pipeline, e.g. labels, event identifiers, ... @@ -140,5 +140,5 @@ the str being the dtype names. Set up like follows: .. code-block:: python - fb = FileBinner(bin_edges_list, mc_info_extr=my_mcinfo_extractor) + fb = FileBinner(bin_edges_list, extractor=my_extractor) diff --git a/orcasong/core.py b/orcasong/core.py index cf0c6895b6241d36a89a395f3d244e0785fc3213..73c2780b83c8f6c2a72d39de83316788121e6ac9 100644 --- a/orcasong/core.py +++ b/orcasong/core.py @@ -21,12 +21,12 @@ class BaseProcessor: Parameters ---------- - mc_info_extr : function, optional - Function that extracts desired mc_info from a blob, which is then + extractor : function, optional + Function that extracts desired info from a blob, which is then stored as the "y" datafield in the .h5 file. The function takes the km3pipe blob as an input, and returns a dict mapping str to floats. - Some examples can be found in orcasong.mc_info_extr. + Examples can be found in orcasong.extractors. det_file : str, optional Path to a .detx detector geometry file, which can be used to calibrate the hits. @@ -82,7 +82,7 @@ class BaseProcessor: each pipeline. 
""" - def __init__(self, mc_info_extr=None, + def __init__(self, extractor=None, det_file=None, center_time=True, add_t0=False, @@ -93,7 +93,7 @@ class BaseProcessor: keep_mc_tracks=False, overwrite=True, mc_info_to_float64=True): - self.mc_info_extr = mc_info_extr + self.extractor = extractor self.det_file = det_file self.center_time = center_time self.add_t0 = add_t0 @@ -198,9 +198,9 @@ class BaseProcessor: def get_cmpts_post(self, outfile): """ Modules that postproc and save the events. """ cmpts = [] - if self.mc_info_extr is not None: + if self.extractor is not None: cmpts.append((modules.McInfoMaker, { - "mc_info_extr": self.mc_info_extr, + "extractor": self.extractor, "to_float64": self.mc_info_to_float64, "store_as": "mc_info"})) diff --git a/orcasong/mc_info_extr.py b/orcasong/extractors.py similarity index 100% rename from orcasong/mc_info_extr.py rename to orcasong/extractors.py diff --git a/orcasong/modules.py b/orcasong/modules.py index d08696b226d2361ec89e35a3d98afecd1a9e5729..88e5a2348c4a7b729c3294109eef1bd49122dc4c 100644 --- a/orcasong/modules.py +++ b/orcasong/modules.py @@ -12,11 +12,11 @@ __author__ = 'Stefan Reck' class McInfoMaker(kp.Module): """ - Store mc info as float64 in the blob. + Stores info as float64 in the blob. Attributes ---------- - mc_info_extr : function + extractor : function Function to extract the info. Takes the blob as input, outputs a dict with the desired mc_infos. 
store_as : str @@ -25,12 +25,12 @@ class McInfoMaker(kp.Module): """ def configure(self): - self.mc_info_extr = self.require('mc_info_extr') + self.extractor = self.require('extractor') self.store_as = self.require('store_as') self.to_float64 = self.get("to_float64", default=True) def process(self, blob): - track = self.mc_info_extr(blob) + track = self.extractor(blob) if self.to_float64: dtypes = [] for key, v in track.items(): diff --git a/tests/test_core.py b/tests/test_core.py index 978bd7674d030a3b3b23010cfa2e6eef01a1b9ef..e21d48d5ace8d37c4bbda7eb8ad4ec4ea3c716e4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -4,7 +4,7 @@ import tempfile import numpy as np import h5py import orcasong.core -import orcasong.mc_info_extr +import orcasong.extractors as extractors from orcasong.plotting.plot_binstats import read_hists_from_h5file @@ -15,6 +15,7 @@ test_dir = os.path.dirname(os.path.realpath(__file__)) MUPAGE_FILE = os.path.join(test_dir, "data", "mupage.root.h5") DET_FILE = os.path.join(test_dir, "data", "KM3NeT_-00000001_20171212.detx") + class TestFileBinner(TestCase): """ Assert that the filebinner still produces the same output. 
""" @classmethod @@ -25,7 +26,7 @@ class TestFileBinner(TestCase): ["time", np.linspace(0, 600, 3)], ["channel_id", np.linspace(-0.5, 30.5, 3)], ], - mc_info_extr=orcasong.mc_info_extr.get_real_data_info_extr(MUPAGE_FILE), + extractor=extractors.get_real_data_info_extr(MUPAGE_FILE), det_file=DET_FILE, add_t0=True, ) @@ -85,7 +86,7 @@ class TestFileGraph(TestCase): max_n_hits=3, time_window=[0, 50], hit_infos=["pos_z", "time", "channel_id"], - mc_info_extr=orcasong.mc_info_extr.get_real_data_info_extr(MUPAGE_FILE), + extractor=extractors.get_real_data_info_extr(MUPAGE_FILE), det_file=DET_FILE, add_t0=True, ) @@ -139,20 +140,3 @@ class TestFileGraph(TestCase): } for k, v in target.items(): np.testing.assert_equal(y[k], v) - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/tests/test_extractor.py b/tests/test_extractor.py index d48d01019ef82888d7dd4c5bc0e031bbc18d3c81..0cab2f5f259979efa0d80e6b62706edfd19ed0ba 100644 --- a/tests/test_extractor.py +++ b/tests/test_extractor.py @@ -4,11 +4,10 @@ import tempfile import numpy as np import h5py import orcasong.core -import orcasong.mc_info_extr -from orcasong.plotting.plot_binstats import read_hists_from_h5file +import orcasong.extractors as extractors -__author__ = 'Daniel Guderian' +__author__ = "Daniel Guderian" test_dir = os.path.dirname(os.path.realpath(__file__)) @@ -18,13 +17,14 @@ DET_FILE_NEUTRINO = os.path.join(test_dir, "data", "neutrino_detector_file.detx" class TestStdRecoExtractor(TestCase): """ Assert that the neutrino info is extracted correctly File has 18 events. 
""" + @classmethod def setUpClass(cls): cls.proc = orcasong.core.FileGraph( max_n_hits=3, time_window=[0, 50], hit_infos=["pos_z", "time", "channel_id"], - mc_info_extr=orcasong.mc_info_extr.get_neutrino_mc_info_extr(NEUTRINO_FILE), + extractor=extractors.get_neutrino_mc_info_extr(NEUTRINO_FILE), det_file=DET_FILE_NEUTRINO, add_t0=True, ) @@ -39,147 +39,155 @@ class TestStdRecoExtractor(TestCase): cls.tmpdir.cleanup() def test_keys(self): - self.assertSetEqual(set(self.f.keys()), { - '_i_event_info', '_i_group_info', '_i_y', - 'event_info', 'group_info', 'x', 'x_indices', 'y'}) + self.assertSetEqual( + set(self.f.keys()), + { + "_i_event_info", + "_i_group_info", + "_i_y", + "event_info", + "group_info", + "x", + "x_indices", + "y", + }, + ) def test_y(self): y = self.f["y"][()] target = { - 'weight_w2': np.array([29650.0, - 297100.0, - 41450.0, - 371400.0, - 1101000000.0, - 2757000.0, - 15280000.0, - 262800000.0, - 22590.0, - 24240.0, - 80030.0, - 3018000.0, - 120600.0, - 872200.0, - 50440000.0, - 21540.0, - 42170.0, - 25230.0]), - - 'n_gen': np.array([60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0, - 60000.0]), - - 'dir_z': np.array([-0.896549, - -0.835252, - 0.300461, - 0.108997, - 0.128445, - -0.543621, - -0.23205, - -0.297228, - 0.694932, - 0.73835, - -0.007682, - 0.437847, - -0.126804, - 0.153432, - -0.263229, - 0.820217, - 0.452473, - 0.294217]), - - 'is_cc': np.array([2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0]), - - 'std_dir_z': np.array([-0.923199825369434, - -0.6422689266782661, - 0.38853917922036363, - -0.16690804339142448, - -0.01584853496341109, - -0.10151549881670698, - -0.0409694104272829, - -0.32964369874021787, - -0.3294926806601529, - 0.6524241250799204, - -0.3899574246450216, - 0.27872277417339086, - 0.0019490791409933206, 
0.20341370281708737, - -0.15739475718286297, - 0.8040250543935723, - 0.08772622550043882, - -0.7766722433951796]), - - 'std_energy': np.array([4.7187625606210775, - 4.169818842606011, - 1.0056373761749966, - 5.908597073055873, - 12.409377607517195, - 7.566695371401163, - 1.3546775620239864, - 2.659528737837978, - 1.0056373761749966, - 2.1968321463948755, - 1.4821714294894754, - 10.135831333340658, - 2.6003934443336765, - 1.4492149732348223, - 71.69167874147956, - 8.094744120333358, - 3.148088080484504, - 1.0056373761749966]), - + "weight_w2": np.array( + [ + 29650.0, + 297100.0, + 41450.0, + 371400.0, + 1101000000.0, + 2757000.0, + 15280000.0, + 262800000.0, + 22590.0, + 24240.0, + 80030.0, + 3018000.0, + 120600.0, + 872200.0, + 50440000.0, + 21540.0, + 42170.0, + 25230.0, + ] + ), + "n_gen": np.array( + [ + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + 60000.0, + ] + ), + "dir_z": np.array( + [ + -0.896549, + -0.835252, + 0.300461, + 0.108997, + 0.128445, + -0.543621, + -0.23205, + -0.297228, + 0.694932, + 0.73835, + -0.007682, + 0.437847, + -0.126804, + 0.153432, + -0.263229, + 0.820217, + 0.452473, + 0.294217, + ] + ), + "is_cc": np.array( + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ] + ), + "std_dir_z": np.array( + [ + -0.923199825369434, + -0.6422689266782661, + 0.38853917922036363, + -0.16690804339142448, + -0.01584853496341109, + -0.10151549881670698, + -0.0409694104272829, + -0.32964369874021787, + -0.3294926806601529, + 0.6524241250799204, + -0.3899574246450216, + 0.27872277417339086, + 0.0019490791409933206, + 0.20341370281708737, + -0.15739475718286297, + 0.8040250543935723, + 0.08772622550043882, + -0.7766722433951796, + ] + ), + "std_energy": np.array( + [ + 4.7187625606210775, + 4.169818842606011, + 1.0056373761749966, + 
5.908597073055873, + 12.409377607517195, + 7.566695371401163, + 1.3546775620239864, + 2.659528737837978, + 1.0056373761749966, + 2.1968321463948755, + 1.4821714294894754, + 10.135831333340658, + 2.6003934443336765, + 1.4492149732348223, + 71.69167874147956, + 8.094744120333358, + 3.148088080484504, + 1.0056373761749966, + ] + ), } for k, v in target.items(): np.testing.assert_equal(y[k], v) - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/tests/test_modules.py b/tests/test_modules.py index 242f26ef86e5688a68bdb8a65bf19f3feb273e27..5c4623166fc0763682c29fab70c6b07714186e89 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -10,7 +10,7 @@ __author__ = 'Stefan Reck' class TestModules(TestCase): def test_mc_info_maker(self): """ Test the mcinfo maker on some dummy data. """ - def mc_info_extr(blob): + def extractor(blob): hits = blob["Hits"] return {"dom_id_0": hits.dom_id[0], "time_2": hits.time[2]} @@ -23,7 +23,7 @@ class TestModules(TestCase): }) } module = modules.McInfoMaker( - mc_info_extr=mc_info_extr, store_as="test") + extractor=extractor, store_as="test") out_blob = module.process(in_blob) self.assertSequenceEqual(list(out_blob.keys()), ["Hits", "test"]) @@ -36,7 +36,7 @@ class TestModules(TestCase): def test_mc_info_maker_dtype(self): """ Test the mcinfo maker on some dummy data. """ - def mc_info_extr(blob): + def extractor(blob): hits = blob["Hits"] return {"dom_id_0": hits.dom_id[0], "time_2": hits.time[2]} @@ -48,7 +48,7 @@ class TestModules(TestCase): }) } module = modules.McInfoMaker( - mc_info_extr=mc_info_extr, store_as="test", to_float64=False) + extractor=extractor, store_as="test", to_float64=False) out_blob = module.process(in_blob) np.testing.assert_array_equal(