diff --git a/tests/test_tools.py b/tests/test_tools.py
index e719649c5319bd9718467ab6f94a1d05304737e1..dcc22aba9de766b357a16a5475a7d45493d94532 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -12,7 +12,8 @@ from km3io.tools import (to_num, cached_property, unfold_indices, unique,
                          uniquecount, fitinf, fitparams, count_nested, _find,
                          mask, best_track, rec_types, get_w2list_param,
                          get_multiplicity, best_jmuon, best_jshower,
-                         best_aashower, best_dusjshower)
+                         best_aashower, best_dusjshower, w2list_genhen_keys,
+                         w2list_gseagen_keys)
 
 OFFLINE_FILE = OfflineReader(data_path("offline/km3net_offline.root"))
 
@@ -353,7 +354,7 @@ class TestRecStagesMasks(unittest.TestCase):
         assert labels[0][2] == 0
         assert labels[1][0] == 0
 
-    def test_mask_with_explicit_rec_stages_with_multiple_events(self):
+    def test_mask_with_explicit_rec_stages_in_list_with_multiple_events(self):
         rec_stages = self.tracks.rec_stages
         stages = [1, 3, 5, 4]
         masks = mask(self.tracks, stages=stages)
@@ -362,6 +363,16 @@ class TestRecStagesMasks(unittest.TestCase):
         assert masks[1][0] == all(rec_stages[1][0] == ak.Array(stages))
         assert masks[0][1] == False
 
+    def test_mask_with_explicit_rec_stages_in_set_with_multiple_events(self):
+        stages = {1, 2, 3, 4, 5}
+        masks = mask(self.tracks, stages=stages)
+        tracks = self.tracks[masks]
+
+        assert 1 in tracks.rec_stages[0][0]
+        assert 3 in tracks.rec_stages[0][0]
+        assert 4 in tracks.rec_stages[0][0]
+        assert 5 in tracks.rec_stages[0][0]
+
     def test_mask_with_start_and_end_of_rec_stages_with_multiple_events(self):
         rec_stages = self.tracks.rec_stages
         stages = [1, 3, 5, 4]
@@ -389,6 +400,24 @@ class TestRecStagesMasks(unittest.TestCase):
         assert track[masks].rec_stages[0][0] == stages[0]
         assert track[masks].rec_stages[0][1] == stages[1]
 
+    def test_mask_raises_when_too_many_inputs(self):
+        with self.assertRaises(ValueError):
+            mask(self.tracks, start=1, end=4, stages=[1, 3, 5, 4])
+
+    def test_mask_raises_when_no_inputs(self):
+        with self.assertRaises(ValueError):
+            mask(self.tracks)
+
+
+class TestW2listGenhenKeys(unittest.TestCase):
+    def test_w2list_genhen_keys(self):
+        assert 'W2LIST_GENHEN_REFF' in w2list_genhen_keys()
+
+
+class TestW2listGseangenKeys(unittest.TestCase):
+    def test_w2list_gseagen_keys(self):
+        assert 'W2LIST_GSEAGEN_PS' in w2list_gseagen_keys()
+
 
 class TestUnique(unittest.TestCase):
     def run_random_test_with_dtype(self, dtype):