From b87b001f47b5fc23f74dc7edf73a88000b8b258f Mon Sep 17 00:00:00 2001
From: Stefan Reck <stefan.reck@fau.de>
Date: Wed, 9 Dec 2020 11:52:29 +0100
Subject: [PATCH] shuffle2: write via temp file, add delete option and iterative h5shuffle2

---
 orcasong/tools/shuffle2.py | 83 +++++++++++++++++++++++++++-----------
 1 file changed, 60 insertions(+), 23 deletions(-)
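
Patch notes: shuffle_v2 now writes its result to a temporary file and only
renames it to the requested output name once all datasets and attributes have
been copied; the new delete option removes the input file afterwards, which
the reworked h5shuffle2 entry point uses to chain several chunk-wise
pseudo-shuffles. A minimal sketch of driving this from Python, with
hypothetical file names chosen only for illustration (keyword names taken
from the function signature in this patch):

    from orcasong.tools.shuffle2 import shuffle_v2

    # First pass: chunk-wise pseudo-shuffle; keep the original file.
    intermediate = shuffle_v2(
        input_file="events.h5",                # hypothetical input file
        output_file="events_shuffled_tmp.h5",  # hypothetical temp name
        chunks=True,
        delete=False,
    )

    # Second pass: shuffle the intermediate file again and delete it
    # afterwards, leaving only the final, better-randomized output.
    shuffle_v2(
        input_file=intermediate,
        output_file="events_shuffled.h5",
        chunks=True,
        delete=True,
    )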

diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py
index 80b02b5..1068d8b 100644
--- a/orcasong/tools/shuffle2.py
+++ b/orcasong/tools/shuffle2.py
@@ -5,6 +5,7 @@ import argparse
 import numpy as np
 import psutil
 import h5py
+from km3pipe.sys import peak_memory_usage
 
 from orcasong.tools.postproc import get_filepath_output, copy_used_files
 from orcasong.tools.concatenate import copy_attrs
@@ -18,8 +19,9 @@ def shuffle_v2(
         datasets=("x", "y"),
         output_file=None,
         max_ram=None,
-        max_ram_fraction=0.45,
+        max_ram_fraction=0.25,
         chunks=False,
+        delete=False,
         seed=42):
     """
     Shuffle datasets in a h5file that have the same length.
@@ -37,13 +39,15 @@ def shuffle_v2(
         Available ram in bytes. Default: Use fraction of
         maximum available (see max_ram_fraction).
     max_ram_fraction : float
-        in [0, 1]. Fraction of ram to use for reading data when max_ram
-        is None. Note: when using chunks, this should be <=0.45, since
-        lots of ram is needed for in-memory shuffling.
+        in [0, 1]. Fraction of ram to use for reading one batch of data
+        when max_ram is None. Note: when using chunks, this should
+        be <=~0.25, since lots of ram is needed for in-memory shuffling.
     chunks : bool
-        Use chunk-wise readout. Up to x8 speed boost, but will
+        Use chunk-wise readout. Large speed boost, but will
         only quasi-randomize order! Needs lots of ram
         to be accurate! (use a node with at least 32gb, the more the better)
+    delete : bool
+        If True, delete the original input file after shuffling.
     seed : int
         Sets a fixed random seed for the shuffling.
 
@@ -62,6 +66,7 @@ def shuffle_v2(
         max_ram = max_ram_fraction * psutil.virtual_memory().available
         print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
 
+    temp_output_file = output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
     with h5py.File(input_file, "r") as f_in:
         dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
         np.random.seed(seed)
@@ -71,7 +76,7 @@ def shuffle_v2(
             indices = np.arange(n_lines)
             np.random.shuffle(indices)
 
-            with h5py.File(output_file, "x") as f_out:
+            with h5py.File(temp_output_file, "x") as f_out:
                 for dset_info in dset_infos:
                     print("Creating dataset", dset_info["name"])
                     make_dset(f_out, dset_info, indices)
@@ -79,15 +84,17 @@ def shuffle_v2(
         else:
             indices_chunked = get_indices_largest(dset_infos)
 
-            with h5py.File(output_file, "x") as f_out:
+            with h5py.File(temp_output_file, "x") as f_out:
                 for dset_info in dset_infos:
                     print("Creating dataset", dset_info["name"])
                     make_dset_chunked(f_out, dset_info, indices_chunked)
                     print("Done!")
 
-    copy_used_files(input_file, output_file)
-    copy_attrs(input_file, output_file)
-
+    copy_used_files(input_file, temp_output_file)
+    copy_attrs(input_file, temp_output_file)
+    os.rename(temp_output_file, output_file)
+    if delete:
+        os.remove(input_file)
     print(f"Elapsed time: "
           f"{datetime.timedelta(seconds=int(time.time() - start_time))}")
     return output_file
@@ -97,10 +104,13 @@ def get_indices_largest(dset_infos):
     largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
     dset_info = dset_infos[largest_dset]
 
-    print(f"Lines per batch: {dset_info['lines_per_batch']}")
-    if dset_info['lines_per_batch'] <= 50000:
+    print(f"Total chunks: {dset_info['n_chunks']}")
+    ratio = dset_info['chunks_per_batch']/dset_info['n_chunks']
+    print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")
+
+    if ratio <= 0.1:
         print("Warning: Should have more than "
-              "50 000 lines per batch to ensure proper shuffling!")
+              "10% of chunks per batch to ensure proper shuffling!")
     return get_indices_chunked(
         dset_info["n_batches_chunkwise"],
         dset_info["n_chunks"],
@@ -229,6 +239,8 @@ def make_dset_chunked(f_out, dset_info, indices_chunked):
             f_out[dset_info["name"]][start_idx:end_idx] = data
             start_idx = end_idx
 
+        print("Memory peak: {0:.3f} MB".format(peak_memory_usage()))
+
     if start_idx != len(dset_info["dset"]):
         print(f"Warning: last index was {start_idx} not {len(dset_info['dset'])}")
 
@@ -241,18 +253,43 @@ def slicify(fancy_indices):
     return [slice(slice_starts[i], slice_ends[i]) for i in range(len(slice_starts))]
 
 
-def h5shuffle():
-    parser = argparse.ArgumentParser(description='Shuffle an h5 file using h5py.')
+def h5shuffle2():
+    parser = argparse.ArgumentParser(
+        description='Shuffle datasets in a h5file that have the same length. '
+                    'Uses chunk-wise readout for a pseudo-shuffle, so shuffling '
+                    'multiple times is recommended for larger files.')
     parser.add_argument('input_file', type=str,
-                        help='File to shuffle.')
+                        help='Path of the file that will be shuffled.')
     parser.add_argument('--output_file', type=str,
-                        help='Name of output file. Default: Auto generate name.')
-    parser.add_argument('--chunks', action='store_true',
-                        help="Use chunk-wise readout. Up to 8x speed boost, but will "
-                             "only quasi-randomize order! Needs lots of ram "
-                             "to be accurate!")
-    shuffle_v2(**vars(parser.parse_args()))
+                        help='If given, this will be the name of the output file. '
+                             'Otherwise, a name is auto-generated.')
+    parser.add_argument('--datasets', type=str, nargs="*", default=("x", "y"),
+                        help='Which datasets to include in output. Default: x, y')
+    parser.add_argument('--max_ram_fraction', type=float, default=0.25,
+                        help="in [0, 1]. Fraction of ram to use for reading one batch of data "
+                             "when max_ram is None. Note: this should be <= ~0.25, "
+                             "since lots of ram is needed for in-memory shuffling.")
+    parser.add_argument('--iterations', type=int, default=2,
+                        help="Shuffle the file this many times. Default: 2")
+    kwargs = vars(parser.parse_args())
+
+    outfile = kwargs.pop("output_file")
+    iterations = kwargs.pop("iterations")
+    for i in range(iterations):
+        print(f"Iteration {i}")
+        # use a temp filename for all but the last iteration
+        if i+1 == iterations:
+            outf = outfile
+        else:
+            outf = f"{outfile}_temp_{i}"
+        # delete intermediate temp files, but never the original input
+        if i == 0:
+            delete = False
+        else:
+            delete = True
+
+        kwargs["input_file"] = shuffle_v2(**kwargs, output_file=outf, chunks=True, delete=delete)
 
 
 if __name__ == '__main__':
-    h5shuffle()
+    h5shuffle2()
-- 
GitLab