Skip to content
Snippets Groups Projects
Commit 6f2d2946 authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Merge branch 'numba-optional' into 'master'

Allow disabling numba jitted functions

See merge request !18
parents ba8a313d e6de27dc
No related branches found
No related tags found
1 merge request!18Allow disabling numba jitted functions
Pipeline #8920 passed with warnings
......@@ -62,6 +62,14 @@ test-py3.8:
- make test
<<: *junit_definition
test-py3.8-no-numba:
image: docker.km3net.de/base/python:3.8
stage: test
script:
- *virtualenv_definition
- DISABLE_NUMBA=1 make test
<<: *junit_definition
code-style:
image: docker.km3net.de/base/python:3.7
stage: test
......
import os
import uproot
import numpy as np
import numba as nb
if os.getenv("DISABLE_NUMBA"):
print("Numba is disabled, DAQ helper functions will not work!")
# A hack to to get the @vectorize, @guvectorize and nb.types silently pass.
def dummy_decorator(*args, **kwargs):
def decorator(f):
def wrapper(*args, **kwargs):
return dummy_decorator(*args, **kwargs)
return wrapper
return decorator
vectorize = dummy_decorator
guvectorize = dummy_decorator
int8 = int16 = int32 = int64 = dummy_decorator
else:
from numba import vectorize, guvectorize, int8, int16, int32, int64
TIMESLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024**2 # [byte]
SUMMARYSLICE_FRAME_BASKET_CACHE_SIZE = 523 * 1024**2 # [byte]
......@@ -16,12 +32,7 @@ RATE_FACTOR = np.log(MAXIMAL_RATE_HZ / MINIMAL_RATE_HZ) / 255
CHANNEL_BITS_TEMPLATE = np.zeros(31, dtype=bool)
@nb.vectorize([
nb.int32(nb.int8),
nb.int32(nb.int16),
nb.int32(nb.int32),
nb.int32(nb.int64)
])
@vectorize([int32(int8), int32(int16), int32(int32), int32(int64)])
def get_rate(value): #pragma: no cover
"""Return the rate in Hz from the short int value"""
if value == 0:
......@@ -30,10 +41,10 @@ def get_rate(value): #pragma: no cover
return MINIMAL_RATE_HZ * np.exp(value * RATE_FACTOR)
@nb.guvectorize("void(i8, b1[:], b1[:])",
"(), (n) -> (n)",
target="parallel",
nopython=True)
@guvectorize("void(i8, b1[:], b1[:])",
"(), (n) -> (n)",
target="parallel",
nopython=True)
def unpack_bits(value, bits_template, out): #pragma: no cover
"""Return a boolean array for a value's bit representation.
......
......@@ -175,6 +175,7 @@ class TestSummaryslices(unittest.TestCase):
def test_rates(self):
assert 3 == len(self.ss.rates)
@unittest.skipIf(os.getenv("DISABLE_NUMBA"), reason="no numba")
def test_fifo(self):
s = self.ss.slices[0]
dct_fifo_stat = {
......@@ -187,6 +188,7 @@ class TestSummaryslices(unittest.TestCase):
frame = s[s.dom_id == dom_id]
assert any(get_channel_flags(frame.fifo[0])) == fifo_status
@unittest.skipIf(os.getenv("DISABLE_NUMBA"), reason="no numba")
def test_has_udp_trailer(self):
s = self.ss.slices[0]
dct_udp_trailer = {
......@@ -206,6 +208,7 @@ class TestSummaryslices(unittest.TestCase):
frame = s[s.dom_id == dom_id]
assert has_udp_trailer(frame.fifo[0]) == udp_trailer
@unittest.skipIf(os.getenv("DISABLE_NUMBA"), reason="no numba")
def test_high_rate_veto(self):
s = self.ss.slices[0]
dct_high_rate_veto = {
......@@ -246,6 +249,7 @@ class TestSummaryslices(unittest.TestCase):
frame = s[s.dom_id == dom_id]
assert any(get_channel_flags(frame.hrv[0])) == high_rate_veto
@unittest.skipIf(os.getenv("DISABLE_NUMBA"), reason="no numba")
def test_max_sequence_number(self):
s = self.ss.slices[0]
dct_seq_numbers = {
......@@ -283,6 +287,7 @@ class TestSummaryslices(unittest.TestCase):
assert get_udp_max_sequence_number(
frame.dq_status[0]) == max_sequence_number
@unittest.skipIf(os.getenv("DISABLE_NUMBA"), reason="no numba")
def test_number_udp_packets(self):
s = self.ss.slices[0]
dct_n_packets = {
......@@ -312,6 +317,7 @@ class TestSummaryslices(unittest.TestCase):
frame = s[s.dom_id == dom_id]
assert get_number_udp_packets(frame.dq_status[0]) == n_udp_packets
@unittest.skipIf(os.getenv("DISABLE_NUMBA"), reason="no numba")
def test_hrv_flags(self):
s = self.ss.slices[0]
dct_hrv_flags = {
......@@ -347,6 +353,7 @@ class TestSummaryslices(unittest.TestCase):
for a, b in zip(get_channel_flags(frame.hrv[0]), hrv_flags)
])
@unittest.skipIf(os.getenv("DISABLE_NUMBA"), reason="no numba")
def test_fifo_flags(self):
s = self.ss.slices[0]
dct_fifo_flags = {
......@@ -399,14 +406,17 @@ class TestSummaryslices(unittest.TestCase):
class TestGetRate(unittest.TestCase):
@unittest.skipIf(os.getenv("DISABLE_NUMBA"), reason="no numba")
def test_zero(self):
assert 0 == get_rate(0)
@unittest.skipIf(os.getenv("DISABLE_NUMBA"), reason="no numba")
def test_some_values(self):
assert 2054 == get_rate(1)
assert 55987 == get_rate(123)
assert 1999999 == get_rate(255)
@unittest.skipIf(os.getenv("DISABLE_NUMBA"), reason="no numba")
def test_vectorized_input(self):
self.assertListEqual([2054], list(get_rate([1])))
self.assertListEqual([2054, 2111, 2169], list(get_rate([1, 2, 3])))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment