Skip to content
Snippets Groups Projects
Commit 069a26ad authored by Stefan Reck's avatar Stefan Reck
Browse files

added test for indexed shuffle2

parent 28e8302a
No related branches found
Tags v1.0.2
1 merge request!23Jagged graph
...@@ -145,9 +145,9 @@ def get_n_iterations( ...@@ -145,9 +145,9 @@ def get_n_iterations(
max_ram = get_max_ram(max_ram_fraction=max_ram_fraction) max_ram = get_max_ram(max_ram_fraction=max_ram_fraction)
with h5py.File(input_file, "r") as f_in: with h5py.File(input_file, "r") as f_in:
dset_info = _get_largest_dset(f_in, datasets, max_ram) dset_info = _get_largest_dset(f_in, datasets, max_ram)
n_iterations = int( n_iterations = np.amax((1, int(
np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"])) np.ceil(np.log(dset_info["n_chunks"]) / np.log(dset_info["chunks_per_batch"]))
) )))
print(f"Largest dataset: {dset_info['name']}") print(f"Largest dataset: {dset_info['name']}")
print(f"Total chunks: {dset_info['n_chunks']}") print(f"Total chunks: {dset_info['n_chunks']}")
print(f"Max. chunks per batch: {dset_info['chunks_per_batch']}") print(f"Max. chunks per batch: {dset_info['chunks_per_batch']}")
......
import tempfile import tempfile
from unittest import TestCase import unittest
import numpy as np import numpy as np
import h5py import h5py
import orcasong.tools.concatenate as conc import orcasong.tools.concatenate as conc
...@@ -8,7 +8,7 @@ import os ...@@ -8,7 +8,7 @@ import os
__author__ = 'Stefan Reck' __author__ = 'Stefan Reck'
class TestFileConcatenator(TestCase): class TestFileConcatenator(unittest.TestCase):
""" """
Test concatenation on pre-generated h5 files. They are in tests/data. Test concatenation on pre-generated h5 files. They are in tests/data.
...@@ -117,20 +117,27 @@ class TestFileConcatenator(TestCase): ...@@ -117,20 +117,27 @@ class TestFileConcatenator(TestCase):
) )
class TestConcatenateIndexed(TestCase): class BaseTestClass:
@classmethod class BaseIndexedFile(unittest.TestCase):
def setUpClass(cls) -> None: @classmethod
cls.infile = tempfile.NamedTemporaryFile() def setUpClass(cls) -> None:
with h5py.File(cls.infile, "w") as f: cls.infile = tempfile.NamedTemporaryFile()
cls.x = np.arange(20) with h5py.File(cls.infile, "w") as f:
dset_x = f.create_dataset("x", data=cls.x, chunks=True) cls.x = np.arange(20)
dset_x.attrs.create("indexed", True) dset_x = f.create_dataset("x", data=cls.x, chunks=True)
cls.indices = np.array( dset_x.attrs.create("indexed", True)
[(0, 5), (5, 12), (17, 3)], cls.indices = np.array(
dtype=[('index', '<i8'), ('n_items', '<i8')] [(0, 5), (5, 12), (17, 3)],
) dtype=[('index', '<i8'), ('n_items', '<i8')]
f.create_dataset("x_indices", data=cls.indices, chunks=True) )
f.create_dataset("x_indices", data=cls.indices, chunks=True)
@classmethod
def tearDownClass(cls) -> None:
cls.infile.close()
class TestConcatenateIndexed(BaseTestClass.BaseIndexedFile):
def setUp(self) -> None: def setUp(self) -> None:
self.outfile = "temp_out.h5" self.outfile = "temp_out.h5"
conc.concatenate([self.infile.name] * 2, outfile=self.outfile) conc.concatenate([self.infile.name] * 2, outfile=self.outfile)
...@@ -139,10 +146,6 @@ class TestConcatenateIndexed(TestCase): ...@@ -139,10 +146,6 @@ class TestConcatenateIndexed(TestCase):
if os.path.exists(self.outfile): if os.path.exists(self.outfile):
os.remove(self.outfile) os.remove(self.outfile)
@classmethod
def tearDownClass(cls) -> None:
cls.infile.close()
def test_check_x(self): def test_check_x(self):
with h5py.File(self.outfile) as f_out: with h5py.File(self.outfile) as f_out:
np.testing.assert_array_equal( np.testing.assert_array_equal(
......
...@@ -4,6 +4,7 @@ import h5py ...@@ -4,6 +4,7 @@ import h5py
import numpy as np import numpy as np
import orcasong.tools.postproc as postproc import orcasong.tools.postproc as postproc
import orcasong.tools.shuffle2 as shuffle2 import orcasong.tools.shuffle2 as shuffle2
from .test_concatenate import BaseTestClass
__author__ = 'Stefan Reck' __author__ = 'Stefan Reck'
...@@ -90,6 +91,40 @@ class TestShuffleV2(TestCase): ...@@ -90,6 +91,40 @@ class TestShuffleV2(TestCase):
os.remove(fname) os.remove(fname)
class TestShuffleIndexed(BaseTestClass.BaseIndexedFile):
def setUp(self) -> None:
self.outfile = "temp_out.h5"
shuffle2.h5shuffle2(
self.infile.name,
output_file=self.outfile,
datasets=("x",),
seed=2,
)
def tearDown(self) -> None:
if os.path.exists(self.outfile):
os.remove(self.outfile)
def test_check_x(self):
with h5py.File(self.outfile) as f_out:
np.testing.assert_array_equal(
f_out["x"],
np.concatenate([np.arange(17, 20), np.arange(5, 17), np.arange(0, 5)])
)
def test_check_x_indices_n_items(self):
with h5py.File(self.outfile) as f_out:
target_n_items = np.array([3, 12, 5])
np.testing.assert_array_equal(
f_out["x_indices"]["n_items"], target_n_items)
def test_check_x_indices_index(self):
with h5py.File(self.outfile) as f_out:
target_index = np.array([0, 3, 15])
np.testing.assert_array_equal(
f_out["x_indices"]["index"], target_index)
def _make_shuffle_dummy_file(filepath): def _make_shuffle_dummy_file(filepath):
x = np.random.rand(22, 2) x = np.random.rand(22, 2)
x[:, 0] = np.arange(22) x[:, 0] = np.arange(22)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment