Commit b6d0ec52 authored by Stefan Reck

cleanup

parent e02df80b
1 merge request: !22 Cleanup
@@ -11,8 +11,8 @@ EXTRACTORS = {
     "nu_chain_muon": extractors.get_muon_mc_info_extr,
     "nu_chain_noise": extractors.get_random_noise_mc_info_extr,
     "nu_chain_data": extractors.get_real_data_info_extr,
-    "bundle_mc": extractors.bundles.BundleMCExtractor,
-    "bundle_data": extractors.bundles.BundleDataExtractor,
+    "bundle_mc": extractors.BundleMCExtractor,
+    "bundle_data": extractors.BundleDataExtractor,
 }
 MODES = {
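This hunk drops the `bundles` submodule indirection: both bundle extractors are now looked up directly on the `extractors` namespace. For context, a minimal sketch of how a registry like `EXTRACTORS` is typically consumed (the `get_extractor` helper and its `**kwargs` passing are illustrative, not part of this commit):

```python
def get_extractor(name, **kwargs):
    """Resolve a config name to an extractor and instantiate it."""
    try:
        extractor = EXTRACTORS[name]  # entries are classes or factory functions
    except KeyError:
        raise KeyError(
            f"Unknown extractor '{name}'. Available: {sorted(EXTRACTORS)}"
        )
    return extractor(**kwargs)
```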
 import os
 import time
 import datetime
 import argparse
 import warnings
 import numpy as np
 import psutil
@@ -23,64 +21,7 @@ def h5shuffle2(
     datasets=("x", "y"),
     max_ram_fraction=0.25,
     max_ram=None,
 ):
-    if output_file is None:
-        output_file = get_filepath_output(input_file, shuffle=True)
-    if iterations is None:
-        iterations = get_n_iterations(
-            input_file,
-            datasets=datasets,
-            max_ram_fraction=max_ram_fraction,
-            max_ram=max_ram,
-        )
-    np.random.seed(42)
-    for i in range(iterations):
-        print(f"\nIteration {i+1}/{iterations}")
-        if iterations == 1:
-            # special case if theres only one iteration
-            stgs = {
-                "input_file": input_file,
-                "output_file": output_file,
-                "delete": False,
-            }
-        elif i == 0:
-            # first iteration
-            stgs = {
-                "input_file": input_file,
-                "output_file": f"{output_file}_temp_{i}",
-                "delete": False,
-            }
-        elif i == iterations - 1:
-            # last iteration
-            stgs = {
-                "input_file": f"{output_file}_temp_{i-1}",
-                "output_file": output_file,
-                "delete": True,
-            }
-        else:
-            # intermediate iterations
-            stgs = {
-                "input_file": f"{output_file}_temp_{i-1}",
-                "output_file": f"{output_file}_temp_{i}",
-                "delete": True,
-            }
-        shuffle_file(
-            datasets=datasets,
-            max_ram=max_ram,
-            max_ram_fraction=max_ram_fraction,
-            chunks=True,
-            **stgs,
-        )
-def shuffle_file(
-    input_file,
-    datasets=("x", "y"),
-    output_file=None,
-    max_ram=None,
-    max_ram_fraction=0.25,
-    chunks=False,
-    delete=False,
-    seed=42,
-):
     """
     Shuffle datasets in a h5file that have the same length.
@@ -89,24 +30,24 @@ def shuffle_file(
     ----------
     input_file : str
         Path of the file that will be shuffled.
-    datasets : tuple
-        Which datasets to include in output.
     output_file : str, optional
         If given, this will be the name of the output file.
         Otherwise, a name is auto generated.
+    iterations : int, optional
+        Shuffle the file this many times. For each additional iteration,
+        a temporary file will be created and then deleted afterwards.
+        Default: Auto choose best number based on available RAM.
+    datasets : tuple
+        Which datasets to include in output.
     max_ram : int, optional
         Available ram in bytes. Default: Use fraction of
         maximum available (see max_ram_fraction).
     max_ram_fraction : float
-        in [0, 1]. Fraction of ram to use for reading one batch of data
+        in [0, 1]. Fraction of RAM to use for reading one batch of data
         when max_ram is None. Note: when using chunks, this should
         be <=~0.25, since lots of ram is needed for in-memory shuffling.
-    chunks : bool
-        Use chunk-wise readout. Large speed boost, but will
-        only quasi-randomize order! Needs lots of ram
-        to be accurate! (use a node with at least 32gb, the more the better)
-    delete : bool
-        Delete the original file afterwards?
+    seed : int or None
+        Seed for randomness.

     Returns
     -------
@@ -114,38 +55,68 @@
         Path to the output file.
     """
-    start_time = time.time()
     if output_file is None:
         output_file = get_filepath_output(input_file, shuffle=True)
+    if iterations is None:
+        iterations = get_n_iterations(
+            input_file,
+            datasets=datasets,
+            max_ram_fraction=max_ram_fraction,
+            max_ram=max_ram,
+        )
+    # filenames of all iterations, in the right order
+    filenames = (
+        input_file,
+        *_get_temp_filenames(output_file, number=iterations - 1),
+        output_file,
+    )
     if seed:
         np.random.seed(seed)
+    for i in range(iterations):
+        print(f"\nIteration {i+1}/{iterations}")
+        _shuffle_file(
+            input_file=filenames[i],
+            output_file=filenames[i + 1],
+            delete=i > 0,
+            datasets=datasets,
+            max_ram=max_ram,
+            max_ram_fraction=max_ram_fraction,
+        )
+    return output_file
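The rewrite replaces the four-way `stgs` branching with a single filename chain: pass `i` reads `filenames[i]` and writes `filenames[i + 1]`, and `delete=i > 0` removes only the temporary intermediates. A small sketch of the chaining, inlining `_get_temp_filenames` (which this diff adds at the bottom of the file):

```python
import os

def _get_temp_filenames(output_file, number):
    path, file = os.path.split(output_file)
    return [os.path.join(path, f"temp_iteration_{i}_{file}") for i in range(number)]

# with iterations=3, the shuffle passes chain like this:
filenames = ("in.h5", *_get_temp_filenames("out.h5", number=2), "out.h5")
for i in range(3):
    print(f"pass {i}: {filenames[i]} -> {filenames[i + 1]}, delete={i > 0}")
# pass 0: in.h5 -> temp_iteration_0_out.h5, delete=False
# pass 1: temp_iteration_0_out.h5 -> temp_iteration_1_out.h5, delete=True
# pass 2: temp_iteration_1_out.h5 -> out.h5, delete=True
```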
+def _shuffle_file(
+    input_file,
+    output_file,
+    datasets=("x", "y"),
+    max_ram=None,
+    max_ram_fraction=0.25,
+    delete=False,
+):
+    start_time = time.time()
+    if os.path.exists(output_file):
+        raise FileExistsError(output_file)
     if max_ram is None:
         max_ram = get_max_ram(max_ram_fraction)
     # create file with temp name first, then rename afterwards
     temp_output_file = (
         output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
     )
     with h5py.File(input_file, "r") as f_in:
-        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
-        print(f"Shuffling datasets {datasets} with {n_lines} lines each")
-        if not chunks:
-            indices = np.arange(n_lines)
-            np.random.shuffle(indices)
-            with h5py.File(temp_output_file, "x") as f_out:
-                for dset_info in dset_infos:
-                    print("Creating dataset", dset_info["name"])
-                    make_dset(f_out, dset_info, indices)
-                    print("Done!")
-        else:
-            indices_chunked = get_indices_largest(dset_infos)
+        _check_dsets(f_in, datasets)
+        dset_info = _get_largest_dset(f_in, datasets, max_ram)
+        print(f"Shuffling datasets {datasets}")
+        indices_per_batch = _get_indices_per_batch(
+            dset_info["n_batches"],
+            dset_info["n_chunks"],
+            dset_info["chunksize"],
+        )
-            with h5py.File(temp_output_file, "x") as f_out:
-                for dset_info in dset_infos:
-                    print("Creating dataset", dset_info["name"])
-                    make_dset_chunked(f_out, dset_info, indices_chunked)
-                    print("Done!")
+        with h5py.File(temp_output_file, "x") as f_out:
+            for dset_name in datasets:
+                print("Creating dataset", dset_name)
+                _shuffle_dset(f_out, f_in, dset_name, indices_per_batch)
+                print("Done!")
     copy_used_files(input_file, temp_output_file)
     copy_attrs(input_file, temp_output_file)
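`copy_used_files` and `copy_attrs` come from the surrounding module and are not touched by this diff. For readers unfamiliar with them: conceptually, carrying the HDF5 attributes over to the shuffled file looks roughly like the sketch below (illustrative only, not the actual implementation):

```python
import h5py

def copy_attrs_sketch(source_path, target_path):
    # mirror root-level and per-dataset attributes into the new file
    with h5py.File(source_path, "r") as f_in, h5py.File(target_path, "r+") as f_out:
        for key, value in f_in.attrs.items():
            f_out.attrs[key] = value
        for name in f_in:
            if name in f_out:
                for key, value in f_in[name].attrs.items():
                    f_out[name].attrs[key] = value
```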
@@ -164,21 +135,6 @@ def get_max_ram(max_ram_fraction):
     return max_ram
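The body of `get_max_ram` lies outside this hunk; given the `psutil` import at the top of the file, it presumably derives the budget from currently available memory, along the lines of this sketch (an assumption, not shown in the diff):

```python
import psutil

def get_max_ram_sketch(max_ram_fraction):
    # take a fixed fraction of the RAM that is free right now
    return int(max_ram_fraction * psutil.virtual_memory().available)
```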
-def get_indices_largest(dset_infos):
-    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
-    dset_info = dset_infos[largest_dset]
-    print(f"Total chunks: {dset_info['n_chunks']}")
-    ratio = dset_info["chunks_per_batch"] / dset_info["n_chunks"]
-    print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")
-    return get_indices_chunked(
-        dset_info["n_batches_chunkwise"],
-        dset_info["n_chunks"],
-        dset_info["chunksize"],
-    )
 def get_n_iterations(
     input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25
 ):
@@ -186,9 +142,7 @@ def get_n_iterations(
     if max_ram is None:
         max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
     with h5py.File(input_file, "r") as f_in:
-        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
-        largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
-        dset_info = dset_infos[largest_dset]
+        dset_info = _get_largest_dset(f_in, datasets, max_ram)
     n_iterations = int(
         np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
     )
     return n_iterations
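The iteration formula: one pass can only mix `chunks_per_batch` chunks with each other, so after `n` passes at most `chunks_per_batch ** n` chunks are mixed; solving `chunks_per_batch ** n >= n_chunks` for `n` gives the log ratio. A worked example:

```python
import numpy as np

n_chunks = 10_000       # chunks in the largest dataset
chunks_per_batch = 100  # chunks that fit into the RAM budget at once

# 100 ** 2 = 10_000 >= n_chunks, so two passes are enough
n_iterations = int(np.ceil(np.log(n_chunks) / np.log(chunks_per_batch)))
print(n_iterations)  # 2
```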
-def get_indices_chunked(n_batches, n_chunks, chunksize):
-    """ Return a list with the chunkwise shuffled indices of each batch. """
+def _get_indices_per_batch(n_batches, n_chunks, chunksize):
+    """
+    Return a list with the shuffled indices for each batch.
+
+    Returns
+    -------
+    indices_per_batch : List
+        Length n_batches, each element is a np.array[int].
+        Element i of the list contains the indices of the samples in batch number i.
+    """
     chunk_indices = np.arange(n_chunks)
     np.random.shuffle(chunk_indices)
     chunk_batches = np.array_split(chunk_indices, n_batches)
-    index_batches = []
+    indices_per_batch = []
     for bat in chunk_batches:
         idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
         np.random.shuffle(idx)
-        index_batches.append(idx)
-    return index_batches
+        indices_per_batch.append(idx)
+    return indices_per_batch
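What the renamed `_get_indices_per_batch` produces: the chunk order is shuffled globally, split into batches, each chunk index is expanded to the sample indices it covers (chunk `i` spans samples `i * chunksize` to `(i + 1) * chunksize - 1`), and the samples are then shuffled within the batch. A toy run of the same logic:

```python
import numpy as np

np.random.seed(1)
n_chunks, chunksize, n_batches = 4, 3, 2

chunk_indices = np.arange(n_chunks)
np.random.shuffle(chunk_indices)  # e.g. [3, 2, 0, 1]
for bat in np.array_split(chunk_indices, n_batches):
    idx = (bat[:, None] * chunksize + np.arange(chunksize)[None, :]).flatten()
    np.random.shuffle(idx)
    print(idx)  # all samples of two whole chunks, shuffled within the batch
```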
+def _get_largest_dset(f, datasets, max_ram):
+    """
+    Get infos about the dset that needs the most batches.
+    This is the dset that determines how many samples are shuffled at a time.
+    """
+    dset_infos = _get_dset_infos(f, datasets, max_ram)
+    return dset_infos[np.argmax([v["n_batches"] for v in dset_infos])]
-def get_dset_infos(f, datasets, max_ram):
-    """ Check datasets and retrieve relevant infos for each. """
+def _check_dsets(f, datasets):
+    # check if all datasets have the same number of lines
+    n_lines_list = [len(f[dset_name]) for dset_name in datasets]
+    if not all([n == n_lines_list[0] for n in n_lines_list]):
+        raise ValueError(f"Datasets have different lengths! {n_lines_list}")
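`_check_dsets` now fails fast before any info gathering. Since it only needs `len()` and lookup by name, a toy dict can stand in for the h5py file handle:

```python
import numpy as np

# unequal lengths raise immediately:
_check_dsets({"x": np.zeros(1000), "y": np.zeros(500)}, ("x", "y"))
# ValueError: Datasets have different lengths! [1000, 500]
```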
+def _get_dset_infos(f, datasets, max_ram):
+    """ Retrieve infos for each dataset. """
     dset_infos = []
-    n_lines = None
     for i, name in enumerate(datasets):
         dset = f[name]
-        if i == 0:
-            n_lines = len(dset)
-        else:
-            if len(dset) != n_lines:
-                raise ValueError(
-                    f"dataset {name} has different length! {len(dset)} vs {n_lines}"
-                )
+        n_lines = len(dset)
         chunksize = dset.chunks[0]
         n_chunks = int(np.ceil(n_lines / chunksize))
         # TODO in h5py 3.X, use .nbytes to get uncompressed size
         bytes_per_line = np.asarray(dset[0]).nbytes
         bytes_per_chunk = bytes_per_line * chunksize
-        lines_per_batch = int(np.floor(max_ram / bytes_per_line))
         chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))
-        dset_infos.append(
-            {
-                "name": name,
-                "dset": dset,
-                "chunksize": chunksize,
-                "n_lines": n_lines,
-                "n_chunks": n_chunks,
-                "bytes_per_line": bytes_per_line,
-                "bytes_per_chunk": bytes_per_chunk,
-                "lines_per_batch": lines_per_batch,
-                "chunks_per_batch": chunks_per_batch,
-                "n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)),
-                "n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)),
-            }
-        )
-    return dset_infos, n_lines
+        dset_infos.append({
+            "name": name,
+            "dset": dset,
+            "n_chunks": n_chunks,
+            "chunks_per_batch": chunks_per_batch,
+            "n_batches": int(np.ceil(n_chunks / chunks_per_batch)),
+            "chunksize": chunksize,
+        })
+    return dset_infos

-def make_dset(f_out, dset_info, indices):
-    """ Create a shuffled dataset in the output file. """
-    for batch_index in range(dset_info["n_batches_linewise"]):
-        print(f"Processing batch {batch_index+1}/{dset_info['n_batches_linewise']}")
-        slc = slice(
-            batch_index * dset_info["lines_per_batch"],
-            (batch_index + 1) * dset_info["lines_per_batch"],
-        )
-        to_read = indices[slc]
-        # reading has to be done with linearly increasing index,
-        # so sort -> read -> undo sorting
-        sort_ix = np.argsort(to_read)
-        unsort_ix = np.argsort(sort_ix)
-        data = dset_info["dset"][to_read[sort_ix]][unsort_ix]
-        if batch_index == 0:
-            in_dset = dset_info["dset"]
-            out_dset = f_out.create_dataset(
-                dset_info["name"],
-                data=data,
-                maxshape=in_dset.shape,
-                chunks=in_dset.chunks,
-                compression=in_dset.compression,
-                compression_opts=in_dset.compression_opts,
-                shuffle=in_dset.shuffle,
-            )
-            out_dset.resize(len(in_dset), axis=0)
-        else:
-            f_out[dset_info["name"]][slc] = data
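The batching numbers in `_get_dset_infos` follow directly from the chunk layout; with some hypothetical sizes:

```python
import numpy as np

n_lines = 1_000_000     # rows in the dataset
chunksize = 1_000       # rows per HDF5 chunk
bytes_per_line = 4_000  # uncompressed bytes per row
max_ram = 2 * 10**9     # 2 GB budget

n_chunks = int(np.ceil(n_lines / chunksize))                 # 1000
bytes_per_chunk = bytes_per_line * chunksize                 # 4 MB
chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk))  # 500
n_batches = int(np.ceil(n_chunks / chunks_per_batch))        # 2
```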
-def make_dset_chunked(f_out, dset_info, indices_chunked):
-    """ Create a shuffled dataset in the output file. """
+def _shuffle_dset(f_out, f_in, dset_name, indices_per_batch):
+    """
+    Create a batchwise-shuffled dataset in the output file using given indices.
+    """
+    dset_in = f_in[dset_name]
     start_idx = 0
-    for batch_index, to_read in enumerate(indices_chunked):
-        print(f"Processing batch {batch_index+1}/{len(indices_chunked)}")
+    for batch_number, indices in enumerate(indices_per_batch):
+        print(f"Processing batch {batch_number+1}/{len(indices_per_batch)}")
         # remove indices outside of dset
-        to_read = to_read[to_read < len(dset_info["dset"])]
+        indices = indices[indices < len(dset_in)]
         # reading has to be done with linearly increasing index
         # fancy indexing is super slow
         # so sort -> turn to slices -> read -> conc -> undo sorting
-        sort_ix = np.argsort(to_read)
+        sort_ix = np.argsort(indices)
         unsort_ix = np.argsort(sort_ix)
-        fancy_indices = to_read[sort_ix]
-        slices = slicify(fancy_indices)
-        data = np.concatenate([dset_info["dset"][slc] for slc in slices])
+        fancy_indices = indices[sort_ix]
+        slices = _slicify(fancy_indices)
+        data = np.concatenate([dset_in[slc] for slc in slices])
         data = data[unsort_ix]
-        if batch_index == 0:
-            in_dset = dset_info["dset"]
+        if batch_number == 0:
             out_dset = f_out.create_dataset(
-                dset_info["name"],
+                dset_name,
                 data=data,
-                maxshape=in_dset.shape,
-                chunks=in_dset.chunks,
-                compression=in_dset.compression,
-                compression_opts=in_dset.compression_opts,
-                shuffle=in_dset.shuffle,
+                maxshape=dset_in.shape,
+                chunks=dset_in.chunks,
+                compression=dset_in.compression,
+                compression_opts=dset_in.compression_opts,
+                shuffle=dset_in.shuffle,
             )
-            out_dset.resize(len(in_dset), axis=0)
+            out_dset.resize(len(dset_in), axis=0)
             start_idx = len(data)
         else:
             end_idx = start_idx + len(data)
-            f_out[dset_info["name"]][start_idx:end_idx] = data
+            f_out[dset_name][start_idx:end_idx] = data
             start_idx = end_idx
         print("Memory peak: {0:.3f} MB".format(peak_memory_usage()))
-    if start_idx != len(dset_info["dset"]):
-        print(f"Warning: last index was {start_idx} not {len(dset_info['dset'])}")
+    if start_idx != len(dset_in):
+        print(f"Warning: last index was {start_idx} not {len(dset_in)}")
-def slicify(fancy_indices):
+def _slicify(fancy_indices):
     """ [0,1,2, 6,7,8] --> [0:3, 6:9] """
     steps = np.diff(fancy_indices) != 1
     slice_starts = np.concatenate([fancy_indices[:1], fancy_indices[1:][steps]])
@@ -336,50 +270,6 @@ def slicify(fancy_indices):
     return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
 def run_parser():
+    # TODO deprecated
+    warnings.warn("h5shuffle2 is deprecated and has been renamed to orcasong h5shuffle2")
-    parser = argparse.ArgumentParser(
-        description="Shuffle datasets in a h5file that have the same length. "
-        "Uses chunkwise readout for speed-up."
-    )
-    parser.add_argument(
-        "input_file", type=str, help="Path of the file that will be shuffled."
-    )
-    parser.add_argument(
-        "--output_file",
-        type=str,
-        default=None,
-        help="If given, this will be the name of the output file. "
-        "Default: input_file + suffix.",
-    )
-    parser.add_argument(
-        "--datasets",
-        type=str,
-        nargs="*",
-        default=("x", "y"),
-        help="Which datasets to include in output. Default: x, y",
-    )
-    parser.add_argument(
-        "--max_ram_fraction",
-        type=float,
-        default=0.25,
-        help="in [0, 1]. Fraction of all available RAM to use for reading one batch of data. "
-        "Note: this should be <=~0.25 or so, since lots of RAM is needed "
-        "for in-memory shuffling. Default: 0.25",
-    )
-    parser.add_argument(
-        "--iterations",
-        type=int,
-        default=None,
-        help="Shuffle the file this many times. Default: Auto choose best number.",
-    )
-    parser.add_argument(
-        "--max_ram",
-        type=int,
-        default=None,
-        help="Available RAM in bytes. Default: Use fraction of maximum "
-        "available instead (see max_ram_fraction).",
-    )
-    h5shuffle2(**vars(parser.parse_args()))
+def _get_temp_filenames(output_file, number):
+    path, file = os.path.split(output_file)
+    return [os.path.join(path, f"temp_iteration_{i}_{file}") for i in range(number)]
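For example, the temp files land next to the final output:

```python
print(_get_temp_filenames("/data/out.h5", number=2))
# ['/data/temp_iteration_0_out.h5', '/data/temp_iteration_1_out.h5']
```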