From 069a26ad097d2c358ca44be0b5e4c8f07bd7752b Mon Sep 17 00:00:00 2001
From: Stefan Reck <stefan.reck@fau.de>
Date: Wed, 7 Jul 2021 16:57:47 +0200
Subject: [PATCH] Add test for indexed shuffle2; enforce n_iterations >= 1

---
 orcasong/tools/shuffle2.py |  4 ++--
 tests/test_concatenate.py  | 41 ++++++++++++++++++++------------------
 tests/test_postproc.py     | 35 ++++++++++++++++++++++++++++++++
 3 files changed, 59 insertions(+), 21 deletions(-)

diff --git a/orcasong/tools/shuffle2.py b/orcasong/tools/shuffle2.py
index e0c3f2d..dcd9ed0 100644
--- a/orcasong/tools/shuffle2.py
+++ b/orcasong/tools/shuffle2.py
@@ -145,9 +145,9 @@ def get_n_iterations(
         max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
     with h5py.File(input_file, "r") as f_in:
         dset_info = _get_largest_dset(f_in, datasets, max_ram)
-    n_iterations = int(
+    n_iterations = np.amax((1, int(
         np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
-    )
+    )))
     print(f"Largest dataset: {dset_info['name']}")
     print(f"Total chunks: {dset_info['n_chunks']}")
     print(f"Max. chunks per batch: {dset_info['chunks_per_batch']}")
diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py
index 302878f..3a417a2 100644
--- a/tests/test_concatenate.py
+++ b/tests/test_concatenate.py
@@ -1,5 +1,5 @@
 import tempfile
-from unittest import TestCase
+import unittest
 import numpy as np
 import h5py
 import orcasong.tools.concatenate as conc
@@ -8,7 +8,7 @@ import os
 __author__ = 'Stefan Reck'
 
 
-class TestFileConcatenator(TestCase):
+class TestFileConcatenator(unittest.TestCase):
     """
     Test concatenation on pre-generated h5 files. They are in tests/data.
 
@@ -117,20 +117,27 @@ class TestFileConcatenator(TestCase):
                 )
 
 
-class TestConcatenateIndexed(TestCase):
-    @classmethod
-    def setUpClass(cls) -> None:
-        cls.infile = tempfile.NamedTemporaryFile()
-        with h5py.File(cls.infile, "w") as f:
-            cls.x = np.arange(20)
-            dset_x = f.create_dataset("x", data=cls.x, chunks=True)
-            dset_x.attrs.create("indexed", True)
-            cls.indices = np.array(
-                [(0, 5), (5, 12), (17, 3)],
-                dtype=[('index', '<i8'), ('n_items', '<i8')]
-            )
-            f.create_dataset("x_indices", data=cls.indices, chunks=True)
+class BaseTestClass:  # NOTE(review): presumably a plain wrapper so the nested base case is not collected by itself
+    class BaseIndexedFile(unittest.TestCase):  # shared fixture: temporary h5 file holding "x" and "x_indices"
+        @classmethod
+        def setUpClass(cls) -> None:
+            cls.infile = tempfile.NamedTemporaryFile()  # stays open until tearDownClass closes it
+            with h5py.File(cls.infile, "w") as f:
+                cls.x = np.arange(20)
+                dset_x = f.create_dataset("x", data=cls.x, chunks=True)
+                dset_x.attrs.create("indexed", True)  # sets the "indexed" attribute on dataset "x"
+                cls.indices = np.array(  # rows look like (start offset, length) into x - TODO confirm
+                    [(0, 5), (5, 12), (17, 3)],
+                    dtype=[('index', '<i8'), ('n_items', '<i8')]
+                )
+                f.create_dataset("x_indices", data=cls.indices, chunks=True)
+
+        @classmethod
+        def tearDownClass(cls) -> None:
+            cls.infile.close()
 
+
+class TestConcatenateIndexed(BaseTestClass.BaseIndexedFile):
     def setUp(self) -> None:
         self.outfile = "temp_out.h5"
         conc.concatenate([self.infile.name] * 2, outfile=self.outfile)
@@ -139,10 +146,6 @@ class TestConcatenateIndexed(TestCase):
         if os.path.exists(self.outfile):
             os.remove(self.outfile)
 
-    @classmethod
-    def tearDownClass(cls) -> None:
-        cls.infile.close()
-
     def test_check_x(self):
         with h5py.File(self.outfile) as f_out:
             np.testing.assert_array_equal(
diff --git a/tests/test_postproc.py b/tests/test_postproc.py
index dea566b..0efac09 100644
--- a/tests/test_postproc.py
+++ b/tests/test_postproc.py
@@ -4,6 +4,7 @@ import h5py
 import numpy as np
 import orcasong.tools.postproc as postproc
 import orcasong.tools.shuffle2 as shuffle2
+from .test_concatenate import BaseTestClass
 
 __author__ = 'Stefan Reck'
 
@@ -90,6 +91,40 @@ class TestShuffleV2(TestCase):
                 os.remove(fname)
 
 
+class TestShuffleIndexed(BaseTestClass.BaseIndexedFile):
+    """Run h5shuffle2 (seed=2) on the indexed fixture file and check "x" and "x_indices"."""
+
+    def setUp(self) -> None:
+        self.outfile = "temp_out.h5"
+        shuffle2.h5shuffle2(
+            self.infile.name, output_file=self.outfile, datasets=("x",), seed=2,
+        )
+
+    def tearDown(self) -> None:
+        if os.path.exists(self.outfile):
+            os.remove(self.outfile)
+
+    def test_check_x(self):
+        # Expected output: the fixture's three index groups in reversed order.
+        with h5py.File(self.outfile, "r") as f_out:
+            np.testing.assert_array_equal(
+                f_out["x"],
+                np.concatenate([np.arange(17, 20), np.arange(5, 17), np.arange(0, 5)])
+            )
+
+    def test_check_x_indices_n_items(self):
+        with h5py.File(self.outfile, "r") as f_out:
+            target_n_items = np.array([3, 12, 5])
+            np.testing.assert_array_equal(
+                f_out["x_indices"]["n_items"], target_n_items)
+
+    def test_check_x_indices_index(self):
+        with h5py.File(self.outfile, "r") as f_out:
+            target_index = np.array([0, 3, 15])
+            np.testing.assert_array_equal(
+                f_out["x_indices"]["index"], target_index)
+
+
 def _make_shuffle_dummy_file(filepath):
     x = np.random.rand(22, 2)
     x[:, 0] = np.arange(22)
-- 
GitLab