Skip to content

Commit e8c9bd3

Browse files
committed
Fix CGemV with empty A
1 parent abedb7f commit e8c9bd3

File tree

2 files changed

+63
-24
lines changed

2 files changed

+63
-24
lines changed

pytensor/tensor/blas_c.py

Lines changed: 24 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -417,20 +417,6 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N
417417
}
418418
}
419419
420-
if (%(must_initialize_y)d && dbeta == 0)
421-
{
422-
// Most BLAS implementations of GEMV ignore y=nan when beta=0
423-
// PyTensor considers that the correct behavior,
424-
// and even exploits it to avoid copying or initializing outputs.
425-
// By deciding to exploit this, however, it becomes our responsibility
426-
// to ensure the behavior even in the rare cases BLAS deviates,
427-
// or users will get errors, even for graphs that had no nan to begin with.
428-
PyObject *zero = PyFloat_FromDouble(0.);
429-
if (zero == NULL) %(fail)s;
430-
if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
431-
Py_DECREF(zero);
432-
}
433-
434420
{
435421
int NA0 = PyArray_DIMS(%(A)s)[0];
436422
int NA1 = PyArray_DIMS(%(A)s)[1];
@@ -439,6 +425,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N
439425
{
440426
// Non-empty A matrix
441427
428+
if (%(must_initialize_y)d && dbeta == 0)
429+
{
430+
// Most BLAS implementations of GEMV ignore y=nan when beta=0
431+
// PyTensor considers that the correct behavior,
432+
// and even exploits it to avoid copying or initializing outputs.
433+
// By deciding to exploit this, however, it becomes our responsibility
434+
// to ensure the behavior even in the rare cases BLAS deviates,
435+
// or users will get errors, even for graphs that had no nan to begin with.
436+
PyArray_FILLWBYTE(%(z)s, 0);
437+
}
438+
442439
/* In the case where A is actually a row or column matrix,
443440
* the strides corresponding to the dummy dimension don't matter,
444441
* but BLAS requires these to be no smaller than the number of elements in the array.
@@ -567,6 +564,18 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N
567564
"A is neither C nor F-contiguous, it should have been copied into a memory-contiguous array;");
568565
%(fail)s
569566
}
567+
} else
568+
{
569+
// Empty A matrix, just scale y by beta
570+
if (dbeta != 1.0)
571+
{
572+
npy_intp Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
573+
dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
574+
for (npy_intp i = 0; i < NA0; ++i)
575+
{
576+
z_data[i * Sz] = (dbeta == 0.0) ? 0 : z_data[i * Sz] * dbeta;
577+
}
578+
}
570579
}
571580
}
572581
"""
@@ -598,7 +607,7 @@ def c_code(self, node, name, inp, out, sub):
598607
return code
599608

600609
def c_code_cache_version(self):
601-
return (17, blas_header_version(), must_initialize_y_gemv())
610+
return (18, blas_header_version(), must_initialize_y_gemv())
602611

603612

604613
cgemv_inplace = CGemv(inplace=True)

tests/tensor/test_blas_c.py

Lines changed: 39 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,15 @@
88
from pytensor.tensor.basic import AllocEmpty
99
from pytensor.tensor.blas import Ger
1010
from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv
11-
from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector
11+
from pytensor.tensor.type import (
12+
dmatrix,
13+
dscalar,
14+
dvector,
15+
matrix,
16+
scalar,
17+
tensor,
18+
vector,
19+
)
1220
from tests import unittest_tools
1321
from tests.tensor.test_blas import BaseGemv, TestBlasStrides
1422
from tests.unittest_tools import OptimizationTestMixin
@@ -143,19 +151,21 @@ def setup_method(self):
143151
def test_nan_beta_0(self, inplace):
144152
mode = self.mode.including()
145153
mode.check_isfinite = False
154+
beta = self.a.type("beta")
146155
f = pytensor.function(
147-
[self.A, self.x, pytensor.In(self.y, mutable=inplace), self.a],
148-
self.a * self.y + pt.dot(self.A, self.x),
156+
[self.A, self.x, pytensor.In(self.y, mutable=inplace), beta],
157+
beta * self.y + pt.dot(self.A, self.x),
149158
mode=mode,
150159
)
151160
[node] = f.maker.fgraph.apply_nodes
152161
assert isinstance(node.op, CGemv) and node.op.inplace == inplace
153-
for rows in (3, 1):
154-
Aval = np.ones((rows, 1), dtype=self.dtype)
155-
xval = np.ones((1,), dtype=self.dtype)
156-
yval = np.full((rows,), np.nan, dtype=self.dtype)
157-
zval = f(Aval, xval, yval, 0)
158-
assert not np.isnan(zval).any()
162+
for rows in (3, 1, 0):
163+
for cols in (1, 0):
164+
Aval = np.ones((rows, cols), dtype=self.dtype)
165+
xval = np.ones((cols,), dtype=self.dtype)
166+
yval = np.full((rows,), np.nan, dtype=self.dtype)
167+
zval = f(Aval, xval, yval, beta=0)
168+
assert not np.isnan(zval).any(), f"{rows=}, {cols=}"
159169

160170
def test_optimizations_vm(self):
161171
skip_if_blas_ldflags_empty()
@@ -294,6 +304,26 @@ def test_multiple_inplace(self):
294304
== 2
295305
)
296306

307+
def test_empty_A(self):
308+
A = dmatrix("A")
309+
x = dvector("x")
310+
y = dvector("y")
311+
alpha = 1.0
312+
beta = dscalar("beta")
313+
gemv = CGemv(inplace=True)(y, alpha, A, x, beta)
314+
fn = pytensor.function(
315+
[A, x, y, beta],
316+
gemv,
317+
accept_inplace=True,
318+
)
319+
test_A = np.empty((10, 0))
320+
test_x = np.empty((0,))
321+
test_y = np.random.random((10,))
322+
for test_beta in [0.0, 1.0, 2.0]:
323+
out = fn(test_A, test_x, test_y.copy(), test_beta)
324+
expected = test_beta * test_y
325+
np.testing.assert_allclose(out, expected)
326+
297327

298328
class TestCGemvFloat32(BaseGemv, OptimizationTestMixin):
299329
mode = mode_blas_opt

0 commit comments

Comments
 (0)