From 24cf773fcf0a642236f2fb85c8845d73d071f6a1 Mon Sep 17 00:00:00 2001 From: Johannes Schumann <johannes.schumann@fau.de> Date: Mon, 7 Dec 2020 17:04:17 +0100 Subject: [PATCH] Add tests --- km3buu/output.py | 2 +- km3buu/tests/test_output.py | 22 ++++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/km3buu/output.py b/km3buu/output.py index 8eb318f..9c6835e 100644 --- a/km3buu/output.py +++ b/km3buu/output.py @@ -297,7 +297,7 @@ def write_detector_file(gibuu_output, ofile="gibuu.aanet.root", can=(0, 476.5, 403.4), livetime=3.156e7, - propagate_tau=True): + propagate_tau=True): # pragma: no cover """ Convert the GiBUU output to a KM3NeT MC (AANET) file diff --git a/km3buu/tests/test_output.py b/km3buu/tests/test_output.py index 69d8f30..dc90b17 100644 --- a/km3buu/tests/test_output.py +++ b/km3buu/tests/test_output.py @@ -11,6 +11,7 @@ __email__ = "jschumann@km3net.de" __status__ = "Development" import unittest +from unittest.mock import patch import numpy as np import pytest import km3io @@ -18,7 +19,7 @@ from km3buu.output import * from os import listdir from os.path import abspath, join, dirname from km3net_testdata import data_path -from tempfile import NamedTemporaryFile +from tempfile import NamedTemporaryFile, TemporaryDirectory TESTDATA_DIR = data_path("gibuu") @@ -40,9 +41,17 @@ class TestXSection(unittest.TestCase): class TestGiBUUOutput(unittest.TestCase): - def setUp(self): + def setup_class(self): self.output = GiBUUOutput(TESTDATA_DIR) + def test_tmp_dir_init(self): + with patch('tempfile.TemporaryDirectory', + spec=TemporaryDirectory) as mock: + instance = mock.return_value + instance.name = abspath(TESTDATA_DIR) + output = GiBUUOutput(instance) + assert output.data_path == abspath(TESTDATA_DIR) + def test_attr(self): assert hasattr(self.output, "df") @@ -54,6 +63,15 @@ class TestGiBUUOutput(unittest.TestCase): n_evts = self.output.flux_interpolation.integral(0.7, 1.0) / 0.02 self.assertAlmostEqual(xsec / n_evts, 0.8, places=2) + def test_nucleus_properties(self): + assert self.output.Z == 8 + assert self.output.A == 16 + + def test_w2weights(self): + w2 = self.output.w2weights(123.0, 2.6e28, 4 * np.pi) + np.testing.assert_array_almost_equal( + w2[:3], [7.63360911e+01, 3.60997502e-01, 1.13273189e+03]) + @pytest.mark.skipif(not AANET_AVAILABLE, reason="aanet required") class TestAANET(unittest.TestCase): -- GitLab