Commit 7a23b53d authored by Stefan Reck

Merge branch 'fix_shuffle2' into 'master'

Fix shuffle2

See merge request !15
parents 87abba00 0195e356
import os
import time
import h5py
import numpy as np
orcasong/tools/concatenate.py
@@ -177,6 +178,7 @@ class FileConcatenator:
errors, rows_per_file, valid_input_files = [], [0], []
for i, file_name in enumerate(input_files, start=1):
file_name = os.path.abspath(file_name)
try:
rows_this_file = _get_rows(file_name, keys_stripped)
except Exception as e:
......
orcasong/tools/shuffle2.py
@@ -14,14 +14,71 @@ from orcasong.tools.concatenate import copy_attrs
__author__ = "Stefan Reck"
def shuffle_v2(
input_file,
datasets=("x", "y"),
output_file=None,
max_ram=None,
max_ram_fraction=0.25,
chunks=False,
delete=False):
def h5shuffle2(
input_file,
output_file=None,
iterations=None,
datasets=("x", "y"),
max_ram_fraction=0.25,
**kwargs,
):
if output_file is None:
output_file = get_filepath_output(input_file, shuffle=True)
if iterations is None:
iterations = get_n_iterations(
input_file,
datasets=datasets,
max_ram_fraction=max_ram_fraction,
)
np.random.seed(42)
for i in range(iterations):
print(f"\nIteration {i+1}/{iterations}")
if iterations == 1:
# special case if there's only one iteration
stgs = {
"input_file": input_file,
"output_file": output_file,
"delete": False,
}
elif i == 0:
# first iteration
stgs = {
"input_file": input_file,
"output_file": f"{output_file}_temp_{i}",
"delete": False,
}
elif i == iterations - 1:
# last iteration
stgs = {
"input_file": f"{output_file}_temp_{i-1}",
"output_file": output_file,
"delete": True,
}
else:
# intermediate iterations
stgs = {
"input_file": f"{output_file}_temp_{i-1}",
"output_file": f"{output_file}_temp_{i}",
"delete": True,
}
shuffle_file(
datasets=datasets,
max_ram_fraction=max_ram_fraction,
chunks=True,
**stgs,
**kwargs,
)
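# Illustration of the temp-file chaining above (hypothetical file names, not
# part of this merge): for iterations=3, input_file="in.h5" and
# output_file="out.h5" the passes run
#   pass 1: in.h5         -> out.h5_temp_0  (input kept)
#   pass 2: out.h5_temp_0 -> out.h5_temp_1  (out.h5_temp_0 deleted)
#   pass 3: out.h5_temp_1 -> out.h5         (out.h5_temp_1 deleted)
# so only the original input and the final shuffled output remain on disk.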
def shuffle_file(
input_file,
datasets=("x", "y"),
output_file=None,
max_ram=None,
max_ram_fraction=0.25,
chunks=False,
delete=False,
):
"""
Shuffle datasets in a h5file that have the same length.
@@ -62,7 +119,9 @@ def shuffle_v2(
if max_ram is None:
max_ram = get_max_ram(max_ram_fraction)
temp_output_file = output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
temp_output_file = (
output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
)
with h5py.File(input_file, "r") as f_in:
dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
print(f"Shuffling datasets {datasets} with {n_lines} lines each")
@@ -90,8 +149,9 @@ def shuffle_v2(
os.rename(temp_output_file, output_file)
if delete:
os.remove(input_file)
print(f"Elapsed time: "
f"{datetime.timedelta(seconds=int(time.time() - start_time))}")
print(
f"Elapsed time: " f"{datetime.timedelta(seconds=int(time.time() - start_time))}"
)
return output_file
@@ -106,7 +166,7 @@ def get_indices_largest(dset_infos):
dset_info = dset_infos[largest_dset]
print(f"Total chunks: {dset_info['n_chunks']}")
ratio = dset_info['chunks_per_batch']/dset_info['n_chunks']
ratio = dset_info["chunks_per_batch"] / dset_info["n_chunks"]
print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")
return get_indices_chunked(
@@ -116,7 +176,9 @@ def get_indices_largest(dset_infos):
)
def get_n_iterations(input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25):
def get_n_iterations(
input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
):
""" Get how often you have to shuffle with given ram to get proper randomness. """
if max_ram is None:
max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
@@ -124,8 +186,9 @@ def get_n_iterations(input_file, datasets=("x", "y"), max_ram=None, max_ram_frac
dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
dset_info = dset_infos[largest_dset]
n_iterations = int(np.ceil(
np.log(dset_info['n_chunks'])/np.log(dset_info['chunks_per_batch'])))
n_iterations = int(
np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
)
print(f"Total chunks: {dset_info['n_chunks']}")
print(f"Chunks per batch: {dset_info['chunks_per_batch']}")
print(f"--> min iterations for full shuffle: {n_iterations}")
@@ -140,7 +203,7 @@ def get_indices_chunked(n_batches, n_chunks, chunksize):
index_batches = []
for bat in chunk_batches:
idx = (bat[:, None]*chunksize + np.arange(chunksize)[None, :]).flatten()
idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
np.random.shuffle(idx)
index_batches.append(idx)
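# Small numpy illustration of the index expansion above (hypothetical
# numbers): a batch of chunk indices [2, 0] with chunksize 3 expands to the
# row indices of those chunks, which are then shuffled within the batch.
_example_bat = np.array([2, 0])
_example_idx = (_example_bat[:, None] * 3 + np.arange(3)[None, :]).flatten()
assert list(_example_idx) == [6, 7, 8, 0, 1, 2]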
@@ -157,8 +220,9 @@ def get_dset_infos(f, datasets, max_ram):
n_lines = len(dset)
else:
if len(dset) != n_lines:
raise ValueError(f"dataset {name} has different length! "
f"{len(dset)} vs {n_lines}")
raise ValueError(
f"dataset {name} has different length! " f"{len(dset)} vs {n_lines}"
)
chunksize = dset.chunks[0]
n_chunks = int(np.ceil(n_lines / chunksize))
# TODO in h5py 3.X, use .nbytes to get uncompressed size
@@ -168,19 +232,21 @@ def get_dset_infos(f, datasets, max_ram):
lines_per_batch = int(np.floor(max_ram / bytes_per_line))
chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))
dset_infos.append({
"name": name,
"dset": dset,
"chunksize": chunksize,
"n_lines": n_lines,
"n_chunks": n_chunks,
"bytes_per_line": bytes_per_line,
"bytes_per_chunk": bytes_per_chunk,
"lines_per_batch": lines_per_batch,
"chunks_per_batch": chunks_per_batch,
"n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)),
"n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)),
})
dset_infos.append(
{
"name": name,
"dset": dset,
"chunksize": chunksize,
"n_lines": n_lines,
"n_chunks": n_chunks,
"bytes_per_line": bytes_per_line,
"bytes_per_chunk": bytes_per_chunk,
"lines_per_batch": lines_per_batch,
"chunks_per_batch": chunks_per_batch,
"n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)),
"n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)),
}
)
return dset_infos, n_lines
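# Rough sketch of the bookkeeping above (made-up numbers; assumes
# bytes_per_chunk = bytes_per_line * chunksize, which is computed in the
# elided lines): with 100 bytes per line, a chunksize of 50 lines and 1 MB of
# usable RAM, one batch holds floor(1e6 / 100) = 10000 lines, i.e.
# floor(1e6 / 5000) = 200 chunks.
_bytes_per_line, _chunksize, _max_ram = 100, 50, 1_000_000
assert int(np.floor(_max_ram / _bytes_per_line)) == 10_000
assert int(np.floor(_max_ram / (_bytes_per_line * _chunksize))) == 200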
@@ -191,7 +257,7 @@ def make_dset(f_out, dset_info, indices):
slc = slice(
batch_index * dset_info["lines_per_batch"],
(batch_index+1) * dset_info["lines_per_batch"],
(batch_index + 1) * dset_info["lines_per_batch"],
)
to_read = indices[slc]
# reading has to be done with linearly increasing index,
@@ -267,59 +333,41 @@ def slicify(fancy_indices):
return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
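# Hedged sketch of what slicify appears to do, based only on the return
# statement above (the full implementation is elided here): group sorted,
# partly contiguous fancy indices into slice objects so h5py can read
# contiguous blocks instead of individual rows. A hypothetical reference
# version for illustration only:
def _slicify_sketch(fancy_indices):
    """E.g. [0, 1, 2, 7, 8] -> [slice(0, 3), slice(7, 9)]."""
    starts, ends = [fancy_indices[0]], []
    for prev, cur in zip(fancy_indices[:-1], fancy_indices[1:]):
        if cur != prev + 1:
            ends.append(prev + 1)
            starts.append(cur)
    ends.append(fancy_indices[-1] + 1)
    return [slice(a, b) for a, b in zip(starts, ends)]

assert _slicify_sketch([0, 1, 2, 7, 8]) == [slice(0, 3), slice(7, 9)]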
def h5shuffle2():
def run_parser():
parser = argparse.ArgumentParser(
description='Shuffle datasets in a h5file that have the same length. '
'Uses chunkwise readout for speed-up.')
parser.add_argument('input_file', type=str,
help='Path of the file that will be shuffled.')
parser.add_argument('--output_file', type=str, default=None,
help='If given, this will be the name of the output file. '
'Default: input_file + suffix.')
parser.add_argument('--datasets', type=str, nargs="*", default=("x", "y"),
help='Which datasets to include in output. Default: x, y')
parser.add_argument('--max_ram_fraction', type=float, default=0.25,
help="in [0, 1]. Fraction of all available ram to use for reading one batch of data "
"Note: this should "
"be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
"Default: 0.25")
parser.add_argument('--iterations', type=int, default=None,
help="Shuffle the file this many times. Default: Auto choose best number.")
kwargs = vars(parser.parse_args())
input_file = kwargs.pop("input_file")
output_file = kwargs.pop("output_file")
if output_file is None:
output_file = get_filepath_output(input_file, shuffle=True)
iterations = kwargs.pop("iterations")
if iterations is None:
iterations = get_n_iterations(
input_file,
datasets=kwargs["datasets"],
max_ram_fraction=kwargs["max_ram_fraction"],
)
np.random.seed(42)
for i in range(iterations):
print(f"\nIteration {i+1}/{iterations}")
if i == 0:
# first iteration
stgs = {
"input_file": input_file,
"output_file": f"{output_file}_temp_{i}",
"delete": False
}
elif i == iterations-1:
# last iteration
stgs = {
"input_file": f"{output_file}_temp_{i-1}",
"output_file": output_file,
"delete": True
}
else:
# intermediate iterations
stgs = {
"input_file": f"{output_file}_temp_{i-1}",
"output_file": f"{output_file}_temp_{i}",
"delete": True
}
shuffle_v2(**kwargs, **stgs, chunks=True)
description="Shuffle datasets in a h5file that have the same length. "
"Uses chunkwise readout for speed-up."
)
parser.add_argument(
"input_file", type=str, help="Path of the file that will be shuffled."
)
parser.add_argument(
"--output_file",
type=str,
default=None,
help="If given, this will be the name of the output file. "
"Default: input_file + suffix.",
)
parser.add_argument(
"--datasets",
type=str,
nargs="*",
default=("x", "y"),
help="Which datasets to include in output. Default: x, y",
)
parser.add_argument(
"--max_ram_fraction",
type=float,
default=0.25,
help="in [0, 1]. Fraction of all available ram to use for reading one batch of data "
"Note: this should "
"be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
"Default: 0.25",
)
parser.add_argument(
"--iterations",
type=int,
default=None,
help="Shuffle the file this many times. Default: Auto choose best number.",
)
h5shuffle2(**vars(parser.parse_args()))
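The setup.py change below wires the `h5shuffle2` console script to `run_parser`, and the same shuffling can also be driven directly from Python. A minimal, hedged usage sketch (the file name is hypothetical, not from this merge):

from orcasong.tools.shuffle2 import h5shuffle2

h5shuffle2("events.h5", datasets=("x", "y"), max_ram_fraction=0.25)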
setup.py
@@ -29,7 +29,7 @@ setup(
entry_points={'console_scripts': [
'concatenate=orcasong.tools.concatenate:main',
'h5shuffle=orcasong.tools.postproc:h5shuffle',
'h5shuffle2=orcasong.tools.shuffle2:h5shuffle2',
'h5shuffle2=orcasong.tools.shuffle2:run_parser',
'plot_binstats=orcasong.plotting.plot_binstats:main',
'make_nn_images=legacy.make_nn_images:main',
'make_dsplit=orcasong_contrib.data_tools.make_data_split.make_data_split:main']}
......
@@ -3,7 +3,7 @@ import os
import h5py
import numpy as np
import orcasong.tools.postproc as postproc
from orcasong.tools.shuffle2 import shuffle_v2
import orcasong.tools.shuffle2 as shuffle2
__author__ = 'Stefan Reck'
@@ -41,11 +41,10 @@ class TestShuffleV2(TestCase):
self.x, self.y = _make_shuffle_dummy_file(self.temp_input)
np.random.seed(42)
shuffle_v2(
shuffle2.h5shuffle2(
input_file=self.temp_input,
output_file=self.temp_output,
datasets=("x", "y"),
chunks=True,
max_ram=400, # -> 2 batches
)
@@ -76,6 +75,20 @@ class TestShuffleV2(TestCase):
f["x"][:, 0], target_order
)
def test_run_3_iterations(self):
# just check if it goes through without errors
fname = "temp_output_triple.h5"
try:
shuffle2.h5shuffle2(
input_file=self.temp_input,
output_file=fname,
datasets=("x", "y"),
iterations=3,
)
finally:
if os.path.exists(fname):
os.remove(fname)
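# Note: with iterations=3 the call above exercises the first, intermediate
# and last-pass branches of h5shuffle2, including creation and deletion of
# the two intermediate "_temp_" files.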
def _make_shuffle_dummy_file(filepath):
x = np.random.rand(22, 2)
......