import os
import time
import datetime

import h5py
import numpy as np
import psutil

from orcasong.tools.postproc import get_filepath_output, copy_used_files
from orcasong.tools.concatenate import copy_attrs
def h5shuffle2(
    input_file,
    output_file=None,
    iterations=None,
    datasets=("x", "y"),
    max_ram=None,
    max_ram_fraction=0.25,
    seed=None,
):
"""
Shuffle datasets in a h5file that have the same length.
Parameters
----------
input_file : str
Path of the file that will be shuffle.
output_file : str, optional
If given, this will be the name of the output file.
Otherwise, a name is auto generated.
iterations : int, optional
Shuffle the file this many times. For each additional iteration,
a temporary file will be created and then deleted afterwards.
Default: Auto choose best number based on available RAM.
datasets : tuple
Which datasets to include in output.
Available ram in bytes. Default: Use fraction of
maximum available (see max_ram_fraction).
max_ram_fraction : float
when max_ram is None. Note: when using chunks, this should
be <=~0.25, since lots of ram is needed for in-memory shuffling.
Returns
-------
output_file : str
Path to the output file.
"""
    if output_file is None:
        output_file = get_filepath_output(input_file, shuffle=True)
    if iterations is None:
        iterations = get_n_iterations(
            input_file,
            datasets=datasets,
            max_ram_fraction=max_ram_fraction,
            max_ram=max_ram,
        )
    # filenames of all iterations, in the right order
    filenames = (
        input_file,
        *_get_temp_filenames(output_file, number=iterations - 1),
        output_file,
    )
    if seed:
        np.random.seed(seed)
    for i in range(iterations):
        print(f"\nIteration {i+1}/{iterations}")
        _shuffle_file(
            input_file=filenames[i],
            output_file=filenames[i + 1],
            delete=i > 0,
            datasets=datasets,
            max_ram=max_ram,
            max_ram_fraction=max_ram_fraction,
        )
    return output_file
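
# A minimal usage sketch of the public entry point above (file name and
# dataset names are illustrative; assumes the input HDF5 file holds chunked
# datasets "x" and "y" of equal length):
#
#   out_path = h5shuffle2("samples.h5", datasets=("x", "y"), seed=42)
#   print("Shuffled file written to:", out_path)
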
def _shuffle_file(
    input_file,
    output_file,
    datasets=("x", "y"),
    max_ram=None,
    max_ram_fraction=0.25,
    delete=False,
):
    start_time = time.time()
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction)
    temp_output_file = (
        output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
    )
    with h5py.File(input_file, "r") as f_in:
        _check_dsets(f_in, datasets)
        dset_info = _get_largest_dset(f_in, datasets, max_ram)
        print(f"Shuffling datasets {datasets}")
        indices_per_batch = _get_indices_per_batch(
            dset_info["n_batches"],
            dset_info["n_chunks"],
            dset_info["chunksize"],
        )
        with h5py.File(temp_output_file, "x") as f_out:
            for dset_name in datasets:
                print("Creating dataset", dset_name)
                _shuffle_dset(f_out, f_in, dset_name, indices_per_batch)
                print("Done!")
    copy_used_files(input_file, temp_output_file)
    copy_attrs(input_file, temp_output_file)
    os.rename(temp_output_file, output_file)
    if delete:
        os.remove(input_file)
    print(
        f"Elapsed time: "
        f"{datetime.timedelta(seconds=int(time.time() - start_time))}"
    )
def get_max_ram(max_ram_fraction):
    max_ram = max_ram_fraction * psutil.virtual_memory().available
    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
    return max_ram
def get_n_iterations(
    input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
):
    """ Get how often you have to shuffle with given ram to get proper randomness. """
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
    with h5py.File(input_file, "r") as f_in:
        dset_info = _get_largest_dset(f_in, datasets, max_ram)
    n_iterations = int(
        np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
    )
    print(f"Total chunks: {dset_info['n_chunks']}")
    print(f"--> min iterations for full shuffle: {n_iterations}")
    return n_iterations
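
# Rough arithmetic behind the formula above (numbers are illustrative): each
# pass can only mix the chunks that fit into one batch, so a full shuffle
# needs chunks_per_batch ** n_iterations >= n_chunks. With 1000 chunks and
# RAM for 100 chunks per batch, ceil(log(1000) / log(100)) = ceil(1.5) = 2
# shuffle passes are required.
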
def _get_indices_per_batch(n_batches, n_chunks, chunksize):
    """
    Return a list with the shuffled indices for each batch.

    Returns
    -------
    indices_per_batch : List
        Length n_batches, each element is a np.array[int].
        Element i of the list are the indices of each sample in batch number i.

    """
    chunk_indices = np.arange(n_chunks)
    np.random.shuffle(chunk_indices)
    chunk_batches = np.array_split(chunk_indices, n_batches)

    indices_per_batch = []
    for bat in chunk_batches:
        # expand each chunk index to the sample indices inside that chunk
        idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
        indices_per_batch.append(idx)

    return indices_per_batch
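
# Illustrative example of the index construction above: with n_chunks=4,
# chunksize=3 and n_batches=2, a shuffled chunk order of [2, 0, 3, 1] is split
# into the batches [2, 0] and [3, 1], which expand to the sample indices
# [6, 7, 8, 0, 1, 2] and [9, 10, 11, 3, 4, 5].
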
def _get_largest_dset(f, datasets, max_ram):
    """
    Get infos about the dset that needs the most batches.

    This is the dset that determines how many samples are shuffled at a time.
    """
    dset_infos = _get_dset_infos(f, datasets, max_ram)
    return dset_infos[np.argmax([v["n_batches"] for v in dset_infos])]
def _check_dsets(f, datasets):
    # check if all datasets have the same number of lines
    n_lines_list = [len(f[dset_name]) for dset_name in datasets]
    if not all([n == n_lines_list[0] for n in n_lines_list]):
        raise ValueError(
            f"Datasets have different lengths! {n_lines_list}"
        )
def _get_dset_infos(f, datasets, max_ram):
    """ Retrieve infos for each dataset. """
    dset_infos = []
    for i, name in enumerate(datasets):
        dset = f[name]
        n_lines = len(dset)
        chunksize = dset.chunks[0]
        n_chunks = int(np.ceil(n_lines / chunksize))
        # uncompressed size of one line resp. one chunk in bytes
        bytes_per_line = np.asarray(dset[0]).nbytes
        bytes_per_chunk = bytes_per_line * chunksize
        # how many chunks fit into the given ram at once
        chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))
        dset_infos.append({
            "name": name,
            "dset": dset,
            "n_chunks": n_chunks,
            "chunks_per_batch": chunks_per_batch,
            "n_batches": int(np.ceil(n_chunks / chunks_per_batch)),
            "chunksize": chunksize,
        })
    return dset_infos
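
# Illustrative numbers for the batching above: a dataset with chunksize 32 and
# roughly 40 kB per line gives about 1.3 MB per chunk, so with max_ram = 4 GB
# around 3000 chunks fit into one batch.
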
def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch):
    """
    Create a batchwise-shuffled dataset in the output file using given indices.

    """
    dset_in = f_in[dset_name]
    start_idx = 0
    for batch_number, indices in enumerate(indices_per_batch):
        print(f"Processing batch {batch_number+1}/{len(indices_per_batch)}")
        # the last chunk may be partial, so drop indices beyond the dset length
        indices = indices[indices < len(dset_in)]
        # reading has to be done with linearly increasing index
        # fancy indexing is super slow
        # so sort -> turn to slices -> read -> conc -> undo sorting
        sort_ix = np.argsort(indices)
        unsort_ix = np.argsort(sort_ix)
        fancy_indices = indices[sort_ix]
        slices = _slicify(fancy_indices)
        data = np.concatenate([dset_in[slc] for slc in slices])
        data = data[unsort_ix]
        if batch_number == 0:
            # first batch: create the output dataset with the input's layout
            out_dset = f_out.create_dataset(
                dset_name,
                data=data,
                maxshape=dset_in.shape,
                chunks=dset_in.chunks,
                compression=dset_in.compression,
                compression_opts=dset_in.compression_opts,
                shuffle=dset_in.shuffle,
            )
            start_idx = len(data)
        else:
            # later batches: grow the dataset and append the shuffled block
            end_idx = start_idx + len(data)
            out_dset.resize(end_idx, axis=0)
            out_dset[start_idx:end_idx] = data
            start_idx = end_idx
    if start_idx != len(dset_in):
        print(f"Warning: last index was {start_idx} not {len(dset_in)}")
""" [0,1,2, 6,7,8] --> [0:3, 6:9] """
steps = np.diff(fancy_indices) != 1
slice_starts = np.concatenate([fancy_indices[:1], fancy_indices[1:][steps]])
slice_ends = np.concatenate([fancy_indices[:-1][steps], fancy_indices[-1:]]) + 1
return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
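
# Sketch of the read pattern used in _shuffle_dset (values illustrative, dset
# stands in for an h5py dataset):
#
#   indices   = np.array([6, 0, 7, 1, 8, 2])
#   sort_ix   = np.argsort(indices)          # [1, 3, 5, 0, 2, 4]
#   unsort_ix = np.argsort(sort_ix)          # [3, 0, 4, 1, 5, 2]
#   slices    = _slicify(indices[sort_ix])   # [slice(0, 3), slice(6, 9)]
#   data      = np.concatenate([dset[s] for s in slices])[unsort_ix]
#
# i.e. two contiguous slice reads replace one slow fancy-indexed read, and the
# rows are then put back into the requested (shuffled) order.
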
def _get_temp_filenames(output_file, number):
    path, file = os.path.split(output_file)
    return [os.path.join(path, f"temp_iteration_{i}_{file}") for i in range(number)]
def run_parser():
    # TODO deprecated
    raise NotImplementedError(
        "h5shuffle2 has been renamed to orcasong h5shuffle2")