diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py index 651603740b80e83ac5aa301d25826a7e2108fafa..61ae1c8af9a474db5aaabe8dcb330d897b084b8c 100644 --- a/orcasong/tools/shuffle2.py +++ b/orcasong/tools/shuffle2.py @@ -14,7 +14,61 @@ from orcasong.tools.concatenate import copy_attrs __author__ = "Stefan Reck" -def shuffle_v2( +def h5shuffle2(input_file, + output_file=None, + iterations=None, + datasets=("x", "y"), + max_ram_fraction=0.25, + **kwargs): + if output_file is None: + output_file = get_filepath_output(input_file, shuffle=True) + if iterations is None: + iterations = get_n_iterations( + input_file, + datasets=datasets, + max_ram_fraction=max_ram_fraction, + ) + np.random.seed(42) + for i in range(iterations): + print(f"\nIteration {i+1}/{iterations}") + if iterations == 1: + # special case if theres only one iteration + stgs = { + "input_file": input_file, + "output_file": output_file, + "delete": False, + } + elif i == 0: + # first iteration + stgs = { + "input_file": input_file, + "output_file": f"{output_file}_temp_{i}", + "delete": False, + } + elif i == iterations-1: + # last iteration + stgs = { + "input_file": f"{output_file}_temp_{i-1}", + "output_file": output_file, + "delete": True, + } + else: + # intermediate iterations + stgs = { + "input_file": f"{output_file}_temp_{i-1}", + "output_file": f"{output_file}_temp_{i}", + "delete": True, + } + shuffle_file( + datasets=datasets, + max_ram_fraction=max_ram_fraction, + chunks=True, + **stgs, + **kwargs, + ) + + +def shuffle_file( input_file, datasets=("x", "y"), output_file=None, @@ -267,7 +321,7 @@ def slicify(fancy_indices): return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))] -def h5shuffle2(): +def run_parser(): parser = argparse.ArgumentParser( description='Shuffle datasets in a h5file that have the same length. ' 'Uses chunkwise readout for speed-up.') @@ -285,41 +339,4 @@ def h5shuffle2(): "Default: 0.25") parser.add_argument('--iterations', type=int, default=None, help="Shuffle the file this many times. Default: Auto choose best number.") - kwargs = vars(parser.parse_args()) - - input_file = kwargs.pop("input_file") - output_file = kwargs.pop("output_file") - if output_file is None: - output_file = get_filepath_output(input_file, shuffle=True) - iterations = kwargs.pop("iterations") - if iterations is None: - iterations = get_n_iterations( - input_file, - datasets=kwargs["datasets"], - max_ram_fraction=kwargs["max_ram_fraction"], - ) - np.random.seed(42) - for i in range(iterations): - print(f"\nIteration {i+1}/{iterations}") - if i == 0: - # first iteration - stgs = { - "input_file": input_file, - "output_file": f"{output_file}_temp_{i}", - "delete": False - } - elif i == iterations-1: - # last iteration - stgs = { - "input_file": f"{output_file}_temp_{i-1}", - "output_file": output_file, - "delete": True - } - else: - # intermediate iterations - stgs = { - "input_file": f"{output_file}_temp_{i-1}", - "output_file": f"{output_file}_temp_{i}", - "delete": True - } - shuffle_v2(**kwargs, **stgs, chunks=True) + h5shuffle2(**vars(parser.parse_args())) diff --git a/tests/test_postproc.py b/tests/test_postproc.py index abb9e8d55f14a2d12fe571e1bde8cfa8002191e2..85a461e3fe2b5c007e60ac0aa0c830a1fcdb98dd 100644 --- a/tests/test_postproc.py +++ b/tests/test_postproc.py @@ -3,7 +3,7 @@ import os import h5py import numpy as np import orcasong.tools.postproc as postproc -from orcasong.tools.shuffle2 import shuffle_v2 +import orcasong.tools.shuffle2 as shuffle2 __author__ = 'Stefan Reck' @@ -41,11 +41,10 @@ class TestShuffleV2(TestCase): self.x, self.y = _make_shuffle_dummy_file(self.temp_input) np.random.seed(42) - shuffle_v2( + shuffle2.h5shuffle2( input_file=self.temp_input, output_file=self.temp_output, datasets=("x", "y"), - chunks=True, max_ram=400, # -> 2 batches ) @@ -76,6 +75,20 @@ class TestShuffleV2(TestCase): f["x"][:, 0], target_order ) + def test_run_3_iterations(self): + # just check if it goes through without errors + fname = "temp_output_triple.h5" + try: + shuffle2.h5shuffle2( + input_file=self.temp_input, + output_file=fname, + datasets=("x", "y"), + iterations=3, + ) + finally: + if os.path.exists(fname): + os.remove(fname) + def _make_shuffle_dummy_file(filepath): x = np.random.rand(22, 2)