Skip to content

Commit 0b251a6

Browse files
authored
Merge pull request #408 from ev-br/torch_round
BUG: torch: work around torch.round not supporting complex inputs
2 parents ac7e997 + 4078863 commit 0b251a6

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,22 @@ def sign(x: Array, /) -> Array:
916916
return out
917917

918918

919+
def round(x: Array, /, **kwargs) -> Array:
920+
# torch.round fails for complex inputs
921+
# https://github.com/pytorch/pytorch/issues/58743#issuecomment-2727603845
922+
if x.dtype.is_complex:
923+
out = kwargs.pop('out', None)
924+
res_r = torch.round(x.real, **kwargs)
925+
res_i = torch.round(x.imag, **kwargs)
926+
res = res_r + 1j*res_i
927+
if out is not None:
928+
out.copy_(res)
929+
return out
930+
return res
931+
else:
932+
return torch.round(x, **kwargs)
933+
934+
919935
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]:
920936
# torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it
921937
# will be required to pass the indexing argument."
@@ -927,7 +943,7 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra
927943
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
928944
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
929945
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
930-
'diff', 'divide',
946+
'diff', 'divide', 'round',
931947
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
932948
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
933949
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',

tests/test_torch.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,12 @@ def test_argsort_stable():
152152

153153
t = xp.zeros(50) # should be >16
154154
assert xp.all(xp.argsort(t) == xp.arange(50))
155+
156+
157+
def test_round():
158+
"""Verify the out= argument of xp.round with complex inputs."""
159+
x = torch.as_tensor([1.23456786]*3) + 3.456789j
160+
o = torch.empty(3, dtype=torch.complex64)
161+
r = xp.round(x, decimals=1, out=o)
162+
assert xp.all(r == o)
163+
assert r is o

torch-xfails.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ array_api_tests/test_statistical_functions.py::test_var
130130

131131

132132
# These functions do not yet support complex numbers
133-
array_api_tests/test_operators_and_elementwise_functions.py::test_round
134133
array_api_tests/test_set_functions.py::test_unique_counts
135134
array_api_tests/test_set_functions.py::test_unique_values
136135

0 commit comments

Comments
 (0)