From 368b5a0db8fcb58d94e840428c043711647ddf83 Mon Sep 17 00:00:00 2001
From: Stefan Reck <stefan.reck@fau.de>
Date: Tue, 15 Dec 2020 15:06:41 +0100
Subject: [PATCH] auto choose best n_iterations

---
 orcasong/tools/shuffle2.py | 64 +++++++++++++++++++++++++------------
 1 file changed, 42 insertions(+), 22 deletions(-)

diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py
index 1d85fa3..6516037 100644
--- a/orcasong/tools/shuffle2.py
+++ b/orcasong/tools/shuffle2.py
@@ -21,8 +21,7 @@ def shuffle_v2(
         max_ram=None,
         max_ram_fraction=0.25,
         chunks=False,
-        delete=False,
-        seed=42):
+        delete=False):
     """
     Shuffle datasets in a h5file that have the same length.
 
@@ -48,8 +47,6 @@
         to be accurate! (use a node with at least 32gb, the more the better)
     delete : bool
         Delete the original file afterwards?
-    seed : int
-        Sets a fixed random seed for the shuffling.
 
     Returns
     -------
@@ -63,13 +60,11 @@
     if os.path.exists(output_file):
         raise FileExistsError(output_file)
     if max_ram is None:
-        max_ram = max_ram_fraction * psutil.virtual_memory().available
-        print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
+        max_ram = get_max_ram(max_ram_fraction)
 
     temp_output_file = output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
     with h5py.File(input_file, "r") as f_in:
         dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
-        np.random.seed(seed)
         print(f"Shuffling datasets {datasets} with {n_lines} lines each")
 
         if not chunks:
@@ -100,6 +95,12 @@
     return output_file
 
 
+def get_max_ram(max_ram_fraction):
+    max_ram = max_ram_fraction * psutil.virtual_memory().available
+    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
+    return max_ram
+
+
 def get_indices_largest(dset_infos):
     largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
     dset_info = dset_infos[largest_dset]
@@ -108,9 +109,6 @@
 
     ratio = dset_info['chunks_per_batch']/dset_info['n_chunks']
     print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")
-    if ratio <= 0.1:
-        print("Warning: Should have more than "
-              "10% of chunks per batch to ensure proper shuffling!")
     return get_indices_chunked(
         dset_info["n_batches_chunkwise"],
         dset_info["n_chunks"],
@@ -118,6 +116,22 @@
     )
 
 
+def get_n_iterations(input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25):
+    """ Get how often you have to shuffle with given ram to get proper randomness. """
+    if max_ram is None:
+        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
+    with h5py.File(input_file, "r") as f_in:
+        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
+    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
+    dset_info = dset_infos[largest_dset]
+    n_iterations = int(np.ceil(
+        np.log(dset_info['n_chunks'])/np.log(dset_info['chunks_per_batch'])))
+    print(f"Total chunks: {dset_info['n_chunks']}")
+    print(f"Chunks per batch: {dset_info['chunks_per_batch']}")
+    print(f"--> min iterations for full shuffle: {n_iterations}")
+    return n_iterations
+
+
 def get_indices_chunked(n_batches, n_chunks, chunksize):
     """ Return a list with the chunkwise shuffled indices of each batch. """
     chunk_indices = np.arange(n_chunks)
@@ -255,22 +269,22 @@ def slicify(fancy_indices):
 
 def h5shuffle2():
     parser = argparse.ArgumentParser(
-        description='Shuffle datasets in a h5file that have the same length.'
-                    'Uses chunkwise readout for a pseudo-shuffle, so shuffling '
-                    'multiple times is recommended for larger files.')
+        description='Shuffle datasets in a h5file that have the same length. '
+                    'Uses chunkwise readout for speed-up.')
     parser.add_argument('input_file', type=str,
                         help='Path of the file that will be shuffled.')
-    parser.add_argument('--output_file', type=str,
+    parser.add_argument('--output_file', type=str, default=None,
                         help='If given, this will be the name of the output file. '
-                             'Otherwise, a name is auto generated.')
+                             'Default: input_file + suffix.')
     parser.add_argument('--datasets', type=str, nargs="*", default=("x", "y"),
                         help='Which datasets to include in output. Default: x, y')
     parser.add_argument('--max_ram_fraction', type=float, default=0.25,
-                        help="in [0, 1]. Fraction of ram to use for reading one batch of data "
-                             "when max_ram is None. Note: this should "
-                             "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling.")
-    parser.add_argument('--iterations', type=int, default=2,
-                        help="Shuffle the file this many times. Default: 2")
+                        help="in [0, 1]. Fraction of all available ram to use for reading one batch of data. "
+                             "Note: this should "
+                             "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
+                             "Default: 0.25")
+    parser.add_argument('--iterations', type=int, default=None,
+                        help="Shuffle the file this many times. Default: auto-choose the best number.")
 
     kwargs = vars(parser.parse_args())
     input_file = kwargs.pop("input_file")
@@ -278,9 +292,15 @@
     if output_file is None:
         output_file = get_filepath_output(input_file, shuffle=True)
     iterations = kwargs.pop("iterations")
-
+    if iterations is None:
+        iterations = get_n_iterations(
+            input_file,
+            datasets=kwargs["datasets"],
+            max_ram_fraction=kwargs["max_ram_fraction"],
+        )
+    np.random.seed(42)
     for i in range(iterations):
-        print(f"Iteration {i}")
+        print(f"\nIteration {i+1}/{iterations}")
         if i == 0:
             # first iteration
             stgs = {
--
GitLab
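
The patch makes --iterations optional: when it is not given, h5shuffle2 calls the new
get_n_iterations() and shuffles the file that many times (with a fixed seed of 42).
The iteration count follows from the chunk geometry of the largest dataset: one pass
only intermixes the chunks that fit into a single in-memory batch, so after k passes
roughly chunks_per_batch**k chunks are mixed, and a full shuffle needs
chunks_per_batch**k >= n_chunks, i.e. k >= log(n_chunks) / log(chunks_per_batch).
Below is a minimal standalone sketch of just that calculation; the helper name and the
example numbers are made up for illustration, while the patched get_n_iterations()
obtains n_chunks and chunks_per_batch from the HDF5 file via get_dset_infos().

    import numpy as np

    def min_shuffle_iterations(n_chunks, chunks_per_batch):
        # Smallest k such that chunks_per_batch**k >= n_chunks, i.e. enough
        # chunkwise passes for the whole file to be properly shuffled.
        return int(np.ceil(np.log(n_chunks) / np.log(chunks_per_batch)))

    # Example: 10000 chunks in the file, but only 500 fit into RAM per batch
    print(min_shuffle_iterations(10000, 500))  # -> 2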