1
+ from typing import Literal
2
+
1
3
import numpy as np
2
4
from numba .core .extending import overload
3
5
from numba .np .linalg import _copy_to_fortran_order , ensure_lapack
13
15
14
16
def _xgeqrf (A : np .ndarray , overwrite_a : bool , lwork : int ):
15
17
"""LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A."""
16
- (geqrf ,) = get_lapack_funcs (("geqrf" ,), (A ,))
18
+ # (geqrf,) = typing_cast(
19
+ # list[Callable[..., np.ndarray]], get_lapack_funcs(("geqrf",), (A,))
20
+ # )
21
+ funcs = get_lapack_funcs (("geqrf" ,), (A ,))
22
+ assert isinstance (funcs , list ) # narrows `funcs: list[F] | F` to `funcs: list[F]`
23
+ geqrf = funcs [0 ]
24
+
17
25
return geqrf (A , overwrite_a = overwrite_a , lwork = lwork )
18
26
19
27
@@ -61,7 +69,10 @@ def impl(A, overwrite_a, lwork):
61
69
62
70
def _xgeqp3 (A : np .ndarray , overwrite_a : bool , lwork : int ):
63
71
"""LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A."""
64
- (geqp3 ,) = get_lapack_funcs (("geqp3" ,), (A ,))
72
+ funcs = get_lapack_funcs (("geqp3" ,), (A ,))
73
+ assert isinstance (funcs , list ) # narrows `funcs: list[F] | F` to `funcs: list[F]`
74
+ geqp3 = funcs [0 ]
75
+
65
76
return geqp3 (A , overwrite_a = overwrite_a , lwork = lwork )
66
77
67
78
@@ -111,7 +122,10 @@ def impl(A, overwrite_a, lwork):
111
122
112
123
def _xorgqr (A : np .ndarray , tau : np .ndarray , overwrite_a : bool , lwork : int ):
113
124
"""LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types)."""
114
- (orgqr ,) = get_lapack_funcs (("orgqr" ,), (A ,))
125
+ funcs = get_lapack_funcs (("orgqr" ,), (A ,))
126
+ assert isinstance (funcs , list ) # narrows `funcs: list[F] | F` to `funcs: list[F]`
127
+ orgqr = funcs [0 ]
128
+
115
129
return orgqr (A , tau , overwrite_a = overwrite_a , lwork = lwork )
116
130
117
131
@@ -160,7 +174,10 @@ def impl(A, tau, overwrite_a, lwork):
160
174
161
175
def _xungqr (A : np .ndarray , tau : np .ndarray , overwrite_a : bool , lwork : int ):
162
176
"""LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types)."""
163
- (ungqr ,) = get_lapack_funcs (("ungqr" ,), (A ,))
177
+ funcs = get_lapack_funcs (("ungqr" ,), (A ,))
178
+ assert isinstance (funcs , list ) # narrows `funcs: list[F] | F` to `funcs: list[F]`
179
+ ungqr = funcs [0 ]
180
+
164
181
return ungqr (A , tau , overwrite_a = overwrite_a , lwork = lwork )
165
182
166
183
@@ -209,8 +226,8 @@ def impl(A, tau, overwrite_a, lwork):
209
226
210
227
def _qr_full_pivot (
211
228
x : np .ndarray ,
212
- mode : str = "full" ,
213
- pivoting : bool = True ,
229
+ mode : Literal [ "full" , "economic" ] = "full" ,
230
+ pivoting : Literal [ True ] = True ,
214
231
overwrite_a : bool = False ,
215
232
check_finite : bool = False ,
216
233
lwork : int | None = None ,
@@ -234,8 +251,8 @@ def _qr_full_pivot(
234
251
235
252
def _qr_full_no_pivot (
236
253
x : np .ndarray ,
237
- mode : str = "full" ,
238
- pivoting : bool = False ,
254
+ mode : Literal [ "full" , "economic" ] = "full" ,
255
+ pivoting : Literal [ False ] = False ,
239
256
overwrite_a : bool = False ,
240
257
check_finite : bool = False ,
241
258
lwork : int | None = None ,
@@ -258,8 +275,8 @@ def _qr_full_no_pivot(
258
275
259
276
def _qr_r_pivot (
260
277
x : np .ndarray ,
261
- mode : str = "r" ,
262
- pivoting : bool = True ,
278
+ mode : Literal [ "r" , "raw" ] = "r" ,
279
+ pivoting : Literal [ True ] = True ,
263
280
overwrite_a : bool = False ,
264
281
check_finite : bool = False ,
265
282
lwork : int | None = None ,
@@ -282,8 +299,8 @@ def _qr_r_pivot(
282
299
283
300
def _qr_r_no_pivot (
284
301
x : np .ndarray ,
285
- mode : str = "r" ,
286
- pivoting : bool = False ,
302
+ mode : Literal [ "r" , "raw" ] = "r" ,
303
+ pivoting : Literal [ False ] = False ,
287
304
overwrite_a : bool = False ,
288
305
check_finite : bool = False ,
289
306
lwork : int | None = None ,
@@ -306,8 +323,8 @@ def _qr_r_no_pivot(
306
323
307
324
def _qr_raw_no_pivot (
308
325
x : np .ndarray ,
309
- mode : str = "raw" ,
310
- pivoting : bool = False ,
326
+ mode : Literal [ "raw" ] = "raw" ,
327
+ pivoting : Literal [ False ] = False ,
311
328
overwrite_a : bool = False ,
312
329
check_finite : bool = False ,
313
330
lwork : int | None = None ,
@@ -332,8 +349,8 @@ def _qr_raw_no_pivot(
332
349
333
350
def _qr_raw_pivot (
334
351
x : np .ndarray ,
335
- mode : str = "raw" ,
336
- pivoting : bool = True ,
352
+ mode : Literal [ "raw" ] = "raw" ,
353
+ pivoting : Literal [ True ] = True ,
337
354
overwrite_a : bool = False ,
338
355
check_finite : bool = False ,
339
356
lwork : int | None = None ,
0 commit comments