diff --git a/tests/test_offline.py b/tests/test_offline.py
index 965ead5edee2a1c4a92d74c8e126b8734807d8a9..a89a5840fc069dd17fd3b189bc6f3de4344d64b8 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -4,12 +4,15 @@ import awkward1 as ak1
 from pathlib import Path
 
 from km3io import OfflineReader
-from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask
+from km3io.offline import _nested_mapper, Header, fitinf, fitparams, count_nested, _find, mask, best_track
 
 SAMPLES_DIR = Path(__file__).parent / 'samples'
 OFFLINE_FILE = OfflineReader(SAMPLES_DIR / 'aanet_v2.0.0.root')
 OFFLINE_USR = OfflineReader(SAMPLES_DIR / 'usr-sample.root')
-OFFLINE_MC_TRACK_USR = OfflineReader(SAMPLES_DIR / 'mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root')
+OFFLINE_MC_TRACK_USR = OfflineReader(
+    SAMPLES_DIR /
+    'mcv5.11r2.gsg_muonCChigherE-CC_50-5000GeV.km3_AAv1.jterbr00004695.jchain.aanet.498.root'
+)
 OFFLINE_NUMUCC = OfflineReader(SAMPLES_DIR / "numucc.root")  # with mc data
 
 
@@ -38,6 +41,21 @@ class TestFitinf(unittest.TestCase):
         assert "JGANDALF_BETA0_RAD" in keys
 
 
+class TestBestTrack(unittest.TestCase):
+    def setUp(self):
+        self.tracks = OFFLINE_FILE.events.tracks
+
+    def test_best_tracks(self):
+        first_tracks = best_track(self.tracks, strategy="first")
+        rec_stages_tracks = best_track(self.tracks,
+                                       strategy="rec_stages",
+                                       rec_stages=[1, 3, 5, 4])
+
+        assert first_tracks.dir_z[0] == self.tracks.dir_z[0][0]
+        assert first_tracks.dir_x[1] == self.tracks.dir_x[1][0]
+        assert rec_stages_tracks.rec_stages[0] == [1, 3, 5, 4]
+
+
 class TestCountNested(unittest.TestCase):
     def test_count_nested(self):
         fit = OFFLINE_FILE.events.tracks.fitinf
@@ -442,10 +460,12 @@ class TestMcTrackUsr(unittest.TestCase):
     def test_usr_names(self):
         n_tracks = len(self.f.events)
         for i in range(3):
-            self.assertListEqual([b'bx', b'by', b'ichan', b'cc'],
-                                self.f.events.mc_tracks.usr_names[i][0].tolist())
-            self.assertListEqual([b'energy_lost_in_can'],
-                                self.f.events.mc_tracks.usr_names[i][1].tolist())
+            self.assertListEqual(
+                [b'bx', b'by', b'ichan', b'cc'],
+                self.f.events.mc_tracks.usr_names[i][0].tolist())
+            self.assertListEqual(
+                [b'energy_lost_in_can'],
+                self.f.events.mc_tracks.usr_names[i][1].tolist())
 
     def test_usr(self):
         assert np.allclose([0.0487, 0.0588, 3, 2],