Skip to content
Snippets Groups Projects
Commit 5586acd2 authored by Stefan Reck's avatar Stefan Reck
Browse files

expand and adjust tests

parent 04928c6c
No related branches found
No related tags found
1 merge request!23Jagged graph
......@@ -15,6 +15,9 @@ clean:
test:
py.test --junitxml=./reports/junit.xml -o junit_suite_name=$(PKGNAME) tests
retest:
py.test --junitxml=./reports/junit.xml -o junit_suite_name=$(PKGNAME) tests --last-failed
test-cov:
py.test tests --cov $(ALLNAMES) --cov-report term-missing --cov-report xml:reports/coverage.xml --cov-report html:reports/coverage tests
......@@ -38,4 +41,4 @@ yapf:
yapf -i -r $(PKGNAME)
yapf -i setup.py
.PHONY: all clean build install install-dev test test-nocov flake8 pep8 dependencies docstyle
.PHONY: all clean build install install-dev test retest test-nocov flake8 pep8 dependencies docstyle
......@@ -83,7 +83,9 @@ class TestFileGraph(TestCase):
""" Assert that the FileGraph still produces the same output. """
@classmethod
def setUpClass(cls):
cls.proc = orcasong.core.FileGraph(
# produce test file, once for fixed_length (old format), and once
# for the new format
cls.proc_fixed_length, cls.proc = [orcasong.core.FileGraph(
max_n_hits=3,
time_window=[0, 50],
hit_infos=["pos_z", "time", "channel_id"],
......@@ -92,8 +94,14 @@ class TestFileGraph(TestCase):
add_t0=False,
keep_event_info=True,
correct_timeslew=False,
)
fixed_length=fixed_length,
) for fixed_length in (True, False)]
cls.tmpdir = tempfile.TemporaryDirectory()
cls.outfile_fixed_length = os.path.join(cls.tmpdir.name, "binned_fixed_length.h5")
cls.proc_fixed_length.run(infile=MUPAGE_FILE, outfile=cls.outfile_fixed_length)
cls.f_fixed_length = h5py.File(cls.outfile_fixed_length, "r")
cls.outfile = os.path.join(cls.tmpdir.name, "binned.h5")
cls.proc.run(infile=MUPAGE_FILE, outfile=cls.outfile)
cls.f = h5py.File(cls.outfile, "r")
......@@ -101,25 +109,44 @@ class TestFileGraph(TestCase):
@classmethod
def tearDownClass(cls):
cls.f.close()
cls.f_fixed_length.close()
cls.tmpdir.cleanup()
def test_keys_fixed_length(self):
self.assertSetEqual(set(self.f_fixed_length.keys()), {
'_i_event_info', '_i_group_info', '_i_y',
'event_info', 'group_info', 'x', 'x_indices', 'y'})
def test_keys(self):
self.assertSetEqual(set(self.f.keys()), {
self.assertSetEqual(set(self.f_fixed_length.keys()), {
'_i_event_info', '_i_group_info', '_i_y',
'event_info', 'group_info', 'x', 'x_indices', 'y'})
def test_x_attrs_fixed_length(self):
to_check = {
"hit_info_0": "pos_z",
"hit_info_1": "time",
"hit_info_2": "channel_id",
"hit_info_3": "is_valid",
"indexed": False,
}
attrs = dict(self.f_fixed_length["x"].attrs)
for k, v in to_check.items():
self.assertTrue(attrs[k] == v)
def test_x_attrs(self):
to_check = {
"hit_info_0": "pos_z",
"hit_info_1": "time",
"hit_info_2": "channel_id",
"hit_info_3": "is_valid",
"indexed": True,
}
attrs = dict(self.f["x"].attrs)
for k, v in to_check.items():
self.assertTrue(attrs[k] == v)
def test_x(self):
def test_x_fixed_length(self):
target = np.array([
[[676.941, 13., 30., 1.],
[461.111, 32., 9., 1.],
......@@ -131,8 +158,33 @@ class TestFileGraph(TestCase):
[605.111, 9., 4., 1.],
[424.889, 46., 29., 1.]]
], dtype=np.float32)
np.testing.assert_equal(target, self.f_fixed_length["x"])
def test_x(self):
target = np.array([
[676.941, 13., 30.],
[461.111, 32., 9.],
[424.941, 1., 30.],
[172.83, 32., 25.],
[316.83, 2., 14.],
[461.059, 1., 3.],
[496.83, 34., 25.],
[605.111, 9., 4.],
[424.889, 46., 29.],
], dtype=np.float32)
np.testing.assert_equal(target, self.f["x"])
def test_y_fixed_length(self):
y = self.f_fixed_length["y"][()]
target = {
'event_id': np.array([0., 1., 2.]),
'run_id': np.array([1., 1., 1.]),
'trigger_mask': np.array([18., 18., 16.]),
'group_id': np.array([0, 1, 2]),
}
for k, v in target.items():
np.testing.assert_equal(y[k], v)
def test_y(self):
y = self.f["y"][()]
target = {
......@@ -142,4 +194,4 @@ class TestFileGraph(TestCase):
'group_id': np.array([0, 1, 2]),
}
for k, v in target.items():
np.testing.assert_equal(y[k], v)
np.testing.assert_equal(y[k], v)
\ No newline at end of file
......@@ -17,6 +17,7 @@ DET_FILE_NEUTRINO = os.path.join(test_dir, "data", "KM3NeT_00000049_20200707.det
NO_COMPLE_RECO_FILE = os.path.join(test_dir, "data", "arca_test_without_some_jmuon_recos.h5")
ARCA_DETX = os.path.join(test_dir, "data", "KM3NeT_-00000001_20171212.detx")
class TestStdRecoExtractor(TestCase):
""" Assert that the neutrino info is extracted correctly. File has 18 events. """
......@@ -31,6 +32,7 @@ class TestStdRecoExtractor(TestCase):
det_file=DET_FILE_NEUTRINO,
add_t0=True,
keep_event_info=True,
fixed_length=True,
)
cls.tmpdir = tempfile.TemporaryDirectory()
cls.outfile = os.path.join(cls.tmpdir.name, "binned.h5")
......@@ -50,6 +52,7 @@ class TestStdRecoExtractor(TestCase):
det_file=ARCA_DETX,
add_t0=True,
keep_event_info=True,
fixed_length=True,
)
cls.outfile_arca = os.path.join(cls.tmpdir.name, "binned_arca.h5")
cls.proc.run(infile=NO_COMPLE_RECO_FILE, outfile=cls.outfile_arca)
......
......@@ -191,6 +191,19 @@ class TestPointMaker(TestCase):
pm = modules.PointMaker(
max_n_hits=4)
result = pm.process(self.input_blob_1)["samples"]
self.assertTupleEqual(
pm.finish()["hit_infos"], ("t0", "time", "x"))
target = np.array(
[[0.1, 1, 4],
[0.2, 2, 5],
[0.3, 3, 6]],
dtype="float32")
np.testing.assert_array_equal(result, target)
def test_default_settings_fixed_length(self):
pm = modules.PointMaker(
max_n_hits=4, fixed_length=True)
result = pm.process(self.input_blob_1)["samples"]
self.assertTupleEqual(
pm.finish()["hit_infos"], ("t0", "time", "x", "is_valid"))
target = np.array(
......@@ -200,12 +213,13 @@ class TestPointMaker(TestCase):
[0, 0, 0, 0]]], dtype="float32")
np.testing.assert_array_equal(result, target)
def test_input_blob_1(self):
def test_input_blob_1_fixed_length(self):
pm = modules.PointMaker(
max_n_hits=4,
hit_infos=("x", "time"),
time_window=None,
dset_n_hits=None,
fixed_length=True,
)
result = pm.process(self.input_blob_1)["samples"]
self.assertTupleEqual(
......@@ -217,7 +231,7 @@ class TestPointMaker(TestCase):
[0, 0, 0]]], dtype="float32")
np.testing.assert_array_equal(result, target)
def test_input_blob_1_max_n_hits(self):
def test_input_blob_1_max_n_hits_fixed_length(self):
input_blob_long = {
"Hits": kp.Table({
"x": np.random.rand(1000).astype("float32"),
......@@ -227,18 +241,20 @@ class TestPointMaker(TestCase):
hit_infos=("x",),
time_window=None,
dset_n_hits=None,
fixed_length=True,
).process(input_blob_long)["samples"]
self.assertSequenceEqual(result.shape, (1, 10, 2))
self.assertTrue(all(
np.isin(result[0, :, 0], input_blob_long["Hits"]["x"])))
def test_input_blob_time_window(self):
def test_input_blob_time_window_fixed_length(self):
result = modules.PointMaker(
max_n_hits=4,
hit_infos=("x", "time"),
time_window=[1, 2],
dset_n_hits=None,
fixed_length=True,
).process(self.input_blob_1)["samples"]
target = np.array(
[[[4, 1, 1],
......@@ -247,12 +263,13 @@ class TestPointMaker(TestCase):
[0, 0, 0]]], dtype="float32")
np.testing.assert_array_equal(result, target)
def test_input_blob_time_window_nhits(self):
def test_input_blob_time_window_nhits_fixed_length(self):
result = modules.PointMaker(
max_n_hits=4,
hit_infos=("x", "time"),
time_window=[1, 2],
dset_n_hits="EventInfo",
fixed_length=True,
).process(self.input_blob_1)["EventInfo"]
print(result)
self.assertEqual(result["n_hits_intime"], 2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment