Commit 368b5a0d authored by Stefan Reck

auto choose best n_iterations

parent aa51e454
@@ -21,8 +21,7 @@ def shuffle_v2(
         max_ram=None,
         max_ram_fraction=0.25,
         chunks=False,
-        delete=False,
-        seed=42):
+        delete=False):
     """
     Shuffle datasets in a h5file that have the same length.
@@ -48,8 +47,6 @@ def shuffle_v2(
         to be accurate! (use a node with at least 32gb, the more the better)
     delete : bool
        Delete the original file afterwards?
-    seed : int
-        Sets a fixed random seed for the shuffling.
 
     Returns
     -------
@@ -63,13 +60,11 @@ def shuffle_v2(
     if os.path.exists(output_file):
         raise FileExistsError(output_file)
     if max_ram is None:
-        max_ram = max_ram_fraction * psutil.virtual_memory().available
-        print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
+        max_ram = get_max_ram(max_ram_fraction)
 
     temp_output_file = output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
     with h5py.File(input_file, "r") as f_in:
         dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
-        np.random.seed(seed)
         print(f"Shuffling datasets {datasets} with {n_lines} lines each")
 
         if not chunks:
@@ -100,6 +95,12 @@ def shuffle_v2(
     return output_file
 
 
+def get_max_ram(max_ram_fraction):
+    max_ram = max_ram_fraction * psutil.virtual_memory().available
+    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
+    return max_ram
+
+
 def get_indices_largest(dset_infos):
     largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
     dset_info = dset_infos[largest_dset]
@@ -108,9 +109,6 @@ def get_indices_largest(dset_infos):
     ratio = dset_info['chunks_per_batch']/dset_info['n_chunks']
     print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")
-    if ratio <= 0.1:
-        print("Warning: Should have more than "
-              "10% of chunks per batch to ensure proper shuffling!")
     return get_indices_chunked(
         dset_info["n_batches_chunkwise"],
         dset_info["n_chunks"],
@@ -118,6 +116,22 @@
     )
 
 
+def get_n_iterations(input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25):
+    """ Get how often you have to shuffle with given ram to get proper randomness. """
+    if max_ram is None:
+        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
+    with h5py.File(input_file, "r") as f_in:
+        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
+    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
+    dset_info = dset_infos[largest_dset]
+    n_iterations = int(np.ceil(
+        np.log(dset_info['n_chunks'])/np.log(dset_info['chunks_per_batch'])))
+    print(f"Total chunks: {dset_info['n_chunks']}")
+    print(f"Chunks per batch: {dset_info['chunks_per_batch']}")
+    print(f"--> min iterations for full shuffle: {n_iterations}")
+    return n_iterations
+
+
 def get_indices_chunked(n_batches, n_chunks, chunksize):
     """ Return a list with the chunkwise shuffled indices of each batch. """
     chunk_indices = np.arange(n_chunks)
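The iteration estimate added above can be read as follows: one chunkwise pass mixes at most chunks_per_batch chunks into any output batch, so after k passes data from up to chunks_per_batch**k chunks can be interleaved, and the smallest k with chunks_per_batch**k >= n_chunks gives a full shuffle. A minimal standalone sketch of the same formula with made-up numbers (both values below are hypothetical, not taken from this commit):

import numpy as np

n_chunks = 1000        # hypothetical: total number of hdf5 chunks in the dataset
chunks_per_batch = 50  # hypothetical: chunks that fit into one in-memory batch

# same estimate as get_n_iterations: smallest k with chunks_per_batch**k >= n_chunks
n_iterations = int(np.ceil(np.log(n_chunks) / np.log(chunks_per_batch)))
print(n_iterations)  # --> 2, i.e. two shuffle passes are needed for a full shuffle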
@@ -255,22 +269,22 @@ def slicify(fancy_indices):
 
 def h5shuffle2():
     parser = argparse.ArgumentParser(
-        description='Shuffle datasets in a h5file that have the same length.'
-                    'Uses chunkwise readout for a pseudo-shuffle, so shuffling '
-                    'multiple times is recommended for larger files.')
+        description='Shuffle datasets in a h5file that have the same length. '
+                    'Uses chunkwise readout for speed-up.')
     parser.add_argument('input_file', type=str,
                         help='Path of the file that will be shuffled.')
-    parser.add_argument('--output_file', type=str,
+    parser.add_argument('--output_file', type=str, default=None,
                         help='If given, this will be the name of the output file. '
-                             'Otherwise, a name is auto generated.')
+                             'Default: input_file + suffix.')
     parser.add_argument('--datasets', type=str, nargs="*", default=("x", "y"),
                         help='Which datasets to include in output. Default: x, y')
     parser.add_argument('--max_ram_fraction', type=float, default=0.25,
-                        help="in [0, 1]. Fraction of ram to use for reading one batch of data "
-                             "when max_ram is None. Note: this should "
-                             "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling.")
-    parser.add_argument('--iterations', type=int, default=2,
-                        help="Shuffle the file this many times. Default: 2")
+                        help="in [0, 1]. Fraction of all available ram to use for reading one batch of data "
+                             "Note: this should "
+                             "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
+                             "Default: 0.25")
+    parser.add_argument('--iterations', type=int, default=None,
+                        help="Shuffle the file this many times. Default: Auto choose best number.")
     kwargs = vars(parser.parse_args())
 
     input_file = kwargs.pop("input_file")
@@ -278,9 +292,15 @@ def h5shuffle2():
     if output_file is None:
         output_file = get_filepath_output(input_file, shuffle=True)
     iterations = kwargs.pop("iterations")
+    if iterations is None:
+        iterations = get_n_iterations(
+            input_file,
+            datasets=kwargs["datasets"],
+            max_ram_fraction=kwargs["max_ram_fraction"],
+        )
+    np.random.seed(42)
 
     for i in range(iterations):
-        print(f"Iteration {i}")
+        print(f"\nIteration {i+1}/{iterations}")
         if i == 0:
             # first iteration
             stgs = {
......
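For reference, a usage sketch of the updated command line interface, assuming the h5shuffle2 function is exposed as a console script of the same name (the entry point is not shown in this commit) and using placeholder file names: leaving out --iterations now picks the number of passes automatically via get_n_iterations, while passing it still forces a fixed count.

h5shuffle2 input.h5 --datasets x y --max_ram_fraction 0.25
h5shuffle2 input.h5 --output_file shuffled.h5 --iterations 2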