diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py index e0c3f2d6e3882f1fb3be1403da320de3acdd3ef9..dcd9ed06d1d0ec66a950c6c45b2c480536ecc15d 100644 --- a/orcasong/tools/shuffle2.py +++ b/orcasong/tools/shuffle2.py @@ -145,9 +145,9 @@ def get_n_iterations( max_ram = get_max_ram(max_ram_fraction=max_ram_fraction) with h5py.File(input_file, "r") as f_in: dset_info = _get_largest_dset(f_in, datasets, max_ram) - n_iterations = int( + n_iterations = np.amax((1, int( np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"])) - ) + ))) print(f"Largest dataset: {dset_info['name']}") print(f"Total chunks: {dset_info['n_chunks']}") print(f"Max. chunks per batch: {dset_info['chunks_per_batch']}") diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 302878f8bdb6bc83b35fb9aae4c65f38276f70e3..3a417a284c6eb463095e457e922f0b588f3d0df8 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -1,5 +1,5 @@ import tempfile -from unittest import TestCase +import unittest import numpy as np import h5py import orcasong.tools.concatenate as conc @@ -8,7 +8,7 @@ import os __author__ = 'Stefan Reck' -class TestFileConcatenator(TestCase): +class TestFileConcatenator(unittest.TestCase): """ Test concatenation on pre-generated h5 files. They are in tests/data. @@ -117,20 +117,27 @@ class TestFileConcatenator(TestCase): ) -class TestConcatenateIndexed(TestCase): - @classmethod - def setUpClass(cls) -> None: - cls.infile = tempfile.NamedTemporaryFile() - with h5py.File(cls.infile, "w") as f: - cls.x = np.arange(20) - dset_x = f.create_dataset("x", data=cls.x, chunks=True) - dset_x.attrs.create("indexed", True) - cls.indices = np.array( - [(0, 5), (5, 12), (17, 3)], - dtype=[('index', '<i8'), ('n_items', '<i8')] - ) - f.create_dataset("x_indices", data=cls.indices, chunks=True) +class BaseTestClass: + class BaseIndexedFile(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.infile = tempfile.NamedTemporaryFile() + with h5py.File(cls.infile, "w") as f: + cls.x = np.arange(20) + dset_x = f.create_dataset("x", data=cls.x, chunks=True) + dset_x.attrs.create("indexed", True) + cls.indices = np.array( + [(0, 5), (5, 12), (17, 3)], + dtype=[('index', '<i8'), ('n_items', '<i8')] + ) + f.create_dataset("x_indices", data=cls.indices, chunks=True) + + @classmethod + def tearDownClass(cls) -> None: + cls.infile.close() + +class TestConcatenateIndexed(BaseTestClass.BaseIndexedFile): def setUp(self) -> None: self.outfile = "temp_out.h5" conc.concatenate([self.infile.name] * 2, outfile=self.outfile) @@ -139,10 +146,6 @@ class TestConcatenateIndexed(TestCase): if os.path.exists(self.outfile): os.remove(self.outfile) - @classmethod - def tearDownClass(cls) -> None: - cls.infile.close() - def test_check_x(self): with h5py.File(self.outfile) as f_out: np.testing.assert_array_equal( diff --git a/tests/test_postproc.py b/tests/test_postproc.py index dea566b37903351b103ec81a5ee234b7ad545176..0efac09ccbd45afb06747f2004037b60345e6d87 100644 --- a/tests/test_postproc.py +++ b/tests/test_postproc.py @@ -4,6 +4,7 @@ import h5py import numpy as np import orcasong.tools.postproc as postproc import orcasong.tools.shuffle2 as shuffle2 +from .test_concatenate import BaseTestClass __author__ = 'Stefan Reck' @@ -90,6 +91,40 @@ class TestShuffleV2(TestCase): os.remove(fname) +class TestShuffleIndexed(BaseTestClass.BaseIndexedFile): + def setUp(self) -> None: + self.outfile = "temp_out.h5" + shuffle2.h5shuffle2( + self.infile.name, + output_file=self.outfile, + datasets=("x",), + seed=2, + ) + + def tearDown(self) -> None: + if os.path.exists(self.outfile): + os.remove(self.outfile) + + def test_check_x(self): + with h5py.File(self.outfile) as f_out: + np.testing.assert_array_equal( + f_out["x"], + np.concatenate([np.arange(17, 20), np.arange(5, 17), np.arange(0, 5)]) + ) + + def test_check_x_indices_n_items(self): + with h5py.File(self.outfile) as f_out: + target_n_items = np.array([3, 12, 5]) + np.testing.assert_array_equal( + f_out["x_indices"]["n_items"], target_n_items) + + def test_check_x_indices_index(self): + with h5py.File(self.outfile) as f_out: + target_index = np.array([0, 3, 15]) + np.testing.assert_array_equal( + f_out["x_indices"]["index"], target_index) + + def _make_shuffle_dummy_file(filepath): x = np.random.rand(22, 2) x[:, 0] = np.arange(22)