import os
import time
import datetime
import argparse
import numpy as np
import h5py
import psutil
from orcasong.tools.postproc import get_filepath_output, copy_used_files
from orcasong.tools.concatenate import copy_attrs


def h5shuffle2(
    input_file,
    output_file=None,
    iterations=None,
    datasets=("x", "y"),
    max_ram_fraction=0.25,
    **kwargs,
):
    if output_file is None:
        output_file = get_filepath_output(input_file, shuffle=True)
    if iterations is None:
        iterations = get_n_iterations(
            input_file,
            datasets=datasets,
            max_ram_fraction=max_ram_fraction,
        )
    np.random.seed(42)
    for i in range(iterations):
        print(f"\nIteration {i+1}/{iterations}")
        if iterations == 1:
            # special case if there is only one iteration
            stgs = {
                "input_file": input_file,
                "output_file": output_file,
                "delete": False,
            }
        elif i == 0:
            # first iteration
            stgs = {
                "input_file": input_file,
                "output_file": f"{output_file}_temp_{i}",
                "delete": False,
            }
        elif i == iterations - 1:
            # last iteration
            stgs = {
                "input_file": f"{output_file}_temp_{i-1}",
                "output_file": output_file,
                "delete": True,
            }
        else:
            # intermediate iterations
            stgs = {
                "input_file": f"{output_file}_temp_{i-1}",
                "output_file": f"{output_file}_temp_{i}",
                "delete": True,
            }
        shuffle_file(
            datasets=datasets,
            max_ram_fraction=max_ram_fraction,
            chunks=True,
            **stgs,
            **kwargs,
        )
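
# Minimal usage sketch (hypothetical file names, not part of the original module):
# h5shuffle2 picks the number of chunk-wise shuffle passes itself when
# iterations is None and writes intermediate "_temp_<i>" files next to the output.
#
#   h5shuffle2("train_unshuffled.h5", output_file="train_shuffled.h5",
#              datasets=("x", "y"), max_ram_fraction=0.25)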


def shuffle_file(
    input_file,
    datasets=("x", "y"),
    output_file=None,
    max_ram=None,
    max_ram_fraction=0.25,
    chunks=False,
    delete=False,
):
    """
    Shuffle datasets in a h5file that have the same length.

    Parameters
    ----------
    input_file : str
        Path of the file that will be shuffled.
    datasets : tuple
        Which datasets to include in output.
    output_file : str, optional
        If given, this will be the name of the output file.
        Otherwise, a name is auto generated.
    max_ram : int, optional
        Available ram in bytes. Default: Use fraction of
        maximum available (see max_ram_fraction).
    max_ram_fraction : float
        in [0, 1]. Fraction of ram to use for reading one batch of data
        when max_ram is None. Note: when using chunks, this should
        be <=~0.25, since lots of ram is needed for in-memory shuffling.
    chunks : bool
        Use chunk-wise readout. Much faster, but needs lots of ram
        to be accurate! (use a node with at least 32gb, the more the better)
    delete : bool
        Delete the input file after shuffling.

    Returns
    -------
    output_file : str
        Path to the output file.

    """
    start_time = time.time()
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction)
    if output_file is None:
        output_file = get_filepath_output(input_file, shuffle=True)
    if os.path.exists(output_file):
        raise FileExistsError(output_file)
    temp_output_file = (
        output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
    )
    with h5py.File(input_file, "r") as f_in:
        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
        print(f"Shuffling datasets {datasets} with {n_lines} lines each")

        with h5py.File(temp_output_file, "x") as f_out:
            if not chunks:
                # one fully random permutation over all lines
                indices = np.arange(n_lines)
                np.random.shuffle(indices)
                for dset_info in dset_infos:
                    print("Creating dataset", dset_info["name"])
                    make_dset(f_out, dset_info, indices)
                    print("Done!")
            else:
                # chunk-wise readout: faster, but only quasi-random per pass
                indices_chunked = get_indices_largest(dset_infos)
                for dset_info in dset_infos:
                    print("Creating dataset", dset_info["name"])
                    make_dset_chunked(f_out, dset_info, indices_chunked)
                    print("Done!")

    copy_used_files(input_file, temp_output_file)
    copy_attrs(input_file, temp_output_file)
    os.rename(temp_output_file, output_file)
    if delete:
        os.remove(input_file)
    print(
        f"Elapsed time: {datetime.timedelta(seconds=int(time.time() - start_time))}"
    )
    return output_file
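
# Hedged example (hypothetical path): a single fully random shuffle of two
# datasets, read in line-wise batches sized by the available ram:
#
#   shuffle_file("sample.h5", datasets=("x", "y"), chunks=False,
#                max_ram_fraction=0.25)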


def get_max_ram(max_ram_fraction):
    max_ram = max_ram_fraction * psutil.virtual_memory().available
    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
    return max_ram


def get_indices_largest(dset_infos):
    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
    dset_info = dset_infos[largest_dset]

    ratio = dset_info["chunks_per_batch"] / dset_info["n_chunks"]
    print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")

    return get_indices_chunked(
        dset_info["n_batches_chunkwise"],
        dset_info["n_chunks"],
        dset_info["chunksize"],
    )


def get_n_iterations(
    input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
):
    """ Get how often you have to shuffle with given ram to get proper randomness. """
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
    with h5py.File(input_file, "r") as f_in:
        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
    dset_info = dset_infos[largest_dset]

    n_iterations = int(
        np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
    )
    print(f"Total chunks: {dset_info['n_chunks']}")
    print(f"Chunks per batch: {dset_info['chunks_per_batch']}")
    print(f"--> min iterations for full shuffle: {n_iterations}")
    return n_iterations
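
# Worked example with assumed numbers: for a file with 10 000 chunks of which
# 100 fit into one in-memory batch, ceil(log(10000) / log(100)) = 2, so two
# chunk-wise passes are reported as the minimum for a full shuffle.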


def get_indices_chunked(n_batches, n_chunks, chunksize):
    """ Return a list with the chunkwise shuffled indices of each batch. """
    chunk_indices = np.arange(n_chunks)
    np.random.shuffle(chunk_indices)
    chunk_batches = np.array_split(chunk_indices, n_batches)

    index_batches = []
    for bat in chunk_batches:
        idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
        np.random.shuffle(idx)
        index_batches.append(idx)

    return index_batches
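
# Illustration with hypothetical values: for n_chunks=4, chunksize=3 and
# n_batches=2, a shuffled chunk order of [2, 0, 3, 1] yields the index batches
# [6, 7, 8, 0, 1, 2] and [9, 10, 11, 3, 4, 5], each of which is then shuffled
# line-wise before being read.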


def get_dset_infos(f, datasets, max_ram):
    """ Check datasets and retrieve relevant infos for each. """
    dset_infos = []
    n_lines = None
    for i, name in enumerate(datasets):
        dset = f[name]
        if i == 0:
            n_lines = len(dset)
        else:
            if len(dset) != n_lines:
                raise ValueError(
                    f"dataset {name} has different length! "
                    f"{len(dset)} vs {n_lines}"
                )
        chunksize = dset.chunks[0]
        n_chunks = int(np.ceil(n_lines / chunksize))
        # TODO in h5py 3.X, use .nbytes to get uncompressed size
        bytes_per_line = np.asarray(dset[0]).nbytes
        bytes_per_chunk = bytes_per_line * chunksize

        lines_per_batch = int(np.floor(max_ram / bytes_per_line))
        chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))

        dset_infos.append(
            {
                "name": name,
                "dset": dset,
                "chunksize": chunksize,
                "n_lines": n_lines,
                "n_chunks": n_chunks,
                "bytes_per_line": bytes_per_line,
                "bytes_per_chunk": bytes_per_chunk,
                "lines_per_batch": lines_per_batch,
                "chunks_per_batch": chunks_per_batch,
                "n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)),
                "n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)),
            }
        )
    return dset_infos, n_lines
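
# Rough example with assumed numbers: a dataset of 1 000 000 lines at 10 kB per
# line and chunksize 32 has bytes_per_chunk = 320 kB; with max_ram = 8 GB this
# gives lines_per_batch = 800 000 and chunks_per_batch = 25 000, so both
# n_batches_linewise and n_batches_chunkwise come out as 2 (31 250 chunks total).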


def make_dset(f_out, dset_info, indices):
    """ Create a shuffled dataset in the output file. """
    for batch_index in range(dset_info["n_batches_linewise"]):
        print(f"Processing batch {batch_index+1}/{dset_info['n_batches_linewise']}")

        slc = slice(
            batch_index * dset_info["lines_per_batch"],
            (batch_index + 1) * dset_info["lines_per_batch"],
        )
        to_read = indices[slc]
        # reading has to be done with linearly increasing index,
        # so sort -> read -> undo sorting
        sort_ix = np.argsort(to_read)
        unsort_ix = np.argsort(sort_ix)
        data = dset_info["dset"][to_read[sort_ix]][unsort_ix]

        if batch_index == 0:
            in_dset = dset_info["dset"]
            out_dset = f_out.create_dataset(
                dset_info["name"],
                data=data,
                maxshape=in_dset.shape,
                chunks=in_dset.chunks,
                compression=in_dset.compression,
                compression_opts=in_dset.compression_opts,
                shuffle=in_dset.shuffle,
            )
            out_dset.resize(len(in_dset), axis=0)
        else:
            f_out[dset_info["name"]][slc] = data


def make_dset_chunked(f_out, dset_info, indices_chunked):
    """ Create a shuffled dataset in the output file. """
    start_idx = 0
    for batch_index, to_read in enumerate(indices_chunked):
        print(f"Processing batch {batch_index+1}/{len(indices_chunked)}")

        # remove indices outside of dset
        to_read = to_read[to_read < len(dset_info["dset"])]

        # reading has to be done with linearly increasing index,
        # fancy indexing is super slow,
        # so sort -> turn to slices -> read -> conc -> undo sorting
        sort_ix = np.argsort(to_read)
        unsort_ix = np.argsort(sort_ix)
        fancy_indices = to_read[sort_ix]
        slices = slicify(fancy_indices)
        data = np.concatenate([dset_info["dset"][slc] for slc in slices])
        data = data[unsort_ix]

        if batch_index == 0:
            in_dset = dset_info["dset"]
            out_dset = f_out.create_dataset(
                dset_info["name"],
                data=data,
                maxshape=in_dset.shape,
                chunks=in_dset.chunks,
                compression=in_dset.compression,
                compression_opts=in_dset.compression_opts,
                shuffle=in_dset.shuffle,
            )
            out_dset.resize(len(in_dset), axis=0)
            start_idx = len(data)
        else:
            end_idx = start_idx + len(data)
            f_out[dset_info["name"]][start_idx:end_idx] = data
            start_idx = end_idx

    if start_idx != len(dset_info["dset"]):
        print(f"Warning: last index was {start_idx} not {len(dset_info['dset'])}")


def slicify(fancy_indices):
    """ [0,1,2, 6,7,8] --> [0:3, 6:9] """
    steps = np.diff(fancy_indices) != 1
    slice_starts = np.concatenate([fancy_indices[:1], fancy_indices[1:][steps]])
    slice_ends = np.concatenate([fancy_indices[:-1][steps], fancy_indices[-1:]]) + 1
    return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
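
# Example with a hypothetical input: slicify(np.array([0, 1, 2, 6, 7, 8]))
# returns [slice(0, 3), slice(6, 9)]. Reading a few contiguous slices from h5py
# is far faster than point-wise fancy indexing over the same rows.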


def main():
    parser = argparse.ArgumentParser(
        description="Shuffle datasets in a h5file that have the same length. "
        "Uses chunkwise readout for speed-up."
    )
    parser.add_argument(
        "input_file", type=str, help="Path of the file that will be shuffled."
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help="If given, this will be the name of the output file. "
        "Default: input_file + suffix.",
    )
    parser.add_argument(
        "--datasets",
        type=str,
        nargs="*",
        default=("x", "y"),
        help="Which datasets to include in output. Default: x, y",
    )
    parser.add_argument(
        "--max_ram_fraction",
        type=float,
        default=0.25,
        help="in [0, 1]. Fraction of all available ram to use for reading "
        "one batch of data. Note: this should be <= ~0.25 or so, since "
        "lots of ram is needed for in-memory shuffling. Default: 0.25",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=None,
        help="Shuffle the file this many times. Default: Auto choose best number.",
    )
    h5shuffle2(**vars(parser.parse_args()))


if __name__ == "__main__":
    main()