|
1 | 1 | import contextlib |
2 | | -import inspect |
3 | 2 | from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union |
4 | | -from unittest import mock |
5 | 3 |
|
6 | 4 | import numba |
7 | 5 | import numpy as np |
@@ -108,73 +106,15 @@ def compare_shape_dtype(x, y): |
108 | 106 | def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode): |
109 | 107 | """Evaluate the Numba implementation in pure Python for coverage purposes.""" |
110 | 108 |
|
111 | | - def py_tuple_setitem(t, i, v): |
112 | | - ll = list(t) |
113 | | - ll[i] = v |
114 | | - return tuple(ll) |
115 | | - |
116 | | - def py_to_scalar(x): |
117 | | - if isinstance(x, np.ndarray): |
118 | | - return x.item() |
119 | | - else: |
120 | | - return x |
121 | | - |
122 | | - def njit_noop(*args, **kwargs): |
123 | | - if len(args) == 1 and callable(args[0]): |
124 | | - return args[0] |
125 | | - else: |
126 | | - return lambda x: x |
127 | | - |
128 | | - def vectorize_noop(*args, **kwargs): |
129 | | - def wrap(fn): |
130 | | - # `numba.vectorize` allows an `out` positional argument. We need |
131 | | - # to account for that |
132 | | - sig = inspect.signature(fn) |
133 | | - nparams = len(sig.parameters) |
134 | | - |
135 | | - def inner_vec(*args): |
136 | | - if len(args) > nparams: |
137 | | - # An `out` argument has been specified for an in-place |
138 | | - # operation |
139 | | - out = args[-1] |
140 | | - out[...] = np.vectorize(fn)(*args[:nparams]) |
141 | | - return out |
142 | | - else: |
143 | | - return np.vectorize(fn)(*args) |
144 | | - |
145 | | - return inner_vec |
146 | | - |
147 | | - if len(args) == 1 and callable(args[0]): |
148 | | - return wrap(args[0], **kwargs) |
149 | | - else: |
150 | | - return wrap |
151 | | - |
152 | | - mocks = [ |
153 | | - mock.patch("numba.njit", njit_noop), |
154 | | - mock.patch("numba.vectorize", vectorize_noop), |
155 | | - mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem), |
156 | | - mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop), |
157 | | - mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop), |
158 | | - mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x), |
159 | | - mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar), |
160 | | - mock.patch( |
161 | | - "aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype", |
162 | | - lambda dtype: dtype, |
163 | | - ), |
164 | | - mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)), |
165 | | - ] |
166 | | - |
167 | | - with contextlib.ExitStack() as stack: |
168 | | - for ctx in mocks: |
169 | | - stack.enter_context(ctx) |
170 | | - |
171 | | - aesara_numba_fn = function( |
172 | | - fn_inputs, |
173 | | - fn_outputs, |
174 | | - mode=mode, |
175 | | - accept_inplace=True, |
176 | | - ) |
177 | | - _ = aesara_numba_fn(*inputs) |
| 109 | + numba.config.DISABLE_JIT = True |
| 110 | + aesara_numba_fn = function( |
| 111 | + fn_inputs, |
| 112 | + fn_outputs, |
| 113 | + mode=mode, |
| 114 | + accept_inplace=True, |
| 115 | + ) |
| 116 | + _ = aesara_numba_fn(*inputs) |
| 117 | + numba.config.DISABLE_JIT = False |
178 | 118 |
|
179 | 119 |
|
180 | 120 | def compare_numba_and_py( |
|
0 commit comments