Skip to content

Commit 255f453

Browse files
committed
🚚 move HasArrayNamespace
Signed-off-by: nstarman <[email protected]>
1 parent 6d54ce7 commit 255f453

File tree

10 files changed

+161
-62
lines changed

10 files changed

+161
-62
lines changed

‎.github/workflows/ci.yml

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,15 @@ jobs:
8888
python-version: "3.11"
8989
activate-environment: true
9090

91-
- name: get major numpy version
92-
id: numpy-major
91+
- name: get major.minor numpy version
92+
id: numpy-version
9393
run: |
94-
version=$(echo ${{ matrix.numpy-version }} | cut -c 1)
95-
echo "::set-output name=version::$version"
94+
version="${{ matrix.numpy-version }}"
95+
major=$(echo "$version" | cut -d. -f1)
96+
minor=$(echo "$version" | cut -d. -f2)
97+
98+
echo "major=$major" >> $GITHUB_OUTPUT
99+
echo "minor=$minor" >> $GITHUB_OUTPUT
96100
97101
- name: install deps
98102
run: |
@@ -101,10 +105,29 @@ jobs:
101105
102106
# NOTE: `uv run --with=...` will be ignored by mypy (and `--isolated` does not help)
103107
- name: mypy
104-
run: >
105-
uv run --no-sync --active
106-
mypy --tb --no-incremental --cache-dir=/dev/null
107-
tests/integration/test_numpy${{ steps.numpy-major.outputs.version }}.pyi
108+
run: |
109+
major="${{ steps.numpy-version.outputs.major }}"
110+
minor="${{ steps.numpy-version.outputs.minor }}"
111+
112+
# Directory containing versioned test files
113+
prefix="tests/integration"
114+
files=""
115+
116+
# Find all test files matching the current major version
117+
for path in $(find "$prefix" -name "test_numpy${major}p*.pyi"); do
118+
# Extract file name
119+
fname=$(basename "$path")
120+
# Parse the minor version from the filename
121+
fminor=$(echo "$fname" | sed -E "s/test_numpy${major}p([0-9]+)\.pyi/\1/")
122+
# Include files where minor version ≤ NumPy's minor
123+
if [ "$fminor" -le "$minor" ]; then
124+
files="$files $path"
125+
fi
126+
done
127+
128+
uv run --no-sync --active \
129+
mypy --tb --no-incremental --cache-dir=/dev/null \
130+
$files
108131
109132
# TODO: (based)pyright
110133

‎pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,16 @@ ignore = [
121121
"D107", # Missing docstring in __init__
122122
"D203", # 1 blank line required before class docstring
123123
"D213", # Multi-line docstring summary should start at the second line
124+
"D401", # First line should be in imperative mood
124125
"FBT", # flake8-boolean-trap
125126
"FIX", # flake8-fixme
126127
"ISC001", # Conflicts with formatter
128+
"PYI041", # Use `float` instead of `int | float`
127129
]
128130

131+
[tool.ruff.lint.pydocstyle]
132+
convention = "google"
133+
129134
[tool.ruff.lint.pylint]
130135
allow-dunder-method-names = [
131136
"__array_api_version__",

‎src/array_api_typing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66
"__version_tuple__",
77
)
88

9-
from ._namespace import HasArrayNamespace
9+
from ._array import HasArrayNamespace
1010
from ._version import version as __version__, version_tuple as __version_tuple__

