From 55b3f66a19cf3188f0692c95133dfe492030d5e4 Mon Sep 17 00:00:00 2001
From: Stefan Reck <stefan.reck@fau.de>
Date: Thu, 11 Apr 2019 13:57:01 +0200
Subject: [PATCH] Added event skipper to orcasong plag.

---
 .../data_tools/concatenate/concatenate_h5.py  | 12 +++++-----
 orcasong_plag/core.py                         | 15 ++++++++-----
 orcasong_plag/modules.py                      | 22 +++++++++++++++++++
 3 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/orcasong_contrib/data_tools/concatenate/concatenate_h5.py b/orcasong_contrib/data_tools/concatenate/concatenate_h5.py
index b9ed5ef..7580f22 100644
--- a/orcasong_contrib/data_tools/concatenate/concatenate_h5.py
+++ b/orcasong_contrib/data_tools/concatenate/concatenate_h5.py
@@ -106,7 +106,8 @@ def parse_input():
 def get_cum_number_of_rows(file_list):
     """
     Returns the cumulative number of rows (axis_0) in a list based on the
-    specified input .h5 files.
+    specified input .h5 files (i.e. [0,100,200,300,...] if each file has
+    100 rows).
 
     Parameters
     ----------
@@ -200,7 +201,8 @@ def get_f_compression_and_chunking(filepath):
     return compression, compression_opts, chunksize
 
 
-def concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, complib, complevel):
+def concatenate_h5_files(output_filepath, file_list,
+                         chunksize=None, complib=None, complevel=None):
     """
     Function that concatenates hdf5 files based on an output_filepath and a file_list of input files.
 
@@ -213,8 +215,6 @@ def concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, c
         String that specifies the filepath (path+name) of the output .h5 file.
     file_list : list
         List that contains all filepaths of the input files.
-    cum_rows_list : list
-        List that contains the cumulative number of rows (i.e. [0,100,200,300,...] if each file has 100 rows).
     chunksize : None/int
         Specifies the chunksize for axis_0 in the concatenated output files.
         If None, the chunksize is read from the first input file.
@@ -231,6 +231,7 @@ def concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, c
         Else, a custom compression level will be used.
 
     """
+    cum_rows_list = get_cum_number_of_rows(file_list)
     complib_f, complevel_f, chunksize_f = get_f_compression_and_chunking(file_list[0])
 
     chunksize = chunksize_f if chunksize is None else chunksize
@@ -305,8 +306,7 @@ def main():
     In deep learning applications for example, the chunksize should be equal to the batch size that is used later on for reading the data.
     """
     file_list, output_filepath, chunksize, complib, complevel = parse_input()
-    cum_rows_list = get_cum_number_of_rows(file_list)
-    concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, complib, complevel)
+    concatenate_h5_files(output_filepath, file_list, chunksize, complib, complevel)
 
 
 if __name__ == '__main__':
diff --git a/orcasong_plag/core.py b/orcasong_plag/core.py
index cfefa2f..12f6d04 100644
--- a/orcasong_plag/core.py
+++ b/orcasong_plag/core.py
@@ -2,7 +2,7 @@ import km3pipe as kp
 import km3modules as km
 import os
 
-from orcasong_plag.modules import TimePreproc, ImageMaker, McInfoMaker, BinningStatsMaker
+from orcasong_plag.modules import TimePreproc, ImageMaker, McInfoMaker, BinningStatsMaker, EventSkipper
 from orcasong_plag.mc_info_types import get_mc_info_extr
 from orcasong_plag.util.bin_stats_plot import plot_hists, add_hists_to_h5file, plot_hist_of_files
 
@@ -29,6 +29,9 @@ class FileBinner:
         Function that extracts desired mc_info from a blob, which is then
         stored as the "y" datafield in the .h5 file.
         Can also give a str identifier for an existing extractor.
+    event_skipper : func, optional
+        Function that takes the blob as an input, and returns a bool.
+        If the bool is true, the blob will be skipped.
     bin_plot_freq : int or None
         If int is given, defines after how many blobs data for an overview
         histogram is extracted.
@@ -59,9 +62,11 @@ class FileBinner:
         but it increases the RAM usage as well.
 
     """
-    def __init__(self, bin_edges_list, mc_info_extr=None, add_bin_stats=True):
+    def __init__(self, bin_edges_list, mc_info_extr=None,
+                 event_skipper=None, add_bin_stats=True):
         self.bin_edges_list = bin_edges_list
         self.mc_info_extr = mc_info_extr
+        self.event_skipper = event_skipper
 
         if add_bin_stats:
             self.bin_plot_freq = 1
@@ -71,7 +76,6 @@ class FileBinner:
         self.n_statusbar = 1000
         self.n_memory_observer = 1000
         self.do_time_preproc = True
-        # self.data_cuts = None
 
         self.chunksize = 32
         self.complib = 'zlib'
@@ -160,9 +164,8 @@ class FileBinner:
         if self.do_time_preproc:
             pipe.attach(TimePreproc)
 
-        # if self.data_cuts is not None:
-        #     from orcasong.utils import EventSkipper
-        #     pipe.attach(EventSkipper, data_cuts=self.data_cuts)
+        if self.event_skipper is not None:
+            pipe.attach(EventSkipper, event_skipper=self.event_skipper)
 
         if self.bin_plot_freq is not None:
             pipe.attach(BinningStatsMaker,
diff --git a/orcasong_plag/modules.py b/orcasong_plag/modules.py
index 72c615a..30e8860 100644
--- a/orcasong_plag/modules.py
+++ b/orcasong_plag/modules.py
@@ -214,3 +214,25 @@ class BinningStatsMaker(kp.Module):
 
         """
         return self.hists
+
+
+class EventSkipper(kp.Module):
+    """
+    Skip events based on some user function.
+
+    Attributes
+    ----------
+    event_skipper : func
+        Function that takes the blob as an input, and returns a bool.
+        If the bool is true, the blob will be skipped.
+
+    """
+    def configure(self):
+        self.event_skipper = self.require('event_skipper')
+
+    def process(self, blob):
+        skip_event = self.event_skipper(blob)
+        if skip_event:
+            return
+        else:
+            return blob
-- 
GitLab