diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py
index 651603740b80e83ac5aa301d25826a7e2108fafa..9fd6970a1c70ef2e93d6f7ccd821c52fe27ea86d 100644
--- a/orcasong/tools/shuffle2.py
+++ b/orcasong/tools/shuffle2.py
@@ -14,14 +14,71 @@ from orcasong.tools.concatenate import copy_attrs
 __author__ = "Stefan Reck"
 
 
-def shuffle_v2(
-        input_file,
-        datasets=("x", "y"),
-        output_file=None,
-        max_ram=None,
-        max_ram_fraction=0.25,
-        chunks=False,
-        delete=False):
+def h5shuffle2(
+    input_file,
+    output_file=None,
+    iterations=None,
+    datasets=("x", "y"),
+    max_ram_fraction=0.25,
+    **kwargs,
+):
+    if output_file is None:
+        output_file = get_filepath_output(input_file, shuffle=True)
+    if iterations is None:
+        iterations = get_n_iterations(
+            input_file,
+            datasets=datasets,
+            max_ram_fraction=max_ram_fraction,
+        )
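+    # fixed seed so repeated runs of the tool produce the same shuffle order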
+    np.random.seed(42)
+    for i in range(iterations):
+        print(f"\nIteration {i+1}/{iterations}")
+        if iterations == 1:
+            # special case if there's only one iteration
+            stgs = {
+                "input_file": input_file,
+                "output_file": output_file,
+                "delete": False,
+            }
+        elif i == 0:
+            # first iteration
+            stgs = {
+                "input_file": input_file,
+                "output_file": f"{output_file}_temp_{i}",
+                "delete": False,
+            }
+        elif i == iterations - 1:
+            # last iteration
+            stgs = {
+                "input_file": f"{output_file}_temp_{i-1}",
+                "output_file": output_file,
+                "delete": True,
+            }
+        else:
+            # intermediate iterations
+            stgs = {
+                "input_file": f"{output_file}_temp_{i-1}",
+                "output_file": f"{output_file}_temp_{i}",
+                "delete": True,
+            }
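+        # chunks=True: shuffle whole hdf5 chunks per pass for fast sequential
+        # reads; row-level randomness builds up over the iterations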
+        shuffle_file(
+            datasets=datasets,
+            max_ram_fraction=max_ram_fraction,
+            chunks=True,
+            **stgs,
+            **kwargs,
+        )
+
+
+def shuffle_file(
+    input_file,
+    datasets=("x", "y"),
+    output_file=None,
+    max_ram=None,
+    max_ram_fraction=0.25,
+    chunks=False,
+    delete=False,
+):
     """
     Shuffle datasets in a h5file that have the same length.
 
@@ -62,7 +119,9 @@ def shuffle_v2(
     if max_ram is None:
         max_ram = get_max_ram(max_ram_fraction)
 
-    temp_output_file = output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
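+    # write to a timestamped temp file first; it is renamed to output_file
+    # once all datasets have been copied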
+    temp_output_file = (
+        output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
+    )
     with h5py.File(input_file, "r") as f_in:
         dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
         print(f"Shuffling datasets {datasets} with {n_lines} lines each")
@@ -90,8 +149,9 @@ def shuffle_v2(
     os.rename(temp_output_file, output_file)
     if delete:
         os.remove(input_file)
-    print(f"Elapsed time: "
-          f"{datetime.timedelta(seconds=int(time.time() - start_time))}")
+    print(
+        f"Elapsed time: {datetime.timedelta(seconds=int(time.time() - start_time))}"
+    )
     return output_file
 
 
@@ -106,7 +166,7 @@ def get_indices_largest(dset_infos):
     dset_info = dset_infos[largest_dset]
 
     print(f"Total chunks: {dset_info['n_chunks']}")
-    ratio = dset_info['chunks_per_batch']/dset_info['n_chunks']
+    ratio = dset_info["chunks_per_batch"] / dset_info["n_chunks"]
     print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")
 
     return get_indices_chunked(
@@ -116,7 +176,9 @@ def get_indices_largest(dset_infos):
     )
 
 
-def get_n_iterations(input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25):
+def get_n_iterations(
+    input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
+):
     """ Get how often you have to shuffle with given ram to get proper randomness. """
     if max_ram is None:
         max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
@@ -124,8 +186,9 @@ def get_n_iterations(input_file, datasets=("x", "y"), max_ram=None, max_ram_frac
         dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
     largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
     dset_info = dset_infos[largest_dset]
-    n_iterations = int(np.ceil(
-        np.log(dset_info['n_chunks'])/np.log(dset_info['chunks_per_batch'])))
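+    # one pass can fully mix chunks_per_batch chunks, so k passes cover about
+    # chunks_per_batch**k chunks; the smallest k covering all n_chunks is
+    # ceil(log(n_chunks) / log(chunks_per_batch))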
+    n_iterations = int(
+        np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
+    )
     print(f"Total chunks: {dset_info['n_chunks']}")
     print(f"Chunks per batch: {dset_info['chunks_per_batch']}")
     print(f"--> min iterations for full shuffle: {n_iterations}")
@@ -140,7 +203,7 @@ def get_indices_chunked(n_batches, n_chunks, chunksize):
 
     index_batches = []
     for bat in chunk_batches:
-        idx = (bat[:, None]*chunksize + np.arange(chunksize)[None, :]).flatten()
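+        # bat holds chunk numbers; expand each into the row indices it covers,
+        # then shuffle those rows within this batch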
+        idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
         np.random.shuffle(idx)
         index_batches.append(idx)
 
@@ -157,8 +220,9 @@ def get_dset_infos(f, datasets, max_ram):
             n_lines = len(dset)
         else:
             if len(dset) != n_lines:
-                raise ValueError(f"dataset {name} has different length! "
-                                 f"{len(dset)} vs {n_lines}")
+                raise ValueError(
+                    f"dataset {name} has a different length! {len(dset)} vs {n_lines}"
+                )
         chunksize = dset.chunks[0]
         n_chunks = int(np.ceil(n_lines / chunksize))
         # TODO in h5py 3.X, use .nbytes to get uncompressed size
@@ -168,19 +232,21 @@ def get_dset_infos(f, datasets, max_ram):
         lines_per_batch = int(np.floor(max_ram / bytes_per_line))
         chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))
 
-        dset_infos.append({
-            "name": name,
-            "dset": dset,
-            "chunksize": chunksize,
-            "n_lines": n_lines,
-            "n_chunks": n_chunks,
-            "bytes_per_line": bytes_per_line,
-            "bytes_per_chunk": bytes_per_chunk,
-            "lines_per_batch": lines_per_batch,
-            "chunks_per_batch": chunks_per_batch,
-            "n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)),
-            "n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)),
-        })
+        dset_infos.append(
+            {
+                "name": name,
+                "dset": dset,
+                "chunksize": chunksize,
+                "n_lines": n_lines,
+                "n_chunks": n_chunks,
+                "bytes_per_line": bytes_per_line,
+                "bytes_per_chunk": bytes_per_chunk,
+                "lines_per_batch": lines_per_batch,
+                "chunks_per_batch": chunks_per_batch,
+                "n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)),
+                "n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)),
+            }
+        )
     return dset_infos, n_lines
 
 
@@ -191,7 +257,7 @@ def make_dset(f_out, dset_info, indices):
 
         slc = slice(
             batch_index * dset_info["lines_per_batch"],
-            (batch_index+1) * dset_info["lines_per_batch"],
+            (batch_index + 1) * dset_info["lines_per_batch"],
         )
         to_read = indices[slc]
         # reading has to be done with linearly increasing index,
@@ -267,59 +333,41 @@ def slicify(fancy_indices):
     return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
 
 
-def h5shuffle2():
+def run_parser():
     parser = argparse.ArgumentParser(
-        description='Shuffle datasets in a h5file that have the same length. '
-                    'Uses chunkwise readout for speed-up.')
-    parser.add_argument('input_file', type=str,
-                        help='Path of the file that will be shuffled.')
-    parser.add_argument('--output_file', type=str, default=None,
-                        help='If given, this will be the name of the output file. '
-                             'Default: input_file + suffix.')
-    parser.add_argument('--datasets', type=str, nargs="*", default=("x", "y"),
-                        help='Which datasets to include in output. Default: x, y')
-    parser.add_argument('--max_ram_fraction', type=float, default=0.25,
-                        help="in [0, 1]. Fraction of all available ram to use for reading one batch of data "
-                             "Note: this should "
-                             "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
-                             "Default: 0.25")
-    parser.add_argument('--iterations', type=int, default=None,
-                        help="Shuffle the file this many times. Default: Auto choose best number.")
-    kwargs = vars(parser.parse_args())
-
-    input_file = kwargs.pop("input_file")
-    output_file = kwargs.pop("output_file")
-    if output_file is None:
-        output_file = get_filepath_output(input_file, shuffle=True)
-    iterations = kwargs.pop("iterations")
-    if iterations is None:
-        iterations = get_n_iterations(
-            input_file,
-            datasets=kwargs["datasets"],
-            max_ram_fraction=kwargs["max_ram_fraction"],
-        )
-    np.random.seed(42)
-    for i in range(iterations):
-        print(f"\nIteration {i+1}/{iterations}")
-        if i == 0:
-            # first iteration
-            stgs = {
-                "input_file": input_file,
-                "output_file": f"{output_file}_temp_{i}",
-                "delete": False
-            }
-        elif i == iterations-1:
-            # last iteration
-            stgs = {
-                "input_file": f"{output_file}_temp_{i-1}",
-                "output_file": output_file,
-                "delete": True
-            }
-        else:
-            # intermediate iterations
-            stgs = {
-                "input_file": f"{output_file}_temp_{i-1}",
-                "output_file": f"{output_file}_temp_{i}",
-                "delete": True
-            }
-        shuffle_v2(**kwargs, **stgs, chunks=True)
+        description="Shuffle datasets in a h5file that have the same length. "
+        "Uses chunkwise readout for speed-up."
+    )
+    parser.add_argument(
+        "input_file", type=str, help="Path of the file that will be shuffled."
+    )
+    parser.add_argument(
+        "--output_file",
+        type=str,
+        default=None,
+        help="If given, this will be the name of the output file. "
+        "Default: input_file + suffix.",
+    )
+    parser.add_argument(
+        "--datasets",
+        type=str,
+        nargs="*",
+        default=("x", "y"),
+        help="Which datasets to include in output. Default: x, y",
+    )
+    parser.add_argument(
+        "--max_ram_fraction",
+        type=float,
+        default=0.25,
+        help="in [0, 1]. Fraction of all available ram to use for reading one batch of data "
+        "Note: this should "
+        "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
+        "Default: 0.25",
+    )
+    parser.add_argument(
+        "--iterations",
+        type=int,
+        default=None,
+        help="Shuffle the file this many times. Default: Auto choose best number.",
+    )
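+    # the argparse option names match h5shuffle2's keyword arguments, so the
+    # parsed namespace can be passed through directly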
+    h5shuffle2(**vars(parser.parse_args()))