From ab6b9683d6a155b06e83171f67622a8390a6d347 Mon Sep 17 00:00:00 2001
From: Stefan Reck <stefan.reck@fau.de>
Date: Fri, 12 Apr 2019 12:49:21 +0200
Subject: [PATCH] Added event skipper functionality to shuffle_h5.

---
 .../data_tools/shuffle/shuffle_h5.py          | 61 ++++++++-----------
 orcasong_plag/core.py                         | 13 +++-
 2 files changed, 36 insertions(+), 38 deletions(-)

diff --git a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py
index 7b57dc0..0abd22e 100644
--- a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py
+++ b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py
@@ -21,6 +21,8 @@ import numpy as np
 import h5py
 import km3pipe as kp
 import km3modules as km
+from orcasong_contrib.data_tools.concatenate.concatenate_h5 import get_f_compression_and_chunking
+from orcasong_plag.modules import EventSkipper
 
 # from memory_profiler import profile # for memory profiling, call with @profile; myfunc()
 
@@ -122,43 +124,15 @@ def parse_input():
     return input_files_list, delete, chunksize, complib, complevel, legacy_mode
 
 
-def get_f_compression_and_chunking(filepath):
-    """
-    Function that gets the used compression library, the compression level (if applicable)
-    and the chunksize of axis_0 of the first dataset of the file.
-
-    Parameters
-    ----------
-    filepath : str
-        Filepath of a .hdf5 file.
-
-    Returns
-    -------
-    compression : str
-        The compression library that has been identified in the input file. E.g. 'gzip', or 'lzf'.
-    complevel : int
-        The compression level that has been identified in the input file.
-    chunksize : None/int
-        The chunksize of axis_0 that has been indentified in the input file.
-
-    """
-    f = h5py.File(filepath, 'r')
-
-    # remove any keys to pytables folders that may be in the file
-    f_keys_stripped = [x for x in list(f.keys()) if '_i_' not in x]
-
-    compression = f[f_keys_stripped[0]].compression  # compression filter
-    compression_opts = f[f_keys_stripped[0]].compression_opts  # filter strength
-    chunksize = f[f_keys_stripped[0]].chunks[0]  # chunksize along axis_0 of the dataset
-
-    return compression, compression_opts, chunksize
-
-
-def shuffle_h5(filepath_input, tool=False, seed=42, delete=True, chunksize=None, complib=None, complevel=None, legacy_mode=False):
+def shuffle_h5(filepath_input, tool=False, seed=42, delete=True, chunksize=None,
+               complib=None, complevel=None, legacy_mode=False, shuffle=True,
+               event_skipper=None):
     """
     Shuffles a .h5 file where each dataset needs to have the same number of rows (axis_0).
     The shuffled data is saved to a new .h5 file with the suffix < _shuffled.h5 >.
 
+    Can also skip certain events if an event_skipper is given.
+
     Parameters
     ----------
     filepath_input : str
@@ -187,6 +161,11 @@ def shuffle_h5(filepath_input, tool=False, seed=42, delete=True, chunksize=None,
     legacy_mode : bool
         Boolean flag that specifies, if the legacy shuffle mode should be used instead of the standard one.
         A more detailed description of this mode can be found in the summary at the top of this python file.
+    shuffle : bool
+        If False, events will not be shuffled (an event_skipper must then be given).
+    event_skipper : func, optional
+        Function that takes the blob as an input, and returns a bool.
+        If the bool is true, the blob will be skipped.
 
     Returns
     -------
@@ -194,6 +173,9 @@ def shuffle_h5(filepath_input, tool=False, seed=42, delete=True, chunksize=None,
         H5py file instance of the shuffled output file.
 
     """
+    if event_skipper is None and not shuffle:
+        raise ValueError("Either event_skipper or shuffle has to be set")
+
     complib_f, complevel_f, chunksize_f = get_f_compression_and_chunking(filepath_input)
 
     chunksize = chunksize_f if chunksize is None else chunksize
@@ -204,7 +186,12 @@ def shuffle_h5(filepath_input, tool=False, seed=42, delete=True, chunksize=None,
         complevel = None
 
     filepath_input_without_ext = os.path.splitext(filepath_input)[0]
-    filepath_output = filepath_input_without_ext + '_shuffled.h5'
+    fname_adtn = ''
+    if shuffle:
+        fname_adtn += '_shuffled'
+    if event_skipper is not None:
+        fname_adtn += '_reb'
+    filepath_output = filepath_input_without_ext + fname_adtn + ".h5"
 
     if not legacy_mode:
         # set random km3pipe (=numpy) seed
@@ -218,7 +205,11 @@ def shuffle_h5(filepath_input, tool=False, seed=42, delete=True, chunksize=None,
         pipe = kp.Pipeline(timeit=True)  # add timeit=True argument for profiling
         pipe.attach(km.common.StatusBar, every=200)
         pipe.attach(km.common.MemoryObserver, every=200)
-        pipe.attach(kp.io.hdf5.HDF5Pump, filename=filepath_input, shuffle=True, reset_index=True)
+        pipe.attach(kp.io.hdf5.HDF5Pump, filename=filepath_input, shuffle=shuffle, reset_index=True)
+
+        if event_skipper is not None:
+            pipe.attach(EventSkipper, event_skipper=event_skipper)
+
         pipe.attach(kp.io.hdf5.HDF5Sink, filename=filepath_output, complib=complib, complevel=complevel, chunksize=chunksize, flush_frequency=1000)
         pipe.drain()
         if delete:
diff --git a/orcasong_plag/core.py b/orcasong_plag/core.py
index 12f6d04..181aa7f 100644
--- a/orcasong_plag/core.py
+++ b/orcasong_plag/core.py
@@ -2,9 +2,15 @@ import km3pipe as kp
 import km3modules as km
 import os
 
-from orcasong_plag.modules import TimePreproc, ImageMaker, McInfoMaker, BinningStatsMaker, EventSkipper
+from orcasong_plag.modules import (TimePreproc,
+                                   ImageMaker,
+                                   McInfoMaker,
+                                   BinningStatsMaker,
+                                   EventSkipper)
 from orcasong_plag.mc_info_types import get_mc_info_extr
-from orcasong_plag.util.bin_stats_plot import plot_hists, add_hists_to_h5file, plot_hist_of_files
+from orcasong_plag.util.bin_stats_plot import (plot_hists,
+                                               add_hists_to_h5file,
+                                               plot_hist_of_files)
 
 
 class FileBinner:
@@ -198,7 +204,8 @@ class FileBinner:
 
     def get_names_and_shape(self):
         """
-        Get names and shape of the resulting x data, e.g. (pos_z, time), (18, 50).
+        Get names and shape of the resulting x data,
+        e.g. (pos_z, time), (18, 50).
         """
         names, shape = [], []
         for bin_name, bin_edges in self.bin_edges_list:
-- 
GitLab