Skip to content

Commit 7088994

Browse files
committed
fix scipy-stubs squigglies
1 parent 3fd0ab3 commit 7088994

File tree

5 files changed

+99
-65
lines changed

5 files changed

+99
-65
lines changed

pytensor/link/numba/dispatch/linalg/decomposition/lu.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable
2-
from typing import cast as typing_cast
2+
from typing import Literal
33

44
import numpy as np
55
from numba import njit as numba_njit
@@ -37,9 +37,9 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
3737

3838
def _lu_1(
3939
a: np.ndarray,
40-
permute_l: bool,
40+
permute_l: Literal[True],
4141
check_finite: bool,
42-
p_indices: bool,
42+
p_indices: Literal[False],
4343
overwrite_a: bool,
4444
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
4545
"""
@@ -48,23 +48,20 @@ def _lu_1(
4848
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
4949
array of row swaps, such that L[perm] @ U = A.
5050
"""
51-
return typing_cast(
52-
tuple[np.ndarray, np.ndarray, np.ndarray],
53-
linalg.lu(
54-
a,
55-
permute_l=permute_l,
56-
check_finite=check_finite,
57-
p_indices=p_indices,
58-
overwrite_a=overwrite_a,
59-
),
51+
return linalg.lu(
52+
a,
53+
permute_l=permute_l,
54+
check_finite=check_finite,
55+
p_indices=p_indices,
56+
overwrite_a=overwrite_a,
6057
)
6158

6259

6360
def _lu_2(
6461
a: np.ndarray,
65-
permute_l: bool,
62+
permute_l: Literal[False],
6663
check_finite: bool,
67-
p_indices: bool,
64+
p_indices: Literal[True],
6865
overwrite_a: bool,
6966
) -> tuple[np.ndarray, np.ndarray]:
7067
"""
@@ -73,23 +70,20 @@ def _lu_2(
7370
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
7471
permuted L matrix, PL = P @ L.
7572
"""
76-
return typing_cast(
77-
tuple[np.ndarray, np.ndarray],
78-
linalg.lu(
79-
a,
80-
permute_l=permute_l,
81-
check_finite=check_finite,
82-
p_indices=p_indices,
83-
overwrite_a=overwrite_a,
84-
),
73+
return linalg.lu(
74+
a,
75+
permute_l=permute_l,
76+
check_finite=check_finite,
77+
p_indices=p_indices,
78+
overwrite_a=overwrite_a,
8579
)
8680

8781

8882
def _lu_3(
8983
a: np.ndarray,
90-
permute_l: bool,
84+
permute_l: Literal[False],
9185
check_finite: bool,
92-
p_indices: bool,
86+
p_indices: Literal[False],
9387
overwrite_a: bool,
9488
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
9589
"""
@@ -98,15 +92,12 @@ def _lu_3(
9892
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
9993
matrix, P @ L @ U = A.
10094
"""
101-
return typing_cast(
102-
tuple[np.ndarray, np.ndarray, np.ndarray],
103-
linalg.lu(
104-
a,
105-
permute_l=permute_l,
106-
check_finite=check_finite,
107-
p_indices=p_indices,
108-
overwrite_a=overwrite_a,
109-
),
95+
return linalg.lu(
96+
a,
97+
permute_l=permute_l,
98+
check_finite=check_finite,
99+
p_indices=p_indices,
100+
overwrite_a=overwrite_a,
110101
)
111102

112103

pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Callable
2+
from typing import cast as typing_cast
23

34
import numpy as np
45
from numba.core.extending import overload
@@ -21,8 +22,13 @@ def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
2122
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also
2223
returns an info code with diagnostic information.
2324
"""
24-
(getrf,) = linalg.get_lapack_funcs("getrf", (A,))
25-
A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a)
25+
funcs = linalg.get_lapack_funcs("getrf", (A,))
26+
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
27+
getrf = funcs[0]
28+
29+
A_copy, ipiv, info = typing_cast(
30+
tuple[np.ndarray, np.ndarray, int], getrf(A, overwrite_a=overwrite_a)
31+
)
2632

2733
return A_copy, ipiv, info
2834

pytensor/link/numba/dispatch/linalg/decomposition/qr.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Literal
2+
13
import numpy as np
24
from numba.core.extending import overload
35
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
@@ -13,7 +15,13 @@
1315

1416
def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int):
1517
"""LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A."""
16-
(geqrf,) = get_lapack_funcs(("geqrf",), (A,))
18+
# (geqrf,) = typing_cast(
19+
# list[Callable[..., np.ndarray]], get_lapack_funcs(("geqrf",), (A,))
20+
# )
21+
funcs = get_lapack_funcs(("geqrf",), (A,))
22+
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
23+
geqrf = funcs[0]
24+
1725
return geqrf(A, overwrite_a=overwrite_a, lwork=lwork)
1826

1927

@@ -61,7 +69,10 @@ def impl(A, overwrite_a, lwork):
6169

6270
def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int):
6371
"""LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A."""
64-
(geqp3,) = get_lapack_funcs(("geqp3",), (A,))
72+
funcs = get_lapack_funcs(("geqp3",), (A,))
73+
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
74+
geqp3 = funcs[0]
75+
6576
return geqp3(A, overwrite_a=overwrite_a, lwork=lwork)
6677

6778

@@ -111,7 +122,10 @@ def impl(A, overwrite_a, lwork):
111122

112123
def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
113124
"""LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types)."""
114-
(orgqr,) = get_lapack_funcs(("orgqr",), (A,))
125+
funcs = get_lapack_funcs(("orgqr",), (A,))
126+
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
127+
orgqr = funcs[0]
128+
115129
return orgqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)
116130

117131

@@ -160,7 +174,10 @@ def impl(A, tau, overwrite_a, lwork):
160174

161175
def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int):
162176
"""LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types)."""
163-
(ungqr,) = get_lapack_funcs(("ungqr",), (A,))
177+
funcs = get_lapack_funcs(("ungqr",), (A,))
178+
assert isinstance(funcs, list) # narrows `funcs: list[F] | F` to `funcs: list[F]`
179+
ungqr = funcs[0]
180+
164181
return ungqr(A, tau, overwrite_a=overwrite_a, lwork=lwork)
165182

166183

@@ -209,8 +226,8 @@ def impl(A, tau, overwrite_a, lwork):
209226

210227
def _qr_full_pivot(
211228
x: np.ndarray,
212-
mode: str = "full",
213-
pivoting: bool = True,
229+
mode: Literal["full", "economic"] = "full",
230+
pivoting: Literal[True] = True,
214231
overwrite_a: bool = False,
215232
check_finite: bool = False,
216233
lwork: int | None = None,
@@ -234,8 +251,8 @@ def _qr_full_pivot(
234251

235252
def _qr_full_no_pivot(
236253
x: np.ndarray,
237-
mode: str = "full",
238-
pivoting: bool = False,
254+
mode: Literal["full", "economic"] = "full",
255+
pivoting: Literal[False] = False,
239256
overwrite_a: bool = False,
240257
check_finite: bool = False,
241258
lwork: int | None = None,
@@ -258,8 +275,8 @@ def _qr_full_no_pivot(
258275

259276
def _qr_r_pivot(
260277
x: np.ndarray,
261-
mode: str = "r",
262-
pivoting: bool = True,
278+
mode: Literal["r", "raw"] = "r",
279+
pivoting: Literal[True] = True,
263280
overwrite_a: bool = False,
264281
check_finite: bool = False,
265282
lwork: int | None = None,
@@ -282,8 +299,8 @@ def _qr_r_pivot(
282299

283300
def _qr_r_no_pivot(
284301
x: np.ndarray,
285-
mode: str = "r",
286-
pivoting: bool = False,
302+
mode: Literal["r", "raw"] = "r",
303+
pivoting: Literal[False] = False,
287304
overwrite_a: bool = False,
288305
check_finite: bool = False,
289306
lwork: int | None = None,
@@ -306,8 +323,8 @@ def _qr_r_no_pivot(
306323

307324
def _qr_raw_no_pivot(
308325
x: np.ndarray,
309-
mode: str = "raw",
310-
pivoting: bool = False,
326+
mode: Literal["raw"] = "raw",
327+
pivoting: Literal[False] = False,
311328
overwrite_a: bool = False,
312329
check_finite: bool = False,
313330
lwork: int | None = None,
@@ -332,8 +349,8 @@ def _qr_raw_no_pivot(
332349

333350
def _qr_raw_pivot(
334351
x: np.ndarray,
335-
mode: str = "raw",
336-
pivoting: bool = True,
352+
mode: Literal["raw"] = "raw",
353+
pivoting: Literal[True] = True,
337354
overwrite_a: bool = False,
338355
check_finite: bool = False,
339356
lwork: int | None = None,

pytensor/link/numba/dispatch/linalg/solve/lu_solve.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Callable
2+
from typing import Literal, TypeAlias
23

34
import numpy as np
45
from numba.core.extending import overload
@@ -20,8 +21,15 @@
2021
)
2122

2223

24+
_Trans: TypeAlias = Literal[0, 1, 2]
25+
26+
2327
def _getrs(
24-
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
28+
LU: np.ndarray,
29+
B: np.ndarray,
30+
IPIV: np.ndarray,
31+
trans: _Trans | bool, # mypy does not realize that `bool <: Literal[0, 1]`
32+
overwrite_b: bool,
2533
) -> tuple[np.ndarray, int]:
2634
"""
2735
Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve.
@@ -31,8 +39,10 @@ def _getrs(
3139

3240
@overload(_getrs)
3341
def getrs_impl(
34-
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
35-
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]]:
42+
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: _Trans, overwrite_b: bool
43+
) -> Callable[
44+
[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int]
45+
]:
3646
ensure_lapack()
3747
_check_scipy_linalg_matrix(LU, "getrs")
3848
_check_scipy_linalg_matrix(B, "getrs")
@@ -41,7 +51,11 @@ def getrs_impl(
4151
numba_getrs = _LAPACK().numba_xgetrs(dtype)
4252

4353
def impl(
44-
LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
54+
LU: np.ndarray,
55+
B: np.ndarray,
56+
IPIV: np.ndarray,
57+
trans: _Trans,
58+
overwrite_b: bool,
4559
) -> tuple[np.ndarray, int]:
4660
_N = np.int32(LU.shape[-1])
4761
_solve_check_input_shapes(LU, B)
@@ -89,7 +103,7 @@ def impl(
89103
def _lu_solve(
90104
lu_and_piv: tuple[np.ndarray, np.ndarray],
91105
b: np.ndarray,
92-
trans: int,
106+
trans: _Trans,
93107
overwrite_b: bool,
94108
check_finite: bool,
95109
):
@@ -105,10 +119,10 @@ def _lu_solve(
105119
def lu_solve_impl(
106120
lu_and_piv: tuple[np.ndarray, np.ndarray],
107121
b: np.ndarray,
108-
trans: int,
122+
trans: _Trans,
109123
overwrite_b: bool,
110124
check_finite: bool,
111-
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, bool, bool, bool], np.ndarray]:
125+
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]:
112126
ensure_lapack()
113127
_check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve")
114128
_check_scipy_linalg_matrix(b, "lu_solve")
@@ -117,7 +131,7 @@ def impl(
117131
lu: np.ndarray,
118132
piv: np.ndarray,
119133
b: np.ndarray,
120-
trans: int,
134+
trans: _Trans,
121135
overwrite_b: bool,
122136
check_finite: bool,
123137
) -> np.ndarray:

pytensor/tensor/conv/abstract_conv.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,23 @@
66
import sys
77
import warnings
88
from math import gcd
9+
from typing import TYPE_CHECKING
910

1011
import numpy as np
1112
from numpy.exceptions import ComplexWarning
1213

1314

14-
try:
15-
from scipy.signal.signaltools import _bvalfromboundary, _valfrommode, convolve
16-
from scipy.signal.sigtools import _convolve2d
17-
except ImportError:
18-
from scipy.signal._signaltools import _bvalfromboundary, _valfrommode, convolve
15+
if TYPE_CHECKING:
16+
# https://github.com/scipy/scipy-stubs/issues/851
17+
from scipy.signal._signaltools import _bvalfromboundary, _valfrommode, convolve # type: ignore[attr-defined]
1918
from scipy.signal._sigtools import _convolve2d
19+
else:
20+
try:
21+
from scipy.signal.signaltools import _bvalfromboundary, _valfrommode, convolve
22+
from scipy.signal.sigtools import _convolve2d
23+
except ImportError:
24+
from scipy.signal._signaltools import _bvalfromboundary, _valfrommode, convolve
25+
from scipy.signal._sigtools import _convolve2d
2026

2127
import pytensor
2228
from pytensor import tensor as pt

0 commit comments

Comments
 (0)