# test_postproc.py -- unit tests for orcasong.tools.postproc and orcasong.tools.shuffle2
from unittest import TestCase
import os
import h5py
import numpy as np
import orcasong.tools.postproc as postproc
import orcasong.tools.shuffle2 as shuffle2

__author__ = "Stefan Reck"

# Folder that contains this test module (path made canonical via realpath).
test_dir = os.path.dirname(os.path.realpath(__file__))
# Sample input file, looked up in the "data" folder next to this module.
MUPAGE_FILE = os.path.join(test_dir, "data", "mupage.root.h5")


class TestPostproc(TestCase):
    """Run ``postproc.postproc_file`` on the mupage sample and check its output."""

    def setUp(self):
        # Scratch file the test writes to; the path is relative to the cwd.
        self.output_file = "temp_output.h5"

    def tearDown(self):
        # Nothing to clean up if the test never got to write its output.
        if not os.path.exists(self.output_file):
            return
        os.remove(self.output_file)

    def test_shuffle(self):
        # Shuffle the sample file with a fixed seed, keeping the input file.
        settings = dict(
            input_file=MUPAGE_FILE,
            output_file=self.output_file,
            shuffle=True,
            event_skipper=None,
            delete=False,
            seed=13,
        )
        postproc.postproc_file(**settings)

        expected_ids = np.array([1, 0, 2])
        with h5py.File(self.output_file, "r") as shuffled:
            event_ids = shuffled["event_info"]["event_id"]
            np.testing.assert_equal(event_ids, expected_ids)
            # The output file must also carry an "origin" attribute.
            has_origin = "origin" in shuffled.attrs.keys()
            self.assertTrue(has_origin)


class TestShuffleV2(TestCase):
    """Tests for ``shuffle2.h5shuffle2`` applied to a small dummy file.

    ``setUp`` writes the dummy input file, seeds numpy's global random
    generator and shuffles the input into ``temp_output.h5``. Both files
    are deleted again in ``tearDown``.
    """

    def setUp(self):
        self.temp_input = "temp_input.h5"
        self.temp_output = "temp_output.h5"

        # Keep the in-memory arrays so tests can compare against the output.
        self.x, self.y = _make_shuffle_dummy_file(self.temp_input)
        # Seed after the dummy data has been drawn and right before the
        # shuffle; presumably this is what pins the order checked in
        # test_seed_produces_this_shuffled_order.
        np.random.seed(42)
        shuffle2.h5shuffle2(
            input_file=self.temp_input,
            output_file=self.temp_output,
            datasets=("x", "y"),
            max_ram=400,  # -> 2 batches
        )

    def tearDown(self):
        # Remove whichever of the temporary files were actually created.
        for f in (self.temp_input, self.temp_output):
            if os.path.exists(f):
                os.remove(f)

    def test_shuffled_has_same_entries_as_input(self):
        # Column 0 holds the original row index: sorting the shuffled rows
        # by it must give back the remaining columns of the input.
        with h5py.File(self.temp_output, "r") as f:
            x_s = f["x"][()]
            np.testing.assert_array_equal(
                self.x[:, 1:], x_s[:, 1:][np.argsort(x_s[:, 0])]
            )

    def test_all_shuffled_datasets_have_same_order(self):
        # Both datasets carry the row index in column 0, so equal index
        # columns mean both were shuffled with the same permutation.
        with h5py.File(self.temp_output, "r") as f:
            np.testing.assert_array_equal(
                f["x"][:, 0], f["y"][:, 0]
            )

    def test_seed_produces_this_shuffled_order(self):
        # Regression check: the row order obtained with the setup above.
        target_order = np.array(
            [5.,  6., 20.,  8., 13., 14., 10., 11.,  7., 21.,  9., 12., 19.,
             2.,  0., 16., 18., 15.,  3., 17.,  1.,  4.])
        with h5py.File(self.temp_output, "r") as f:
            np.testing.assert_array_equal(
                f["x"][:, 0], target_order
            )

    def test_run_3_iterations(self):
        # just check if it goes through without errors
        fname = "temp_output_triple.h5"
        try:
            shuffle2.h5shuffle2(
                input_file=self.temp_input,
                output_file=fname,
                datasets=("x", "y"),
                iterations=3,
            )
        finally:
            # Clean up the extra output even when the shuffle raised.
            if os.path.exists(fname):
                os.remove(fname)


def _make_shuffle_dummy_file(filepath):
    """Write a small HDF5 file with datasets "x" and "y" for the shuffle tests.

    Both datasets have 22 rows of uniform random numbers ("x" has 2 columns,
    "y" has 3); column 0 is overwritten with the row index 0..21. The
    datasets are stored chunked with 5 rows per chunk.

    Parameters
    ----------
    filepath : str
        Where to create the file (opened in "w" mode).

    Returns
    -------
    tuple
        The arrays ``(x, y)`` exactly as they were written to the file.

    """
    n_rows = 22
    arrays = {}
    # Draw "x" before "y" so the global random stream is consumed in a
    # fixed order.
    for name, n_cols in (("x", 2), ("y", 3)):
        values = np.random.rand(n_rows, n_cols)
        values[:, 0] = np.arange(n_rows)
        arrays[name] = values

    with h5py.File(filepath, "w") as out:
        for name, values in arrays.items():
            out.create_dataset(name, data=values, chunks=(5, values.shape[1]))
    return arrays["x"], arrays["y"]