import os
import time
import datetime

import numpy as np
import psutil
import h5py
from km3pipe.sys import peak_memory_usage

from orcasong.tools.postproc import get_filepath_output, copy_used_files
from orcasong.tools.concatenate import copy_attrs


__author__ = "Stefan Reck"


def h5shuffle2(
    input_file,
    output_file=None,
    iterations=None,
    datasets=("x", "y"),
    max_ram_fraction=0.25,
    max_ram=None,
    seed=42,
):
    """
    Shuffle datasets in an h5 file that have the same length.

    Parameters
    ----------
    input_file : str
        Path of the file that will be shuffled.
    output_file : str, optional
        If given, this will be the name of the output file.
        Otherwise, a name is auto-generated.
    iterations : int, optional
        Shuffle the file this many times. For each additional iteration,
        a temporary file will be created and then deleted afterwards.
        Default: Automatically choose the best number based on available RAM.
    datasets : tuple
        Which datasets to include in output.
    max_ram_fraction : float
        In [0, 1]. Fraction of available RAM to use for reading one batch
        of data when max_ram is None. Note: when using chunks, this should
        be <= ~0.25, since lots of RAM is needed for the in-memory shuffle.
    max_ram : int, optional
        Available RAM in bytes. Default: Use a fraction of the maximum
        available RAM (see max_ram_fraction).
    seed : int or None
        Seed for the random number generator.

    Returns
    -------
    output_file : str
        Path to the output file.
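
    Examples
    --------
    A minimal usage sketch ("data.h5" is a placeholder filename):

    >>> h5shuffle2("data.h5", seed=42)  # doctest: +SKIP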

    """
    if output_file is None:
        output_file = get_filepath_output(input_file, shuffle=True)
    if iterations is None:
        iterations = get_n_iterations(
            input_file,
            datasets=datasets,
            max_ram_fraction=max_ram_fraction,
            max_ram=max_ram,
        )
    # filenames of all iterations, in the right order
    filenames = (
        input_file,
        *_get_temp_filenames(output_file, number=iterations - 1),
        output_file,
    )
    if seed is not None:
        np.random.seed(seed)
    for i in range(iterations):
        print(f"\nIteration {i+1}/{iterations}")
        _shuffle_file(
            input_file=filenames[i],
            output_file=filenames[i + 1],
            delete=i > 0,
            datasets=datasets,
            max_ram=max_ram,
            max_ram_fraction=max_ram_fraction,
        )
    return output_file


def _shuffle_file(
    input_file,
    output_file,
    datasets=("x", "y"),
    max_ram=None,
    max_ram_fraction=0.25,
    delete=False,
):
    start_time = time.time()
    if os.path.exists(output_file):
        raise FileExistsError(output_file)
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction)
    # create file with temp name first, then rename afterwards
    temp_output_file = (
        output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
    )
    with h5py.File(input_file, "r") as f_in:
        _check_dsets(f_in, datasets)
        dset_info = _get_largest_dset(f_in, datasets, max_ram)
        print(f"Shuffling datasets {datasets}")
        indices_per_batch = _get_indices_per_batch(
            dset_info["n_batches"],
            dset_info["n_chunks"],
            dset_info["chunksize"],
        )

        with h5py.File(temp_output_file, "x") as f_out:
            for dset_name in datasets:
                print("Creating dataset", dset_name)
                _shuffle_dset(f_out, f_in, dset_name, indices_per_batch)
                print("Done!")
Stefan Reck's avatar
Stefan Reck committed

Stefan Reck's avatar
Stefan Reck committed
    copy_used_files(input_file, temp_output_file)
    copy_attrs(input_file, temp_output_file)
    os.rename(temp_output_file, output_file)
    if delete:
        os.remove(input_file)
    print(
        f"Elapsed time: {datetime.timedelta(seconds=int(time.time() - start_time))}"
    )
    return output_file


def get_max_ram(max_ram_fraction):
    max_ram = max_ram_fraction * psutil.virtual_memory().available
    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
    return max_ram


def get_n_iterations(
    input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
):
    """ Get how often you have to shuffle with given ram to get proper randomness. """
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
    with h5py.File(input_file, "r") as f_in:
        dset_info = _get_largest_dset(f_in, datasets, max_ram)
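    # Worked example (illustrative numbers): with 10_000 chunks and RAM for
    # 100 chunks per batch, one pass can only mix chunks that share a batch,
    # so ceil(log(10_000) / log(100)) = 2 iterations are needed.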
    n_iterations = int(
        np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
    )
    print(f"Largest dataset: {dset_info['name']}")
    print(f"Total chunks: {dset_info['n_chunks']}")
    print(f"Max. chunks per batch: {dset_info['chunks_per_batch']}")
    print(f"--> min iterations for full shuffle: {n_iterations}")
    return n_iterations


def _get_indices_per_batch(n_batches, n_chunks, chunksize):
    """
    Return a list with the shuffled indices for each batch.

    Returns
    -------
    indices_per_batch : List
        Length n_batches, each element is a np.array[int].
        Element i of the list holds the indices of the samples in batch number i.

    """
    chunk_indices = np.arange(n_chunks)
    np.random.shuffle(chunk_indices)
    chunk_batches = np.array_split(chunk_indices, n_batches)

    indices_per_batch = []
    for bat in chunk_batches:
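        # expand each chunk index to the sample indices it covers:
        # chunk c --> [c*chunksize, ..., (c+1)*chunksize - 1]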
        idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
        np.random.shuffle(idx)
        indices_per_batch.append(idx)

    return indices_per_batch


def _get_largest_dset(f, datasets, max_ram):
    """
    Get info about the dataset that needs the most batches.
    This is the dataset that determines how many samples are shuffled at a time.
    """
    dset_infos = _get_dset_infos(f, datasets, max_ram)
    return dset_infos[np.argmax([v["n_batches"] for v in dset_infos])]


def _check_dsets(f, datasets):
    # check if all datasets have the same number of lines
    n_lines_list = [len(f[dset_name]) for dset_name in datasets]
    if not all([n == n_lines_list[0] for n in n_lines_list]):
        raise ValueError(f"Datasets have different lengths! {n_lines_list}")


def _get_dset_infos(f, datasets, max_ram):
    """ Retrieve infos for each dataset. """
    dset_infos = []
    for i, name in enumerate(datasets):
        dset = f[name]
        n_lines = len(dset)
        if dset.chunks is None:
            raise ValueError(f"Dataset '{name}' must be chunked to be shuffled")
        chunksize = dset.chunks[0]
        n_chunks = int(np.ceil(n_lines / chunksize))
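        # estimate the memory footprint per sample from the first row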
        bytes_per_line = np.asarray(dset[0]).nbytes
        bytes_per_chunk = bytes_per_line * chunksize
        chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))

        dset_infos.append({
            "name": name,
            "dset": dset,
            "n_chunks": n_chunks,
            "chunks_per_batch": chunks_per_batch,
            "n_batches": int(np.ceil(n_chunks / chunks_per_batch)),
            "chunksize": chunksize,
        })

    return dset_infos


def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch):
    """
    Create a batchwise-shuffled dataset in the output file using given indices.
    """
    dset_in = f_in[dset_name]
    start_idx = 0
    for batch_number, indices in enumerate(indices_per_batch):
        print(f"Processing batch {batch_number+1}/{len(indices_per_batch)}")
        # remove indices outside of dset
        indices = indices[indices < len(dset_in)]
        # reading has to be done with linearly increasing indices, since
        # fancy indexing is very slow in h5py, so:
        # sort -> turn into slices -> read -> concatenate -> undo the sorting
        sort_ix = np.argsort(indices)
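        # the argsort of the argsort is the inverse permutation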
        unsort_ix = np.argsort(sort_ix)
        fancy_indices = indices[sort_ix]
        slices = _slicify(fancy_indices)
        data = np.concatenate([dset_in[slc] for slc in slices])
        data = data[unsort_ix]

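        # first batch: create the output dataset with the same layout as the
        # input, resize it to full length, then fill it batch by batch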
        if batch_number == 0:
            out_dset = f_out.create_dataset(
                dset_name,
                data=data,
                maxshape=dset_in.shape,
                chunks=dset_in.chunks,
                compression=dset_in.compression,
                compression_opts=dset_in.compression_opts,
                shuffle=dset_in.shuffle,
            )
            out_dset.resize(len(dset_in), axis=0)
            start_idx = len(data)
        else:
            end_idx = start_idx + len(data)
            f_out[dset_name][start_idx:end_idx] = data
            start_idx = end_idx

        print(f"Memory peak: {peak_memory_usage():.3f} MB")

    if start_idx != len(dset_in):
        print(f"Warning: last index was {start_idx} not {len(dset_in)}")
Stefan Reck's avatar
Stefan Reck committed
def _slicify(fancy_indices):
    """ Turn sorted indices into slices: [0,1,2, 6,7,8] --> [0:3, 6:9] """
    steps = np.diff(fancy_indices) != 1
    slice_starts = np.concatenate([fancy_indices[:1], fancy_indices[1:][steps]])
    slice_ends = np.concatenate([fancy_indices[:-1][steps], fancy_indices[-1:]]) + 1
    return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]


def _get_temp_filenames(output_file, number):
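    # e.g. ("path/out.h5", 2) -> ["path/temp_iteration_0_out.h5",
    #                             "path/temp_iteration_1_out.h5"]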
    path, file = os.path.split(output_file)
    return [os.path.join(path, f"temp_iteration_{i}_{file}") for i in range(number)]


def run_parser():
    # TODO deprecated
    raise NotImplementedError(
        "h5shuffle2 has been renamed to orcasong h5shuffle2"
    )