diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 91161d24..715182a1 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -241,6 +241,20 @@ def sort( ) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values + +# Wrap torch.argsort to set stable=True by default +def argsort( + x: Array, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: + return torch.argsort(x, dim=axis, descending=descending, stable=stable, **kwargs) + + def _normalize_axes(axis, ndim): axes = [] if ndim == 0 and axis: diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 76342980..f11b3eb5 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from typing import Literal -import torch +import torch # noqa: F401 import torch.fft from ._typing import Array