Skip to content
Snippets Groups Projects
Commit 27e9d8cd authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Fix fitinf

parent 080d9d54
No related branches found
No related tags found
1 merge request!47Resolve "uproot4 integration"
Pipeline #16340 failed
......@@ -134,34 +134,19 @@ def fitinf(fitparam, tracks):
----------
fitparam : int
the fit parameter key according to fitparameters defined in
KM3NeT-Dataformat.
tracks : km3io.offline.OfflineBranch
the tracks class. both full tracks branch or a slice of the
tracks branch (example tracks[:, 0]) work.
KM3NeT-Dataformat (see km3io.definitions.fitparameters).
tracks : ak.Array or km3io.rootio.Branch
reconstructed tracks with .fitinf attribute
Returns
-------
awkward1.Array
awkward array of the values of the fit parameter requested.
awkward array of the values of the fit parameter requested. Missing
values are set to NaN.
"""
fit = tracks.fitinf
index = fitparam
if tracks.is_single and len(tracks) != 1:
params = fit[count_nested(fit, axis=1) > index]
out = params[:, index]
if tracks.is_single and len(tracks) == 1:
out = fit[index]
else:
if len(tracks[0]) == 1: # case of tracks slice with 1 track per event.
params = fit[count_nested(fit, axis=1) > index]
out = params[:, index]
else:
params = fit[count_nested(fit, axis=2) > index]
out = ak.Array([i[:, index] for i in params])
return out
nonempty = ak.num(fit, axis=-1) > 0
return ak.fill_none(fit.mask[nonempty][..., 0], np.nan)
def count_nested(arr, axis=0):
......@@ -487,7 +472,7 @@ def is_cc(fobj):
cc_flag = w2list[:, kw2gen.W2LIST_GENHEN_CC]
out = cc_flag > 0
else:
raise ValueError(f"simulation program {program} is not implemented.")
raise NotImplementedError(f"don't know how to determine the CCness of {program} files.")
return out
......
......@@ -50,7 +50,6 @@ class TestFitinf(unittest.TestCase):
self.best = self.tracks[:, 0]
self.best_fit = self.best.fitinf
@unittest.skip
def test_fitinf_from_all_events(self):
beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks)
......@@ -58,7 +57,6 @@ class TestFitinf(unittest.TestCase):
assert beta[0][1] == self.fit[0][1][0]
assert beta[0][2] == self.fit[0][2][0]
@unittest.skip
def test_fitinf_from_one_event(self):
beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.best)
......@@ -66,12 +64,6 @@ class TestFitinf(unittest.TestCase):
assert beta[1] == self.best_fit[1][0]
assert beta[2] == self.best_fit[2][0]
@unittest.skip
def test_fitinf_from_one_event_and_one_track(self):
beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks[0][0])
assert beta == self.tracks[0][0].fitinf[0]
class TestBestTrackSelection(unittest.TestCase):
def setUp(self):
......@@ -85,7 +77,6 @@ class TestBestTrackSelection(unittest.TestCase):
assert len(best) == 10
# TODO: nested items, no idea how to solve this...
assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
......@@ -118,7 +109,6 @@ class TestBestTrackSelection(unittest.TestCase):
assert len(best) == 10
# TODO: nested items, no idea how to solve this...
assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
......@@ -138,7 +128,6 @@ class TestBestTrackSelection(unittest.TestCase):
assert len(best) == 10
# TODO: nested items, no idea how to solve this...
assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
......@@ -208,52 +197,81 @@ class TestBestTrackSelection(unittest.TestCase):
assert best2.lik == ak.max(tracks_slice.lik)
assert best2.rec_stages.tolist() == [1, 3, 5, 4]
@unittest.skip
def test_best_track_on_slices_with_start_end_one_event(self):
tracks_slice = self.one_event.tracks[0:5]
best = best_track(tracks_slice, startend=(1, 4))
assert len(best.lik) == 1
assert best.lik == ak.max(tracks_slice.lik)
assert best.rec_stages[0][0] == 1
assert best.rec_stages[0][-1] == 4
assert best.rec_stages[0] == 1
assert best.rec_stages[-1] == 4
@unittest.skip
def test_best_track_on_slices_with_explicit_rec_stages_one_event(self):
tracks_slice = self.one_event.tracks[0:5]
best = best_track(tracks_slice, stages=[1, 3, 5, 4])
assert best.lik == ak.max(tracks_slice.lik)
assert best.rec_stages[0][0] == 1
assert best.rec_stages[0][-1] == 4
assert best.rec_stages[0] == 1
assert best.rec_stages[-1] == 4
@unittest.skip
def test_best_track_on_slices_multiple_events(self):
tracks_slice = self.events.tracks[0:5]
tracks_slice = self.events[0:5].tracks
# stages in list
best = best_track(tracks_slice, stages=[1, 3, 5, 4])
assert len(best) == 5
assert best.lik == ak.max(tracks_slice.lik)
assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert np.allclose(
best.lik.tolist(),
[
294.6407542676734,
96.75133289411137,
560.2775306614813,
278.2872951665753,
99.59098153341449,
],
)
for i in range(len(best)):
assert best.rec_stages[i].tolist() == [1, 3, 5, 4]
# stages in set
best = best_track(tracks_slice, stages={1, 3, 4, 5})
best = best_track(tracks_slice, stages={3, 4, 5})
assert len(best) == 5
assert best.lik == ak.max(tracks_slice.lik)
assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert np.allclose(
best.lik.tolist(),
[
294.6407542676734,
96.75133289411137,
560.2775306614813,
278.2872951665753,
99.59098153341449,
],
)
for i in range(len(best)):
assert best.rec_stages[i].tolist() == [1, 3, 5, 4]
# using start and end
best = best_track(tracks_slice, startend=(1, 4))
start, end = (1, 4)
best = best_track(tracks_slice, startend=(start, end))
assert len(best) == 5
assert best.lik == ak.max(tracks_slice.lik)
assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert np.allclose(
best.lik.tolist(),
[
294.6407542676734,
96.75133289411137,
560.2775306614813,
278.2872951665753,
99.59098153341449,
],
)
for i in range(len(best)):
rs = best.rec_stages[i].tolist()
assert rs[0] == start
assert rs[-1] == end
def test_best_track_raises_when_unknown_stages(self):
with self.assertRaises(ValueError):
......@@ -265,7 +283,6 @@ class TestBestTrackSelection(unittest.TestCase):
class TestBestJmuon(unittest.TestCase):
@unittest.skip
def test_best_jmuon(self):
best = best_jmuon(OFFLINE_FILE.events.tracks)
......@@ -537,7 +554,6 @@ class TestUnfoldIndices(unittest.TestCase):
class TestIsCC(unittest.TestCase):
@unittest.skip
def test_is_cc(self):
NC_file = is_cc(GENHEN_OFFLINE_FILE)
CC_file = is_cc(GSEAGEN_OFFLINE_FILE)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment