Skip to content
Snippets Groups Projects
Commit 27e9d8cd authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Fix fitinf

parent 080d9d54
No related branches found
No related tags found
1 merge request!47Resolve "uproot4 integration"
Pipeline #16340 failed
...@@ -134,34 +134,19 @@ def fitinf(fitparam, tracks): ...@@ -134,34 +134,19 @@ def fitinf(fitparam, tracks):
---------- ----------
fitparam : int fitparam : int
the fit parameter key according to fitparameters defined in the fit parameter key according to fitparameters defined in
KM3NeT-Dataformat. KM3NeT-Dataformat (see km3io.definitions.fitparameters).
tracks : km3io.offline.OfflineBranch tracks : ak.Array or km3io.rootio.Branch
the tracks class. both full tracks branch or a slice of the reconstructed tracks with .fitinf attribute
tracks branch (example tracks[:, 0]) work.
Returns Returns
------- -------
awkward1.Array awkward1.Array
awkward array of the values of the fit parameter requested. awkward array of the values of the fit parameter requested. Missing
values are set to NaN.
""" """
fit = tracks.fitinf fit = tracks.fitinf
index = fitparam nonempty = ak.num(fit, axis=-1) > 0
if tracks.is_single and len(tracks) != 1: return ak.fill_none(fit.mask[nonempty][..., 0], np.nan)
params = fit[count_nested(fit, axis=1) > index]
out = params[:, index]
if tracks.is_single and len(tracks) == 1:
out = fit[index]
else:
if len(tracks[0]) == 1: # case of tracks slice with 1 track per event.
params = fit[count_nested(fit, axis=1) > index]
out = params[:, index]
else:
params = fit[count_nested(fit, axis=2) > index]
out = ak.Array([i[:, index] for i in params])
return out
def count_nested(arr, axis=0): def count_nested(arr, axis=0):
...@@ -487,7 +472,7 @@ def is_cc(fobj): ...@@ -487,7 +472,7 @@ def is_cc(fobj):
cc_flag = w2list[:, kw2gen.W2LIST_GENHEN_CC] cc_flag = w2list[:, kw2gen.W2LIST_GENHEN_CC]
out = cc_flag > 0 out = cc_flag > 0
else: else:
raise ValueError(f"simulation program {program} is not implemented.") raise NotImplementedError(f"don't know how to determine the CCness of {program} files.")
return out return out
......
...@@ -50,7 +50,6 @@ class TestFitinf(unittest.TestCase): ...@@ -50,7 +50,6 @@ class TestFitinf(unittest.TestCase):
self.best = self.tracks[:, 0] self.best = self.tracks[:, 0]
self.best_fit = self.best.fitinf self.best_fit = self.best.fitinf
@unittest.skip
def test_fitinf_from_all_events(self): def test_fitinf_from_all_events(self):
beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks) beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks)
...@@ -58,7 +57,6 @@ class TestFitinf(unittest.TestCase): ...@@ -58,7 +57,6 @@ class TestFitinf(unittest.TestCase):
assert beta[0][1] == self.fit[0][1][0] assert beta[0][1] == self.fit[0][1][0]
assert beta[0][2] == self.fit[0][2][0] assert beta[0][2] == self.fit[0][2][0]
@unittest.skip
def test_fitinf_from_one_event(self): def test_fitinf_from_one_event(self):
beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.best) beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.best)
...@@ -66,12 +64,6 @@ class TestFitinf(unittest.TestCase): ...@@ -66,12 +64,6 @@ class TestFitinf(unittest.TestCase):
assert beta[1] == self.best_fit[1][0] assert beta[1] == self.best_fit[1][0]
assert beta[2] == self.best_fit[2][0] assert beta[2] == self.best_fit[2][0]
@unittest.skip
def test_fitinf_from_one_event_and_one_track(self):
beta = fitinf(kfit.JGANDALF_BETA0_RAD, self.tracks[0][0])
assert beta == self.tracks[0][0].fitinf[0]
class TestBestTrackSelection(unittest.TestCase): class TestBestTrackSelection(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -85,7 +77,6 @@ class TestBestTrackSelection(unittest.TestCase): ...@@ -85,7 +77,6 @@ class TestBestTrackSelection(unittest.TestCase):
assert len(best) == 10 assert len(best) == 10
# TODO: nested items, no idea how to solve this...
assert best.rec_stages[0].tolist() == [1, 3, 5, 4] assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert best.rec_stages[1].tolist() == [1, 3, 5, 4] assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
assert best.rec_stages[2].tolist() == [1, 3, 5, 4] assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
...@@ -118,7 +109,6 @@ class TestBestTrackSelection(unittest.TestCase): ...@@ -118,7 +109,6 @@ class TestBestTrackSelection(unittest.TestCase):
assert len(best) == 10 assert len(best) == 10
# TODO: nested items, no idea how to solve this...
assert best.rec_stages[0].tolist() == [1, 3, 5, 4] assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert best.rec_stages[1].tolist() == [1, 3, 5, 4] assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
assert best.rec_stages[2].tolist() == [1, 3, 5, 4] assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
...@@ -138,7 +128,6 @@ class TestBestTrackSelection(unittest.TestCase): ...@@ -138,7 +128,6 @@ class TestBestTrackSelection(unittest.TestCase):
assert len(best) == 10 assert len(best) == 10
# TODO: nested items, no idea how to solve this...
assert best.rec_stages[0].tolist() == [1, 3, 5, 4] assert best.rec_stages[0].tolist() == [1, 3, 5, 4]
assert best.rec_stages[1].tolist() == [1, 3, 5, 4] assert best.rec_stages[1].tolist() == [1, 3, 5, 4]
assert best.rec_stages[2].tolist() == [1, 3, 5, 4] assert best.rec_stages[2].tolist() == [1, 3, 5, 4]
...@@ -208,52 +197,81 @@ class TestBestTrackSelection(unittest.TestCase): ...@@ -208,52 +197,81 @@ class TestBestTrackSelection(unittest.TestCase):
assert best2.lik == ak.max(tracks_slice.lik) assert best2.lik == ak.max(tracks_slice.lik)
assert best2.rec_stages.tolist() == [1, 3, 5, 4] assert best2.rec_stages.tolist() == [1, 3, 5, 4]
@unittest.skip
def test_best_track_on_slices_with_start_end_one_event(self): def test_best_track_on_slices_with_start_end_one_event(self):
tracks_slice = self.one_event.tracks[0:5] tracks_slice = self.one_event.tracks[0:5]
best = best_track(tracks_slice, startend=(1, 4)) best = best_track(tracks_slice, startend=(1, 4))
assert len(best.lik) == 1
assert best.lik == ak.max(tracks_slice.lik) assert best.lik == ak.max(tracks_slice.lik)
assert best.rec_stages[0][0] == 1 assert best.rec_stages[0] == 1
assert best.rec_stages[0][-1] == 4 assert best.rec_stages[-1] == 4
@unittest.skip
def test_best_track_on_slices_with_explicit_rec_stages_one_event(self): def test_best_track_on_slices_with_explicit_rec_stages_one_event(self):
tracks_slice = self.one_event.tracks[0:5] tracks_slice = self.one_event.tracks[0:5]
best = best_track(tracks_slice, stages=[1, 3, 5, 4]) best = best_track(tracks_slice, stages=[1, 3, 5, 4])
assert best.lik == ak.max(tracks_slice.lik) assert best.lik == ak.max(tracks_slice.lik)
assert best.rec_stages[0][0] == 1 assert best.rec_stages[0] == 1
assert best.rec_stages[0][-1] == 4 assert best.rec_stages[-1] == 4
@unittest.skip
def test_best_track_on_slices_multiple_events(self): def test_best_track_on_slices_multiple_events(self):
tracks_slice = self.events.tracks[0:5] tracks_slice = self.events[0:5].tracks
# stages in list # stages in list
best = best_track(tracks_slice, stages=[1, 3, 5, 4]) best = best_track(tracks_slice, stages=[1, 3, 5, 4])
assert len(best) == 5 assert len(best) == 5
assert best.lik == ak.max(tracks_slice.lik) assert np.allclose(
assert best.rec_stages[0].tolist() == [1, 3, 5, 4] best.lik.tolist(),
[
294.6407542676734,
96.75133289411137,
560.2775306614813,
278.2872951665753,
99.59098153341449,
],
)
for i in range(len(best)):
assert best.rec_stages[i].tolist() == [1, 3, 5, 4]
# stages in set # stages in set
best = best_track(tracks_slice, stages={1, 3, 4, 5}) best = best_track(tracks_slice, stages={3, 4, 5})
assert len(best) == 5 assert len(best) == 5
assert best.lik == ak.max(tracks_slice.lik) assert np.allclose(
assert best.rec_stages[0].tolist() == [1, 3, 5, 4] best.lik.tolist(),
[
294.6407542676734,
96.75133289411137,
560.2775306614813,
278.2872951665753,
99.59098153341449,
],
)
for i in range(len(best)):
assert best.rec_stages[i].tolist() == [1, 3, 5, 4]
# using start and end # using start and end
best = best_track(tracks_slice, startend=(1, 4)) start, end = (1, 4)
best = best_track(tracks_slice, startend=(start, end))
assert len(best) == 5 assert len(best) == 5
assert best.lik == ak.max(tracks_slice.lik) assert np.allclose(
assert best.rec_stages[0].tolist() == [1, 3, 5, 4] best.lik.tolist(),
[
294.6407542676734,
96.75133289411137,
560.2775306614813,
278.2872951665753,
99.59098153341449,
],
)
for i in range(len(best)):
rs = best.rec_stages[i].tolist()
assert rs[0] == start
assert rs[-1] == end
def test_best_track_raises_when_unknown_stages(self): def test_best_track_raises_when_unknown_stages(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -265,7 +283,6 @@ class TestBestTrackSelection(unittest.TestCase): ...@@ -265,7 +283,6 @@ class TestBestTrackSelection(unittest.TestCase):
class TestBestJmuon(unittest.TestCase): class TestBestJmuon(unittest.TestCase):
@unittest.skip
def test_best_jmuon(self): def test_best_jmuon(self):
best = best_jmuon(OFFLINE_FILE.events.tracks) best = best_jmuon(OFFLINE_FILE.events.tracks)
...@@ -537,7 +554,6 @@ class TestUnfoldIndices(unittest.TestCase): ...@@ -537,7 +554,6 @@ class TestUnfoldIndices(unittest.TestCase):
class TestIsCC(unittest.TestCase): class TestIsCC(unittest.TestCase):
@unittest.skip
def test_is_cc(self): def test_is_cc(self):
NC_file = is_cc(GENHEN_OFFLINE_FILE) NC_file = is_cc(GENHEN_OFFLINE_FILE)
CC_file = is_cc(GSEAGEN_OFFLINE_FILE) CC_file = is_cc(GSEAGEN_OFFLINE_FILE)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment