From 368b5a0db8fcb58d94e840428c043711647ddf83 Mon Sep 17 00:00:00 2001
From: Stefan Reck <stefan.reck@fau.de>
Date: Tue, 15 Dec 2020 15:06:41 +0100
Subject: [PATCH] Automatically choose the best n_iterations

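With chunkwise readout, one pass can only interleave the chunks that
fit into a single batch, so a single pass is only a pseudo-shuffle.
After k passes, each output chunk can draw lines from up to
chunks_per_batch**k input chunks; a full shuffle therefore needs
chunks_per_batch**k >= n_chunks, i.e.
k = ceil(log(n_chunks) / log(chunks_per_batch)),
which is what the new get_n_iterations() computes.

A minimal sketch of that calculation (the helper name min_iterations
and the example numbers are made up for illustration):

    import numpy as np

    def min_iterations(n_chunks, chunks_per_batch):
        # A full shuffle needs chunks_per_batch**k >= n_chunks,
        # so k >= log(n_chunks) / log(chunks_per_batch).
        return int(np.ceil(np.log(n_chunks) / np.log(chunks_per_batch)))

    # Hypothetical numbers: 1000 chunks, ram for 32 chunks per batch.
    print(min_iterations(1000, 32))  # ceil(6.91 / 3.47) = 2 passes
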
---
 orcasong/tools/shuffle2.py | 64 +++++++++++++++++++++++++-------------
 1 file changed, 42 insertions(+), 22 deletions(-)

diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py
index 1d85fa3..6516037 100644
--- a/orcasong/tools/shuffle2.py
+++ b/orcasong/tools/shuffle2.py
@@ -21,8 +21,7 @@ def shuffle_v2(
         max_ram=None,
         max_ram_fraction=0.25,
         chunks=False,
-        delete=False,
-        seed=42):
+        delete=False):
     """
     Shuffle datasets in a h5file that have the same length.
 
@@ -48,8 +47,6 @@ def shuffle_v2(
         to be accurate! (use a node with at least 32gb, the more the better)
     delete : bool
         Delete the original file afterwards?
-    seed : int
-        Sets a fixed random seed for the shuffling.
 
     Returns
     -------
@@ -63,13 +60,11 @@ def shuffle_v2(
     if os.path.exists(output_file):
         raise FileExistsError(output_file)
     if max_ram is None:
-        max_ram = max_ram_fraction * psutil.virtual_memory().available
-        print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
+        max_ram = get_max_ram(max_ram_fraction)
 
     temp_output_file = output_file + "_temp_" + time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())
     with h5py.File(input_file, "r") as f_in:
         dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
-        np.random.seed(seed)
         print(f"Shuffling datasets {datasets} with {n_lines} lines each")
 
         if not chunks:
@@ -100,6 +95,12 @@ def shuffle_v2(
     return output_file
 
 
+def get_max_ram(max_ram_fraction):
+    max_ram = max_ram_fraction * psutil.virtual_memory().available
+    print(f"Using {max_ram_fraction:.2%} of available ram = {max_ram} bytes")
+    return max_ram
+
+
 def get_indices_largest(dset_infos):
     largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
     dset_info = dset_infos[largest_dset]
@@ -108,9 +109,6 @@ def get_indices_largest(dset_infos):
     ratio = dset_info['chunks_per_batch']/dset_info['n_chunks']
     print(f"Chunks per batch: {dset_info['chunks_per_batch']} ({ratio:.2%})")
 
-    if ratio <= 0.1:
-        print("Warning: Should have more than "
-              "10% of chunks per batch to ensure proper shuffling!")
     return get_indices_chunked(
         dset_info["n_batches_chunkwise"],
         dset_info["n_chunks"],
@@ -118,6 +116,22 @@ def get_indices_largest(dset_infos):
     )
 
 
+def get_n_iterations(input_file, datasets=("x", "y"), max_ram=None, max_ram_fraction=0.25):
+    """ Get how often you have to shuffle with given ram to get proper randomness. """
+    if max_ram is None:
+        max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
+    with h5py.File(input_file, "r") as f_in:
+        dset_infos, n_lines = get_dset_infos(f_in, datasets, max_ram)
+    largest_dset = np.argmax([v["n_batches_chunkwise"] for v in dset_infos])
+    dset_info = dset_infos[largest_dset]
+    n_iterations = int(np.ceil(
+        np.log(dset_info['n_chunks'])/np.log(dset_info['chunks_per_batch'])))
+    print(f"Total chunks: {dset_info['n_chunks']}")
+    print(f"Chunks per batch: {dset_info['chunks_per_batch']}")
+    print(f"--> min iterations for full shuffle: {n_iterations}")
+    return n_iterations
+
+
 def get_indices_chunked(n_batches, n_chunks, chunksize):
     """ Return a list with the chunkwise shuffled indices of each batch. """
     chunk_indices = np.arange(n_chunks)
@@ -255,22 +269,22 @@ def slicify(fancy_indices):
 
 def h5shuffle2():
     parser = argparse.ArgumentParser(
-        description='Shuffle datasets in a h5file that have the same length.'
-                    'Uses chunkwise readout for a pseudo-shuffle, so shuffling '
-                    'multiple times is recommended for larger files.')
+        description='Shuffle datasets in a h5file that have the same length. '
+                    'Uses chunkwise readout for a speed-up.')
     parser.add_argument('input_file', type=str,
                         help='Path of the file that will be shuffled.')
-    parser.add_argument('--output_file', type=str,
+    parser.add_argument('--output_file', type=str, default=None,
                         help='If given, this will be the name of the output file. '
-                             'Otherwise, a name is auto generated.')
+                             'Default: input_file + suffix.')
     parser.add_argument('--datasets', type=str, nargs="*", default=("x", "y"),
                         help='Which datasets to include in output. Default: x, y')
     parser.add_argument('--max_ram_fraction', type=float, default=0.25,
-                        help="in [0, 1]. Fraction of ram to use for reading one batch of data "
-                             "when max_ram is None. Note: this should "
-                             "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling.")
-    parser.add_argument('--iterations', type=int, default=2,
-                        help="Shuffle the file this many times. Default: 2")
+                        help="in [0, 1]. Fraction of all available ram to use for reading one batch of data "
+                             "Note: this should "
+                             "be <=~0.25 or so, since lots of ram is needed for in-memory shuffling. "
+                             "Default: 0.25")
+    parser.add_argument('--iterations', type=int, default=None,
+                        help="Shuffle the file this many times. Default: Auto choose best number.")
     kwargs = vars(parser.parse_args())
 
     input_file = kwargs.pop("input_file")
@@ -278,9 +292,15 @@ def h5shuffle2():
     if output_file is None:
         output_file = get_filepath_output(input_file, shuffle=True)
     iterations = kwargs.pop("iterations")
-
+    if iterations is None:
+        iterations = get_n_iterations(
+            input_file,
+            datasets=kwargs["datasets"],
+            max_ram_fraction=kwargs["max_ram_fraction"],
+        )
+    np.random.seed(42)
     for i in range(iterations):
-        print(f"Iteration {i}")
+        print(f"\nIteration {i+1}/{iterations}")
         if i == 0:
             # first iteration
             stgs = {
-- 
GitLab
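
Usage sketch (the file name is hypothetical): with --iterations omitted,
the number of passes is now derived from get_n_iterations() instead of
the previous fixed default of 2:

    h5shuffle2 events.h5 --datasets x y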