Commit 368b5a0d authored by Stefan Reck

auto choose best n_iterations

parent aa51e454
@@ -21,8 +21,7 @@ def shuffle_v2(
         max_ram=None,
         max_ram_fraction=0.25,
         chunks=False,
-        delete=False,
-        seed=42):
+        delete=False):
     """
     Shuffle datasets in a h5file that have the same length.
@@ -48,8 +47,6 @@
         to be accurate! (use a node with at least 32gb, the more the better)
     delete : bool
        Delete the original file afterwards?
-    seed : int
-        Sets a fixed random seed for the shuffling.
 
     Returns
     -------
@@ -63,13 +60,11 @@
     if os.path.exists(output_file):
         raise FileExistsError(output_file)
     if max_ram is None:
-        max_ram = max_ram_fraction * psutil.virtual_memory().available
-        print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
+        max_ram = get_max_ram(max_ram_fraction)
 
     temp_output_file = output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
     with h5py.File(input_file, "r") as f_in:
         dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
-        np.random.seed(seed)
         print(f"Shuffling datasets {datasets} with {n_lines} lines each")
 
     if not chunks:
@@ -100,6 +95,12 @@
     return output_file
 
 
+def get_max_ram(max_ram_fraction):
+    max_ram = max_ram_fraction * psutil.virtual_memory().available
+    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
+    return max_ram
+
+
 def get_indices_largest(dset_infos):
     largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
     dset_info = dset_infos[largest_dset]
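For illustration, calling the new helper on a node with roughly 32 GB of free memory would look like this (the numbers below are made up; psutil reports whatever is actually available at call time):

    >>> get_max_ram(0.25)
    Using 25.00% of available ram = 8000000000.0 bytes
    8000000000.0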
@@ -108,9 +109,6 @@ def get_indices_largest(dset_infos):
     ratio = dset_info['chunks_per_batch']/dset_info['n_chunks']
     print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")
-    if ratio <= 0.1:
-        print("Warning: Should have more than "
-              "10% of chunks per batch to ensure proper shuffling!")
     return get_indices_chunked(
         dset_info["n_batches_chunkwise"],
         dset_info["n_chunks"],
@@ -118,6 +116,22 @@
     )
 
 
+def get_n_iterations(input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25):
+    """ Get how often you have to shuffle with given ram to get proper randomness. """
+    if max_ram is None:
+        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
+    with h5py.File(input_file, "r") as f_in:
+        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
+    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
+    dset_info = dset_infos[largest_dset]
+    n_iterations = int(np.ceil(
+        np.log(dset_info['n_chunks'])/np.log(dset_info['chunks_per_batch'])))
+    print(f"Total chunks: {dset_info['n_chunks']}")
+    print(f"Chunks per batch: {dset_info['chunks_per_batch']}")
+    print(f"--> min iterations for full shuffle: {n_iterations}")
+    return n_iterations
+
+
 def get_indices_chunked(n_batches, n_chunks, chunksize):
     """ Return a list with the chunkwise shuffled indices of each batch. """
     chunk_indices = np.arange(n_chunks)
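The formula in get_n_iterations follows from a mixing argument: one pass shuffles rows across chunks_per_batch chunks, so k passes mix on the order of chunks_per_batch**k chunks, and full randomness needs chunks_per_batch**k >= n_chunks, i.e. k = ceil(log(n_chunks) / log(chunks_per_batch)). A quick sanity check with made-up numbers (not taken from a real file):

    import numpy as np

    # Hypothetical file: 5000 HDF5 chunks, with RAM for 20 chunks per batch.
    n_chunks, chunks_per_batch = 5000, 20

    # Same formula as in get_n_iterations above.
    n_iterations = int(np.ceil(np.log(n_chunks) / np.log(chunks_per_batch)))
    print(n_iterations)  # ceil(log(5000)/log(20)) = ceil(2.84) -> 3 passes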
@@ -255,22 +269,22 @@ def slicify(fancy_indices):
 
 def h5shuffle2():
     parser = argparse.ArgumentParser(
-        description='Shuffle datasets in a h5file that have the same length.'
-                    'Uses chunkwise readout for a pseudo-shuffle, so shuffling '
-                    'multiple times is recommended for larger files.')
+        description='Shuffle datasets in a h5file that have the same length. '
+                    'Uses chunkwise readout for speed-up.')
     parser.add_argument('input_file', type=str,
                         help='Path of the file that will be shuffled.')
-    parser.add_argument('--output_file', type=str,
+    parser.add_argument('--output_file', type=str, default=None,
                         help='If given, this will be the name of the output file. '
-                             'Otherwise, a name is auto generated.')
+                             'Default: input_file + suffix.')
     parser.add_argument('--datasets', type=str, nargs="*", default=("x", "y"),
                         help='Which datasets to include in output. Default: x, y')
     parser.add_argument('--max_ram_fraction', type=float, default=0.25,
-                        help="in [0, 1]. Fraction of ram to use for reading one batch of data "
-                             "when max_ram is None. Note: this should "
-                             "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling.")
-    parser.add_argument('--iterations', type=int, default=2,
-                        help="Shuffle the file this many times. Default: 2")
+                        help="in [0, 1]. Fraction of all available ram to use for reading one batch of data. "
+                             "Note: this should "
+                             "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
+                             "Default: 0.25")
+    parser.add_argument('--iterations', type=int, default=None,
+                        help="Shuffle the file this many times. Default: Auto choose best number.")
 
     kwargs = vars(parser.parse_args())
     input_file = kwargs.pop("input_file")
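Since --iterations now defaults to None, the tool picks the iteration count itself. Assuming h5shuffle2 is wired up as a console script (the function looks like an entry point, but the packaging is not part of this diff), a typical invocation would be:

    h5shuffle2 input.h5 --datasets x y --max_ram_fraction 0.25

Omitting --iterations triggers the call to get_n_iterations shown in the next hunk.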
@@ -278,9 +292,15 @@ def h5shuffle2():
     if output_file is None:
         output_file = get_filepath_output(input_file, shuffle=True)
     iterations = kwargs.pop("iterations")
-
+    if iterations is None:
+        iterations = get_n_iterations(
+            input_file,
+            datasets=kwargs["datasets"],
+            max_ram_fraction=kwargs["max_ram_fraction"],
+        )
+    np.random.seed(42)
     for i in range(iterations):
-        print(f"Iteration {i}")
+        print(f"\nIteration {i+1}/{iterations}")
         if i == 0:
             # first iteration
             stgs = {
...
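To see why repeating the chunkwise pass converges to a proper shuffle, here is a toy model of the scheme. This is not the actual shuffle_v2 code (which streams HDF5 chunks via get_indices_chunked); it only mimics the pattern of reading chunks in random order and shuffling rows within each in-memory batch:

    import numpy as np

    rng = np.random.default_rng(42)

    def pseudo_shuffle_pass(arr, chunksize, chunks_per_batch):
        """One chunkwise pass: pick chunks in random order, then shuffle
        rows within each in-memory batch of chunks_per_batch chunks."""
        chunks = [arr[i:i + chunksize] for i in range(0, len(arr), chunksize)]
        rng.shuffle(chunks)  # random chunk order, like the shuffled read indices
        batches = []
        for b in range(0, len(chunks), chunks_per_batch):
            batch = np.concatenate(chunks[b:b + chunks_per_batch])
            rng.shuffle(batch)  # in-memory shuffle of one batch
            batches.append(batch)
        return np.concatenate(batches)

    data = np.arange(100_000)  # 1000 chunks of 100 rows each
    for _ in range(3):  # ceil(log(1000)/log(10)) = 3 passes
        data = pseudo_shuffle_pass(data, chunksize=100, chunks_per_batch=10)

After three passes (the count get_n_iterations would report for 1000 chunks at 10 per batch), any row can in principle end up anywhere in the file.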