diff --git a/orcasong/from_toml.py b/orcasong/from_toml.py
index e3715c161cc7a4dbdca0f96fb76d513f903e7774..65029414b12dd7aae755893f68edf0f7ea0ccc34 100644
--- a/orcasong/from_toml.py
+++ b/orcasong/from_toml.py
@@ -11,8 +11,8 @@ EXTRACTORS = {
     "nu_chain_muon": extractors.get_muon_mc_info_extr,
     "nu_chain_noise": extractors.get_random_noise_mc_info_extr,
     "nu_chain_data": extractors.get_real_data_info_extr,
-    "bundle_mc": extractors.bundles.BundleMCExtractor,
-    "bundle_data": extractors.bundles.BundleDataExtractor,
+    "bundle_mc": extractors.BundleMCExtractor,
+    "bundle_data": extractors.BundleDataExtractor,
 }
 
 MODES = {
diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py
index 0fcdaa79a7eb7532b8e6e8c3cd344940c0708811..639d62a96b2e39d8c1730e7a971f3e2a919e4395 100644
--- a/orcasong/tools/shuffle2.py
+++ b/orcasong/tools/shuffle2.py
@@ -1,8 +1,6 @@
 import os
 import time
 import datetime
-import argparse
-import warnings
 
 import numpy as np
 import psutil
@@ -23,64 +21,7 @@ def h5shuffle2(
     datasets=("x", "y"),
     max_ram_fraction=0.25,
     max_ram=None,
-):
-    if output_file is None:
-        output_file = get_filepath_output(input_file, shuffle=True)
-    if iterations is None:
-        iterations = get_n_iterations(
-            input_file,
-            datasets=datasets,
-            max_ram_fraction=max_ram_fraction,
-            max_ram=max_ram,
-        )
-    np.random.seed(42)
-    for i in range(iterations):
-        print(f"\nIteration {i+1}/{iterations}")
-        if iterations == 1:
-            # special case if theres only one iteration
-            stgs = {
-                "input_file": input_file,
-                "output_file": output_file,
-                "delete": False,
-            }
-        elif i == 0:
-            # first iteration
-            stgs = {
-                "input_file": input_file,
-                "output_file": f"{output_file}_temp_{i}",
-                "delete": False,
-            }
-        elif i == iterations - 1:
-            # last iteration
-            stgs = {
-                "input_file": f"{output_file}_temp_{i-1}",
-                "output_file": output_file,
-                "delete": True,
-            }
-        else:
-            # intermediate iterations
-            stgs = {
-                "input_file": f"{output_file}_temp_{i-1}",
-                "output_file": f"{output_file}_temp_{i}",
-                "delete": True,
-            }
-        shuffle_file(
-            datasets=datasets,
-            max_ram=max_ram,
-            max_ram_fraction=max_ram_fraction,
-            chunks=True,
-            **stgs,
-        )
-
-
-def shuffle_file(
-    input_file,
-    datasets=("x", "y"),
-    output_file=None,
-    max_ram=None,
-    max_ram_fraction=0.25,
-    chunks=False,
-    delete=False,
+    seed=42,
 ):
     """
     Shuffle datasets in a h5file that have the same length.
@@ -89,24 +30,24 @@ def shuffle_file(
     ----------
     input_file : str
         Path of the file that will be shuffle.
-    datasets : tuple
-        Which datasets to include in output.
     output_file : str, optional
         If given, this will be the name of the output file.
         Otherwise, a name is auto generated.
+    iterations : int, optional
+        Shuffle the file this many times. For each additional iteration,
+        a temporary file will be created and then deleted afterwards.
+        Default: Auto choose best number based on available RAM.
+    datasets : tuple
+        Which datasets to include in output.
     max_ram : int, optional
         Available ram in bytes. Default: Use fraction of
         maximum available (see max_ram_fraction).
     max_ram_fraction : float
-        in [0, 1]. Fraction of ram to use for reading one batch of data
+        in [0, 1]. Fraction of RAM to use for reading one batch of data
         when max_ram is None. Note: when using chunks, this should
         be <=~0.25, since lots of ram is needed for in-memory shuffling.
-    chunks : bool
-        Use chunk-wise readout. Large speed boost, but will
-        only quasi-randomize order! Needs lots of ram
-        to be accurate! (use a node with at least 32gb, the more the better)
-    delete : bool
-        Delete the original file afterwards?
+    seed : int or None
+        Seed for randomness.
 
     Returns
     -------
@@ -114,38 +55,68 @@ def shuffle_file(
         Path to the output file.
 
     """
-    start_time = time.time()
     if output_file is None:
         output_file = get_filepath_output(input_file, shuffle=True)
+    if iterations is None:
+        iterations = get_n_iterations(
+            input_file,
+            datasets=datasets,
+            max_ram_fraction=max_ram_fraction,
+            max_ram=max_ram,
+        )
+    # filenames of all iterations, in the right order
+    filenames = (
+        input_file,
+        *_get_temp_filenames(output_file, number=iterations - 1),
+        output_file,
+    )
+    if seed:
+        np.random.seed(seed)
+    for i in range(iterations):
+        print(f"\nIteration {i+1}/{iterations}")
+        _shuffle_file(
+            input_file=filenames[i],
+            output_file=filenames[i + 1],
+            delete=i > 0,
+            datasets=datasets,
+            max_ram=max_ram,
+            max_ram_fraction=max_ram_fraction,
+        )
+    return output_file
+
+
+def _shuffle_file(
+    input_file,
+    output_file,
+    datasets=("x", "y"),
+    max_ram=None,
+    max_ram_fraction=0.25,
+    delete=False,
+):
+    start_time = time.time()
     if os.path.exists(output_file):
         raise FileExistsError(output_file)
     if max_ram is None:
         max_ram = get_max_ram(max_ram_fraction)
-
+    # create file with temp name first, then rename afterwards
     temp_output_file = (
         output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
     )
     with h5py.File(input_file, "r") as f_in:
-        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
-        print(f"Shuffling datasets {datasets} with {n_lines} lines each")
-
-        if not chunks:
-            indices = np.arange(n_lines)
-            np.random.shuffle(indices)
-
-            with h5py.File(temp_output_file, "x") as f_out:
-                for dset_info in dset_infos:
-                    print("Creating dataset", dset_info["name"])
-                    make_dset(f_out, dset_info, indices)
-                    print("Done!")
-        else:
-            indices_chunked = get_indices_largest(dset_infos)
+        _check_dsets(f_in, datasets)
+        dset_info = _get_largest_dset(f_in, datasets, max_ram)
+        print(f"Shuffling datasets {datasets}")
+        indices_per_batch = _get_indices_per_batch(
+            dset_info["n_batches"],
+            dset_info["n_chunks"],
+            dset_info["chunksize"],
+        )
 
-            with h5py.File(temp_output_file, "x") as f_out:
-                for dset_info in dset_infos:
-                    print("Creating dataset", dset_info["name"])
-                    make_dset_chunked(f_out, dset_info, indices_chunked)
-                    print("Done!")
+        with h5py.File(temp_output_file, "x") as f_out:
+            for dset_name in datasets:
+                print("Creating dataset", dset_name)
+                _shuffle_dset(f_out, f_in, dset_name, indices_per_batch)
+                print("Done!")
 
     copy_used_files(input_file, temp_output_file)
     copy_attrs(input_file, temp_output_file)
@@ -164,21 +135,6 @@ def get_max_ram(max_ram_fraction):
     return max_ram
 
 
-def get_indices_largest(dset_infos):
-    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
-    dset_info = dset_infos[largest_dset]
-
-    print(f"Total chunks: {dset_info['n_chunks']}")
-    ratio = dset_info["chunks_per_batch"] / dset_info["n_chunks"]
-    print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")
-
-    return get_indices_chunked(
-        dset_info["n_batches_chunkwise"],
-        dset_info["n_chunks"],
-        dset_info["chunksize"],
-    )
-
-
 def get_n_iterations(
     input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
 ):
@@ -186,9 +142,7 @@ def get_n_iterations(
     if max_ram is None:
         max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
     with h5py.File(input_file, "r") as f_in:
-        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
-    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
-    dset_info = dset_infos[largest_dset]
+        dset_info = _get_largest_dset(f_in, datasets, max_ram)
     n_iterations = int(
         np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
     )
@@ -198,137 +152,117 @@ def get_n_iterations(
     return n_iterations
 
 
-def get_indices_chunked(n_batches, n_chunks, chunksize):
-    """ Return a list with the chunkwise shuffled indices of each batch. """
+def _get_indices_per_batch(n_batches, n_chunks, chunksize):
+    """
+    Return a list with the shuffled indices for each batch.
+
+    Returns
+    -------
+    indices_per_batch : List
+        Length n_batches, each element is a np.array[int].
+        Element i of the list are the indices of each sample in batch number i.
+
+    """
     chunk_indices = np.arange(n_chunks)
     np.random.shuffle(chunk_indices)
     chunk_batches = np.array_split(chunk_indices, n_batches)
 
-    index_batches = []
+    indices_per_batch = []
     for bat in chunk_batches:
         idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
         np.random.shuffle(idx)
-        index_batches.append(idx)
+        indices_per_batch.append(idx)
+
+    return indices_per_batch
+
 
-    return index_batches
+def _get_largest_dset(f, datasets, max_ram):
+    """
+    Get infos about the dset that needs the most batches.
+    This is the dset that determines how many samples are shuffled at a time.
+    """
+    dset_infos = _get_dset_infos(f, datasets, max_ram)
+    return dset_infos[np.argmax([v["n_batches"] for v in dset_infos])]
 
 
-def get_dset_infos(f, datasets, max_ram):
-    """ Check datasets and retrieve relevant infos for each. """
+def _check_dsets(f, datasets):
+    # check if all datasets have the same number of lines
+    n_lines_list = [len(f[dset_name]) for dset_name in datasets]
+    if not all([n == n_lines_list[0] for n in n_lines_list]):
+        raise ValueError(
+            f"Datasets have different lengths! " f"{n_lines_list}"
+        )
+
+
+def _get_dset_infos(f, datasets, max_ram):
+    """ Retrieve infos for each dataset. """
     dset_infos = []
-    n_lines = None
     for i, name in enumerate(datasets):
         dset = f[name]
-        if i == 0:
-            n_lines = len(dset)
-        else:
-            if len(dset) != n_lines:
-                raise ValueError(
-                    f"dataset {name} has different length! " f"{len(dset)} vs {n_lines}"
-                )
+        n_lines = len(dset)
         chunksize = dset.chunks[0]
         n_chunks = int(np.ceil(n_lines / chunksize))
-        # TODO in h5py 3.X, use .nbytes to get uncompressed size
         bytes_per_line = np.asarray(dset[0]).nbytes
         bytes_per_chunk = bytes_per_line * chunksize
-
-        lines_per_batch = int(np.floor(max_ram / bytes_per_line))
         chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))
 
-        dset_infos.append(
-            {
-                "name": name,
-                "dset": dset,
-                "chunksize": chunksize,
-                "n_lines": n_lines,
-                "n_chunks": n_chunks,
-                "bytes_per_line": bytes_per_line,
-                "bytes_per_chunk": bytes_per_chunk,
-                "lines_per_batch": lines_per_batch,
-                "chunks_per_batch": chunks_per_batch,
-                "n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)),
-                "n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)),
-            }
-        )
-    return dset_infos, n_lines
-
+        dset_infos.append({
+            "name": name,
+            "dset": dset,
+            "n_chunks": n_chunks,
+            "chunks_per_batch": chunks_per_batch,
+            "n_batches": int(np.ceil(n_chunks / chunks_per_batch)),
+            "chunksize": chunksize,
+        })
 
-def make_dset(f_out, dset_info, indices):
-    """ Create a shuffled dataset in the output file. """
-    for batch_index in range(dset_info["n_batches_linewise"]):
-        print(f"Processing batch {batch_index+1}/{dset_info['n_batches_linewise']}")
+    return dset_infos
 
-        slc = slice(
-            batch_index * dset_info["lines_per_batch"],
-            (batch_index + 1) * dset_info["lines_per_batch"],
-        )
-        to_read = indices[slc]
-        # reading has to be done with linearly increasing index,
-        #  so sort -> read -> undo sorting
-        sort_ix = np.argsort(to_read)
-        unsort_ix = np.argsort(sort_ix)
-        data = dset_info["dset"][to_read[sort_ix]][unsort_ix]
-
-        if batch_index == 0:
-            in_dset = dset_info["dset"]
-            out_dset = f_out.create_dataset(
-                dset_info["name"],
-                data=data,
-                maxshape=in_dset.shape,
-                chunks=in_dset.chunks,
-                compression=in_dset.compression,
-                compression_opts=in_dset.compression_opts,
-                shuffle=in_dset.shuffle,
-            )
-            out_dset.resize(len(in_dset), axis=0)
-        else:
-            f_out[dset_info["name"]][slc] = data
 
+def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch):
+    """
+    Create a batchwise-shuffled dataset in the output file using given indices.
 
-def make_dset_chunked(f_out, dset_info, indices_chunked):
-    """ Create a shuffled dataset in the output file. """
+    """
+    dset_in = f_in[dset_name]
     start_idx = 0
