Skip to content

Commit 176513d

Browse files
authored
Merge pull request #399 from ev-br/test_all_2
BUG: cupy/linalg: add all names
2 parents 7c540f4 + 2ddef55 commit 176513d

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

array_api_compat/cupy/linalg.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
from cupy.linalg import * # noqa: F403
2-
# cupy.linalg doesn't have __all__. If it is added, replace this with
2+
3+
# https://github.com/cupy/cupy/issues/9749
4+
from cupy.linalg import lstsq # noqa: F401
5+
6+
# cupy.linalg doesn't have __all__ in cupy<14. If it is added, replace this with
37
#
48
# from cupy.linalg import __all__ as linalg_all
59
_n: dict[str, object] = {}
610
exec('from cupy.linalg import *', _n)
711
del _n['__builtins__']
8-
linalg_all = list(_n)
12+
linalg_all = list(_n) + ['lstsq']
913
del _n
1014

15+
try:
16+
# cupy 14 exports it, cupy 13 does not
17+
from cupy.linalg import annotations # noqa: F401
18+
linalg_all += ['annotations']
19+
except ImportError:
20+
pass
21+
22+
1123
from ..common import _linalg
1224
from .._internal import get_xp
1325

@@ -43,5 +55,8 @@
4355

4456
__all__ = linalg_all + _linalg.__all__
4557

58+
# cupy 13 does not have __all__, cupy 14 has it: remove duplicates
59+
__all__ = sorted(list(set(__all__)))
60+
4661
def __dir__() -> list[str]:
4762
return __all__

tests/test_all.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@
140140
# Manipulation Functions
141141
"broadcast_arrays",
142142
"broadcast_to",
143+
"broadcast_shapes",
143144
"concat",
144145
"expand_dims",
145146
"flip",
@@ -164,6 +165,7 @@
164165
"unique_counts",
165166
"unique_inverse",
166167
"unique_values",
168+
"isin",
167169
# Sorting Functions
168170
"argsort",
169171
"sort",
@@ -205,6 +207,8 @@
205207
"diagonal",
206208
"eigh",
207209
"eigvalsh",
210+
"eig",
211+
"eigvals",
208212
"inv",
209213
"matmul",
210214
"matrix_norm",
@@ -227,12 +231,14 @@
227231

228232
XFAILS = {
229233
("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [],
230-
("dask.array", ""): ["from_dlpack", "take_along_axis"],
234+
("dask.array", ""): ["from_dlpack", "take_along_axis", "broadcast_shapes"],
231235
("dask.array", "linalg"): [
232236
"cross",
233237
"det",
234238
"eigh",
235239
"eigvalsh",
240+
"eig",
241+
"eigvals",
236242
"matrix_power",
237243
"pinv",
238244
"slogdet",

0 commit comments

Comments
 (0)