Skip to content

Commit 3ed3497

Browse files
committed
add test for Blockwise SolveTriangular
1 parent 684914d commit 3ed3497

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

aesara/tensor/slinalg.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,6 @@ class SolveTriangular(SolveBase):
294294
"trans",
295295
"unit_diagonal",
296296
"check_finite",
297-
"gufunc_sig",
298297
)
299298

300299
def __init__(

tests/tensor/test_blockwise.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import aesara
55
import aesara.tensor as at
66
from aesara.configdefaults import config
7+
from aesara.tensor.basic import Tri
78
from aesara.tensor.blockwise import (
89
Blockwise,
910
_calculate_shapes,
@@ -14,7 +15,7 @@
1415
)
1516
from aesara.tensor.math import Dot
1617
from aesara.tensor.nlinalg import Det
17-
from aesara.tensor.slinalg import Cholesky, Solve
18+
from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular
1819
from aesara.tensor.type import TensorType
1920
from tests import unittest_tools as utt
2021
from tests.unittest_tools import check_infer_shape, verify_grad
@@ -100,6 +101,12 @@ def test_parse_input_dimensions(args, arg_vals, input_core_dims, output_core_dim
100101
(np.zeros((5, 3, 3)),),
101102
lambda x: np.linalg.det(x),
102103
),
104+
(
105+
Tri(),
106+
(at.scalar(), at.scalar(), at.scalar()),
107+
(3, 4, 0),
108+
lambda n, m, k: np.tri(n, m, k),
109+
),
103110
],
104111
)
105112
def test_Blockwise_perform(op, args, arg_vals, np_fn):
@@ -261,3 +268,25 @@ def test_Blockwise_get_output_info():
261268
out_dtype, output_shapes, inputs = blk_op.get_output_info(a, b, c)
262269

263270
assert out_dtype == ["float64"]
271+
272+
273+
@pytest.mark.parametrize(
274+
"a_shape, b_shape",
275+
[
276+
(
277+
(3, 3),
278+
(3, 1),
279+
)
280+
],
281+
)
282+
def test_blockwise_SolveTriangular_grad(a_shape, b_shape):
283+
rng = np.random.default_rng(utt.fetch_seed())
284+
A_val = (rng.normal(size=a_shape) * 0.5 + np.eye(a_shape[-1])).astype(config.floatX)
285+
b_val = rng.normal(size=b_shape).astype(config.floatX)
286+
287+
eps = None
288+
if config.floatX == "float64":
289+
eps = 2e-8
290+
291+
solve_op = Blockwise(SolveTriangular())
292+
verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)

0 commit comments

Comments
 (0)