‎src/array_api_typing/_array.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
__all__ = ("HasArrayNamespace",)
2+
3+
from types import ModuleType
4+
from typing import Literal, Protocol
5+
from typing_extensions import TypeVar
6+
7+
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
8+
9+
10+
class HasArrayNamespace(Protocol[NamespaceT_co]):
11+
"""Protocol for classes that have an `__array_namespace__` method.
12+
13+
This `Protocol` is intended for use in static typing to ensure that an
14+
object has an `__array_namespace__` method that returns a namespace for
15+
array operations. This `Protocol` should not be used at runtime, for type
16+
checking or as a base class.
17+
18+
Example:
19+
>>> import array_api_typing as xpt
20+
>>>
21+
>>> class MyArray:
22+
... def __array_namespace__(self):
23+
... return object()
24+
>>>
25+
>>> x = MyArray()
26+
>>> def has_array_namespace(x: xpt.HasArrayNamespace) -> bool:
27+
... return hasattr(x, "__array_namespace__")
28+
>>> has_array_namespace(x)
29+
True
30+
31+
"""
32+
33+
def __array_namespace__(
34+
self, /, *, api_version: Literal["2021.12"] | None = None
35+
) -> NamespaceT_co:
36+
"""Returns an object that has all the array API functions on it.
37+
38+
Args:
39+
api_version: string representing the version of the array API
40+
specification to be returned, in 'YYYY.MM' form, for example,
41+
'2020.10'. If it is `None`, it should return the namespace
42+
corresponding to latest version of the array API specification.
43+
If the given version is invalid or not implemented for the given
44+
module, an error should be raised. Default: `None`.
45+
46+
Returns:
47+
NamespaceT_co: An object representing the array API namespace. It
48+
should have every top-level function defined in the
49+
specification as an attribute. It may contain other public names
50+
as well, but it is recommended to only include those names that
51+
are part of the specification.
52+
53+
"""
54+
...

‎src/array_api_typing/_namespace.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

‎tests/integration/test_numpy1.pyi

Lines changed: 0 additions & 12 deletions
This file was deleted.
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# mypy: disable-error-code="no-redef"
2+
3+
from types import ModuleType
4+
from typing import TypeAlias
5+
6+
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
7+
8+
import array_api_typing as xpt
9+
10+
# Define NDArrays against which we can test the protocols
11+
# Note that `np.array_api` doesn't support boolean arrays.
12+
nparr = np.eye(2)
13+
nparr_i32 = np.asarray([1], dtype=np.int32)
14+
nparr_f32 = np.asarray([1.0], dtype=np.float32)
15+
16+
# =========================================================
17+
# `xpt.HasArrayNamespace`
18+
19+
_: xpt.HasArrayNamespace[ModuleType] = nparr
20+
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
21+
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
22+
23+
# Check `__array_namespace__` method
24+
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
25+
ns: ModuleType = a_ns.__array_namespace__()
26+
27+
# Incorrect values are caught when using `__array_namespace__` and
28+
# backpropagated to the type of `a_ns`
29+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught

‎tests/integration/test_numpy2.pyi

Lines changed: 0 additions & 11 deletions
This file was deleted.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# mypy: disable-error-code="no-redef"
2+
3+
from types import ModuleType
4+
from typing import Any, TypeAlias
5+
6+
import numpy as np
7+
import numpy.typing as npt
8+
9+
import array_api_typing as xpt
10+
11+
# DType aliases
12+
F32: TypeAlias = np.float32
13+
I32: TypeAlias = np.int32
14+
15+
# Define NDArrays against which we can test the protocols
16+
nparr: npt.NDArray[Any]
17+
nparr_i32: npt.NDArray[I32]
18+
nparr_f32: npt.NDArray[F32]
19+
nparr_b: npt.NDArray[np.bool_]
20+
21+
# =========================================================
22+
# `xpt.HasArrayNamespace`
23+
24+
# Check assignment
25+
_: xpt.HasArrayNamespace[ModuleType] = nparr
26+
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
27+
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
28+
_: xpt.HasArrayNamespace[ModuleType] = nparr_b
29+
30+
# Check `__array_namespace__` method
31+
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
32+
ns: ModuleType = a_ns.__array_namespace__()
33+
34+
# Incorrect values are caught when using `__array_namespace__` and
35+
# backpropagated to the type of `a_ns`
36+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from test_numpy2p0 import nparr
2+
3+
import array_api_typing as xpt
4+
5+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # type: ignore[assignment]

0 commit comments

Comments
 (0)