diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py
index 651603740b80e83ac5aa301d25826a7e2108fafa..9fd6970a1c70ef2e93d6f7ccd821c52fe27ea86d 100644
--- a/orcasong/tools/shuffle2.py
+++ b/orcasong/tools/shuffle2.py
@@ -14,14 +14,71 @@ from orcasong.tools.concatenate import copy_attrs
 
 __author__ = "Stefan Reck"
 
 
-def shuffle_v2(
-        input_file,
-        datasets=("x", "y"),
-        output_file=None,
-        max_ram=None,
-        max_ram_fraction=0.25,
-        chunks=False,
-        delete=False):
+def h5shuffle2(
+    input_file,
+    output_file=None,
+    iterations=None,
+    datasets=("x", "y"),
+    max_ram_fraction=0.25,
+    **kwargs,
+):
+    if output_file is None:
+        output_file = get_filepath_output(input_file, shuffle=True)
+    if iterations is None:
+        iterations = get_n_iterations(
+            input_file,
+            datasets=datasets,
+            max_ram_fraction=max_ram_fraction,
+        )
+    np.random.seed(42)
+    for i in range(iterations):
+        print(f"\nIteration {i+1}/{iterations}")
+        if iterations == 1:
+            # special case if theres only one iteration
+            stgs = {
+                "input_file": input_file,
+                "output_file": output_file,
+                "delete": False,
+            }
+        elif i == 0:
+            # first iteration
+            stgs = {
+                "input_file": input_file,
+                "output_file": f"{output_file}_temp_{i}",
+                "delete": False,
+            }
+        elif i == iterations - 1:
+            # last iteration
+            stgs = {
+                "input_file": f"{output_file}_temp_{i-1}",
+                "output_file": output_file,
+                "delete": True,
+            }
+        else:
+            # intermediate iterations
+            stgs = {
+                "input_file": f"{output_file}_temp_{i-1}",
+                "output_file": f"{output_file}_temp_{i}",
+                "delete": True,
+            }
+        shuffle_file(
+            datasets=datasets,
+            max_ram_fraction=max_ram_fraction,
+            chunks=True,
+            **stgs,
+            **kwargs,
+        )
+
+
+def shuffle_file(
+    input_file,
+    datasets=("x", "y"),
+    output_file=None,
+    max_ram=None,
+    max_ram_fraction=0.25,
+    chunks=False,
+    delete=False,
+):
     """
     Shuffle datasets in a h5file that have the same length.
@@ -62,7 +119,9 @@
     if max_ram is None:
         max_ram = get_max_ram(max_ram_fraction)
 
-    temp_output_file = output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
+    temp_output_file = (
+        output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
+    )
     with h5py.File(input_file, "r") as f_in:
         dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
         print(f"Shuffling datasets {datasets} with {n_lines} lines each")
@@ -90,8 +149,9 @@
     os.rename(temp_output_file, output_file)
     if delete:
         os.remove(input_file)
-    print(f"Elapsed time: "
-          f"{datetime.timedelta(seconds=int(time.time() - start_time))}")
+    print(
+        f"Elapsed time: " f"{datetime.timedelta(seconds=int(time.time() - start_time))}"
+    )
     return output_file
 
 
@@ -106,7 +166,7 @@
     dset_info = dset_infos[largest_dset]
 
     print(f"Total chunks: {dset_info['n_chunks']}")
-    ratio = dset_info['chunks_per_batch']/dset_info['n_chunks']
+    ratio = dset_info["chunks_per_batch"] / dset_info["n_chunks"]
     print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")
 
     return get_indices_chunked(
@@ -116,7 +176,9 @@
-def get_n_iterations(input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25):
+def get_n_iterations(
+    input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
+):
     """
     Get how often you have to shuffle with given ram to get proper randomness.
 
     """
     if max_ram is None:
         max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
@@ -124,8 +186,9 @@
         dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
         largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
         dset_info = dset_infos[largest_dset]
-        n_iterations = int(np.ceil(
-            np.log(dset_info['n_chunks'])/np.log(dset_info['chunks_per_batch'])))
+        n_iterations = int(
+            np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
+        )
         print(f"Total chunks: {dset_info['n_chunks']}")
         print(f"Chunks per batch: {dset_info['chunks_per_batch']}")
         print(f"--> min iterations for full shuffle: {n_iterations}")
@@ -140,7 +203,7 @@
 
     index_batches = []
     for bat in chunk_batches:
-        idx = (bat[:, None]*chunksize + np.arange(chunksize)[None, :]).flatten()
+        idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
         np.random.shuffle(idx)
         index_batches.append(idx)
 
@@ -157,8 +220,9 @@
             n_lines = len(dset)
         else:
             if len(dset) != n_lines:
-                raise ValueError(f"dataset {name} has different length! "
-                                 f"{len(dset)} vs {n_lines}")
+                raise ValueError(
+                    f"dataset {name} has different length! " f"{len(dset)} vs {n_lines}"
+                )
         chunksize = dset.chunks[0]
         n_chunks = int(np.ceil(n_lines / chunksize))
         # TODO in h5py 3.X, use .nbytes to get uncompressed size
@@ -168,19 +232,21 @@
         lines_per_batch = int(np.floor(max_ram / bytes_per_line))
         chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))
 
-        dset_infos.append({
-            "name": name,
-            "dset": dset,
-            "chunksize": chunksize,
-            "n_lines": n_lines,
-            "n_chunks": n_chunks,
-            "bytes_per_line": bytes_per_line,
-            "bytes_per_chunk": bytes_per_chunk,
-            "lines_per_batch": lines_per_batch,
-            "chunks_per_batch": chunks_per_batch,
-            "n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)),
-            "n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)),
-        })
+        dset_infos.append(
+            {
+                "name": name,
+                "dset": dset,
+                "chunksize": chunksize,
+                "n_lines": n_lines,
+                "n_chunks": n_chunks,
+                "bytes_per_line": bytes_per_line,
+                "bytes_per_chunk": bytes_per_chunk,
+                "lines_per_batch": lines_per_batch,
+                "chunks_per_batch": chunks_per_batch,
+                "n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)),
+                "n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)),
+            }
+        )
 
     return dset_infos, n_lines
 
@@ -191,7 +257,7 @@
 
         slc = slice(
            batch_index * dset_info["lines_per_batch"],
-            (batch_index+1) * dset_info["lines_per_batch"],
+            (batch_index + 1) * dset_info["lines_per_batch"],
         )
         to_read = indices[slc]
         # reading has to be done with linearly increasing index,
@@ -267,59 +333,41 @@
     return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
 
 
-def h5shuffle2():
+def run_parser():
     parser = argparse.ArgumentParser(
-        description='Shuffle datasets in a h5file that have the same length. '
-                    'Uses chunkwise readout for speed-up.')
-    parser.add_argument('input_file', type=str,
-                        help='Path of the file that will be shuffled.')
-    parser.add_argument('--output_file', type=str, default=None,
-                        help='If given, this will be the name of the output file. '
-                             'Default: input_file + suffix.')
-    parser.add_argument('--datasets', type=str, nargs="*", default=("x", "y"),
-                        help='Which datasets to include in output. Default: x, y')
-    parser.add_argument('--max_ram_fraction', type=float, default=0.25,
-                        help="in [0, 1]. Fraction of all available ram to use for reading one batch of data "
-                             "Note: this should "
-                             "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
-                             "Default: 0.25")
-    parser.add_argument('--iterations', type=int, default=None,
-                        help="Shuffle the file this many times. Default: Auto choose best number.")
-    kwargs = vars(parser.parse_args())
-
-    input_file = kwargs.pop("input_file")
-    output_file = kwargs.pop("output_file")
-    if output_file is None:
-        output_file = get_filepath_output(input_file, shuffle=True)
-    iterations = kwargs.pop("iterations")
-    if iterations is None:
-        iterations = get_n_iterations(
-            input_file,
-            datasets=kwargs["datasets"],
-            max_ram_fraction=kwargs["max_ram_fraction"],
-        )
-    np.random.seed(42)
-    for i in range(iterations):
-        print(f"\nIteration {i+1}/{iterations}")
-        if i == 0:
-            # first iteration
-            stgs = {
-                "input_file": input_file,
-                "output_file": f"{output_file}_temp_{i}",
-                "delete": False
-            }
-        elif i == iterations-1:
-            # last iteration
-            stgs = {
-                "input_file": f"{output_file}_temp_{i-1}",
-                "output_file": output_file,
-                "delete": True
-            }
-        else:
-            # intermediate iterations
-            stgs = {
-                "input_file": f"{output_file}_temp_{i-1}",
-                "output_file": f"{output_file}_temp_{i}",
-                "delete": True
-            }
-        shuffle_v2(**kwargs, **stgs, chunks=True)
+        description="Shuffle datasets in a h5file that have the same length. "
+        "Uses chunkwise readout for speed-up."
+    )
+    parser.add_argument(
+        "input_file", type=str, help="Path of the file that will be shuffled."
+    )
+    parser.add_argument(
+        "--output_file",
+        type=str,
+        default=None,
+        help="If given, this will be the name of the output file. "
+        "Default: input_file + suffix.",
+    )
+    parser.add_argument(
+        "--datasets",
+        type=str,
+        nargs="*",
+        default=("x", "y"),
+        help="Which datasets to include in output. Default: x, y",
+    )
+    parser.add_argument(
+        "--max_ram_fraction",
+        type=float,
+        default=0.25,
+        help="in [0, 1]. Fraction of all available ram to use for reading one batch of data "
+        "Note: this should "
+        "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
+        "Default: 0.25",
+    )
+    parser.add_argument(
+        "--iterations",
+        type=int,
+        default=None,
+        help="Shuffle the file this many times. Default: Auto choose best number.",
+    )
+    h5shuffle2(**vars(parser.parse_args()))
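
Usage sketch (illustrative, not part of the diff): the import path follows the file touched above, the call uses the h5shuffle2 signature introduced here, and the file names are placeholders.

    # Shuffle the datasets "x" and "y" of an HDF5 file in several chunkwise passes.
    # Leaving iterations=None lets get_n_iterations() pick the minimum number of
    # passes, i.e. ceil(log(n_chunks) / log(chunks_per_batch)).
    from orcasong.tools.shuffle2 import h5shuffle2

    h5shuffle2(
        "events.h5",                        # placeholder input file
        output_file="events_shuffled.h5",   # placeholder output file
        datasets=("x", "y"),
        max_ram_fraction=0.25,
    )

For example, with 10000 chunks and enough RAM for 100 chunks per batch, ceil(log(10000) / log(100)) = 2 passes are needed for a full shuffle.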