@@ -916,6 +916,22 @@ def sign(x: Array, /) -> Array:
916916 return out
917917
918918
919+ def round (x : Array , / , ** kwargs ) -> Array :
920+ # torch.round fails for complex inputs
921+ # https://github.com/pytorch/pytorch/issues/58743#issuecomment-2727603845
922+ if x .dtype .is_complex :
923+ out = kwargs .pop ('out' , None )
924+ res_r = torch .round (x .real , ** kwargs )
925+ res_i = torch .round (x .imag , ** kwargs )
926+ res = res_r + 1j * res_i
927+ if out is not None :
928+ out .copy_ (res )
929+ return out
930+ return res
931+ else :
932+ return torch .round (x , ** kwargs )
933+
934+
919935def meshgrid (* arrays : Array , indexing : Literal ['xy' , 'ij' ] = 'xy' ) -> tuple [Array , ...]:
920936 # torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it
921937 # will be required to pass the indexing argument."
@@ -927,7 +943,7 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra
927943 'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
928944 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
929945 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'count_nonzero' ,
930- 'diff' , 'divide' ,
946+ 'diff' , 'divide' , 'round' ,
931947 'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
932948 'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
933949 'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
0 commit comments