diff --git a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py index 56cdbc35ae50f2541fa26ecf79d7d66e6cc2b2d6..601e3d33211f01f439a88099e5e341fadf2caa33 100644 --- a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py +++ b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py @@ -124,6 +124,21 @@ def parse_input(): return input_files_list, delete, chunksize, complib, complevel, legacy_mode +def get_filepath_output(filepath_input, shuffle, event_skipper): + """ + Get the filename of the shuffled / rebalanced output file as a str. + + """ + filepath_input_without_ext = os.path.splitext(filepath_input)[0] + fname_adtn = '' + if shuffle: + fname_adtn += '_shuffled' + if event_skipper is not None: + fname_adtn += '_reb' + filepath_output = filepath_input_without_ext + fname_adtn + ".h5" + return filepath_output + + def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None, complib=None, complevel=None, legacy_mode=False, shuffle=True, event_skipper=None, filepath_output=None): @@ -189,13 +204,8 @@ def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None complevel = None if filepath_output is None: - filepath_input_without_ext = os.path.splitext(filepath_input)[0] - fname_adtn = '' - if shuffle: - fname_adtn += '_shuffled' - if event_skipper is not None: - fname_adtn += '_reb' - filepath_output = filepath_input_without_ext + fname_adtn + ".h5" + filepath_output = get_filepath_output(filepath_input, shuffle, + event_skipper) if not legacy_mode: # set random km3pipe (=numpy) seed diff --git a/orcasong_plag/util/split_conc.py b/orcasong_plag/util/split_conc.py index 7b311b797edc35943f4e29ac6cf8fa948542bfd7..3b8a285907cc9c104de47397b90be30cdc10f4c8 100644 --- a/orcasong_plag/util/split_conc.py +++ b/orcasong_plag/util/split_conc.py @@ -13,12 +13,14 @@ Example: [ { - 'file_list': array(['file_2.h5', file_1.h5]), dtype='<U69'), - 'output_filepath': 'test_train_0.h5' + 'file_list': 
array(['file_2.h5', 'file_1.h5']), + 'output_filepath': 'test_train_0.h5', + 'is_train': True, }, { - 'file_list': array(['file_4.h5', 'file_0.h5']), dtype='<U69'), - 'output_filepath': 'test_val_0.h5' + 'file_list': array(['file_4.h5', 'file_0.h5']), + 'output_filepath': 'test_val_0.h5', + 'is_train': False, }, ] @@ -136,7 +138,8 @@ def get_split(folder, outfile_basestr, n_train_files=1, n_val_files=1, output_filepath = "{}_train_{}.h5".format(outfile_basestr, i) job_dict = { "file_list": job_files, - "output_filepath": output_filepath + "output_filepath": output_filepath, + "is_train": True, } jobs.append(job_dict) @@ -144,7 +147,8 @@ def get_split(folder, outfile_basestr, n_train_files=1, n_val_files=1, output_filepath = "{}_val_{}.h5".format(outfile_basestr, i) job_dict = { "file_list": job_files, - "output_filepath": output_filepath + "output_filepath": output_filepath, + "is_train": False, } jobs.append(job_dict)