import os
import time
import datetime

import numpy as np
import psutil
import h5py
from km3pipe.sys import peak_memory_usage

from orcasong.tools.postproc import get_filepath_output, copy_used_files
from orcasong.tools.concatenate import copy_attrs


__author__ = "Stefan Reck"


def h5shuffle2(
    input_file,
    output_file=None,
    iterations=None,
    datasets=("x", "y"),
    max_ram_fraction=0.25,
    max_ram=None,
    seed=42,
):
    """
    Shuffle datasets of equal length in an h5 file.

    Parameters
    ----------
    input_file : str
        Path of the file that will be shuffled.
    output_file : str, optional
        If given, this will be the name of the output file.
        Otherwise, a name is auto-generated.
    iterations : int, optional
        Shuffle the file this many times. For each additional iteration,
        a temporary file will be created and then deleted afterwards.
        Default: Auto choose best number based on available RAM.
    datasets : tuple
        Which datasets to include in output.
    max_ram_fraction : float
        In [0, 1]. Fraction of the available RAM to use for reading one
        batch of data when max_ram is None. Note: when using chunks, this
        should be <= ~0.25, since a lot of RAM is needed for the in-memory
        shuffling.
    max_ram : int, optional
        Available RAM in bytes. Default: Use a fraction of the maximum
        available RAM (see max_ram_fraction).
    seed : int or None
        Seed for randomness.

    Returns
    -------
    output_file : str
        Path to the output file.

    """
    if output_file is None:
        output_file = get_filepath_output(input_file, shuffle=True)
    if iterations is None:
        iterations = get_n_iterations(
            input_file,
            datasets=datasets,
            max_ram_fraction=max_ram_fraction,
            max_ram=max_ram,
        )
    # filenames of all iterations, in the right order
    filenames = (
        input_file,
        *_get_temp_filenames(output_file, number=iterations - 1),
        output_file,
    )
    if seed is not None:
        np.random.seed(seed)
    for i in range(iterations):
        print(f"\nIteration {i+1}/{iterations}")
        _shuffle_file(
            input_file=filenames[i],
            output_file=filenames[i + 1],
            delete=i > 0,
            datasets=datasets,
            max_ram=max_ram,
            max_ram_fraction=max_ram_fraction,
        )
    return output_file
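
# Illustrative usage sketch (comments only, not executed on import; the file
# path is a placeholder): shuffle the "x" and "y" datasets of a file while
# capping the RAM used for each read batch at 2 GiB.
#
#   shuffled_path = h5shuffle2(
#       "events.h5",           # placeholder input path
#       datasets=("x", "y"),
#       max_ram=2 * 1024**3,   # 2 GiB per read batch
#   )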


def _shuffle_file(
    input_file,
    output_file,
    datasets=("x", "y"),
    max_ram=None,
    max_ram_fraction=0.25,
    delete=False,
):
    start_time = time.time()
    if os.path.exists(output_file):
        raise FileExistsError(output_file)
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction)
    # create file with temp name first, then rename afterwards
    temp_output_file = (
        output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
    )
    with h5py.File(input_file, "r") as f_in:
        _check_dsets(f_in, datasets)
        dset_info = _get_largest_dset(f_in, datasets, max_ram)
        print(f"Shuffling datasets {datasets}")
        indices_per_batch = _get_indices_per_batch(
            dset_info["n_batches"],
            dset_info["n_chunks"],
            dset_info["chunksize"],
        )

        with h5py.File(temp_output_file, "x") as f_out:
            for dset_name in datasets:
                print("Creating dataset", dset_name)
                _shuffle_dset(f_out, f_in, dset_name, indices_per_batch)
                print("Done!")

    copy_used_files(input_file, temp_output_file)
    copy_attrs(input_file, temp_output_file)
    os.rename(temp_output_file, output_file)
    if delete:
        os.remove(input_file)
    print(
        f"Elapsed time: " f"{datetime.timedelta(seconds=int(time.time() - start_time))}"
    )
    return output_file


def get_max_ram(max_ram_fraction):
    max_ram = max_ram_fraction * psutil.virtual_memory().available
    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
    return max_ram


def get_n_iterations(
    input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
):
    """ Get how often you have to shuffle with given ram to get proper randomness. """
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
    with h5py.File(input_file, "r") as f_in:
        dset_info = _get_largest_dset(f_in, datasets, max_ram)
    n_iterations = int(
        np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
    )
    print(f"Largest dataset: {dset_info['name']}")
    print(f"Total chunks: {dset_info['n_chunks']}")
    print(f"Max. chunks per batch: {dset_info['chunks_per_batch']}")
    print(f"--> min iterations for full shuffle: {n_iterations}")
    return n_iterations
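
# Worked example of the formula above (illustrative numbers): if the largest
# dataset has 10000 chunks and 100 chunks fit into one in-memory batch, then
# ceil(log(10000) / log(100)) = ceil(2.0) = 2 iterations are needed, since
# after two passes of batchwise shuffling a sample can end up in any of
# 100 * 100 = 10000 chunks, i.e. anywhere in the file.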


def _get_indices_per_batch(n_batches, n_chunks, chunksize):
    """
    Return a list with the shuffled indices for each batch.

    Returns
    -------
    indices_per_batch : List
        Of length n_batches; each element is an np.array of ints.
        Element i of the list contains the indices of the samples in batch i.

    """
    chunk_indices = np.arange(n_chunks)
    np.random.shuffle(chunk_indices)
    chunk_batches = np.array_split(chunk_indices, n_batches)

    indices_per_batch = []
    for bat in chunk_batches:
        idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
        np.random.shuffle(idx)
        indices_per_batch.append(idx)

    return indices_per_batch
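
# Illustrative sketch (toy numbers): with n_batches=2, n_chunks=4, chunksize=3,
# the 4 chunk indices are shuffled and split into 2 batches of 2 chunks each,
# and the 2 * 3 = 6 sample indices within each batch are shuffled again:
#
#   batches = _get_indices_per_batch(n_batches=2, n_chunks=4, chunksize=3)
#   assert len(batches) == 2 and all(len(b) == 6 for b in batches)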


def _get_largest_dset(f, datasets, max_ram):
    """
    Get info about the dset that needs the most batches.
    This is the dset that determines how many samples are shuffled at a time.
    """
    dset_infos = _get_dset_infos(f, datasets, max_ram)
    return dset_infos[np.argmax([v["n_batches"] for v in dset_infos])]


def _check_dsets(f, datasets):
    # check if all datasets have the same number of lines
    n_lines_list = [len(f[dset_name]) for dset_name in datasets]
    if not all([n == n_lines_list[0] for n in n_lines_list]):
        raise ValueError(f"Datasets have different lengths! {n_lines_list}")


def _get_dset_infos(f, datasets, max_ram):
    """ Retrieve infos for each dataset. """
    dset_infos = []
    for i, name in enumerate(datasets):
        dset = f[name]
        n_lines = len(dset)
        chunksize = dset.chunks[0]
        n_chunks = int(np.ceil(n_lines / chunksize))
        bytes_per_line = np.asarray(dset[0]).nbytes
        bytes_per_chunk = bytes_per_line * chunksize
        chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))

        dset_infos.append({
            "name": name,
            "dset": dset,
            "n_chunks": n_chunks,
            "chunks_per_batch": chunks_per_batch,
            "n_batches": int(np.ceil(n_chunks / chunks_per_batch)),
            "chunksize": chunksize,
        })

    return dset_infos
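
# Worked example for the numbers above (illustrative): a dataset with
# 1_000_000 lines, a chunksize of 32 and 4096 bytes per line has
# ceil(1_000_000 / 32) = 31250 chunks of 32 * 4096 bytes = 128 KiB each.
# With max_ram = 4 GiB, floor(4 * 1024**3 / (128 * 1024)) = 32768 chunks
# fit into one batch, so the whole dataset is shuffled in a single batch.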


def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch):
    """
    Create a batchwise-shuffled dataset in the output file using given indices.

    """
    dset_in = f_in[dset_name]
    start_idx = 0
    for batch_number, indices in enumerate(indices_per_batch):
        print(f"Processing batch {batch_number+1}/{len(indices_per_batch)}")
        # drop indices beyond the end of the dset (the last chunk may be partial)
        indices = indices[indices < len(dset_in)]
        # reading has to be done with monotonically increasing indices
        # (fancy indexing is very slow), so:
        # sort -> turn into slices -> read -> concatenate -> undo the sorting
        sort_ix = np.argsort(indices)
        unsort_ix = np.argsort(sort_ix)
        fancy_indices = indices[sort_ix]
        slices = _slicify(fancy_indices)
        data = np.concatenate([dset_in[slc] for slc in slices])
        data = data[unsort_ix]

        if batch_number == 0:
            out_dset = f_out.create_dataset(
                dset_name,
                data=data,
                maxshape=dset_in.shape,
                chunks=dset_in.chunks,
                compression=dset_in.compression,
                compression_opts=dset_in.compression_opts,
                shuffle=dset_in.shuffle,
            )
            out_dset.resize(len(dset_in), axis=0)
            start_idx = len(data)
        else:
            end_idx = start_idx + len(data)
            f_out[dset_name][start_idx:end_idx] = data
            start_idx = end_idx

        print("Memory peak: {0:.3f} MB".format(peak_memory_usage()))

    if start_idx != len(dset_in):
        print(f"Warning: last index was {start_idx} not {len(dset_in)}")


def _slicify(fancy_indices):
    """ [0,1,2, 6,7,8] --> [0:3, 6:9] """
    steps = np.diff(fancy_indices) != 1
    slice_starts = np.concatenate([fancy_indices[:1], fancy_indices[1:][steps]])
    slice_ends = np.concatenate([fancy_indices[:-1][steps], fancy_indices[-1:]]) + 1
    return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]


def _get_temp_filenames(output_file, number):
    path, file = os.path.split(output_file)
    return [os.path.join(path, f"temp_iteration_{i}_{file}") for i in range(number)]


def run_parser():
    # TODO deprecated
    raise NotImplementedError(
        "h5shuffle2 has been renamed to orcasong h5shuffle2")