From 55b3f66a19cf3188f0692c95133dfe492030d5e4 Mon Sep 17 00:00:00 2001 From: Stefan Reck <stefan.reck@fau.de> Date: Thu, 11 Apr 2019 13:57:01 +0200 Subject: [PATCH] Added event skipper to orcasong plag. --- .../data_tools/concatenate/concatenate_h5.py | 12 +++++----- orcasong_plag/core.py | 15 ++++++++----- orcasong_plag/modules.py | 22 +++++++++++++++++++ 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/orcasong_contrib/data_tools/concatenate/concatenate_h5.py b/orcasong_contrib/data_tools/concatenate/concatenate_h5.py index b9ed5ef..7580f22 100644 --- a/orcasong_contrib/data_tools/concatenate/concatenate_h5.py +++ b/orcasong_contrib/data_tools/concatenate/concatenate_h5.py @@ -106,7 +106,8 @@ def parse_input(): def get_cum_number_of_rows(file_list): """ Returns the cumulative number of rows (axis_0) in a list based on the - specified input .h5 files. + specified input .h5 files (i.e. [0,100,200,300,...] if each file has + 100 rows). Parameters ---------- @@ -200,7 +201,8 @@ def get_f_compression_and_chunking(filepath): return compression, compression_opts, chunksize -def concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, complib, complevel): +def concatenate_h5_files(output_filepath, file_list, + chunksize=None, complib=None, complevel=None): """ Function that concatenates hdf5 files based on an output_filepath and a file_list of input files. @@ -213,8 +215,6 @@ def concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, c String that specifies the filepath (path+name) of the output .h5 file. file_list : list List that contains all filepaths of the input files. - cum_rows_list : list - List that contains the cumulative number of rows (i.e. [0,100,200,300,...] if each file has 100 rows). chunksize : None/int Specifies the chunksize for axis_0 in the concatenated output files. If None, the chunksize is read from the first input file. @@ -231,6 +231,7 @@ def concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, c Else, a custom compression level will be used. """ + cum_rows_list = get_cum_number_of_rows(file_list) complib_f, complevel_f, chunksize_f = get_f_compression_and_chunking(file_list[0]) chunksize = chunksize_f if chunksize is None else chunksize @@ -305,8 +306,7 @@ def main(): In deep learning applications for example, the chunksize should be equal to the batch size that is used later on for reading the data. """ file_list, output_filepath, chunksize, complib, complevel = parse_input() - cum_rows_list = get_cum_number_of_rows(file_list) - concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, complib, complevel) + concatenate_h5_files(output_filepath, file_list, chunksize, complib, complevel) if __name__ == '__main__': diff --git a/orcasong_plag/core.py b/orcasong_plag/core.py index cfefa2f..12f6d04 100644 --- a/orcasong_plag/core.py +++ b/orcasong_plag/core.py @@ -2,7 +2,7 @@ import km3pipe as kp import km3modules as km import os -from orcasong_plag.modules import TimePreproc, ImageMaker, McInfoMaker, BinningStatsMaker +from orcasong_plag.modules import TimePreproc, ImageMaker, McInfoMaker, BinningStatsMaker, EventSkipper from orcasong_plag.mc_info_types import get_mc_info_extr from orcasong_plag.util.bin_stats_plot import plot_hists, add_hists_to_h5file, plot_hist_of_files @@ -29,6 +29,9 @@ class FileBinner: Function that extracts desired mc_info from a blob, which is then stored as the "y" datafield in the .h5 file. Can also give a str identifier for an existing extractor. + event_skipper : func, optional + Function that takes the blob as an input, and returns a bool. + If the bool is true, the blob will be skipped. bin_plot_freq : int or None If int is given, defines after how many blobs data for an overview histogram is extracted. @@ -59,9 +62,11 @@ class FileBinner: but it increases the RAM usage as well. """ - def __init__(self, bin_edges_list, mc_info_extr=None, add_bin_stats=True): + def __init__(self, bin_edges_list, mc_info_extr=None, + event_skipper=None, add_bin_stats=True): self.bin_edges_list = bin_edges_list self.mc_info_extr = mc_info_extr + self.event_skipper = event_skipper if add_bin_stats: self.bin_plot_freq = 1 @@ -71,7 +76,6 @@ class FileBinner: self.n_statusbar = 1000 self.n_memory_observer = 1000 self.do_time_preproc = True - # self.data_cuts = None self.chunksize = 32 self.complib = 'zlib' @@ -160,9 +164,8 @@ class FileBinner: if self.do_time_preproc: pipe.attach(TimePreproc) - # if self.data_cuts is not None: - # from orcasong.utils import EventSkipper - # pipe.attach(EventSkipper, data_cuts=self.data_cuts) + if self.event_skipper is not None: + pipe.attach(EventSkipper, event_skipper=self.event_skipper) if self.bin_plot_freq is not None: pipe.attach(BinningStatsMaker, diff --git a/orcasong_plag/modules.py b/orcasong_plag/modules.py index 72c615a..30e8860 100644 --- a/orcasong_plag/modules.py +++ b/orcasong_plag/modules.py @@ -214,3 +214,25 @@ class BinningStatsMaker(kp.Module): """ return self.hists + + +class EventSkipper(kp.Module): + """ + Skip events based on some user function. + + Attributes + ---------- + event_skipper : func + Function that takes the blob as an input, and returns a bool. + If the bool is true, the blob will be skipped. + + """ + def configure(self): + self.event_skipper = self.require('event_skipper') + + def process(self, blob): + skip_event = self.event_skipper(blob) + if skip_event: + return + else: + return blob -- GitLab