|
4 | 4 | import aesara |
5 | 5 | import aesara.tensor as at |
6 | 6 | from aesara.configdefaults import config |
| 7 | +from aesara.tensor.basic import Tri |
7 | 8 | from aesara.tensor.blockwise import ( |
8 | 9 | Blockwise, |
9 | 10 | _calculate_shapes, |
|
14 | 15 | ) |
15 | 16 | from aesara.tensor.math import Dot |
16 | 17 | from aesara.tensor.nlinalg import Det |
17 | | -from aesara.tensor.slinalg import Cholesky, Solve |
| 18 | +from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular |
18 | 19 | from aesara.tensor.type import TensorType |
19 | 20 | from tests import unittest_tools as utt |
20 | 21 | from tests.unittest_tools import check_infer_shape, verify_grad |
@@ -100,6 +101,12 @@ def test_parse_input_dimensions(args, arg_vals, input_core_dims, output_core_dim |
100 | 101 | (np.zeros((5, 3, 3)),), |
101 | 102 | lambda x: np.linalg.det(x), |
102 | 103 | ), |
| 104 | + ( |
| 105 | + Tri(), |
| 106 | + (at.scalar(), at.scalar(), at.scalar()), |
| 107 | + (3, 4, 0), |
| 108 | + lambda n, m, k: np.tri(n, m, k), |
| 109 | + ), |
103 | 110 | ], |
104 | 111 | ) |
105 | 112 | def test_Blockwise_perform(op, args, arg_vals, np_fn): |
@@ -261,3 +268,25 @@ def test_Blockwise_get_output_info(): |
261 | 268 | out_dtype, output_shapes, inputs = blk_op.get_output_info(a, b, c) |
262 | 269 |
|
263 | 270 | assert out_dtype == ["float64"] |
| 271 | + |
| 272 | + |
| 273 | +@pytest.mark.parametrize( |
| 274 | + "a_shape, b_shape", |
| 275 | + [ |
| 276 | + ( |
| 277 | + (3, 3), |
| 278 | + (3, 1), |
| 279 | + ) |
| 280 | + ], |
| 281 | +) |
| 282 | +def test_blockwise_SolveTriangular_grad(a_shape, b_shape): |
| 283 | + rng = np.random.default_rng(utt.fetch_seed()) |
| 284 | + A_val = (rng.normal(size=a_shape) * 0.5 + np.eye(a_shape[-1])).astype(config.floatX) |
| 285 | + b_val = rng.normal(size=b_shape).astype(config.floatX) |
| 286 | + |
| 287 | + eps = None |
| 288 | + if config.floatX == "float64": |
| 289 | + eps = 2e-8 |
| 290 | + |
| 291 | + solve_op = Blockwise(SolveTriangular()) |
| 292 | + verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) |
0 commit comments