From 378a84c962f19f3b04113df387fda0f37027ff51 Mon Sep 17 00:00:00 2001
From: Stefan Reck <stefan.reck@fau.de>
Date: Mon, 6 May 2019 13:32:30 +0200
Subject: [PATCH] Minor.

---
 .../data_tools/shuffle/shuffle_h5.py          | 24 +++++++++++++------
 orcasong_plag/util/split_conc.py              | 16 ++++++++-----
 2 files changed, 27 insertions(+), 13 deletions(-)

diff --git a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py
index 56cdbc3..601e3d3 100644
--- a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py
+++ b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py
@@ -124,6 +124,21 @@ def parse_input():
     return input_files_list, delete, chunksize, complib, complevel, legacy_mode
 
 
+def get_filepath_output(filepath_input, shuffle, event_skipper):
+    """
+    Get the filepath of the shuffled / rebalanced output file as a str:
+    input path minus its extension, plus the suffixes below, plus ".h5".
+    """
+    filepath_input_without_ext = os.path.splitext(filepath_input)[0]
+    fname_adtn = ''  # suffix built up from the requested operations
+    if shuffle:
+        fname_adtn += '_shuffled'
+    if event_skipper is not None:  # "_reb": presumably "rebalanced"
+        fname_adtn += '_reb'
+    filepath_output = filepath_input_without_ext + fname_adtn + ".h5"
+    return filepath_output
+
+
 def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None,
                complib=None, complevel=None, legacy_mode=False, shuffle=True,
                event_skipper=None, filepath_output=None):
@@ -189,13 +204,8 @@ def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None
         complevel = None
 
     if filepath_output is None:
-        filepath_input_without_ext = os.path.splitext(filepath_input)[0]
-        fname_adtn = ''
-        if shuffle:
-            fname_adtn += '_shuffled'
-        if event_skipper is not None:
-            fname_adtn += '_reb'
-        filepath_output = filepath_input_without_ext + fname_adtn + ".h5"
+        filepath_output = get_filepath_output(filepath_input, shuffle,
+                                              event_skipper)
 
     if not legacy_mode:
         # set random km3pipe (=numpy) seed
diff --git a/orcasong_plag/util/split_conc.py b/orcasong_plag/util/split_conc.py
index 7b311b7..3b8a285 100644
--- a/orcasong_plag/util/split_conc.py
+++ b/orcasong_plag/util/split_conc.py
@@ -13,12 +13,14 @@ Example:
 
     [
      {
-      'file_list': array(['file_2.h5', file_1.h5]), dtype='<U69'),
-      'output_filepath': 'test_train_0.h5'
+      'file_list': array(['file_2.h5', 'file_1.h5']),
+      'output_filepath': 'test_train_0.h5',
+      'is_train': True,
      },
      {
-      'file_list': array(['file_4.h5', 'file_0.h5']), dtype='<U69'),
-      'output_filepath': 'test_val_0.h5'
+      'file_list': array(['file_4.h5', 'file_0.h5']),
+      'output_filepath': 'test_val_0.h5',
+      'is_train': False,
      },
     ]
 
@@ -136,7 +138,8 @@ def get_split(folder, outfile_basestr, n_train_files=1, n_val_files=1,
         output_filepath = "{}_train_{}.h5".format(outfile_basestr, i)
         job_dict = {
             "file_list": job_files,
-            "output_filepath": output_filepath
+            "output_filepath": output_filepath,
+            "is_train": True,
         }
         jobs.append(job_dict)
 
@@ -144,7 +147,8 @@ def get_split(folder, outfile_basestr, n_train_files=1, n_val_files=1,
         output_filepath = "{}_val_{}.h5".format(outfile_basestr, i)
         job_dict = {
             "file_list": job_files,
-            "output_filepath": output_filepath
+            "output_filepath": output_filepath,
+            "is_train": False,
         }
         jobs.append(job_dict)
 
-- 
GitLab