import os
import time
import datetime
import argparse
import numpy as np
import psutil
import h5py
from orcasong.tools.postproc import get_filepath_output, copy_used_files
from orcasong.tools.concatenate import copy_attrs

def h5shuffle2(input_file,
               output_file=None,
               iterations=None,
               datasets=("x", "y"),
               max_ram_fraction=0.25,
               **kwargs):
    if output_file is None:
        output_file = get_filepath_output(input_file, shuffle=True)
    if iterations is None:
        iterations = get_n_iterations(
            input_file,
            datasets=datasets,
            max_ram_fraction=max_ram_fraction,
        )
    np.random.seed(42)
    for i in range(iterations):
        print(f"\nIteration {i+1}/{iterations}")
        if iterations == 1:
            # special case if there's only one iteration
            stgs = {
                "input_file": input_file,
                "output_file": output_file,
                "delete": False,
            }
        elif i == 0:
            # first iteration
            stgs = {
                "input_file": input_file,
                "output_file": f"{output_file}_temp_{i}",
                "delete": False,
            }
        elif i == iterations - 1:
            # last iteration
            stgs = {
                "input_file": f"{output_file}_temp_{i-1}",
                "output_file": output_file,
                "delete": True,
            }
        else:
            # intermediate iterations
            stgs = {
                "input_file": f"{output_file}_temp_{i-1}",
                "output_file": f"{output_file}_temp_{i}",
                "delete": True,
            }
        shuffle_file(
            datasets=datasets,
            max_ram_fraction=max_ram_fraction,
            chunks=True,
            **stgs,
            **kwargs,
        )
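
# Example usage as a library call (an illustrative sketch; "infile.h5" is a
# hypothetical path assumed to contain equally long datasets "x" and "y"):
#
#   out_path = h5shuffle2("infile.h5", max_ram_fraction=0.25)
#
# The file is shuffled as many times as get_n_iterations suggests; intermediate
# results are written to "<output>_temp_<i>" files, which are removed again.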

def shuffle_file(
        input_file,
        datasets=("x", "y"),
        output_file=None,
        max_ram=None,
        max_ram_fraction=0.25,
        chunks=False,
        delete=False):
    """
    Shuffle datasets in a h5file that have the same length.

    Parameters
    ----------
    input_file : str
        Path of the file that will be shuffled.
    datasets : tuple
        Which datasets to include in output.
    output_file : str, optional
        If given, this will be the name of the output file.
        Otherwise, a name is auto generated.
    max_ram : int, optional
        Available ram in bytes. Default: Use fraction of
        maximum available (see max_ram_fraction).
    max_ram_fraction : float
        in [0, 1]. Fraction of ram to use for reading one batch of data
        when max_ram is None. Note: when using chunks, this should
        be <= ~0.25, since lots of ram is needed for in-memory shuffling.
    chunks : bool
        Use chunkwise readout for speed-up. Needs lots of ram
        to be accurate! (use a node with at least 32gb, the more the better)
    delete : bool
        Delete the original file afterwards?

    Returns
    -------
    output_file : str
        Path to the output file.

    """
    start_time = time.time()
    if output_file is None:
        output_file = get_filepath_output(input_file, shuffle=True)
    if os.path.exists(output_file):
        raise FileExistsError(output_file)
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction)

    temp_output_file = output_file + "_temp_" + time.strftime(
        "%d-%m-%Y-%H-%M-%S", time.gmtime())
    with h5py.File(input_file, "r") as f_in:
        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
        print(f"Shuffling datasets {datasets} with {n_lines} lines each")

        if not chunks:
            indices = np.arange(n_lines)
            np.random.shuffle(indices)
            with h5py.File(temp_output_file, "x") as f_out:
                for dset_info in dset_infos:
                    print("Creating dataset", dset_info["name"])
                    make_dset(f_out, dset_info, indices)
                    print("Done!")
        else:
            indices_chunked = get_indices_largest(dset_infos)
            with h5py.File(temp_output_file, "x") as f_out:
                for dset_info in dset_infos:
                    print("Creating dataset", dset_info["name"])
                    make_dset_chunked(f_out, dset_info, indices_chunked)
                    print("Done!")

    copy_used_files(input_file, temp_output_file)
    copy_attrs(input_file, temp_output_file)
    os.rename(temp_output_file, output_file)
    if delete:
        os.remove(input_file)
    print(f"Elapsed time: "
          f"{datetime.timedelta(seconds=int(time.time() - start_time))}")
    return output_file

def get_max_ram(max_ram_fraction):
    max_ram = max_ram_fraction * psutil.virtual_memory().available
    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
    return max_ram

def get_indices_largest(dset_infos):
    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
    dset_info = dset_infos[largest_dset]

    print(f"Total chunks: {dset_info['n_chunks']}")
    ratio = dset_info['chunks_per_batch'] / dset_info['n_chunks']
    print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")

    return get_indices_chunked(
        dset_info["n_batches_chunkwise"],
        dset_info["n_chunks"],
        dset_info["chunksize"],
    )

def get_n_iterations(input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25):
    """ Get how often you have to shuffle with given ram to get proper randomness. """
    if max_ram is None:
        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
    with h5py.File(input_file, "r") as f_in:
        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
    dset_info = dset_infos[largest_dset]
    n_iterations = int(np.ceil(
        np.log(dset_info['n_chunks']) / np.log(dset_info['chunks_per_batch'])))
    print(f"Total chunks: {dset_info['n_chunks']}")
    print(f"Chunks per batch: {dset_info['chunks_per_batch']}")
    print(f"--> min iterations for full shuffle: {n_iterations}")
    return n_iterations
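
# Worked example for the iteration estimate above (made-up numbers): with
# 1000 chunks in total and ram for 100 chunks per batch, one chunkwise pass only
# mixes chunks within each batch, so ceil(log(1000) / log(100)) = 2 passes are
# needed before any chunk can end up anywhere in the file.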

def get_indices_chunked(n_batches, n_chunks, chunksize):
    """ Return a list with the chunkwise shuffled indices of each batch. """
    chunk_indices = np.arange(n_chunks)
    np.random.shuffle(chunk_indices)
    chunk_batches = np.array_split(chunk_indices, n_batches)

    index_batches = []
    for bat in chunk_batches:
        idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
        np.random.shuffle(idx)
        index_batches.append(idx)

    return index_batches
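
# Illustrative example for get_indices_chunked (tiny made-up numbers):
# with n_chunks=4, chunksize=3, n_batches=2, the shuffled chunk ids might be
# split as [[2, 0], [3, 1]], which expand to line indices [6, 7, 8, 0, 1, 2]
# and [9, 10, 11, 3, 4, 5] before each batch is shuffled again line-wise.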

def get_dset_infos(f, datasets, max_ram):
    """ Check datasets and retrieve relevant infos for each. """
    dset_infos = []
    n_lines = None
    for i, name in enumerate(datasets):
        dset = f[name]
        if i == 0:
            n_lines = len(dset)
        else:
            if len(dset) != n_lines:
                raise ValueError(f"dataset {name} has different length! "
                                 f"{len(dset)} vs {n_lines}")
        chunksize = dset.chunks[0]
        n_chunks = int(np.ceil(n_lines / chunksize))
        # TODO in h5py 3.X, use .nbytes to get uncompressed size
        bytes_per_line = np.asarray(dset[0]).nbytes
        bytes_per_chunk = bytes_per_line * chunksize

        lines_per_batch = int(np.floor(max_ram / bytes_per_line))
        chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))

        dset_infos.append({
            "name": name,
            "dset": dset,
            "chunksize": chunksize,
            "n_lines": n_lines,
            "n_chunks": n_chunks,
            "lines_per_batch": lines_per_batch,
            "chunks_per_batch": chunks_per_batch,
            "n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)),
            "n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)),
        })

    return dset_infos, n_lines

def make_dset(f_out, dset_info, indices):
    """ Create a shuffled dataset in the output file. """
    for batch_index in range(dset_info["n_batches_linewise"]):
        print(f"Processing batch {batch_index+1}/{dset_info['n_batches_linewise']}")
        slc = slice(
            batch_index * dset_info["lines_per_batch"],
            (batch_index + 1) * dset_info["lines_per_batch"],
        )
        to_read = indices[slc]
        # reading has to be done with linearly increasing index,
        # so sort -> read -> undo sorting
        sort_ix = np.argsort(to_read)
        unsort_ix = np.argsort(sort_ix)
        data = dset_info["dset"][to_read[sort_ix]][unsort_ix]

        if batch_index == 0:
            in_dset = dset_info["dset"]
            out_dset = f_out.create_dataset(
                dset_info["name"],
                data=data,
                maxshape=in_dset.shape,
                chunks=in_dset.chunks,
                compression=in_dset.compression,
                compression_opts=in_dset.compression_opts,
                shuffle=in_dset.shuffle,
            )
            out_dset.resize(len(in_dset), axis=0)
        else:
            f_out[dset_info["name"]][slc] = data
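
# Why the sort -> read -> unsort trick in make_dset works (illustrative numbers):
#   to_read   = [5, 2, 9]
#   sort_ix   = np.argsort(to_read)   -> [1, 0, 2]   (rows 2, 5, 9 read in order)
#   unsort_ix = np.argsort(sort_ix)   -> [1, 0, 2]
#   data[unsort_ix]                   -> rows 5, 2, 9, i.e. the requested order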

def make_dset_chunked(f_out, dset_info, indices_chunked):
    """ Create a shuffled dataset in the output file. """
    start_idx = 0
    for batch_index, to_read in enumerate(indices_chunked):
        print(f"Processing batch {batch_index+1}/{len(indices_chunked)}")
        # remove indices outside of dset
        to_read = to_read[to_read < len(dset_info["dset"])]
        # reading has to be done with linearly increasing index;
        # fancy indexing is super slow,
        # so sort -> turn to slices -> read -> conc -> undo sorting
        sort_ix = np.argsort(to_read)
        unsort_ix = np.argsort(sort_ix)
        fancy_indices = to_read[sort_ix]
        slices = slicify(fancy_indices)
        data = np.concatenate([dset_info["dset"][slc] for slc in slices])
        data = data[unsort_ix]

        if batch_index == 0:
            in_dset = dset_info["dset"]
            out_dset = f_out.create_dataset(
                dset_info["name"],
                data=data,
                maxshape=in_dset.shape,
                chunks=in_dset.chunks,
                compression=in_dset.compression,
                compression_opts=in_dset.compression_opts,
                shuffle=in_dset.shuffle,
            )
            out_dset.resize(len(in_dset), axis=0)
            start_idx = len(data)
        else:
            end_idx = start_idx + len(data)
            f_out[dset_info["name"]][start_idx:end_idx] = data
            start_idx = end_idx

    if start_idx != len(dset_info["dset"]):
        print(f"Warning: last index was {start_idx} not {len(dset_info['dset'])}")

def slicify(fancy_indices):
    """ [0,1,2, 6,7,8] --> [0:3, 6:9] """
    steps = np.diff(fancy_indices) != 1
    slice_starts = np.concatenate([fancy_indices[:1], fancy_indices[1:][steps]])
    slice_ends = np.concatenate([fancy_indices[:-1][steps], fancy_indices[-1:]]) + 1
    return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
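
# Quick sanity check of slicify on its docstring example (illustrative):
#
#   slicify(np.array([0, 1, 2, 6, 7, 8]))
#   # -> [slice(0, 3), slice(6, 9)]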

def run_parser():
    parser = argparse.ArgumentParser(
        description='Shuffle datasets in a h5file that have the same length. '
                    'Uses chunkwise readout for speed-up.')
    parser.add_argument('input_file', type=str,
                        help='Path of the file that will be shuffled.')
    parser.add_argument('--output_file', type=str, default=None,
                        help='If given, this will be the name of the output file. '
                             'Otherwise, a name is auto generated.')
    parser.add_argument('--datasets', type=str, nargs="*", default=("x", "y"),
                        help='Which datasets to include in output. Default: x, y')
    parser.add_argument('--max_ram_fraction', type=float, default=0.25,
                        help="in [0, 1]. Fraction of all available ram to use for reading one batch of data. "
                             "Note: this should "
                             "be <= ~0.25 or so, since lots of ram is needed for in-memory shuffling. "
                             "Default: 0.25")
    parser.add_argument('--iterations', type=int, default=None,
                        help="Shuffle the file this many times. Default: Auto choose best number.")
    h5shuffle2(**vars(parser.parse_args()))
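
# Example command-line usage (illustrative; assumes run_parser is wired up as a
# console-script entry point named "h5shuffle2", which is not shown here):
#
#   h5shuffle2 infile.h5 --datasets x y --max_ram_fraction 0.25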