From 378a84c962f19f3b04113df387fda0f37027ff51 Mon Sep 17 00:00:00 2001 From: Stefan Reck <stefan.reck@fau.de> Date: Mon, 6 May 2019 13:32:30 +0200 Subject: [PATCH] Minor. --- .../data_tools/shuffle/shuffle_h5.py | 24 +++++++++++++------ orcasong_plag/util/split_conc.py | 16 ++++++++----- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py index 56cdbc3..601e3d3 100644 --- a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py +++ b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py @@ -124,6 +124,21 @@ def parse_input(): return input_files_list, delete, chunksize, complib, complevel, legacy_mode +def get_filepath_output(filepath_input, shuffle, event_skipper): + """ + Get the filename of the shuffled / rebalanced output file as a str. + + """ + filepath_input_without_ext = os.path.splitext(filepath_input)[0] + fname_adtn = '' + if shuffle: + fname_adtn += '_shuffled' + if event_skipper is not None: + fname_adtn += '_reb' + filepath_output = filepath_input_without_ext + fname_adtn + ".h5" + return filepath_output + + def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None, complib=None, complevel=None, legacy_mode=False, shuffle=True, event_skipper=None, filepath_output=None): @@ -189,13 +204,8 @@ def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None complevel = None if filepath_output is None: - filepath_input_without_ext = os.path.splitext(filepath_input)[0] - fname_adtn = '' - if shuffle: - fname_adtn += '_shuffled' - if event_skipper is not None: - fname_adtn += '_reb' - filepath_output = filepath_input_without_ext + fname_adtn + ".h5" + filepath_output = get_filepath_output(filepath_input, shuffle, + event_skipper) if not legacy_mode: # set random km3pipe (=numpy) seed diff --git a/orcasong_plag/util/split_conc.py b/orcasong_plag/util/split_conc.py index 7b311b7..3b8a285 100644 --- a/orcasong_plag/util/split_conc.py +++ b/orcasong_plag/util/split_conc.py @@ -13,12 +13,14 @@ Example: [ { - 'file_list': array(['file_2.h5', file_1.h5]), dtype='<U69'), - 'output_filepath': 'test_train_0.h5' + 'file_list': array(['file_2.h5', file_1.h5]), + 'output_filepath': 'test_train_0.h5', + 'is_train': True, }, { - 'file_list': array(['file_4.h5', 'file_0.h5']), dtype='<U69'), - 'output_filepath': 'test_val_0.h5' + 'file_list': array(['file_4.h5', 'file_0.h5']), + 'output_filepath': 'test_val_0.h5', + 'is_train': False, }, ] @@ -136,7 +138,8 @@ def get_split(folder, outfile_basestr, n_train_files=1, n_val_files=1, output_filepath = "{}_train_{}.h5".format(outfile_basestr, i) job_dict = { "file_list": job_files, - "output_filepath": output_filepath + "output_filepath": output_filepath, + "is_train": True, } jobs.append(job_dict) @@ -144,7 +147,8 @@ def get_split(folder, outfile_basestr, n_train_files=1, n_val_files=1, output_filepath = "{}_val_{}.h5".format(outfile_basestr, i) job_dict = { "file_list": job_files, - "output_filepath": output_filepath + "output_filepath": output_filepath, + "is_train": False, } jobs.append(job_dict) -- GitLab