from unittest import TestCase
import os
import h5py
import numpy as np
import orcasong.tools.postproc as postproc
import orcasong.tools.shuffle2 as shuffle2

__author__ = 'Stefan Reck'

test_dir = os.path.dirname(os.path.realpath(__file__))
MUPAGE_FILE = os.path.join(test_dir, "data", "mupage.root.h5")


class TestPostproc(TestCase):
    def setUp(self):
        self.output_file = "temp_output.h5"

    def tearDown(self):
        if os.path.exists(self.output_file):
            os.remove(self.output_file)

    def test_shuffle(self):
        postproc.postproc_file(
            input_file=MUPAGE_FILE,
            output_file=self.output_file,
            shuffle=True,
            event_skipper=None,
            delete=False,
            seed=13,
        )

        with h5py.File(self.output_file, "r") as f:
            np.testing.assert_equal(f["event_info"]["event_id"], np.array([1, 0, 2]))
            self.assertTrue("origin" in f.attrs.keys())


class TestShuffleV2(TestCase):
    def setUp(self):
        self.temp_input = "temp_input.h5"
        self.temp_output = "temp_output.h5"

        self.x, self.y = _make_shuffle_dummy_file(self.temp_input)
        np.random.seed(42)
        shuffle2.h5shuffle2(
            input_file=self.temp_input,
            output_file=self.temp_output,
            datasets=("x", "y"),
            max_ram=400,  # -> 2 batches
        )

    def tearDown(self):
        for f in (self.temp_input, self.temp_output):
            if os.path.exists(f):
                os.remove(f)

    def test_shuffled_has_same_entries_as_input(self):
        with h5py.File(self.temp_output, "r") as f:
            x_s = f["x"][()]
            np.testing.assert_array_equal(
                self.x[:, 1:], x_s[:, 1:][np.argsort(x_s[:, 0])]
            )

    def test_all_shuffled_datasets_have_same_order(self):
        with h5py.File(self.temp_output, "r") as f:
            np.testing.assert_array_equal(
                f["x"][:, 0], f["y"][:, 0]
            )

    def test_seed_produces_this_shuffled_order(self):
        target_order = np.array(
            [5.,  6., 20.,  8., 13., 14., 10., 11.,  7., 21.,  9., 12., 19.,
             2.,  0., 16., 18., 15.,  3., 17.,  1.,  4.])
        with h5py.File(self.temp_output, "r") as f:
            np.testing.assert_array_equal(
                f["x"][:, 0], target_order
            )

    def test_run_3_iterations(self):
        # just check if it goes through without errors
        fname = "temp_output_triple.h5"
        try:
            shuffle2.h5shuffle2(
                input_file=self.temp_input,
                output_file=fname,
                datasets=("x", "y"),
                iterations=3,
            )
        finally:
            if os.path.exists(fname):
                os.remove(fname)


def _make_shuffle_dummy_file(filepath):
    x = np.random.rand(22, 2)
    x[:, 0] = np.arange(22)
    y = np.random.rand(22, 3)
    y[:, 0] = np.arange(22)
    with h5py.File(filepath, "w") as f:
        f.create_dataset("x", data=x, chunks=(5, 2))
        f.create_dataset("y", data=y, chunks=(5, 3))
    return x, y