Skip to content

Commit c66b750

Browse files
committed
TYP: fix typing errors in numpy.linalg
1 parent 953d7c0 commit c66b750

File tree

1 file changed

+75
-22
lines changed

1 file changed

+75
-22
lines changed

array_api_compat/numpy/linalg.py

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,35 @@
1-
from numpy.linalg import * # noqa: F403
2-
from numpy.linalg import __all__ as linalg_all
3-
import numpy as _np
1+
# pyright: reportAttributeAccessIssue=false
2+
# pyright: reportUnknownArgumentType=false
3+
# pyright: reportUnknownMemberType=false
4+
# pyright: reportUnknownVariableType=false
5+
6+
from __future__ import annotations
7+
8+
import numpy as np
9+
10+
# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__`
11+
from numpy.linalg import (
12+
LinAlgError,
13+
cond,
14+
det,
15+
eig,
16+
eigvals,
17+
eigvalsh,
18+
inv,
19+
lstsq,
20+
matrix_power,
21+
multi_dot,
22+
norm,
23+
tensorinv,
24+
tensorsolve,
25+
)
426

5-
from ..common import _linalg
627
from .._internal import get_xp
28+
from ..common import _linalg
729

830
# These functions are in both the main and linalg namespaces
9-
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
10-
11-
import numpy as np
31+
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
32+
from ._typing import Array
1233

1334
cross = get_xp(np)(_linalg.cross)
1435
outer = get_xp(np)(_linalg.outer)
@@ -38,19 +59,28 @@
3859
# To workaround this, the below is the code from np.linalg.solve except
3960
# only calling solve1 in the exactly 1D case.
4061

62+
4163
# This code is here instead of in common because it is numpy specific. Also
4264
# note that CuPy's solve() does not currently support broadcasting (see
4365
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
44-
def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
66+
def solve(x1: Array, x2: Array, /) -> Array:
4567
try:
4668
from numpy.linalg._linalg import (
47-
_makearray, _assert_stacked_2d, _assert_stacked_square,
48-
_commonType, isComplexType, _raise_linalgerror_singular
69+
_assert_stacked_2d,
70+
_assert_stacked_square,
71+
_commonType,
72+
_makearray,
73+
_raise_linalgerror_singular,
74+
isComplexType,
4975
)
5076
except ImportError:
5177
from numpy.linalg.linalg import (
52-
_makearray, _assert_stacked_2d, _assert_stacked_square,
53-
_commonType, isComplexType, _raise_linalgerror_singular
78+
_assert_stacked_2d,
79+
_assert_stacked_square,
80+
_commonType,
81+
_makearray,
82+
_raise_linalgerror_singular,
83+
isComplexType,
5484
)
5585
from numpy.linalg import _umath_linalg
5686

@@ -61,30 +91,53 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
6191
t, result_t = _commonType(x1, x2)
6292

6393
# This part is different from np.linalg.solve
94+
gufunc: np.ufunc
6495
if x2.ndim == 1:
6596
gufunc = _umath_linalg.solve1
6697
else:
6798
gufunc = _umath_linalg.solve
6899

69100
# This does nothing currently but is left in because it will be relevant
70101
# when complex dtype support is added to the spec in 2022.
71-
signature = 'DD->D' if isComplexType(t) else 'dd->d'
72-
with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
73-
over='ignore', divide='ignore', under='ignore'):
74-
r = gufunc(x1, x2, signature=signature)
102+
signature = "DD->D" if isComplexType(t) else "dd->d"
103+
with np.errstate(
104+
call=_raise_linalgerror_singular,
105+
invalid="call",
106+
over="ignore",
107+
divide="ignore",
108+
under="ignore",
109+
):
110+
r: Array = gufunc(x1, x2, signature=signature)
75111

76112
return wrap(r.astype(result_t, copy=False))
77113

114+
78115
# These functions are completely new here. If the library already has them
79116
# (i.e., numpy 2.0), use the library version instead of our wrapper.
80-
if hasattr(np.linalg, 'vector_norm'):
117+
if hasattr(np.linalg, "vector_norm"):
81118
vector_norm = np.linalg.vector_norm
82119
else:
83120
vector_norm = get_xp(np)(_linalg.vector_norm)
84121

85-
__all__ = linalg_all + _linalg.__all__ + ['solve']
86122

87-
del get_xp
88-
del np
89-
del linalg_all
90-
del _linalg
123+
__all__ = [
124+
"LinAlgError",
125+
"cond",
126+
"det",
127+
"eig",
128+
"eigvals",
129+
"eigvalsh",
130+
"inv",
131+
"lstsq",
132+
"matrix_power",
133+
"multi_dot",
134+
"norm",
135+
"tensorinv",
136+
"tensorsolve",
137+
]
138+
__all__ += _linalg.__all__
139+
__all__ += ["solve", "vector_norm"]
140+
141+
142+
def __dir__() -> list[str]:
143+
return __all__

0 commit comments

Comments
 (0)