Skip to content

Commit c0a17bc

Browse files
Skip some tests on Gen12 (#120)
Co-authored-by: reazul.hoque <[email protected]> Co-authored-by: Sergey Pokhodenko <[email protected]>
1 parent d8345cf commit c0a17bc

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

numba_dppy/tests/skip_tests.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import dpctl
2+
3+
def is_gen12(device_type):
4+
with dpctl.device_context(device_type):
5+
q = dpctl.get_current_queue()
6+
device = q.get_sycl_device()
7+
name = device.get_device_name()
8+
if "Gen12" in name:
9+
return True
10+
11+
return False

numba_dppy/tests/test_numpy_math_functions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from numba import njit
44
import dpctl
55
import unittest
6-
6+
from . import skip_tests
77

88
@unittest.skipUnless(dpctl.has_gpu_queues(), 'test only on GPU system')
99
class TestNumpy_math_functions(unittest.TestCase):
@@ -179,6 +179,7 @@ def f(a):
179179

180180
self.assertTrue(np.all(c == -input_arr))
181181

182+
@unittest.skipIf(skip_tests.is_gen12("opencl:gpu"), "Gen12 not supported")
182183
def test_sign(self):
183184
@njit
184185
def f(a):
@@ -221,6 +222,7 @@ def f(a):
221222
max_abs_err = c.sum() - d.sum()
222223
self.assertTrue(max_abs_err < 1e-5)
223224

225+
@unittest.skipIf(skip_tests.is_gen12("opencl:gpu"), "Gen12 not supported")
224226
def test_log(self):
225227
@njit
226228
def f(a):
@@ -236,6 +238,7 @@ def f(a):
236238
max_abs_err = c.sum() - d.sum()
237239
self.assertTrue(max_abs_err < 1e-5)
238240

241+
@unittest.skipIf(skip_tests.is_gen12("opencl:gpu"), "Gen12 not supported")
239242
def test_log10(self):
240243
@njit
241244
def f(a):
@@ -251,6 +254,7 @@ def f(a):
251254
max_abs_err = c.sum() - d.sum()
252255
self.assertTrue(max_abs_err < 1e-5)
253256

257+
@unittest.skipIf(skip_tests.is_gen12("opencl:gpu"), "Gen12 not supported")
254258
def test_expm1(self):
255259
@njit
256260
def f(a):

numba_dppy/tests/test_numpy_trigonomteric_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from numba import njit
44
import dpctl
55
import unittest
6+
from . import skip_tests
67

78

89
@unittest.skipUnless(dpctl.has_gpu_queues(), 'test only on GPU system')
@@ -155,6 +156,7 @@ def f(a):
155156
max_abs_err = c.sum() - d.sum()
156157
self.assertTrue(max_abs_err < 1e-5)
157158

159+
@unittest.skipIf(skip_tests.is_gen12("opencl:gpu"), "Gen12 not supported")
158160
def test_arccosh(self):
159161
@njit
160162
def f(a):

0 commit comments

Comments
 (0)