Skip to content
This repository was archived by the owner on Feb 17, 2021. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 102 additions & 7 deletions numpy-stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,17 @@ class void:
# and this is why we let int32 be a subclass of int64; and similarly for float32 and float64
# the same logic applies when adding unsigned and signed values (uint + int -> int)

class complex128(void, int):
def __complex__(self) -> complex: ...

class complex64(complex128): ...

# this would be the correct definition, but it makes `int` conflict with `float`
# class float64(void, float): ...
class float64(void, int):
class float64(complex128):
def __float__(self) -> float: ...

class float32(float64): ...
class float32(float64, complex64): ...
class float16(float32): ...

floating = float64
Expand All @@ -70,6 +75,8 @@ integer = int64
_DType = TypeVar(
"_DType",
bool_,
complex64,
complex128,
float16,
float32,
float64,
Expand All @@ -87,6 +94,8 @@ _DType = TypeVar(
_DType2 = TypeVar(
"_DType2",
bool_,
complex64,
complex128,
float16,
float32,
float64,
Expand All @@ -110,7 +119,7 @@ _ScalarLike = Union[_DType, str, int, float]
_ConditionType = Union[ndarray[bool_], bool_, bool]
newaxis: None = ...

_AnyNum = Union[int, float, bool]
_AnyNum = Union[int, float, bool, complex]
# generic types that are only allowed to take on dtype values

_Float = TypeVar("_Float", float16, float32, float64)
Expand Down Expand Up @@ -390,6 +399,8 @@ class ndarray(Generic[_DType]):
def __radd__(self, value: ndarray[_DType]) -> ndarray[_DType]: ...
@overload
def __radd__(self, value: _DType) -> ndarray[_DType]: ...
@overload
def __radd__(self, value: float) -> ndarray[_DType2]: ...
def __rand__(self, value: object) -> ndarray[_DType]: ...
def __rdivmod__(self, value: object) -> Tuple[ndarray[_DType], ndarray[_DType]]: ...
def __rfloordiv__(self, value: object) -> ndarray[_DType]: ...
Expand Down Expand Up @@ -504,6 +515,8 @@ def array(object: _NestedList[int]) -> ndarray[int64]: ...
@overload
def array(object: _NestedList[float]) -> ndarray[float64]: ...
@overload
def array(object: _NestedList[complex]) -> ndarray[complex64]: ...
@overload
def array(object: _NestedList[str]) -> ndarray[str_]: ...
@overload
def array(object: str) -> ndarray[str_]: ...
Expand Down Expand Up @@ -774,10 +787,6 @@ def interp(
) -> ndarray: ...
def isin(element: Sequence[_DType], test_element: _DType) -> ndarray[_DType]: ...
@overload
def isnan(x: float64) -> bool: ...
@overload
def isnan(x: ndarray[_DType]) -> ndarray[bool_]: ...
@overload
def ix_(x: ndarray[_DType]) -> ndarray[_DType]: ...
@overload
def ix_(x1: ndarray[_DType], x2: ndarray[_DType]) -> Tuple[ndarray[_DType], ndarray[_DType]]: ...
Expand Down Expand Up @@ -993,6 +1002,92 @@ def set_printoptions(
*,
legacy: Any = ...,
) -> None: ...
def isscalar(element: Any) -> bool: ...
def diagonal(a: _ArrayLike, offset: int = ..., axis1: int = ..., axis2: int = ...) -> ndarray: ...
def allclose(
a: Union[_ArrayLike, _FloatLike],
b: Union[_ArrayLike, _FloatLike],
rtol: float = ...,
atol: float = ...,
equal_nan: bool = ...,
) -> bool: ...

#
# ufunc
#

# Backported from latest NumPy
class ufunc:
@property
def __name__(self) -> str: ...
def __call__(
self,
*args: Union[_FloatLike, _ArrayLike],
out: Optional[Union[ndarray, Tuple[ndarray, ...]]] = ...,
where: Optional[ndarray] = ...,
# The list should be a list of tuples of ints, but since we
# don't know the signature it would need to be
# Tuple[int, ...]. But, since List is invariant something like
# e.g. List[Tuple[int, int]] isn't a subtype of
# List[Tuple[int, ...]], so we can't type precisely here.
axes: List[Any] = ...,
axis: int = ...,
keepdims: bool = ...,
casting: Any = ...,
order: Any = ...,
dtype: Any = ...,
subok: bool = ...,
signature: Union[str, Tuple[str]] = ...,
# In reality this should be a length of list 3 containing an
# int, an int, and a callable, but there's no way to express
# that.
extobj: List[Union[int, Callable]] = ...,
) -> Any: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love this Any return type. Is it possible to do an overload like we had before?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took this from NumPy stubs, I guess the return type from ufuncs is too broad, check this issue: numpy/numpy#17805

@property
def nin(self) -> int: ...
@property
def nout(self) -> int: ...
@property
def nargs(self) -> int: ...
@property
def ntypes(self) -> int: ...
@property
def types(self) -> List[str]: ...
# Broad return type because it has to encompass things like
#
# >>> np.logical_and.identity is True
# True
# >>> np.add.identity is 0
# True
# >>> np.sin.identity is None
# True
#
# and any user-defined ufuncs.
@property
def identity(self) -> Any: ...
# This is None for ufuncs and a string for gufuncs.
@property
def signature(self) -> Optional[str]: ...
# The next four methods will always exist, but they will just
# raise a ValueError ufuncs with that don't accept two input
# arguments and return one output argument. Because of that we
# can't type them very precisely.
@property
def reduce(self) -> Any: ...
@property
def accumulate(self) -> Any: ...
@property
def reduceat(self) -> Any: ...
@property
def outer(self) -> Any: ...
# Similarly at won't be defined for ufuncs that return multiple
# outputs, so we can't type it very precisely.
@property
def at(self) -> Any: ...

isfinite: ufunc
isinf: ufunc
isnan: ufunc

#
# Specific values
Expand Down
56 changes: 56 additions & 0 deletions tests/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
DType = TypeVar(
"DType",
np.bool_,
np.complex64,
np.complex128,
np.float32,
np.float64,
np.int8,
Expand Down Expand Up @@ -259,3 +261,57 @@ def test_interp() -> None:
def test_genfromtxt() -> None:
result = np.genfromtxt(["0.1, 0.2"], dtype=np.float64, delimiter=",")
assert list(result) == [0.1, 0.2]


def test_isfinite_isinf_isnan() -> None:
import math

assert np.isfinite(0.0)
assert not np.isfinite(np.inf)
assert np.isinf(np.inf)
assert not np.isfinite(math.inf)
assert not np.isfinite(np.nan)
assert np.isnan(np.nan)
assert np.all(np.isfinite([0.0, -np.inf]) == [True, False])
assert np.all(np.isfinite(np.array([0.0, np.nan])) == np.array([True, False]))
assert np.all(
np.isfinite(np.array([np.inf, np.nan], dtype=np.float32)) == np.array([False, False])
)
assert np.all(np.isnan([0.0, -np.inf]) == [False, False])
assert np.all(np.isinf([0.0, -np.inf]) == [False, True])


def test_diagonal() -> None:
assert np.all(np.diagonal([[1]]) == np.array([1]))
x = np.arange(12).reshape(3, 4)
assert np.all(np.diagonal(x) == np.array([0.0, 5.0, 10.0]))


def test_allclose() -> None:
assert np.allclose([1.0, 2.0], [1.0 + 1e-9, 2.0 + 1e-9])
assert np.allclose(np.array([1.0, 2.0]), np.array([1.0 + 1e-9, 2.0 + 1e-9]))
assert np.allclose(np.array([1.0, 1.0]), 1.0 + 1e-9)
assert np.allclose(1.0 + 1e-9, np.array([1.0, 1.0]))


def test_isscalar() -> None:
assert np.isscalar(1.0)
assert not np.isscalar([1.0])
assert not np.isscalar(np.array([1.0]))
assert not np.isscalar(np.array([]))
assert np.isscalar(np.array([1.0], dtype=np.float32)[0])


def test_newaxis() -> None:
x = np.array([1.0, 2.0])
assert x[np.newaxis, :].shape == (1, 2)


def test_sum_scalar_before() -> None:
x: np.ndarray[np.float64] = 273.15 + np.array([10, 20])
assert isinstance(x, np.ndarray)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we assert the type of the ndarray?

Copy link
Author

@tadeu tadeu Dec 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, done :)

assert_dtype(x, np.float64)

y: np.ndarray[np.complex128] = 10.0 + np.array([1j, 2j])
assert isinstance(y, np.ndarray)
assert_dtype(y, np.complex128)