Skip to content
Snippets Groups Projects
Commit 55b3f66a authored by Stefan Reck's avatar Stefan Reck
Browse files

Added event skipper to orcasong plag.

parent f9acd679
No related branches found
No related tags found
No related merge requests found
......@@ -106,7 +106,8 @@ def parse_input():
def get_cum_number_of_rows(file_list):
"""
Returns the cumulative number of rows (axis_0) in a list based on the
specified input .h5 files.
specified input .h5 files (i.e. [0,100,200,300,...] if each file has
100 rows).
Parameters
----------
......@@ -200,7 +201,8 @@ def get_f_compression_and_chunking(filepath):
return compression, compression_opts, chunksize
def concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, complib, complevel):
def concatenate_h5_files(output_filepath, file_list,
chunksize=None, complib=None, complevel=None):
"""
Function that concatenates hdf5 files based on an output_filepath and a file_list of input files.
......@@ -213,8 +215,6 @@ def concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, c
String that specifies the filepath (path+name) of the output .h5 file.
file_list : list
List that contains all filepaths of the input files.
cum_rows_list : list
List that contains the cumulative number of rows (i.e. [0,100,200,300,...] if each file has 100 rows).
chunksize : None/int
Specifies the chunksize for axis_0 in the concatenated output files.
If None, the chunksize is read from the first input file.
......@@ -231,6 +231,7 @@ def concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, c
Else, a custom compression level will be used.
"""
cum_rows_list = get_cum_number_of_rows(file_list)
complib_f, complevel_f, chunksize_f = get_f_compression_and_chunking(file_list[0])
chunksize = chunksize_f if chunksize is None else chunksize
......@@ -305,8 +306,7 @@ def main():
In deep learning applications for example, the chunksize should be equal to the batch size that is used later on for reading the data.
"""
file_list, output_filepath, chunksize, complib, complevel = parse_input()
cum_rows_list = get_cum_number_of_rows(file_list)
concatenate_h5_files(output_filepath, file_list, cum_rows_list, chunksize, complib, complevel)
concatenate_h5_files(output_filepath, file_list, chunksize, complib, complevel)
if __name__ == '__main__':
......
......@@ -2,7 +2,7 @@ import km3pipe as kp
import km3modules as km
import os
from orcasong_plag.modules import TimePreproc, ImageMaker, McInfoMaker, BinningStatsMaker
from orcasong_plag.modules import TimePreproc, ImageMaker, McInfoMaker, BinningStatsMaker, EventSkipper
from orcasong_plag.mc_info_types import get_mc_info_extr
from orcasong_plag.util.bin_stats_plot import plot_hists, add_hists_to_h5file, plot_hist_of_files
......@@ -29,6 +29,9 @@ class FileBinner:
Function that extracts desired mc_info from a blob, which is then
stored as the "y" datafield in the .h5 file.
Can also give a str identifier for an existing extractor.
event_skipper : func, optional
Function that takes the blob as an input, and returns a bool.
If the bool is true, the blob will be skipped.
bin_plot_freq : int or None
If int is given, defines after how many blobs data for an overview
histogram is extracted.
......@@ -59,9 +62,11 @@ class FileBinner:
but it increases the RAM usage as well.
"""
def __init__(self, bin_edges_list, mc_info_extr=None, add_bin_stats=True):
def __init__(self, bin_edges_list, mc_info_extr=None,
event_skipper=None, add_bin_stats=True):
self.bin_edges_list = bin_edges_list
self.mc_info_extr = mc_info_extr
self.event_skipper = event_skipper
if add_bin_stats:
self.bin_plot_freq = 1
......@@ -71,7 +76,6 @@ class FileBinner:
self.n_statusbar = 1000
self.n_memory_observer = 1000
self.do_time_preproc = True
# self.data_cuts = None
self.chunksize = 32
self.complib = 'zlib'
......@@ -160,9 +164,8 @@ class FileBinner:
if self.do_time_preproc:
pipe.attach(TimePreproc)
# if self.data_cuts is not None:
# from orcasong.utils import EventSkipper
# pipe.attach(EventSkipper, data_cuts=self.data_cuts)
if self.event_skipper is not None:
pipe.attach(EventSkipper, event_skipper=self.event_skipper)
if self.bin_plot_freq is not None:
pipe.attach(BinningStatsMaker,
......
......@@ -214,3 +214,25 @@ class BinningStatsMaker(kp.Module):
"""
return self.hists
class EventSkipper(kp.Module):
"""
Skip events based on some user function.
Attributes
----------
event_skipper : func
Function that takes the blob as an input, and returns a bool.
If the bool is true, the blob will be skipped.
"""
def configure(self):
self.event_skipper = self.require('event_skipper')
def process(self, blob):
skip_event = self.event_skipper(blob)
if skip_event:
return
else:
return blob
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment