|
8 | 8 | from pytensor.tensor.basic import AllocEmpty |
9 | 9 | from pytensor.tensor.blas import Ger |
10 | 10 | from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv |
11 | | -from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector |
| 11 | +from pytensor.tensor.type import ( |
| 12 | + dmatrix, |
| 13 | + dscalar, |
| 14 | + dvector, |
| 15 | + matrix, |
| 16 | + scalar, |
| 17 | + tensor, |
| 18 | + vector, |
| 19 | +) |
12 | 20 | from tests import unittest_tools |
13 | 21 | from tests.tensor.test_blas import BaseGemv, TestBlasStrides |
14 | 22 | from tests.unittest_tools import OptimizationTestMixin |
@@ -143,19 +151,21 @@ def setup_method(self): |
143 | 151 | def test_nan_beta_0(self, inplace): |
144 | 152 | mode = self.mode.including() |
145 | 153 | mode.check_isfinite = False |
| 154 | + beta = self.a.type("beta") |
146 | 155 | f = pytensor.function( |
147 | | - [self.A, self.x, pytensor.In(self.y, mutable=inplace), self.a], |
148 | | - self.a * self.y + pt.dot(self.A, self.x), |
| 156 | + [self.A, self.x, pytensor.In(self.y, mutable=inplace), beta], |
| 157 | + beta * self.y + pt.dot(self.A, self.x), |
149 | 158 | mode=mode, |
150 | 159 | ) |
151 | 160 | [node] = f.maker.fgraph.apply_nodes |
152 | 161 | assert isinstance(node.op, CGemv) and node.op.inplace == inplace |
153 | | - for rows in (3, 1): |
154 | | - Aval = np.ones((rows, 1), dtype=self.dtype) |
155 | | - xval = np.ones((1,), dtype=self.dtype) |
156 | | - yval = np.full((rows,), np.nan, dtype=self.dtype) |
157 | | - zval = f(Aval, xval, yval, 0) |
158 | | - assert not np.isnan(zval).any() |
| 162 | + for rows in (3, 1, 0): |
| 163 | + for cols in (1, 0): |
| 164 | + Aval = np.ones((rows, cols), dtype=self.dtype) |
| 165 | + xval = np.ones((cols,), dtype=self.dtype) |
| 166 | + yval = np.full((rows,), np.nan, dtype=self.dtype) |
| 167 | + zval = f(Aval, xval, yval, beta=0) |
| 168 | + assert not np.isnan(zval).any(), f"{rows=}, {cols=}" |
159 | 169 |
|
160 | 170 | def test_optimizations_vm(self): |
161 | 171 | skip_if_blas_ldflags_empty() |
@@ -294,6 +304,26 @@ def test_multiple_inplace(self): |
294 | 304 | == 2 |
295 | 305 | ) |
296 | 306 |
|
| 307 | + def test_empty_A(self): |
| 308 | + A = dmatrix("A") |
| 309 | + x = dvector("x") |
| 310 | + y = dvector("y") |
| 311 | + alpha = 1.0 |
| 312 | + beta = dscalar("beta") |
| 313 | + gemv = CGemv(inplace=True)(y, alpha, A, x, beta) |
| 314 | + fn = pytensor.function( |
| 315 | + [A, x, y, beta], |
| 316 | + gemv, |
| 317 | + accept_inplace=True, |
| 318 | + ) |
| 319 | + test_A = np.empty((10, 0)) |
| 320 | + test_x = np.empty((0,)) |
| 321 | + test_y = np.random.random((10,)) |
| 322 | + for test_beta in [0.0, 1.0, 2.0]: |
| 323 | + out = fn(test_A, test_x, test_y.copy(), test_beta) |
| 324 | + expected = test_beta * test_y |
| 325 | + np.testing.assert_allclose(out, expected) |
| 326 | + |
297 | 327 |
|
298 | 328 | class TestCGemvFloat32(BaseGemv, OptimizationTestMixin): |
299 | 329 | mode = mode_blas_opt |
|
0 commit comments