From 7daf7437361f366418a35f516f34345e57cc8240 Mon Sep 17 00:00:00 2001 From: Stefan Reck <stefan.reck@fau.de> Date: Mon, 7 Dec 2020 15:50:37 +0100 Subject: [PATCH] init --- orcasong/tools/shuffle2.py | 143 +++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 orcasong/tools/shuffle2.py diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py new file mode 100644 index 0000000..266cf15 --- /dev/null +++ b/orcasong/tools/shuffle2.py @@ -0,0 +1,143 @@ +import os +import time +import datetime +import argparse +import numpy as np +import h5py + +from orcasong.tools.postproc import get_filepath_output, copy_used_files +from orcasong.tools.concatenate import copy_attrs + +# neu max_ram = 1e9: 2:41 (161s) +# neu max_ram = 6e9: 0:25 (25s) +# alt: 3:38 (218s) + + +def shuffle_v2( + input_file, + datasets=("x", "y"), + output_file=None, + max_ram=1e9, + seed=42): + """ + Shuffle datasets in a h5file that have the same length. + + Parameters + ---------- + input_file : str + Path of the file that will be shuffle. + datasets : tuple + Which datasets to include in output. + output_file : str, optional + If given, this will be the name of the output file. + Otherwise, a name is auto generated. + max_ram : int + Available ram. + seed : int + Sets a fixed random seed for the shuffling. + + Returns + ------- + output_file : str + Path to the output file. + + """ + start_time = time.time() + if output_file is None: + output_file = get_filepath_output(input_file, shuffle=True) + if os.path.exists(output_file): + raise FileExistsError(output_file) + + with h5py.File(input_file, "r") as f_in: + dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram) + print(f"Shuffling datasets {datasets} with {n_lines} lines each") + np.random.seed(seed) + indices = np.arange(n_lines) + np.random.shuffle(indices) + + with h5py.File(output_file, "x") as f_out: + for dset_info in dset_infos: + print("Creating dataset", dset_info["name"]) + make_dset(f_out, dset_info, indices) + print("Done!") + + copy_used_files(input_file, output_file) + copy_attrs(input_file, output_file) + + print(f"Elapsed time: " + f"{datetime.timedelta(seconds=int(time.time() - start_time))}") + return output_file + + +def get_dset_infos(f, datasets, max_ram): + """ Check datasets and retrieve relevant infos for each. """ + dset_infos = [] + n_lines = None + for i, name in enumerate(datasets): + dset = f[name] + if i == 0: + n_lines = len(dset) + else: + if len(dset) != n_lines: + raise ValueError(f"dataset {name} has different length! " + f"{len(dset)} vs {n_lines}") + # TODO in h5py 3.X, use .nbytes to get uncompressed size + bytes_per_line = np.asarray(dset[0]).nbytes + lines_per_batch = int(max_ram / bytes_per_line) + n_batches = int(np.ceil(n_lines / lines_per_batch)) + dset_infos.append({ + "name": name, + "dset": dset, + "bytes_per_line": bytes_per_line, + "lines_per_batch": lines_per_batch, + "n_batches": n_batches, + }) + return dset_infos, n_lines + + +def get_indices(n_lines, chunksize, chunks_per_batch): + indices = np.arange(n_lines) + chunk_starts = indices[::chunksize] + + np.random.shuffle(chunk_starts) + for batch_no in range(chunks_per_batch) + + +def make_dset(f_out, dset_info, indices): + """ Create a shuffled dataset in the output file. """ + for batch_index in range(dset_info["n_batches"]): + print(f"Processing batch {batch_index+1}/{dset_info['n_batches']}") + + slc = slice( + batch_index * dset_info["lines_per_batch"], + (batch_index+1) * dset_info["lines_per_batch"], + ) + to_read = indices[slc] + # reading has to be done with linearly increasing index, + # so sort -> read -> undo sorting + sort_ix = np.argsort(to_read) + unsort_ix = np.argsort(sort_ix) + data = dset_info["dset"][to_read[sort_ix]][unsort_ix] + + if batch_index == 0: + in_dset = dset_info["dset"] + out_dset = f_out.create_dataset( + dset_info["name"], + data=data, + maxshape=in_dset.shape, + chunks=in_dset.chunks, + compression=in_dset.compression, + compression_opts=in_dset.compression_opts, + shuffle=in_dset.shuffle, + ) + out_dset.resize(len(in_dset), axis=0) + else: + f_out[dset_info["name"]][slc] = data + + +def h5shuffle(): + parser = argparse.ArgumentParser(description='Shuffle an h5 file using h5py.') + parser.add_argument('input_file', type=str, help='File to shuffle.') + parser.add_argument('--output_file', type=str, + help='Name of output file. Default: Auto generate name.') + shuffle_v2(**vars(parser.parse_args())) -- GitLab