import os import time import datetime import argparse import numpy as np import psutil import h5py from km3pipe.sys import peak_memory_usage from orcasong.tools.postproc import get_filepath_output, copy_used_files from orcasong.tools.concatenate import copy_attrs __author__ = "Stefan Reck" def h5shuffle2(input_file, output_file=None, iterations=None, datasets=("x", "y"), max_ram_fraction=0.25, **kwargs): if output_file is None: output_file = get_filepath_output(input_file, shuffle=True) if iterations is None: iterations = get_n_iterations( input_file, datasets=datasets, max_ram_fraction=max_ram_fraction, ) np.random.seed(42) for i in range(iterations): print(f"\nIteration {i+1}/{iterations}") if iterations == 1: # special case if theres only one iteration stgs = { "input_file": input_file, "output_file": output_file, "delete": False, } elif i == 0: # first iteration stgs = { "input_file": input_file, "output_file": f"{output_file}_temp_{i}", "delete": False, } elif i == iterations-1: # last iteration stgs = { "input_file": f"{output_file}_temp_{i-1}", "output_file": output_file, "delete": True, } else: # intermediate iterations stgs = { "input_file": f"{output_file}_temp_{i-1}", "output_file": f"{output_file}_temp_{i}", "delete": True, } shuffle_file( datasets=datasets, max_ram_fraction=max_ram_fraction, chunks=True, **stgs, **kwargs, ) def shuffle_file( input_file, datasets=("x", "y"), output_file=None, max_ram=None, max_ram_fraction=0.25, chunks=False, delete=False): """ Shuffle datasets in a h5file that have the same length. Parameters ---------- input_file : str Path of the file that will be shuffle. datasets : tuple Which datasets to include in output. output_file : str, optional If given, this will be the name of the output file. Otherwise, a name is auto generated. max_ram : int, optional Available ram in bytes. Default: Use fraction of maximum available (see max_ram_fraction). max_ram_fraction : float in [0, 1]. Fraction of ram to use for reading one batch of data when max_ram is None. Note: when using chunks, this should be <=~0.25, since lots of ram is needed for in-memory shuffling. chunks : bool Use chunk-wise readout. Large speed boost, but will only quasi-randomize order! Needs lots of ram to be accurate! (use a node with at least 32gb, the more the better) delete : bool Delete the original file afterwards? Returns ------- output_file : str Path to the output file. """ start_time = time.time() if output_file is None: output_file = get_filepath_output(input_file, shuffle=True) if os.path.exists(output_file): raise FileExistsError(output_file) if max_ram is None: max_ram = get_max_ram(max_ram_fraction) temp_output_file = output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime()) with h5py.File(input_file, "r") as f_in: dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram) print(f"Shuffling datasets {datasets} with {n_lines} lines each") if not chunks: indices = np.arange(n_lines) np.random.shuffle(indices) with h5py.File(temp_output_file, "x") as f_out: for dset_info in dset_infos: print("Creating dataset", dset_info["name"]) make_dset(f_out, dset_info, indices) print("Done!") else: indices_chunked = get_indices_largest(dset_infos) with h5py.File(temp_output_file, "x") as f_out: for dset_info in dset_infos: print("Creating dataset", dset_info["name"]) make_dset_chunked(f_out, dset_info, indices_chunked) print("Done!") copy_used_files(input_file, temp_output_file) copy_attrs(input_file, temp_output_file) os.rename(temp_output_file, output_file) if delete: os.remove(input_file) print(f"Elapsed time: " f"{datetime.timedelta(seconds=int(time.time() - start_time))}") return output_file def get_max_ram(max_ram_fraction): max_ram = max_ram_fraction * psutil.virtual_memory().available print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes") return max_ram def get_indices_largest(dset_infos): largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos]) dset_info = dset_infos[largest_dset] print(f"Total chunks: {dset_info['n_chunks']}") ratio = dset_info['chunks_per_batch']/dset_info['n_chunks'] print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})") return get_indices_chunked( dset_info["n_batches_chunkwise"], dset_info["n_chunks"], dset_info["chunksize"], ) def get_n_iterations(input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25): """ Get how often you have to shuffle with given ram to get proper randomness. """ if max_ram is None: max_ram = get_max_ram(max_ram_fraction=max_ram_fraction) with h5py.File(input_file, "r") as f_in: dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram) largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos]) dset_info = dset_infos[largest_dset] n_iterations = int(np.ceil( np.log(dset_info['n_chunks'])/np.log(dset_info['chunks_per_batch']))) print(f"Total chunks: {dset_info['n_chunks']}") print(f"Chunks per batch: {dset_info['chunks_per_batch']}") print(f"--> min iterations for full shuffle: {n_iterations}") return n_iterations def get_indices_chunked(n_batches, n_chunks, chunksize): """ Return a list with the chunkwise shuffled indices of each batch. """ chunk_indices = np.arange(n_chunks) np.random.shuffle(chunk_indices) chunk_batches = np.array_split(chunk_indices, n_batches) index_batches = [] for bat in chunk_batches: idx = (bat[:, None]*chunksize + np.arange(chunksize)[None, :]).flatten() np.random.shuffle(idx) index_batches.append(idx) return index_batches def get_dset_infos(f, datasets, max_ram): """ Check datasets and retrieve relevant infos for each. """ dset_infos = [] n_lines = None for i, name in enumerate(datasets): dset = f[name] if i == 0: n_lines = len(dset) else: if len(dset) != n_lines: raise ValueError(f"dataset {name} has different length! " f"{len(dset)} vs {n_lines}") chunksize = dset.chunks[0] n_chunks = int(np.ceil(n_lines / chunksize)) # TODO in h5py 3.X, use .nbytes to get uncompressed size bytes_per_line = np.asarray(dset[0]).nbytes bytes_per_chunk = bytes_per_line * chunksize lines_per_batch = int(np.floor(max_ram / bytes_per_line)) chunks_per_batch = int(np.floor(max_ram / bytes_per_chunk)) dset_infos.append({ "name": name, "dset": dset, "chunksize": chunksize, "n_lines": n_lines, "n_chunks": n_chunks, "bytes_per_line": bytes_per_line, "bytes_per_chunk": bytes_per_chunk, "lines_per_batch": lines_per_batch, "chunks_per_batch": chunks_per_batch, "n_batches_linewise": int(np.ceil(n_lines / lines_per_batch)), "n_batches_chunkwise": int(np.ceil(n_chunks / chunks_per_batch)), }) return dset_infos, n_lines def make_dset(f_out, dset_info, indices): """ Create a shuffled dataset in the output file. """ for batch_index in range(dset_info["n_batches_linewise"]): print(f"Processing batch {batch_index+1}/{dset_info['n_batches_linewise']}") slc = slice( batch_index * dset_info["lines_per_batch"], (batch_index+1) * dset_info["lines_per_batch"], ) to_read = indices[slc] # reading has to be done with linearly increasing index, # so sort -> read -> undo sorting sort_ix = np.argsort(to_read) unsort_ix = np.argsort(sort_ix) data = dset_info["dset"][to_read[sort_ix]][unsort_ix] if batch_index == 0: in_dset = dset_info["dset"] out_dset = f_out.create_dataset( dset_info["name"], data=data, maxshape=in_dset.shape, chunks=in_dset.chunks, compression=in_dset.compression, compression_opts=in_dset.compression_opts, shuffle=in_dset.shuffle, ) out_dset.resize(len(in_dset), axis=0) else: f_out[dset_info["name"]][slc] = data def make_dset_chunked(f_out, dset_info, indices_chunked): """ Create a shuffled dataset in the output file. """ start_idx = 0 for batch_index, to_read in enumerate(indices_chunked): print(f"Processing batch {batch_index+1}/{len(indices_chunked)}") # remove indices outside of dset to_read = to_read[to_read < len(dset_info["dset"])] # reading has to be done with linearly increasing index # fancy indexing is super slow # so sort -> turn to slices -> read -> conc -> undo sorting sort_ix = np.argsort(to_read) unsort_ix = np.argsort(sort_ix) fancy_indices = to_read[sort_ix] slices = slicify(fancy_indices) data = np.concatenate([dset_info["dset"][slc] for slc in slices]) data = data[unsort_ix] if batch_index == 0: in_dset = dset_info["dset"] out_dset = f_out.create_dataset( dset_info["name"], data=data, maxshape=in_dset.shape, chunks=in_dset.chunks, compression=in_dset.compression, compression_opts=in_dset.compression_opts, shuffle=in_dset.shuffle, ) out_dset.resize(len(in_dset), axis=0) start_idx = len(data) else: end_idx = start_idx + len(data) f_out[dset_info["name"]][start_idx:end_idx] = data start_idx = end_idx print("Memory peak: {0:.3f} MB".format(peak_memory_usage())) if start_idx != len(dset_info["dset"]): print(f"Warning: last index was {start_idx} not {len(dset_info['dset'])}") def slicify(fancy_indices): """ [0,1,2, 6,7,8] --> [0:3, 6:9] """ steps = np.diff(fancy_indices) != 1 slice_starts = np.concatenate([fancy_indices[:1], fancy_indices[1:][steps]]) slice_ends = np.concatenate([fancy_indices[:-1][steps], fancy_indices[-1:]]) + 1 return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))] def run_parser(): parser = argparse.ArgumentParser( description='Shuffle datasets in a h5file that have the same length. ' 'Uses chunkwise readout for speed-up.') parser.add_argument('input_file', type=str, help='Path of the file that will be shuffled.') parser.add_argument('--output_file', type=str, default=None, help='If given, this will be the name of the output file. ' 'Default: input_file + suffix.') parser.add_argument('--datasets', type=str, nargs="*", default=("x", "y"), help='Which datasets to include in output. Default: x, y') parser.add_argument('--max_ram_fraction', type=float, default=0.25, help="in [0, 1]. Fraction of all available ram to use for reading one batch of data " "Note: this should " "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. " "Default: 0.25") parser.add_argument('--iterations', type=int, default=None, help="Shuffle the file this many times. Default: Auto choose best number.") h5shuffle2(**vars(parser.parse_args()))