import numba as nb
import numpy as np


@nb.jit(nopython=True)
def unique(array, dtype=np.int64):
    """Return the unique elements of an array with a given dtype.

    The elements are returned in order of first appearance; the result is
    not sorted.  Runs of equal neighbouring values are skipped without
    searching the values found so far, so the performance is better for
    pre-sorted input arrays.

    Parameters
    ----------
    array : 1D array-like
        Anything supporting ``len()`` and integer indexing.
    dtype : numpy dtype, optional
        The dtype of the returned array.  Every new value is stored with
        this dtype and later elements are compared against the stored
        values, so it should be able to represent the input exactly.

    Returns
    -------
    numpy.ndarray
        The distinct values of ``array``; empty if ``array`` is empty.

    """
    n = len(array)
    out = np.empty(n, dtype)
    if n == 0:
        # Nothing to scan.  Without this guard the code below would read
        # array[0] and write out[0] past the end of zero-length buffers;
        # nopython mode does not bounds-check by default, so that would
        # go unnoticed instead of raising an IndexError.
        return out[:0]
    last = array[0]
    entry_idx = 0  # index of the most recently stored unique value
    out[entry_idx] = last
    for i in range(1, n):
        current = array[i]
        if current == last:  # shortcut for sorted arrays
            continue
        already_present = False
        # linear search through the unique values collected so far
        for j in range(entry_idx + 1):
            if current == out[j]:
                already_present = True
                break
        if not already_present:
            entry_idx += 1
            out[entry_idx] = current
        last = current
    return out[:entry_idx + 1]


@nb.jit(nopython=True)
def uniquecount(array, dtype=np.int64):
    """Count the number of unique elements in a jagged Awkward1 array.

    Parameters
    ----------
    array : jagged array
        An array of sub-arrays; one count is produced per sub-array.
    dtype : numpy dtype, optional
        The dtype of the returned array of counts.  It is not forwarded to
        ``unique()``, which collects the values of each sub-array in its
        default ``np.int64`` buffer.

    Returns
    -------
    numpy.ndarray
        An array of length ``len(array)`` holding the number of distinct
        values of each sub-array; empty sub-arrays yield ``0``.

    """
    n = len(array)
    out = np.empty(n, dtype)
    for i in range(n):
        sub_array = array[i]
        if len(sub_array) == 0:
            # fast path, no need to call unique() for an empty sub-array
            out[i] = 0
        else:
            out[i] = len(unique(sub_array))
    return out
class TestUnique(unittest.TestCase):
    """Compare ``unique()`` with ``np.unique()`` on random integer arrays."""

    def run_random_test_with_dtype(self, dtype):
        """Check ``unique()`` against ``np.unique()`` for 23 random arrays.

        Each array holds between 0 and 99 random integers which are cast to
        ``dtype``.  Both results are sorted before the comparison since
        ``unique()`` is only compared by content, not by order.  The random
        generator is not seeded, so the inputs are part of the failure
        message to make a failing case reproducible.
        """
        max_range = 100
        for _ in range(23):
            low = np.random.randint(0, max_range)
            span = np.random.randint(max_range)
            high = np.random.randint(low + 1, low + 2 + span)
            n = np.random.randint(max_range)
            arr = np.random.randint(low, high, n).astype(dtype)
            np_reference = np.sort(np.unique(arr))
            result = np.sort(unique(arr, dtype=dtype))
            details = "low: {}\nhigh: {}\nn: {}\narr = {}".format(
                low, high, n, list(arr)
            )
            # Exact comparison which fails on differing values as well as
            # on differing shapes.
            np.testing.assert_array_equal(
                result, np_reference, err_msg=details
            )

    def test_unique_with_dtype_int8(self):
        self.run_random_test_with_dtype(np.int8)

    def test_unique_with_dtype_int16(self):
        self.run_random_test_with_dtype(np.int16)

    def test_unique_with_dtype_int32(self):
        self.run_random_test_with_dtype(np.int32)

    def test_unique_with_dtype_int64(self):
        self.run_random_test_with_dtype(np.int64)

    def test_unique_with_dtype_uint8(self):
        self.run_random_test_with_dtype(np.uint8)

    def test_unique_with_dtype_uint16(self):
        self.run_random_test_with_dtype(np.uint16)

    def test_unique_with_dtype_uint32(self):
        self.run_random_test_with_dtype(np.uint32)

    def test_unique_with_dtype_uint64(self):
        self.run_random_test_with_dtype(np.uint64)


class TestUniqueCount(unittest.TestCase):
    """Check ``uniquecount()`` on small hand-written jagged arrays."""

    def test_uniquecount(self):
        arr = ak.Array([[1, 2, 3], [2, 2, 2], [3, 4, 5, 6, 6], [4, 4, 3, 1]])
        # counts are integers, so compare exactly (and without broadcasting)
        np.testing.assert_array_equal(uniquecount(arr), [3, 1, 4, 3])

    def test_uniquecount_with_empty_subarrays(self):
        arr = ak.Array([[1, 2, 3], [2, 2, 2], [], [4, 4, 3, 1]])
        np.testing.assert_array_equal(uniquecount(arr), [3, 1, 0, 3])