From 24cf773fcf0a642236f2fb85c8845d73d071f6a1 Mon Sep 17 00:00:00 2001
From: Johannes Schumann <johannes.schumann@fau.de>
Date: Mon, 7 Dec 2020 17:04:17 +0100
Subject: [PATCH] Add tests

---
 km3buu/output.py            |  2 +-
 km3buu/tests/test_output.py | 22 ++++++++++++++++++++--
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/km3buu/output.py b/km3buu/output.py
index 8eb318f..9c6835e 100644
--- a/km3buu/output.py
+++ b/km3buu/output.py
@@ -297,7 +297,7 @@ def write_detector_file(gibuu_output,
                         ofile="gibuu.aanet.root",
                         can=(0, 476.5, 403.4),
                         livetime=3.156e7,
-                        propagate_tau=True):
+                        propagate_tau=True):  # pragma: no cover
     """
     Convert the GiBUU output to a KM3NeT MC (AANET) file
 
diff --git a/km3buu/tests/test_output.py b/km3buu/tests/test_output.py
index 69d8f30..dc90b17 100644
--- a/km3buu/tests/test_output.py
+++ b/km3buu/tests/test_output.py
@@ -11,6 +11,7 @@ __email__ = "jschumann@km3net.de"
 __status__ = "Development"
 
 import unittest
+from unittest.mock import patch
 import numpy as np
 import pytest
 import km3io
@@ -18,7 +19,7 @@ from km3buu.output import *
 from os import listdir
 from os.path import abspath, join, dirname
 from km3net_testdata import data_path
-from tempfile import NamedTemporaryFile
+from tempfile import NamedTemporaryFile, TemporaryDirectory
 
 TESTDATA_DIR = data_path("gibuu")
 
@@ -40,9 +41,17 @@ class TestXSection(unittest.TestCase):
 
 
 class TestGiBUUOutput(unittest.TestCase):
-    def setUp(self):
+    def setup_class(self):
         self.output = GiBUUOutput(TESTDATA_DIR)
 
+    def test_tmp_dir_init(self):
+        with patch('tempfile.TemporaryDirectory',
+                   spec=TemporaryDirectory) as mock:
+            instance = mock.return_value
+            instance.name = abspath(TESTDATA_DIR)
+            output = GiBUUOutput(instance)
+            assert output.data_path == abspath(TESTDATA_DIR)
+
     def test_attr(self):
         assert hasattr(self.output, "df")
 
@@ -54,6 +63,15 @@ class TestGiBUUOutput(unittest.TestCase):
         n_evts = self.output.flux_interpolation.integral(0.7, 1.0) / 0.02
         self.assertAlmostEqual(xsec / n_evts, 0.8, places=2)
 
+    def test_nucleus_properties(self):
+        assert self.output.Z == 8
+        assert self.output.A == 16
+
+    def test_w2weights(self):
+        w2 = self.output.w2weights(123.0, 2.6e28, 4 * np.pi)
+        np.testing.assert_array_almost_equal(
+            w2[:3], [7.63360911e+01, 3.60997502e-01, 1.13273189e+03])
+
 
 @pytest.mark.skipif(not AANET_AVAILABLE, reason="aanet required")
 class TestAANET(unittest.TestCase):
-- 
GitLab