@@ -4,69 +4,91 @@ from typing_extensions import TypeVar
44import numpy as np
55import optype as op
66import optype .numpy as onp
7+ import optype .numpy .compat as npc
78
89from scipy ._typing import Falsy , Truthy
910
1011__all__ = ["diagsvd" , "null_space" , "orth" , "subspace_angles" , "svd" , "svdvals" ]
1112
13+ _T = TypeVar ("_T" )
14+ _Tuple3 : TypeAlias = tuple [_T , _T , _T ]
15+
1216_Float : TypeAlias = np .float32 | np .float64
1317_FloatND : TypeAlias = onp .ArrayND [_Float ]
1418
1519_Complex : TypeAlias = np .complex64 | np .complex128
16- _ComplexND : TypeAlias = onp .ArrayND [_Complex ]
1720
1821_LapackDriver : TypeAlias = Literal ["gesdd" , "gesvd" ]
1922
20- _FloatSVD : TypeAlias = tuple [_FloatND , _FloatND , _FloatND ]
21- _ComplexSVD : TypeAlias = tuple [_ComplexND , _FloatND , _ComplexND ]
22-
2323_RealT = TypeVar ("_RealT" , bound = np .bool_ | np .integer [Any ] | np .floating [Any ])
2424_InexactT = TypeVar ("_InexactT" , bound = _Float | _Complex )
2525
26+ _as_f32 : TypeAlias = np .float32 | np .float16 # noqa: PYI042
27+ _as_f64 : TypeAlias = np .longdouble | np .float64 | npc .integer | np .bool_ # noqa: PYI042
28+
2629###
2730
28- @overload
31+ @overload # nd float64
2932def svd (
30- a : onp .ToFloatND ,
33+ a : onp .ToArrayND [ float , _as_f64 ] ,
3134 full_matrices : onp .ToBool = True ,
3235 compute_uv : Truthy = True ,
3336 overwrite_a : onp .ToBool = False ,
3437 check_finite : onp .ToBool = True ,
3538 lapack_driver : _LapackDriver = "gesdd" ,
36- ) -> _FloatSVD : ...
37- @overload
39+ ) -> _Tuple3 [onp .ArrayND [np .float64 ]]: ...
40+ @overload # nd float32
41+ def svd (
42+ a : onp .ToArrayND [_as_f32 , _as_f32 ],
43+ full_matrices : onp .ToBool = True ,
44+ compute_uv : Truthy = True ,
45+ overwrite_a : onp .ToBool = False ,
46+ check_finite : onp .ToBool = True ,
47+ lapack_driver : _LapackDriver = "gesdd" ,
48+ ) -> _Tuple3 [onp .ArrayND [np .float32 ]]: ...
49+ @overload # nd complex128
3850def svd (
39- a : onp .ToComplexND ,
51+ a : onp .ToArrayND [ op . JustComplex , np . complex128 | np . clongdouble ] ,
4052 full_matrices : onp .ToBool = True ,
4153 compute_uv : Truthy = True ,
4254 overwrite_a : onp .ToBool = False ,
4355 check_finite : onp .ToBool = True ,
4456 lapack_driver : _LapackDriver = "gesdd" ,
45- ) -> _FloatSVD | _ComplexSVD : ...
46- @overload # complex, compute_uv: {False}
57+ ) -> tuple [ onp . ArrayND [ np . complex128 ], onp . ArrayND [ np . float64 ], onp . ArrayND [ np . complex128 ]] : ...
58+ @overload # nd complex64
4759def svd (
48- a : onp .ToComplexND ,
49- full_matrices : onp .ToBool ,
60+ a : onp .ToArrayND [np .complex64 , np .complex64 ],
61+ full_matrices : onp .ToBool = True ,
62+ compute_uv : Truthy = True ,
63+ overwrite_a : onp .ToBool = False ,
64+ check_finite : onp .ToBool = True ,
65+ lapack_driver : _LapackDriver = "gesdd" ,
66+ ) -> tuple [onp .ArrayND [np .complex64 ], onp .ArrayND [np .float32 ], onp .ArrayND [np .complex64 ]]: ...
67+ @overload # nd float64 | complex128, compute_uv=False (keyword)
68+ def svd (
69+ a : onp .ToArrayND [complex , _as_f64 | np .complex128 | np .clongdouble ],
70+ full_matrices : onp .ToBool = True ,
71+ * ,
5072 compute_uv : Falsy ,
5173 overwrite_a : onp .ToBool = False ,
5274 check_finite : onp .ToBool = True ,
5375 lapack_driver : _LapackDriver = "gesdd" ,
54- ) -> _FloatND : ...
55- @overload # complex, * , compute_uv: { False}
76+ ) -> onp . ArrayND [ np . float64 ] : ...
77+ @overload # nd float32 | complex64 , compute_uv= False (keyword)
5678def svd (
57- a : onp .ToComplexND ,
79+ a : onp .ToArrayND [ _as_f32 , _as_f32 | np . complex64 ] ,
5880 full_matrices : onp .ToBool = True ,
5981 * ,
6082 compute_uv : Falsy ,
6183 overwrite_a : onp .ToBool = False ,
6284 check_finite : onp .ToBool = True ,
6385 lapack_driver : _LapackDriver = "gesdd" ,
64- ) -> _FloatND : ...
86+ ) -> onp . ArrayND [ np . float32 ] : ...
6587
6688#
6789def svdvals (a : onp .ToComplexND , overwrite_a : onp .ToBool = False , check_finite : onp .ToBool = True ) -> _FloatND : ...
6890
69- # beware the overlapping overloads for bool <: int (<: float)
91+ #
7092@overload
7193def diagsvd (s : onp .SequenceND [_RealT ] | onp .CanArrayND [_RealT ], M : op .CanIndex , N : op .CanIndex ) -> onp .ArrayND [_RealT ]: ...
7294@overload
0 commit comments