From 437c0bea579d2d0681b4ce8c73adf2f0668baba0 Mon Sep 17 00:00:00 2001 From: Tamas Gal <tgal@km3net.de> Date: Wed, 9 Dec 2020 17:32:54 +0100 Subject: [PATCH] Fix usr --- examples/plot_offline_usr.py | 15 ++----- km3io/tools.py | 17 ++++++++ tests/test_tools.py | 76 ++++++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 11 deletions(-) diff --git a/examples/plot_offline_usr.py b/examples/plot_offline_usr.py index 9d7959b..8bf3b2d 100644 --- a/examples/plot_offline_usr.py +++ b/examples/plot_offline_usr.py @@ -18,19 +18,12 @@ r = ki.OfflineReader(data_path("offline/usr-sample.root")) ##################################################### -# Accessing the usr data: +# Accessing the usr fields: -usr = r.events.usr -print(usr) +print(r.events.usr_names.tolist()) ##################################################### -# to access data of a specific key, you can either do: +# to access data of a specific key: -print(usr.DeltaPosZ) - - -##################################################### -# or - -print(usr["RecoQuality"]) +print(ki.tools.usr(r.events, "DeltaPosZ")) diff --git a/km3io/tools.py b/km3io/tools.py index 2bc0e00..8db9a5b 100644 --- a/km3io/tools.py +++ b/km3io/tools.py @@ -494,3 +494,20 @@ def is_cc(fobj): raise ValueError(f"simulation program {program} is not implemented.") return out + + +def usr(objects, field): + """Return the usr-data for a given field. + + Parameters + ---------- + objects : awkward.Array + Events, tracks, hits or whatever objects which have usr and usr_names + fields (e.g. OfflineReader().events). + """ + if len(unique(ak.num(objects.usr_names))) > 1: + # let's do it the hard way + return ak.flatten(objects.usr[objects.usr_names == field]) + available_fields = objects.usr_names[0].tolist() + idx = available_fields.index(field) + return objects.usr[:, idx] diff --git a/tests/test_tools.py b/tests/test_tools.py index 3505b6a..69bcdca 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -27,9 +27,16 @@ from km3io.tools import ( best_aashower, best_dusjshower, is_cc, + usr, ) OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root")) +OFFLINE_USR = OfflineReader(data_path("offline/usr-sample.root")) +OFFLINE_MC_TRACK_USR = OfflineReader( + data_path( + "offline/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root" + ) +) GENHEN_OFFLINE_FILE = OfflineReader( data_path("offline/mcv5.1.genhen_anumuNC.sirene.jte.jchain.aashower.sample.root") ) @@ -540,3 +547,72 @@ class TestIsCC(unittest.TestCase): all(NC_file) == True ) # this test fails because the CC flags are not reliable in old files self.assertTrue(all(CC_file) == True) + + +class TestUsr(unittest.TestCase): + def setUp(self): + self.f = OFFLINE_USR + + def test_str_flat(self): + print(self.f.events.usr) + + def test_keys_flat(self): + self.assertListEqual( + [ + "RecoQuality", + "RecoNDF", + "CoC", + "ToT", + "ChargeAbove", + "ChargeBelow", + "ChargeRatio", + "DeltaPosZ", + "FirstPartPosZ", + "LastPartPosZ", + "NSnapHits", + "NTrigHits", + "NTrigDOMs", + "NTrigLines", + "NSpeedVetoHits", + "NGeometryVetoHits", + "ClassficationScore", + ], + self.f.events.usr_names[0].tolist(), + ) + + def test_getitem_flat(self): + assert np.allclose( + [118.6302815337638, 44.33580521344907, 99.93916717621543], + usr(self.f.events, "CoC").tolist(), + ) + assert np.allclose( + [37.51967774166617, -10.280346193553832, 13.67595659707355], + usr(self.f.events, "DeltaPosZ").tolist(), + ) + + +class TestMcTrackUsr(unittest.TestCase): + def setUp(self): + self.f = OFFLINE_MC_TRACK_USR + + def test_usr_names(self): + n_tracks = len(self.f.events) + for i in range(3): + self.assertListEqual( + ["bx", "by", "ichan", "cc"], + self.f.events.mc_tracks.usr_names[i][0].tolist(), + ) + self.assertListEqual( + ["energy_lost_in_can"], + self.f.events.mc_tracks.usr_names[i][1].tolist(), + ) + + def test_usr(self): + assert np.allclose( + [0.0487, 0.0588, 3, 2], + self.f.events.mc_tracks.usr[0][0].tolist(), + atol=0.0001, + ) + assert np.allclose( + [0.147, 0.4, 3, 2], self.f.events.mc_tracks.usr[1][0].tolist(), atol=0.001 + ) -- GitLab