1
- from numpy .linalg import * # noqa: F403
2
- from numpy .linalg import __all__ as linalg_all
3
- import numpy as _np
1
+ # pyright: reportAttributeAccessIssue=false
2
+ # pyright: reportUnknownArgumentType=false
3
+ # pyright: reportUnknownMemberType=false
4
+ # pyright: reportUnknownVariableType=false
5
+
6
+ from __future__ import annotations
7
+
8
+ import numpy as np
9
+
10
+ # intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__`
11
+ from numpy .linalg import (
12
+ LinAlgError ,
13
+ cond ,
14
+ det ,
15
+ eig ,
16
+ eigvals ,
17
+ eigvalsh ,
18
+ inv ,
19
+ lstsq ,
20
+ matrix_power ,
21
+ multi_dot ,
22
+ norm ,
23
+ tensorinv ,
24
+ tensorsolve ,
25
+ )
4
26
5
- from ..common import _linalg
6
27
from .._internal import get_xp
28
+ from ..common import _linalg
7
29
8
30
# These functions are in both the main and linalg namespaces
9
- from ._aliases import matmul , matrix_transpose , tensordot , vecdot # noqa: F401
10
-
11
- import numpy as np
31
+ from ._aliases import matmul , matrix_transpose , tensordot , vecdot # noqa: F401
32
+ from ._typing import Array
12
33
13
34
cross = get_xp (np )(_linalg .cross )
14
35
outer = get_xp (np )(_linalg .outer )
38
59
# To workaround this, the below is the code from np.linalg.solve except
39
60
# only calling solve1 in the exactly 1D case.
40
61
62
+
41
63
# This code is here instead of in common because it is numpy specific. Also
42
64
# note that CuPy's solve() does not currently support broadcasting (see
43
65
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
44
- def solve (x1 : _np . ndarray , x2 : _np . ndarray , / ) -> _np . ndarray :
66
+ def solve (x1 : Array , x2 : Array , / ) -> Array :
45
67
try :
46
68
from numpy .linalg ._linalg import (
47
- _makearray , _assert_stacked_2d , _assert_stacked_square ,
48
- _commonType , isComplexType , _raise_linalgerror_singular
69
+ _assert_stacked_2d ,
70
+ _assert_stacked_square ,
71
+ _commonType ,
72
+ _makearray ,
73
+ _raise_linalgerror_singular ,
74
+ isComplexType ,
49
75
)
50
76
except ImportError :
51
77
from numpy .linalg .linalg import (
52
- _makearray , _assert_stacked_2d , _assert_stacked_square ,
53
- _commonType , isComplexType , _raise_linalgerror_singular
78
+ _assert_stacked_2d ,
79
+ _assert_stacked_square ,
80
+ _commonType ,
81
+ _makearray ,
82
+ _raise_linalgerror_singular ,
83
+ isComplexType ,
54
84
)
55
85
from numpy .linalg import _umath_linalg
56
86
@@ -61,30 +91,53 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
61
91
t , result_t = _commonType (x1 , x2 )
62
92
63
93
# This part is different from np.linalg.solve
94
+ gufunc : np .ufunc
64
95
if x2 .ndim == 1 :
65
96
gufunc = _umath_linalg .solve1
66
97
else :
67
98
gufunc = _umath_linalg .solve
68
99
69
100
# This does nothing currently but is left in because it will be relevant
70
101
# when complex dtype support is added to the spec in 2022.
71
- signature = 'DD->D' if isComplexType (t ) else 'dd->d'
72
- with _np .errstate (call = _raise_linalgerror_singular , invalid = 'call' ,
73
- over = 'ignore' , divide = 'ignore' , under = 'ignore' ):
74
- r = gufunc (x1 , x2 , signature = signature )
102
+ signature = "DD->D" if isComplexType (t ) else "dd->d"
103
+ with np .errstate (
104
+ call = _raise_linalgerror_singular ,
105
+ invalid = "call" ,
106
+ over = "ignore" ,
107
+ divide = "ignore" ,
108
+ under = "ignore" ,
109
+ ):
110
+ r : Array = gufunc (x1 , x2 , signature = signature )
75
111
76
112
return wrap (r .astype (result_t , copy = False ))
77
113
114
+
78
115
# These functions are completely new here. If the library already has them
79
116
# (i.e., numpy 2.0), use the library version instead of our wrapper.
80
- if hasattr (np .linalg , ' vector_norm' ):
117
+ if hasattr (np .linalg , " vector_norm" ):
81
118
vector_norm = np .linalg .vector_norm
82
119
else :
83
120
vector_norm = get_xp (np )(_linalg .vector_norm )
84
121
85
- __all__ = linalg_all + _linalg .__all__ + ['solve' ]
86
122
87
- del get_xp
88
- del np
89
- del linalg_all
90
- del _linalg
123
+ __all__ = [
124
+ "LinAlgError" ,
125
+ "cond" ,
126
+ "det" ,
127
+ "eig" ,
128
+ "eigvals" ,
129
+ "eigvalsh" ,
130
+ "inv" ,
131
+ "lstsq" ,
132
+ "matrix_power" ,
133
+ "multi_dot" ,
134
+ "norm" ,
135
+ "tensorinv" ,
136
+ "tensorsolve" ,
137
+ ]
138
+ __all__ += _linalg .__all__
139
+ __all__ += ["solve" , "vector_norm" ]
140
+
141
+
142
+ def __dir__ () -> list [str ]:
143
+ return __all__
0 commit comments