From 8335f02764ff72a141271383b0f078fe6f6dd2f7 Mon Sep 17 00:00:00 2001
From: Johannes Schumann <johannes.schumann@fau.de>
Date: Tue, 23 Nov 2021 18:20:37 +0100
Subject: [PATCH] Update propagation tests to predefined values

---
 km3buu/tests/test_propagation.py | 46 ++++++++++++++++++++------------
 1 file changed, 29 insertions(+), 17 deletions(-)

diff --git a/km3buu/tests/test_propagation.py b/km3buu/tests/test_propagation.py
index 6b013f0..7bceaa3 100644
--- a/km3buu/tests/test_propagation.py
+++ b/km3buu/tests/test_propagation.py
@@ -16,40 +16,49 @@ import numpy as np
 import uproot
 from os.path import abspath, join, dirname
 from thepipe.logger import get_logger
-from km3net_testdata import data_path
 
 import proposal as pp
 
-from km3buu.output import GiBUUOutput
+import awkward as ak
 from km3buu.propagation import propagate_lepton
 
-TESTDATA_DIR = data_path("gibuu")
-
 pp.RandomGenerator.get().set_seed(1234)
 
 
 class TestTauPropagation(unittest.TestCase):
     def setUp(self):
-        log = get_logger("ctrl.py")
-        log.setLevel("INFO")
-        self.gibuu_output = GiBUUOutput(TESTDATA_DIR)
-        fname = join(TESTDATA_DIR, self.gibuu_output.root_pert_files[0])
-        fobj = uproot.open(fname)
-        data = fobj["RootTuple"].arrays()
+        data = ak.Array({
+            "lepOut_E": [
+                17.45830624434573, 3.1180989952362594, 21.270059768902005,
+                5.262659790136034, 23.52185741888274
+            ],
+            "lepOut_Px": [
+                -0.42224402086330426, -1.0232258668453014, -0.5801431899058521,
+                -0.9038349288874724, 0.9022573877437422
+            ],
+            "lepOut_Py": [
+                0.3644190693190108, -0.24542303987320932, 0.24499631087268617,
+                -1.1060562370375715, -3.982173292871768
+            ],
+            "lepOut_Pz": [
+                17.35867612031871, 2.336148261778657, 21.186342871282157,
+                4.743161507744377, 23.096499191566885
+            ]
+        })
         self.sec = propagate_lepton(data, 15)
 
     def test_secondary_momenta(self):
         np.testing.assert_array_almost_equal(np.array(self.sec[0].E),
-                                             [0.535, 1.316, 0.331],
+                                             [2.182, 13.348, 1.928],
                                              decimal=3)
         np.testing.assert_array_almost_equal(np.array(self.sec[0].Px),
-                                             [-0.467, 0.321, -0.246],
+                                             [0.295, -0.48, -0.237],
                                              decimal=3)
         np.testing.assert_array_almost_equal(np.array(self.sec[0].Py),
-                                             [0.127, -0.822, 0.218],
+                                             [-0.375, 0.784, -0.044],
                                              decimal=3)
         np.testing.assert_array_almost_equal(np.array(self.sec[0].Pz),
-                                             [0.179, 0.967, -0.041],
+                                             [2.129, 13.316, 1.913],
                                              decimal=3)
 
     def test_secondary_types(self):
@@ -57,9 +66,12 @@ class TestTauPropagation(unittest.TestCase):
                                       [13, 16, -14])
 
     def test_secondary_positions(self):
-        np.testing.assert_array_almost_equal(np.array(self.sec[0].x), [0, 0],
+        np.testing.assert_array_almost_equal(np.array(self.sec[0].x),
+                                             [-1.4e-05, -1.4e-05, -1.4e-05],
                                              decimal=1)
-        np.testing.assert_array_almost_equal(np.array(self.sec[0].y), [0, 0],
+        np.testing.assert_array_almost_equal(np.array(self.sec[0].y),
+                                             [1.2e-05, 1.2e-05, 1.2e-05],
                                              decimal=1)
-        np.testing.assert_array_almost_equal(np.array(self.sec[0].z), [0, 0],
+        np.testing.assert_array_almost_equal(np.array(self.sec[0].z),
+                                             [0., 0., 0.],
                                              decimal=1)
-- 
GitLab