-    for batch_index, to_read in enumerate(indices_chunked):
-        print(f"Processing batch {batch_index+1}/{len(indices_chunked)}")
-
+    for batch_number, indices in enumerate(indices_per_batch):
+        print(f"Processing batch {batch_number+1}/{len(indices_per_batch)}")
         # remove indices outside of dset
-        to_read = to_read[to_read < len(dset_info["dset"])]
-
+        indices = indices[indices < len(dset_in)]
         # reading has to be done with linearly increasing index
         #  fancy indexing is super slow
         #  so sort -> turn to slices -> read -> conc -> undo sorting
-        sort_ix = np.argsort(to_read)
+        sort_ix = np.argsort(indices)
         unsort_ix = np.argsort(sort_ix)
-        fancy_indices = to_read[sort_ix]
-        slices = slicify(fancy_indices)
-        data = np.concatenate([dset_info["dset"][slc] for slc in slices])
+        fancy_indices = indices[sort_ix]
+        slices = _slicify(fancy_indices)
+        data = np.concatenate([dset_in[slc] for slc in slices])
         data = data[unsort_ix]
 
-        if batch_index == 0:
-            in_dset = dset_info["dset"]
+        if batch_number == 0:
             out_dset = f_out.create_dataset(
-                dset_info["name"],
+                dset_name,
                 data=data,
-                maxshape=in_dset.shape,
-                chunks=in_dset.chunks,
-                compression=in_dset.compression,
-                compression_opts=in_dset.compression_opts,
-                shuffle=in_dset.shuffle,
+                maxshape=dset_in.shape,
+                chunks=dset_in.chunks,
+                compression=dset_in.compression,
+                compression_opts=dset_in.compression_opts,
+                shuffle=dset_in.shuffle,
             )
-            out_dset.resize(len(in_dset), axis=0)
+            out_dset.resize(len(dset_in), axis=0)
             start_idx = len(data)
         else:
             end_idx = start_idx + len(data)
-            f_out[dset_info["name"]][start_idx:end_idx] = data
+            f_out[dset_name][start_idx:end_idx] = data
             start_idx = end_idx
 
         print("Memory peak: {0:.3f} MB".format(peak_memory_usage()))
 
-    if start_idx != len(dset_info["dset"]):
-        print(f"Warning: last index was {start_idx} not {len(dset_info['dset'])}")
+    if start_idx != len(dset_in):
+        print(f"Warning: last index was {start_idx} not {len(dset_in)}")
 
 
-def slicify(fancy_indices):
+def _slicify(fancy_indices):
     """ [0,1,2, 6,7,8] --> [0:3, 6:9] """
     steps = np.diff(fancy_indices) != 1
     slice_starts = np.concatenate([fancy_indices[:1], fancy_indices[1:][steps]])
@@ -336,50 +270,6 @@ def slicify(fancy_indices):
     return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
 
 
-def run_parser():
-    # TODO deprecated
-    warnings.warn("h5shuffle2 is deprecated and has been renamed to orcasong h5shuffle2")
-    parser = argparse.ArgumentParser(
-        description="Shuffle datasets in a h5file that have the same length. "
-        "Uses chunkwise readout for speed-up."
-    )
-    parser.add_argument(
-        "input_file", type=str, help="Path of the file that will be shuffled."
-    )
-    parser.add_argument(
-        "--output_file",
-        type=str,
-        default=None,
-        help="If given, this will be the name of the output file. "
-        "Default: input_file + suffix.",
-    )
-    parser.add_argument(
-        "--datasets",
-        type=str,
-        nargs="*",
-        default=("x", "y"),
-        help="Which datasets to include in output. Default: x, y",
-    )
-    parser.add_argument(
-        "--max_ram_fraction",
-        type=float,
-        default=0.25,
-        help="in [0, 1]. Fraction of all available ram to use for reading one batch of data "
-        "Note: this should "
-        "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
-        "Default: 0.25",
-    )
-    parser.add_argument(
-        "--iterations",
-        type=int,
-        default=None,
-        help="Shuffle the file this many times. Default: Auto choose best number.",
-    )
-    parser.add_argument(
-        "--max_ram",
-        type=int,
-        default=None,
-        help="Available ram in bytes. Default: Use fraction of maximum "
-             "available instead (see max_ram_fraction).",
-    )
-    h5shuffle2(**vars(parser.parse_args()))
+def _get_temp_filenames(output_file, number):
+    path, file = os.path.split(output_file)
+    return [os.path.join(path, f"temp_iteration_{i}_{file}") for i in range(number)]