@@ -11,6 +11,7 @@ __all__ = ["log_softmax", "logsumexp", "softmax"]
# Private type variables used by the overloads below.
# `_InexactOrArrayT` matches either a bare inexact scalar or an inexact
# ndarray, so a softmax/log_softmax call can return exactly the type it
# was given (the "T" passthrough overload).
_InexactT = TypeVar("_InexactT", bound=npc.inexact)
_FloatingT = TypeVar("_FloatingT", bound=npc.floating)
_CFloatingT = TypeVar("_CFloatingT", bound=npc.complexfloating)
_InexactOrArrayT = TypeVar("_InexactOrArrayT", bound=npc.inexact | onp.ArrayND[npc.inexact])
1415
1516###
1617
@@ -98,7 +99,7 @@ def logsumexp(
9899 keepdims : bool = False ,
99100 return_sign : Falsy = False ,
100101) -> onp .ArrayND [np .float64 | Any ] | Any : ...
101- @overload # ccomplex fallback, return_sign=False
102+ @overload # complex fallback, return_sign=False
102103def logsumexp (
103104 a : onp .ToComplex | onp .ToComplexND ,
104105 axis : AnyShape | None = None ,
@@ -223,7 +224,7 @@ def logsumexp(
223224 * ,
224225 return_sign : Truthy ,
225226) -> tuple [onp .ArrayND [np .float64 | Any ] | Any , onp .ArrayND [np .float64 | Any ] | Any ]: ...
226- @overload # ccomplex fallback, return_sign=True
227+ @overload # complex fallback, return_sign=True
227228def logsumexp (
228229 a : onp .ToComplex | onp .ToComplexND ,
229230 axis : AnyShape | None = None ,
@@ -233,22 +234,46 @@ def logsumexp(
233234 return_sign : Truthy ,
234235) -> tuple [onp .ArrayND [np .float64 | Any ] | Any , onp .ArrayND [np .complex128 | Any ] | Any ]: ...
235236
# NOTE: keep in sync with `log_softmax`
@overload  # T
def softmax(x: _InexactOrArrayT, axis: AnyShape | None = None) -> _InexactOrArrayT: ...  # type: ignore[overload-overlap]
@overload  # 0d +float64
def softmax(x: onp.ToInt | onp.ToJustFloat64, axis: AnyShape | None = None) -> np.float64: ...
@overload  # 0d ~complex128
def softmax(x: onp.ToJustComplex128, axis: AnyShape | None = None) -> np.complex128: ...
@overload  # nd T@inexact
def softmax(x: onp.ToArrayND[_InexactT, _InexactT], axis: AnyShape | None = None) -> onp.ArrayND[_InexactT]: ...
@overload  # nd +float64
def softmax(x: onp.ToIntND | onp.ToJustFloat64_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
@overload  # nd ~complex128
def softmax(x: onp.ToJustComplex128_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128]: ...
@overload  # 0d float fallback
def softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64 | Any: ...
@overload  # 0d complex fallback
def softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.complex128 | Any: ...
@overload  # nd float fallback
def softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | Any]: ...
@overload  # nd complex fallback
def softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128 | Any]: ...
245258
# NOTE: keep in sync with `softmax`
@overload  # T
def log_softmax(x: _InexactOrArrayT, axis: AnyShape | None = None) -> _InexactOrArrayT: ...  # type: ignore[overload-overlap]
@overload  # 0d +float64
def log_softmax(x: onp.ToInt | onp.ToJustFloat64, axis: AnyShape | None = None) -> np.float64: ...
@overload  # 0d ~complex128
def log_softmax(x: onp.ToJustComplex128, axis: AnyShape | None = None) -> np.complex128: ...
@overload  # nd T@inexact
def log_softmax(x: onp.ToArrayND[_InexactT, _InexactT], axis: AnyShape | None = None) -> onp.ArrayND[_InexactT]: ...
@overload  # nd +float64
def log_softmax(x: onp.ToIntND | onp.ToJustFloat64_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
@overload  # nd ~complex128
def log_softmax(x: onp.ToJustComplex128_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128]: ...
@overload  # 0d float fallback
def log_softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64 | Any: ...
@overload  # 0d complex fallback
def log_softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.complex128 | Any: ...
@overload  # nd float fallback
def log_softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | Any]: ...
@overload  # nd complex fallback
def log_softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128 | Any]: ...
0 commit comments