Skip to content
Snippets Groups Projects
Commit 437c0bea authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Fix usr

parent 3ba9eccc
No related branches found
No related tags found
1 merge request!47Resolve "uproot4 integration"
Pipeline #16288 failed
This commit is part of merge request !47. Comments created here will be created in the context of that merge request.
......@@ -18,19 +18,12 @@ r = ki.OfflineReader(data_path("offline/usr-sample.root"))
#####################################################
# Accessing the usr data:
# Accessing the usr fields:
usr = r.events.usr
print(usr)
print(r.events.usr_names.tolist())
#####################################################
# to access data of a specific key, you can either do:
# to access data of a specific key:
print(usr.DeltaPosZ)
#####################################################
# or
print(usr["RecoQuality"])
print(ki.tools.usr(r.events, "DeltaPosZ"))
......@@ -494,3 +494,20 @@ def is_cc(fobj):
raise ValueError(f"simulation program {program} is not implemented.")
return out
def usr(objects, field):
"""Return the usr-data for a given field.
Parameters
----------
objects : awkward.Array
Events, tracks, hits or whatever objects which have usr and usr_names
fields (e.g. OfflineReader().events).
"""
if len(unique(ak.num(objects.usr_names))) > 1:
# let's do it the hard way
return ak.flatten(objects.usr[objects.usr_names == field])
available_fields = objects.usr_names[0].tolist()
idx = available_fields.index(field)
return objects.usr[:, idx]
......@@ -27,9 +27,16 @@ from km3io.tools import (
best_aashower,
best_dusjshower,
is_cc,
usr,
)
OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root"))
OFFLINE_USR = OfflineReader(data_path("offline/usr-sample.root"))
OFFLINE_MC_TRACK_USR = OfflineReader(
data_path(
"offline/mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root"
)
)
GENHEN_OFFLINE_FILE = OfflineReader(
data_path("offline/mcv5.1.genhen_anumuNC.sirene.jte.jchain.aashower.sample.root")
)
......@@ -540,3 +547,72 @@ class TestIsCC(unittest.TestCase):
all(NC_file) == True
) # this test fails because the CC flags are not reliable in old files
self.assertTrue(all(CC_file) == True)
class TestUsr(unittest.TestCase):
def setUp(self):
self.f = OFFLINE_USR
def test_str_flat(self):
print(self.f.events.usr)
def test_keys_flat(self):
self.assertListEqual(
[
"RecoQuality",
"RecoNDF",
"CoC",
"ToT",
"ChargeAbove",
"ChargeBelow",
"ChargeRatio",
"DeltaPosZ",
"FirstPartPosZ",
"LastPartPosZ",
"NSnapHits",
"NTrigHits",
"NTrigDOMs",
"NTrigLines",
"NSpeedVetoHits",
"NGeometryVetoHits",
"ClassficationScore",
],
self.f.events.usr_names[0].tolist(),
)
def test_getitem_flat(self):
assert np.allclose(
[118.6302815337638, 44.33580521344907, 99.93916717621543],
usr(self.f.events, "CoC").tolist(),
)
assert np.allclose(
[37.51967774166617, -10.280346193553832, 13.67595659707355],
usr(self.f.events, "DeltaPosZ").tolist(),
)
class TestMcTrackUsr(unittest.TestCase):
def setUp(self):
self.f = OFFLINE_MC_TRACK_USR
def test_usr_names(self):
n_tracks = len(self.f.events)
for i in range(3):
self.assertListEqual(
["bx", "by", "ichan", "cc"],
self.f.events.mc_tracks.usr_names[i][0].tolist(),
)
self.assertListEqual(
["energy_lost_in_can"],
self.f.events.mc_tracks.usr_names[i][1].tolist(),
)
def test_usr(self):
assert np.allclose(
[0.0487, 0.0588, 3, 2],
self.f.events.mc_tracks.usr[0][0].tolist(),
atol=0.0001,
)
assert np.allclose(
[0.147, 0.4, 3, 2], self.f.events.mc_tracks.usr[1][0].tolist(), atol=0.001
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment