Skip to content

Commit d63541d

Browse files
committed
BUG: torch: work around torch.round not supporting complex inputs
1 parent a88067a commit d63541d

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,22 @@ def sign(x: Array, /) -> Array:
912912
return out
913913

914914

915+
def round(x: Array, /, **kwargs) -> Array:
916+
# torch.round fails for complex inputs
917+
# https://github.com/pytorch/pytorch/issues/58743#issuecomment-2727603845
918+
if x.dtype.is_complex:
919+
out = kwargs.pop('out', None)
920+
if out is not None:
921+
o_r, o_i = out.real, out.imag
922+
else:
923+
o_r, o_i = None, None
924+
res_r = torch.round(x.real, **kwargs, out=o_r)
925+
res_i = torch.round(x.imag, **kwargs, out=o_i)
926+
return res_r + 1j*res_i
927+
else:
928+
return torch.round(x, **kwargs)
929+
930+
915931
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]:
916932
# torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it
917933
# will be required to pass the indexing argument."
@@ -923,7 +939,7 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra
923939
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
924940
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
925941
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
926-
'diff', 'divide',
942+
'diff', 'divide', 'round',
927943
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
928944
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
929945
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',

tests/test_torch.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,11 @@ def test_argsort_stable():
152152

153153
t = xp.zeros(50) # should be >16
154154
assert xp.all(xp.argsort(t) == xp.arange(50))
155+
156+
157+
def test_round():
158+
"""Verify the out= argument of xp.round with complex inputs."""
159+
x = torch.as_tensor([1.23456786]*3) + 3.456789j
160+
o = torch.empty(3, dtype=torch.complex64)
161+
r = xp.round(x, decimals=1, out=o)
162+
assert xp.all(r == o)

torch-xfails.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ array_api_tests/test_statistical_functions.py::test_var
130130

131131

132132
# These functions do not yet support complex numbers
133-
array_api_tests/test_operators_and_elementwise_functions.py::test_round
134133
array_api_tests/test_set_functions.py::test_unique_counts
135134
array_api_tests/test_set_functions.py::test_unique_values
136135

0 commit comments

Comments
 (0)