From aeeaa6fa3997580c94964c2a39995fdc08a8d72a Mon Sep 17 00:00:00 2001
From: Stefan Reck <stefan.reck@fau.de>
Date: Fri, 28 Feb 2020 11:42:55 +0100
Subject: [PATCH] Improve concatenate - add shuffle filter - add option to
 concatenate individual files - remove hardcoded x/y/event_info - improve tests

---
 orcasong/tools/concatenate.py                 | 253 ++++++++++++------
 .../data_tools/shuffle/shuffle_h5.py          |   5 +-
 scraps/misc.py                                |  21 --
 tests/test_concatenate.py                     |  76 ++++--
 4 files changed, 221 insertions(+), 134 deletions(-)
 delete mode 100644 scraps/misc.py

diff --git a/orcasong/tools/concatenate.py b/orcasong/tools/concatenate.py
index 55bda1e..0cfcd0f 100644
--- a/orcasong/tools/concatenate.py
+++ b/orcasong/tools/concatenate.py
@@ -17,14 +17,26 @@ class FileConcatenator:
     input_files : list
         List that contains all filepaths of the input files.
     comptopts : dict
-        Options for compression. They are read from the first input file.
-        E.g. complib
+        Options for compression. They are read from the first input file,
+        but can also be updated during init.
     cumu_rows : np.array
         The cumulative number of rows (axis_0) of the specified
         input .h5 files (i.e. [0,100,200,300,...] if each file has 100 rows).
 
     """
-    def __init__(self, input_files):
+    def __init__(self, input_files, comptopts_update=None):
+        """
+        Check the files to concatenate.
+
+        Parameters
+        ----------
+        input_files : List
+            List that contains all filepaths of the input files.
+        comptopts_update : dict, optional
+            Overwrite the compression options that get read from the
+            first file. E.g. {'chunksize': 10} to get a chunksize of 10.
+
+        """
         self.input_files = input_files
         print(f"Checking {len(self.input_files)} files ...")
 
@@ -33,9 +45,11 @@ class FileConcatenator:
 
         # Get compression options from first file in the list
         self.comptopts = get_compopts(self.input_files[0])
-        print("\n".join([f"{k}:\t{v}" for k, v in self.comptopts.items()]))
+        if comptopts_update:
+            self.comptopts.update(comptopts_update)
+        print("\n".join([f"  {k}:\t{v}" for k, v in self.comptopts.items()]))
 
-        self._append_mc_index = False
+        self._modify_folder = False
 
     @classmethod
     def from_list(cls, list_file, n_files=None, **kwargs):
@@ -60,68 +74,102 @@ class FileConcatenator:
             input_files = input_files[:n_files]
         return cls(input_files, **kwargs)
 
-    def concatenate(self, output_filepath):
-        """ Concatenate input files and save output to given path. """
-        f_out = h5py.File(output_filepath, 'x')
-        start_time = time.time()
-        for n, input_file in enumerate(self.input_files):
-            print(f'Processing file {n+1}/{len(self.input_files)}: {input_file}')
-            f_in = h5py.File(input_file, 'r')
-
-            # create metadata
-            if n == 0 and 'format_version' in list(f_in.attrs.keys()):
-                f_out.attrs['format_version'] = f_in.attrs['format_version']
-
-            for folder_name in f_in:
-                if is_folder_ignored(folder_name):
-                    # we dont need datasets created by pytables anymore
-                    continue
-
-                folder_data = f_in[folder_name][()]
-                if n > 0 and folder_name in [
-                        'event_info', 'group_info', 'x_indices', 'y']:
-                    # we need to add the current number of the group_id / index in the file_output
-                    # to the group_ids / indices of the file that is to be appended
-                    column_name = 'group_id' if folder_name in [
-                        'event_info', 'group_info', 'y'] else 'index'
-                    # add 1 because the group_ids / indices start with 0
-                    folder_data[column_name] += np.amax(
-                        f_out[folder_name][column_name]) + 1
-
-                if self._append_mc_index and folder_name == "event_info":
-                    folder_data = self._modify_event_info(input_file, folder_data)
-
-                if n == 0:
-                    # first file; create the dummy dataset with no max shape
-                    print(f"\tCreating dataset '{folder_name}' with shape "
-                          f"{(self.cumu_rows[-1],) + folder_data.shape[1:]}")
-                    output_dataset = f_out.create_dataset(
-                        folder_name,
-                        data=folder_data,
-                        maxshape=(None,) + folder_data.shape[1:],
-                        chunks=(self.comptopts["chunksize"],) + folder_data.shape[1:],
-                        compression=self.comptopts["complib"],
-                        compression_opts=self.comptopts["complevel"],
-                    )
-                    output_dataset.resize(self.cumu_rows[-1], axis=0)
-
-                else:
-                    f_out[folder_name][self.cumu_rows[n]:self.cumu_rows[n + 1]] = folder_data
-            f_in.close()
-            f_out.flush()
-
-        elapsed_time = time.time() - start_time
-        # include the used filepaths in the file
-        f_out.create_dataset(
-            "used_files",
-            data=[n.encode("ascii", "ignore") for n in self.input_files]
-        )
-        f_out.close()
+    def concatenate(self, output_filepath, append_used_files=True):
+        """
+        Concatenate the input files.
+
+        Parameters
+        ----------
+        output_filepath : str
+            Path of the concatenated output file.
+        append_used_files : bool
+            If True (default), add a dataset called 'used_files' to the
+            output that contains the paths of the input_files.
+
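+        Examples
+        --------
+        A minimal usage sketch; the file paths are placeholders:
+
+        >>> fc = FileConcatenator(["file_1.h5", "file_2.h5"])
+        >>> fc.concatenate("concatenated.h5")
+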
+        """
+        print(f"Creating file {output_filepath}")
+        with h5py.File(output_filepath, 'x') as f_out:
+            start_time = time.time()
+            for input_file_nmbr, input_file in enumerate(self.input_files):
+                print(f'Processing file {input_file_nmbr+1}/'
+                      f'{len(self.input_files)}: {input_file}')
+                with h5py.File(input_file, 'r') as f_in:
+                    self._conc_file(f_in, f_out, input_file, input_file_nmbr)
+                f_out.flush()
+            elapsed_time = time.time() - start_time
+
+            if append_used_files:
+                # include the used filepaths in the file
+                print("Adding used files to output")
+                f_out.create_dataset(
+                    "used_files",
+                    data=[n.encode("ascii", "ignore") for n in self.input_files]
+                )
+
         print(f"\nConcatenation complete!"
               f"\nElapsed time: {elapsed_time/60:.2f} min "
               f"({elapsed_time/len(self.input_files):.2f} s per file)")
 
-    def _modify_event_info(self, input_file, folder_data):
+    def _conc_file(self, f_in, f_out, input_file, input_file_nmbr):
+        """ Conc one file to the output. """
+        # create metadata
+        if input_file_nmbr == 0 and 'format_version' in list(f_in.attrs.keys()):
+            f_out.attrs['format_version'] = f_in.attrs['format_version']
+
+        for folder_name in f_in:
+            if is_folder_ignored(folder_name):
+                # we don't need datasets created by pytables anymore
+                continue
+
+            folder_data = f_in[folder_name][()]
+
+            if input_file_nmbr > 0:
+                # shift the group_ids / indices of the file that is to be
+                # appended by the current maximum group_id / index in the
+                # output file
+                try:
+                    if folder_name.endswith("_indices") and \
+                            "index" in folder_data.dtype.names:
+                        column_name = "index"
+                    elif "group_id" in folder_data.dtype.names:
+                        column_name = "group_id"
+                    else:
+                        column_name = None
+                except TypeError:
+                    column_name = None
+                if column_name is not None:
+                    # add 1 because the group_ids / indices start with 0
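+                    # e.g. (hypothetical numbers): if the output already
+                    # contains group_ids 0..9, the group_ids 0..14 of the
+                    # appended file become 10..24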
+                    folder_data[column_name] += \
+                        np.amax(f_out[folder_name][column_name]) + 1
+
+            if self._modify_folder:
+                data_mody = self._modify(
+                    input_file, folder_data, folder_name)
+                if data_mody is not None:
+                    folder_data = data_mody
+
+            if input_file_nmbr == 0:
+                # first file; create the dataset
+                dset_shape = (self.cumu_rows[-1],) + folder_data.shape[1:]
+                print(f"\tCreating dataset '{folder_name}' with shape "
+                      f"{dset_shape}")
+                output_dataset = f_out.create_dataset(
+                    folder_name,
+                    data=folder_data,
+                    maxshape=dset_shape,
+                    chunks=(self.comptopts["chunksize"],) + folder_data.shape[1:],
+                    compression=self.comptopts["complib"],
+                    compression_opts=self.comptopts["complevel"],
+                    shuffle=self.comptopts["shuffle"],
+                )
+                output_dataset.resize(self.cumu_rows[-1], axis=0)
+
+            else:
+                f_out[folder_name][
+                    self.cumu_rows[input_file_nmbr]:
+                    self.cumu_rows[input_file_nmbr + 1]
+                ] = folder_data
+
+    def _modify(self, input_file, folder_data, folder_name):
         raise NotImplementedError
 
     def _get_cumu_rows(self):
@@ -152,67 +200,72 @@ class FileConcatenator:
             raise OSError(
                 f"{len(errors)} error(s) during check of files! See above"
             )
+        print("Datasets:\t" + ", ".join(keys_stripped))
         return np.cumsum(rows_per_file)
 
 
-def _get_rows(file_name, keys_stripped):
+def _get_rows(file_name, target_datasets):
+    """ Get no of rows from a file and check if its good for conc'ing. """
     with h5py.File(file_name, 'r') as f:
-        if not all(k in f.keys() for k in keys_stripped):
+        # check if all target datasets are in the file
+        if not all(k in f.keys() for k in target_datasets):
             raise KeyError(
                 f"File {file_name} does not have the "
                 f"keys of the first file! "
-                f"It has {f.keys()} First file: {keys_stripped}")
-        # length of each dataset
-        rows = [f[k].shape[0] for k in keys_stripped]
+                f"It has {f.keys()} First file: {target_datasets}"
+            )
+        # check if all target datasets in the file have the same length
+        rows = [f[k].shape[0] for k in target_datasets]
         if not all(row == rows[0] for row in rows):
             raise ValueError(
                 f"Datasets in file {file_name} have varying length! "
-                f"{dict(zip(keys_stripped, rows))}"
+                f"{dict(zip(target_datasets, rows))}"
             )
-        if not all(k in keys_stripped for k in strip_keys(list(f.keys()))):
+        # check if the file has additional datasets apart from the target keys
+        if not all(k in target_datasets for k in strip_keys(list(f.keys()))):
             warnings.warn(
                 f"Additional datasets found in file {file_name} compared "
                 f"to the first file, they wont be in the output! "
                 f"This file: {strip_keys(list(f.keys()))} "
-                f"First file {keys_stripped}"
+                f"First file {target_datasets}"
             )
     return rows[0]
 
 
 def strip_keys(f_keys):
+    """ Remove unwanted keys from list. """
+    return [x for x in f_keys if not is_folder_ignored(x)]
+
+
+def is_folder_ignored(folder_name):
     """
+    Check if a dataset should be ignored during concat.
+
     Remove pytables folders starting with '_i_', because the shape
     of its first axis does not correspond to the number of events
     in the file. All other folders normally have an axis_0 shape
     that is equal to the number of events in the file.
     Also remove bin_stats.
-    """
-    return [x for x in f_keys if not is_folder_ignored(x)]
 
-
-def is_folder_ignored(folder_name):
-    """
-    Defines pytable folders which should be ignored during concat.
     """
     return '_i_' in folder_name or "bin_stats" in folder_name
 
 
 def get_compopts(file):
     """
-    Extract the following compression options:
+    Get the following compression options from an h5 file as a dict:
 
     complib : str
         Specifies the compression library that should be used for saving
         the concatenated output files.
-        It's read from the first input file.
     complevel : None/int
         Specifies the compression level that should be used for saving
         the concatenated output files.
         A compression level is only available for gzip compression, not lzf!
-        It's read from the first input file.
     chunksize : None/int
         Specifies the chunksize for axis_0 in the concatenated output files.
-        It's read from the first input file.
+    shuffle : bool
+        Whether the shuffle filter is enabled for the chunks.
 
     """
     with h5py.File(file, 'r') as f:
@@ -224,24 +277,46 @@ def get_compopts(file):
         else:
             comptopts["complevel"] = dset.compression_opts
         comptopts["chunksize"] = dset.chunks[0]
+        comptopts["shuffle"] = dset.shuffle
     return comptopts
 
 
-def main():
+def get_parser():
     parser = argparse.ArgumentParser(
-        description='Concatenate many small h5 files to a single large one. '
+        description='Concatenate many small h5 files to a single large one '
+                    'in a km3pipe compatible format. This is intended for '
+                    'files that get generated by orcasong, i.e. all datasets '
+                    'should have the same length, with one row per '
+                    'blob. '
                     'Compression options and the datasets to be created in '
                     'the new file will be read from the first input file.')
     parser.add_argument(
-        'list_file', type=str, help='A txt list of files to concatenate. '
-                                    'One absolute filepath per line. ')
+        'file', type=str, nargs="*",
+        help="Define the files to concatenate. If it's one argument: A txt list "
+             "with pathes of h5 files to concatenate (one path per line). "
+             "If it's multiple arguments: "
+             "The pathes of h5 files to concatenate.")
     parser.add_argument(
-        'output_filepath', type=str, help='The absoulte filepath of the output '
-                                          '.h5 file that will be created. ')
+        '--outfile', type=str, default="concatenated.h5",
+        help='The filepath of the output .h5 file that will be created. ')
+    parser.add_argument(
+        '--no_used_files', action='store_true',
+        help="Per default, the paths of the input files are added "
+             "as their own datagroup in the output file. Use this flag to "
+             "disable. ")
+    return parser
+
+
+def main():
+    parser = get_parser()
     parsed_args = parser.parse_args()
 
-    fc = FileConcatenator.from_list(parsed_args.list_file)
-    fc.concatenate(parsed_args.output_filepath)
+    if len(parsed_args.file) == 1:
+        fc = FileConcatenator.from_list(parsed_args.file[0])
+    else:
+        fc = FileConcatenator(input_files=parsed_args.file)
+    fc.concatenate(parsed_args.outfile,
+                   append_used_files=not parsed_args.no_used_files)
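+
+
+# Example invocations, assuming this module is run directly as a script
+# (the file names are placeholders):
+#     python concatenate.py file_1.h5 file_2.h5 --outfile combined.h5
+#     python concatenate.py file_list.txt --no_used_files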
 
 
 if __name__ == '__main__':
diff --git a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py
index 9103dd8..4e0a93c 100644
--- a/orcasong_contrib/data_tools/shuffle/shuffle_h5.py
+++ b/orcasong_contrib/data_tools/shuffle/shuffle_h5.py
@@ -22,7 +22,7 @@ import h5py
 import km3pipe as kp
 import km3modules as km
 from orcasong_contrib.data_tools.concatenate.concatenate_h5 import get_f_compression_and_chunking
-from orcasong.modules import EventSkipper
+import orcasong.modules as os_modules
 
 # from memory_profiler import profile # for memory profiling, call with @profile; myfunc()
 
@@ -220,9 +220,8 @@ def shuffle_h5(filepath_input, tool=False, seed=42, delete=False, chunksize=None
         pipe.attach(km.common.StatusBar, every=200)
         pipe.attach(km.common.MemoryObserver, every=200)
         pipe.attach(kp.io.hdf5.HDF5Pump, filename=filepath_input, shuffle=shuffle, reset_index=True)
-
         if event_skipper is not None:
-            pipe.attach(EventSkipper, event_skipper=event_skipper)
+            pipe.attach(os_modules.EventSkipper, event_skipper=event_skipper)
 
         pipe.attach(kp.io.hdf5.HDF5Sink, filename=filepath_output, complib=complib, complevel=complevel, chunksize=chunksize, flush_frequency=1000)
         pipe.drain()
diff --git a/scraps/misc.py b/scraps/misc.py
deleted file mode 100644
index b3bd005..0000000
--- a/scraps/misc.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import h5py
-import numpy as np
-
-
-def create_dummy_file(filepath, columns=10, val_array=1, val_recarray=(1, 3)):
-    """ Create a dummy h5 file with an array and a recarray in it. """
-    with h5py.File(filepath, "w") as f:
-        f.create_dataset(
-            "numpy_array",
-            data=np.ones(shape=(columns, 7, 3))*val_array,
-            chunks=(5, 7, 3),
-            compression="gzip",
-            compression_opts=1
-        )
-        f.create_dataset(
-            "rec_array",
-            data=np.array([val_recarray] * columns, dtype=[('x', '<f8'), ('y', '<i8')]),
-            chunks=(5,),
-            compression="gzip",
-            compression_opts=1
-        )
diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py
index a8a6257..7e6ca38 100644
--- a/tests/test_concatenate.py
+++ b/tests/test_concatenate.py
@@ -1,6 +1,5 @@
 import tempfile
 from unittest import TestCase
-import os
 import numpy as np
 import h5py
 import orcasong.tools.concatenate as conc
@@ -12,26 +11,30 @@ class TestFileConcatenator(TestCase):
     """
-    Test concatenation on pre-generated h5 files. They are in test/data.
+    Test concatenation on dummy h5 files created in setUpClass.
 
-    create_dummy_file(
-        "dummy_file_1.h5", columns=10, val_array=1, val_recarray=(1, 3)
-    )
-    create_dummy_file(
-        "dummy_file_2.h5", columns=15, val_array=2, val_recarray=(4, 5)
-    )
-
     """
-    def setUp(self):
-        # the files to test on
-        data_dir = os.path.join(os.path.dirname(__file__), "data")
-        self.dummy_files = (
-            os.path.join(data_dir, "dummy_file_1.h5"),  # 10 columns
-            os.path.join(data_dir, "dummy_file_2.h5"),  # 15 columns
+    @classmethod
+    def setUpClass(cls):
+        cls.dummy_file_1 = tempfile.NamedTemporaryFile()
+        _create_dummy_file(
+            cls.dummy_file_1, columns=10, val_array=1, val_recarray=(1, 3)
+        )
+        cls.dummy_file_2 = tempfile.NamedTemporaryFile()
+        _create_dummy_file(
+            cls.dummy_file_2, columns=15, val_array=2, val_recarray=(4, 5)
+        )
+        cls.dummy_files = (
+            cls.dummy_file_1.name,
+            cls.dummy_file_2.name,
         )
-        # their compression opts
-        self.compt_opts = {
-            'complib': 'gzip', 'complevel': 1, 'chunksize': 5
+        cls.compt_opts = {
+            'complib': 'gzip', 'complevel': 1, 'chunksize': 5, 'shuffle': False
         }
 
+    @classmethod
+    def tearDownClass(cls):
+        cls.dummy_file_1.close()
+        cls.dummy_file_2.close()
+
     def test_from_list(self):
         with tempfile.NamedTemporaryFile("w+") as tf:
             tf.writelines([f + "\n" for f in self.dummy_files])
@@ -47,6 +50,12 @@ class TestFileConcatenator(TestCase):
         fc = conc.FileConcatenator(self.dummy_files)
         self.assertDictEqual(fc.comptopts, self.compt_opts)
 
+    def test_fc_get_comptopts_updates(self):
+        fc = conc.FileConcatenator(self.dummy_files, comptopts_update={'chunksize': 1})
+        target_compt_opts = dict(self.compt_opts)
+        target_compt_opts["chunksize"] = 1
+        self.assertDictEqual(fc.comptopts, target_compt_opts)
+
     def test_get_cumu_rows(self):
         fc = conc.FileConcatenator(self.dummy_files)
         np.testing.assert_array_equal(fc.cumu_rows, [0, 10, 25])
@@ -73,18 +82,43 @@ class TestFileConcatenator(TestCase):
                     f["numpy_array"][()]
                 )
 
-    def test_concatenate_recarray(self):
+    def test_concatenate_recarray_with_groupid(self):
         fc = conc.FileConcatenator(self.dummy_files)
         with tempfile.TemporaryFile() as tf:
             fc.concatenate(tf)
             with h5py.File(tf) as f:
                 target = np.array(
-                    [(1, 3)] * 25,
-                    dtype=[('x', '<f8'), ('y', '<i8')]
+                    [(1, 3, 1)] * 25,
+                    dtype=[('x', '<f8'), ('y', '<i8'), ("group_id", "<i8")]
                 )
                 target["x"][10:] = 4.
                 target["y"][10:] = 5.
+                target["group_id"] = np.arange(25)
                 np.testing.assert_array_equal(
                     target,
                     f["rec_array"][()]
-                )
\ No newline at end of file
+                )
+
+
+def _create_dummy_file(filepath, columns=10, val_array=1, val_recarray=(1, 3)):
+    """ Create a dummy h5 file with an array and a recarray in it. """
+    with h5py.File(filepath, "w") as f:
+        f.create_dataset(
+            "numpy_array",
+            data=np.ones(shape=(columns, 7, 3))*val_array,
+            chunks=(5, 7, 3),
+            compression="gzip",
+            compression_opts=1
+        )
+        rec_array = np.array(
+            [val_recarray + (1, )] * columns,
+            dtype=[('x', '<f8'), ('y', '<i8'), ("group_id", "<i8")]
+        )
+        rec_array["group_id"] = np.arange(columns)
+        f.create_dataset(
+            "rec_array",
+            data=rec_array,
+            chunks=(5,),
+            compression="gzip",
+            compression_opts=1
+        )
-- 
GitLab