import os
import time
import datetime
import argparse
import numpy as np
import psutil
import h5py
from km3pipe.sys import peak_memory_usage

from orcasong.tools.postproc import get_filepath_output, copy_used_files
from orcasong.tools.concatenate import copy_attrs


__author__ = "Stefan Reck"


def h5shuffle2(
    input_file,
    output_file=None,
    iterations=None,
    datasets=("x", "y"),
    max_ram_fraction=0.25,
    **kwargs,
):
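    """
    Shuffle datasets in a h5 file over one or more iterations.

    Convenience wrapper around shuffle_file: the number of iterations is
    chosen automatically via get_n_iterations unless given, intermediate
    iterations are written to temp files, and extra keyword arguments are
    passed on to shuffle_file.
    """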
    if output_file is None:
        output_file = get_filepath_output(input_file, shuffle=True)
    if iterations is None:
        iterations = get_n_iterations(
            input_file,
            datasets=datasets,
            max_ram_fraction=max_ram_fraction,
        )
    np.random.seed(42)
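    # chain the iterations: intermediate passes write temp files which the
    # next pass reads from; only the final pass writes to output_file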
    for i in range(iterations):
        print(f"\nIteration {i+1}/{iterations}")
        if iterations == 1:
            # special case if there's only one iteration
            stgs = {
                "input_file": input_file,
                "output_file": output_file,
                "delete": False,
            }
        elif i == 0:
            # first iteration
            stgs = {
                "input_file": input_file,
                "output_file": f"{output_file}_temp_{i}",
                "delete": False,
            }
        elif i == iterations - 1:
            # last iteration
            stgs = {
                "input_file": f"{output_file}_temp_{i-1}",
                "output_file": output_file,
                "delete": True,
            }
        else:
            # intermediate iterations
            stgs = {
                "input_file": f"{output_file}_temp_{i-1}",
                "output_file": f"{output_file}_temp_{i}",
                "delete": True,
            }
        shuffle_file(
            datasets=datasets,
            max_ram_fraction=max_ram_fraction,
            chunks=True,
            **stgs,
            **kwargs,
        )


def shuffle_file(
    input_file,
    datasets=("x", "y"),
    output_file=None,
    max_ram=None,
    max_ram_fraction=0.25,
    chunks=False,
    delete=False,
):
    """
    Shuffle datasets in a h5file that have the same length.

    Parameters
    ----------
    input_file : str
        Path of the file that will be shuffled.
    datasets : tuple
        Which datasets to include in output.
    output_file : str, optional
        If given, this will be the name of the output file.
        Otherwise, a name is auto generated.
    max_ram : int, optional
        Available ram in bytes. Default: Use fraction of
        maximum available (see max_ram_fraction).
    max_ram_fraction : float
        In [0, 1]. Fraction of ram to use for reading one batch of data
        when max_ram is None. Note: when using chunks, this should be
        at most ~0.25, since lots of ram is needed for in-memory shuffling.
    chunks : bool
        Use chunk-wise readout. Large speed boost, but will only
        quasi-randomize the order! Needs lots of ram to be accurate
        (use a node with at least 32 GB, the more the better).
    delete : bool
        Delete the original file afterwards?

    Returns
    -------
    output_file : str
        Path to the output file.

    """
    start_time = time.time()
    if output_file is None:
        output_file = get_filepath_output(input_file, shuffle=True)
    if os.path.exists(output_file):
        raise FileExistsError(output_file)
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction)

    temp_output_file = (
        output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
    )
    with h5py.File(input_file, "r") as f_in:
        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
        print(f"Shuffling datasets {datasets} with {n_lines} lines each")

        if not chunks:
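            # line-wise mode: shuffle with one full random permutation of all lines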
            indices = np.arange(n_lines)
            np.random.shuffle(indices)

            with h5py.File(temp_output_file, "x") as f_out:
                for dset_info in dset_infos:
                    print("Creating dataset", dset_info["name"])
                    make_dset(f_out, dset_info, indices)
                    print("Done!")
        else:
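            # chunk-wise mode: shuffle whole chunks, then lines within each batch of chunks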
            indices_chunked = get_indices_largest(dset_infos)

            with h5py.File(temp_output_file, "x") as f_out:
                for dset_info in dset_infos:
                    print("Creating dataset", dset_info["name"])
                    make_dset_chunked(f_out, dset_info, indices_chunked)
                    print("Done!")

    copy_used_files(input_file, temp_output_file)
    copy_attrs(input_file, temp_output_file)
    os.rename(temp_output_file, output_file)
    if delete:
        os.remove(input_file)
    print(
        f"Elapsed time: {datetime.timedelta(seconds=int(time.time() - start_time))}"
    )
    return output_file


def get_max_ram(max_ram_fraction):
    max_ram = max_ram_fraction * psutil.virtual_memory().available
    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
    return max_ram


def get_indices_largest(dset_infos):
    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
    dset_info = dset_infos[largest_dset]

    print(f"Total chunks: {dset_info['n_chunks']}")
    ratio = dset_info["chunks_per_batch"] / dset_info["n_chunks"]
    print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")

    return get_indices_chunked(
        dset_info["n_batches_chunkwise"],
        dset_info["n_chunks"],
        dset_info["chunksize"],
    )


def get_n_iterations(
    input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
):
    """ Get how often you have to shuffle with given ram to get proper randomness. """
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
    with h5py.File(input_file, "r") as f_in:
        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
    dset_info = dset_infos[largest_dset]
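    # each pass mixes chunks_per_batch chunks with each other, so after k passes
    # up to chunks_per_batch**k chunks are interleaved; k must be large enough
    # to cover all n_chunks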
    n_iterations = int(
        np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
    )
    print(f"Total chunks: {dset_info['n_chunks']}")
    print(f"Chunks per batch: {dset_info['chunks_per_batch']}")
    print(f"--> min iterations for full shuffle: {n_iterations}")
    return n_iterations


def get_indices_chunked(n_batches, n_chunks, chunksize):
    """ Return a list with the chunkwise shuffled indices of each batch. """
    chunk_indices = np.arange(n_chunks)
    np.random.shuffle(chunk_indices)
    chunk_batches = np.array_split(chunk_indices, n_batches)

    index_batches = []
    for bat in chunk_batches:
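        # expand each chunk index to the line indices that chunk covers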
        idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
        np.random.shuffle(idx)
        index_batches.append(idx)

    return index_batches


def get_dset_infos(f, datasets, max_ram):
    """ Check datasets and retrieve relevant infos for each. """
    dset_infos = []
    n_lines = None
    for i, name in enumerate(datasets):
        dset = f[name]
        if i == 0:
            n_lines = len(dset)
        else:
            if len(dset) != n_lines:
                raise ValueError(
                    f"dataset {name} has a different length! {len(dset)} vs {n_lines}"
                )
        chunksize = dset.chunks[0]
        n_chunks = int(np.ceil(n_lines / chunksize))
        # TODO in h5py 3.X, use .nbytes to get uncompressed size
        bytes_per_line = np.asarray(dset[0]).nbytes
        bytes_per_chunk = bytes_per_line * chunksize

        lines_per_batch = int(np.floor(max_ram / bytes_per_line))
        chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))

        dset_infos.append(
            {
                "name": name,
                "dset": dset,
                "chunksize": chunksize,
                "n_lines": n_lines,
                "n_chunks": n_chunks,
                "bytes_per_line": bytes_per_line,
                "bytes_per_chunk": bytes_per_chunk,
                "lines_per_batch": lines_per_batch,
                "chunks_per_batch": chunks_per_batch,
                "n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)),
                "n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)),
            }
        )
    return dset_infos, n_lines


def make_dset(f_out, dset_info, indices):
    """ Create a shuffled dataset in the output file. """
    for batch_index in range(dset_info["n_batches_linewise"]):
        print(f"Processing batch {batch_index+1}/{dset_info['n_batches_linewise']}")

        slc = slice(
            batch_index * dset_info["lines_per_batch"],
            (batch_index + 1) * dset_info["lines_per_batch"],
        )
        to_read = indices[slc]
        # reading has to be done with linearly increasing index,
        #  so sort -> read -> undo sorting
        sort_ix = np.argsort(to_read)
        unsort_ix = np.argsort(sort_ix)
        data = dset_info["dset"][to_read[sort_ix]][unsort_ix]

        if batch_index == 0:
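            # create the output dataset from the first batch (copies dtype and
            # layout from the input dataset), then resize it to the full length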
            in_dset = dset_info["dset"]
            out_dset = f_out.create_dataset(
                dset_info["name"],
                data=data,
                maxshape=in_dset.shape,
                chunks=in_dset.chunks,
                compression=in_dset.compression,
                compression_opts=in_dset.compression_opts,
                shuffle=in_dset.shuffle,
            )
            out_dset.resize(len(in_dset), axis=0)
        else:
            f_out[dset_info["name"]][slc] = data


def make_dset_chunked(f_out, dset_info, indices_chunked):
    """ Create a shuffled dataset in the output file. """
    start_idx = 0
    for batch_index, to_read in enumerate(indices_chunked):
        print(f"Processing batch {batch_index+1}/{len(indices_chunked)}")

        # remove indices outside of dset
        to_read = to_read[to_read < len(dset_info["dset"])]

        # reading has to be done with linearly increasing index
        #  fancy indexing is super slow
        #  so sort -> turn to slices -> read -> conc -> undo sorting
        sort_ix = np.argsort(to_read)
        unsort_ix = np.argsort(sort_ix)
        fancy_indices = to_read[sort_ix]
        slices = slicify(fancy_indices)
        data = np.concatenate([dset_info["dset"][slc] for slc in slices])
        data = data[unsort_ix]

        if batch_index == 0:
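            # create the output dataset from the first batch, then resize it to
            # the full length; start_idx tracks how many lines have been written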
            in_dset = dset_info["dset"]
            out_dset = f_out.create_dataset(
                dset_info["name"],
                data=data,
                maxshape=in_dset.shape,
                chunks=in_dset.chunks,
                compression=in_dset.compression,
                compression_opts=in_dset.compression_opts,
                shuffle=in_dset.shuffle,
            )
            out_dset.resize(len(in_dset), axis=0)
            start_idx = len(data)
        else:
            end_idx = start_idx + len(data)
            f_out[dset_info["name"]][start_idx:end_idx] = data
            start_idx = end_idx

        print("Memory peak: {0:.3f} MB".format(peak_memory_usage()))

    if start_idx != len(dset_info["dset"]):
        print(f"Warning: last index was {start_idx} not {len(dset_info['dset'])}")


def slicify(fancy_indices):
    """ [0,1,2, 6,7,8] --> [0:3, 6:9] """
    steps = np.diff(fancy_indices) != 1
    slice_starts = np.concatenate([fancy_indices[:1], fancy_indices[1:][steps]])
    slice_ends = np.concatenate([fancy_indices[:-1][steps], fancy_indices[-1:]]) + 1
    return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]


def main():
    parser = argparse.ArgumentParser(
        description="Shuffle datasets in a h5file that have the same length. "
        "Uses chunkwise readout for speed-up."
    )
    parser.add_argument(
        "input_file", type=str, help="Path of the file that will be shuffled."
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help="If given, this will be the name of the output file. "
        "Default: input_file + suffix.",
    )
    parser.add_argument(
        "--datasets",
        type=str,
        nargs="*",
        default=("x", "y"),
        help="Which datasets to include in output. Default: x, y",
    )
    parser.add_argument(
        "--max_ram_fraction",
        type=float,
        default=0.25,
        help="in [0, 1]. Fraction of all available ram to use for reading one batch of data "
        "Note: this should "
        "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
        "Default: 0.25",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=None,
        help="Shuffle the file this many times. Default: Auto choose best number.",
    )
    h5shuffle2(**vars(parser.parse_args()))


if __name__ == "__main__":
    main()