Skip to content
Snippets Groups Projects
Commit 378a84c9 authored by Stefan Reck's avatar Stefan Reck
Browse files

Minor.

parent 9e69e74a
No related branches found
No related tags found
No related merge requests found
......@@ -124,6 +124,21 @@ def parse_input():
return input_files_list, delete, chunksize, complib, complevel, legacy_mode
def get_filepath_output(filepath_input, shuffle, event_skipper):
    """
    Build the path of the shuffled / rebalanced output file.

    The extension of the input path is dropped, optional suffixes are
    appended to the remaining base name, and ".h5" is added at the end.

    Parameters
    ----------
    filepath_input : str
        Path of the input file.
    shuffle : bool
        If true, '_shuffled' is appended to the base name.
    event_skipper : object or None
        If not None, '_reb' is appended to the base name.

    Returns
    -------
    str
        The path of the output file.

    """
    stem, _ = os.path.splitext(filepath_input)

    # Collect the name additions in order: shuffling first, rebalancing second.
    suffixes = []
    if shuffle:
        suffixes.append('_shuffled')
    if event_skipper is not None:
        suffixes.append('_reb')

    return stem + ''.join(suffixes) + ".h5"
def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None,
complib=None, complevel=None, legacy_mode=False, shuffle=True,
event_skipper=None, filepath_output=None):
......@@ -189,13 +204,8 @@ def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None
complevel = None
if filepath_output is None:
filepath_input_without_ext = os.path.splitext(filepath_input)[0]
fname_adtn = ''
if shuffle:
fname_adtn += '_shuffled'
if event_skipper is not None:
fname_adtn += '_reb'
filepath_output = filepath_input_without_ext + fname_adtn + ".h5"
filepath_output = get_filepath_output(filepath_input, shuffle,
event_skipper)
if not legacy_mode:
# set random km3pipe (=numpy) seed
......
......@@ -13,12 +13,14 @@ Example:
[
{
'file_list': array(['file_2.h5', file_1.h5]), dtype='<U69'),
'output_filepath': 'test_train_0.h5'
'file_list': array(['file_2.h5', 'file_1.h5']),
'output_filepath': 'test_train_0.h5',
'is_train': True,
},
{
'file_list': array(['file_4.h5', 'file_0.h5']), dtype='<U69'),
'output_filepath': 'test_val_0.h5'
'file_list': array(['file_4.h5', 'file_0.h5']),
'output_filepath': 'test_val_0.h5',
'is_train': False,
},
]
......@@ -136,7 +138,8 @@ def get_split(folder, outfile_basestr, n_train_files=1, n_val_files=1,
output_filepath = "{}_train_{}.h5".format(outfile_basestr, i)
job_dict = {
"file_list": job_files,
"output_filepath": output_filepath
"output_filepath": output_filepath,
"is_train": True,
}
jobs.append(job_dict)
......@@ -144,7 +147,8 @@ def get_split(folder, outfile_basestr, n_train_files=1, n_val_files=1,
output_filepath = "{}_val_{}.h5".format(outfile_basestr, i)
job_dict = {
"file_list": job_files,
"output_filepath": output_filepath
"output_filepath": output_filepath,
"is_train": False,
}
jobs.append(job_dict)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment