From 437c0bea579d2d0681b4ce8c73adf2f0668baba0 Mon Sep 17 00:00:00 2001
From: Tamas Gal <tgal@km3net.de>
Date: Wed, 9 Dec 2020 17:32:54 +0100
Subject: [PATCH] Fix usr

---
 examples/plot_offline_usr.py | 15 ++-----
 km3io/tools.py               | 17 ++++++++
 tests/test_tools.py          | 76 ++++++++++++++++++++++++++++++++++++
 3 files changed, 97 insertions(+), 11 deletions(-)

diff --git a/examples/plot_offline_usr.py b/examples/plot_offline_usr.py
index 9d7959b..8bf3b2d 100644
--- a/examples/plot_offline_usr.py
+++ b/examples/plot_offline_usr.py
@@ -18,19 +18,12 @@ r = ki.OfflineReader(data_path("offline/usr-sample.root"))
 
 
 #####################################################
-# Accessing the usr data:
+# Accessing the usr fields:
 
-usr = r.events.usr
-print(usr)
+print(r.events.usr_names.tolist())
 
 
 #####################################################
-# to access data of a specific key, you can either do:
+# to access data of a specific key:
 
-print(usr.DeltaPosZ)
-
-
-#####################################################
-# or
-
-print(usr["RecoQuality"])
+print(ki.tools.usr(r.events, "DeltaPosZ"))
diff --git a/km3io/tools.py b/km3io/tools.py
index 2bc0e00..8db9a5b 100644
--- a/km3io/tools.py
+++ b/km3io/tools.py
@@ -494,3 +494,20 @@ def is_cc(fobj):
             raise ValueError(f"simulation program {program} is not implemented.")
 
     return out
+
+
+def usr(objects, field):
+    """Return the usr-data for a given field.
+
+    Parameters
+    ----------
+    objects : awkward.Array
+      Events, tracks, hits or whatever objects which have usr and usr_names
+      fields (e.g. OfflineReader().events).
+    """
+    if len(unique(ak.num(objects.usr_names))) > 1:
+        # let's do it the hard way
+        return ak.flatten(objects.usr[objects.usr_names == field])
+    available_fields = objects.usr_names[0].tolist()
+    idx = available_fields.index(field)
+    return objects.usr[:, idx]
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 3505b6a..69bcdca 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -27,9 +27,16 @@ from km3io.tools import (
     best_aashower,
     best_dusjshower,
     is_cc,
+    usr,
 )
 
 OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root"))
+OFFLINE_USR = OfflineReader(data_path("offline/usr-sample.root"))
+OFFLINE_MC_TRACK_USR = OfflineReader(
+    data_path(
+        "offline/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root"
+    )
+)
 GENHEN_OFFLINE_FILE = OfflineReader(
     data_path("offline/mcv5.1.genhen_anumuNC.sirene.jte.jchain.aashower.sample.root")
 )
@@ -540,3 +547,72 @@ class TestIsCC(unittest.TestCase):
             all(NC_file) == True
         )  # this test fails because the CC flags are not reliable in old files
         self.assertTrue(all(CC_file) == True)
+
+
+class TestUsr(unittest.TestCase):
+    def setUp(self):
+        self.f = OFFLINE_USR
+
+    def test_str_flat(self):
+        print(self.f.events.usr)
+
+    def test_keys_flat(self):
+        self.assertListEqual(
+            [
+                "RecoQuality",
+                "RecoNDF",
+                "CoC",
+                "ToT",
+                "ChargeAbove",
+                "ChargeBelow",
+                "ChargeRatio",
+                "DeltaPosZ",
+                "FirstPartPosZ",
+                "LastPartPosZ",
+                "NSnapHits",
+                "NTrigHits",
+                "NTrigDOMs",
+                "NTrigLines",
+                "NSpeedVetoHits",
+                "NGeometryVetoHits",
+                "ClassficationScore",
+            ],
+            self.f.events.usr_names[0].tolist(),
+        )
+
+    def test_getitem_flat(self):
+        assert np.allclose(
+            [118.6302815337638, 44.33580521344907, 99.93916717621543],
+            usr(self.f.events, "CoC").tolist(),
+        )
+        assert np.allclose(
+            [37.51967774166617, -10.280346193553832, 13.67595659707355],
+            usr(self.f.events, "DeltaPosZ").tolist(),
+        )
+
+
+class TestMcTrackUsr(unittest.TestCase):
+    def setUp(self):
+        self.f = OFFLINE_MC_TRACK_USR
+
+    def test_usr_names(self):
+        n_tracks = len(self.f.events)
+        for i in range(3):
+            self.assertListEqual(
+                ["bx", "by", "ichan", "cc"],
+                self.f.events.mc_tracks.usr_names[i][0].tolist(),
+            )
+            self.assertListEqual(
+                ["energy_lost_in_can"],
+                self.f.events.mc_tracks.usr_names[i][1].tolist(),
+            )
+
+    def test_usr(self):
+        assert np.allclose(
+            [0.0487, 0.0588, 3, 2],
+            self.f.events.mc_tracks.usr[0][0].tolist(),
+            atol=0.0001,
+        )
+        assert np.allclose(
+            [0.147, 0.4, 3, 2], self.f.events.mc_tracks.usr[1][0].tolist(), atol=0.001
+        )
-- 
GitLab