Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 9 additions & 69 deletions tests/link/numba/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import contextlib
import inspect
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
from unittest import mock

import numba
import numpy as np
Expand Down Expand Up @@ -108,73 +106,15 @@ def compare_shape_dtype(x, y):
def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
"""Evaluate the Numba implementation in pure Python for coverage purposes."""

def py_tuple_setitem(t, i, v):
ll = list(t)
ll[i] = v
return tuple(ll)

def py_to_scalar(x):
if isinstance(x, np.ndarray):
return x.item()
else:
return x

def njit_noop(*args, **kwargs):
if len(args) == 1 and callable(args[0]):
return args[0]
else:
return lambda x: x

def vectorize_noop(*args, **kwargs):
def wrap(fn):
# `numba.vectorize` allows an `out` positional argument. We need
# to account for that
sig = inspect.signature(fn)
nparams = len(sig.parameters)

def inner_vec(*args):
if len(args) > nparams:
# An `out` argument has been specified for an in-place
# operation
out = args[-1]
out[...] = np.vectorize(fn)(*args[:nparams])
return out
else:
return np.vectorize(fn)(*args)

return inner_vec

if len(args) == 1 and callable(args[0]):
return wrap(args[0], **kwargs)
else:
return wrap

mocks = [
mock.patch("numba.njit", njit_noop),
mock.patch("numba.vectorize", vectorize_noop),
mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem),
mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop),
mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop),
mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x),
mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar),
mock.patch(
"aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
lambda dtype: dtype,
),
mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)),
]

with contextlib.ExitStack() as stack:
for ctx in mocks:
stack.enter_context(ctx)

aesara_numba_fn = function(
fn_inputs,
fn_outputs,
mode=mode,
accept_inplace=True,
)
_ = aesara_numba_fn(*inputs)
numba.config.DISABLE_JIT = True
aesara_numba_fn = function(
fn_inputs,
fn_outputs,
mode=mode,
accept_inplace=True,
)
_ = aesara_numba_fn(*inputs)
numba.config.DISABLE_JIT = False


def compare_numba_and_py(
Expand Down