Skip to content
Snippets Groups Projects
Commit 8809183e authored by Stefan Reck's avatar Stefan Reck
Browse files

fix incorrect filename when theres only one iteration & restructure

parent 87abba00
No related branches found
No related tags found
1 merge request!15Fix shuffle2
......@@ -14,7 +14,61 @@ from orcasong.tools.concatenate import copy_attrs
__author__ = "Stefan Reck"
def shuffle_v2(
def h5shuffle2(input_file,
output_file=None,
iterations=None,
datasets=("x", "y"),
max_ram_fraction=0.25,
**kwargs):
if output_file is None:
output_file = get_filepath_output(input_file, shuffle=True)
if iterations is None:
iterations = get_n_iterations(
input_file,
datasets=datasets,
max_ram_fraction=max_ram_fraction,
)
np.random.seed(42)
for i in range(iterations):
print(f"\nIteration {i+1}/{iterations}")
if iterations == 1:
# special case if theres only one iteration
stgs = {
"input_file": input_file,
"output_file": output_file,
"delete": False,
}
elif i == 0:
# first iteration
stgs = {
"input_file": input_file,
"output_file": f"{output_file}_temp_{i}",
"delete": False,
}
elif i == iterations-1:
# last iteration
stgs = {
"input_file": f"{output_file}_temp_{i-1}",
"output_file": output_file,
"delete": True,
}
else:
# intermediate iterations
stgs = {
"input_file": f"{output_file}_temp_{i-1}",
"output_file": f"{output_file}_temp_{i}",
"delete": True,
}
shuffle_file(
datasets=datasets,
max_ram_fraction=max_ram_fraction,
chunks=True,
**stgs,
**kwargs,
)
def shuffle_file(
input_file,
datasets=("x", "y"),
output_file=None,
......@@ -267,7 +321,7 @@ def slicify(fancy_indices):
return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
def h5shuffle2():
def run_parser():
parser = argparse.ArgumentParser(
description='Shuffle datasets in a h5file that have the same length. '
'Uses chunkwise readout for speed-up.')
......@@ -285,41 +339,4 @@ def h5shuffle2():
"Default: 0.25")
parser.add_argument('--iterations', type=int, default=None,
help="Shuffle the file this many times. Default: Auto choose best number.")
kwargs = vars(parser.parse_args())
input_file = kwargs.pop("input_file")
output_file = kwargs.pop("output_file")
if output_file is None:
output_file = get_filepath_output(input_file, shuffle=True)
iterations = kwargs.pop("iterations")
if iterations is None:
iterations = get_n_iterations(
input_file,
datasets=kwargs["datasets"],
max_ram_fraction=kwargs["max_ram_fraction"],
)
np.random.seed(42)
for i in range(iterations):
print(f"\nIteration {i+1}/{iterations}")
if i == 0:
# first iteration
stgs = {
"input_file": input_file,
"output_file": f"{output_file}_temp_{i}",
"delete": False
}
elif i == iterations-1:
# last iteration
stgs = {
"input_file": f"{output_file}_temp_{i-1}",
"output_file": output_file,
"delete": True
}
else:
# intermediate iterations
stgs = {
"input_file": f"{output_file}_temp_{i-1}",
"output_file": f"{output_file}_temp_{i}",
"delete": True
}
shuffle_v2(**kwargs, **stgs, chunks=True)
h5shuffle2(**vars(parser.parse_args()))
......@@ -3,7 +3,7 @@ import os
import h5py
import numpy as np
import orcasong.tools.postproc as postproc
from orcasong.tools.shuffle2 import shuffle_v2
import orcasong.tools.shuffle2 as shuffle2
__author__ = 'Stefan Reck'
......@@ -41,11 +41,10 @@ class TestShuffleV2(TestCase):
self.x, self.y = _make_shuffle_dummy_file(self.temp_input)
np.random.seed(42)
shuffle_v2(
shuffle2.h5shuffle2(
input_file=self.temp_input,
output_file=self.temp_output,
datasets=("x", "y"),
chunks=True,
max_ram=400, # -> 2 batches
)
......@@ -76,6 +75,20 @@ class TestShuffleV2(TestCase):
f["x"][:, 0], target_order
)
def test_run_3_iterations(self):
# just check if it goes through without errors
fname = "temp_output_triple.h5"
try:
shuffle2.h5shuffle2(
input_file=self.temp_input,
output_file=fname,
datasets=("x", "y"),
iterations=3,
)
finally:
if os.path.exists(fname):
os.remove(fname)
def _make_shuffle_dummy_file(filepath):
x = np.random.rand(22, 2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment