@@ -912,6 +912,22 @@ def sign(x: Array, /) -> Array:
912912 return out
913913
914914
915+ def round (x : Array , / , ** kwargs ) -> Array :
916+ # torch.round fails for complex inputs
917+ # https://github.com/pytorch/pytorch/issues/58743#issuecomment-2727603845
918+ if x .dtype .is_complex :
919+ out = kwargs .pop ('out' , None )
920+ if out is not None :
921+ o_r , o_i = out .real , out .imag
922+ else :
923+ o_r , o_i = None , None
924+ res_r = torch .round (x .real , ** kwargs , out = o_r )
925+ res_i = torch .round (x .imag , ** kwargs , out = o_i )
926+ return res_r + 1j * res_i
927+ else :
928+ return torch .round (x , ** kwargs )
929+
930+
915931def meshgrid (* arrays : Array , indexing : Literal ['xy' , 'ij' ] = 'xy' ) -> tuple [Array , ...]:
916932 # torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it
917933 # will be required to pass the indexing argument."
@@ -923,7 +939,7 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra
923939 'permute_dims' , 'bitwise_invert' , 'newaxis' , 'conj' , 'add' ,
924940 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
925941 'bitwise_right_shift' , 'bitwise_xor' , 'copysign' , 'count_nonzero' ,
926- 'diff' , 'divide' ,
942+ 'diff' , 'divide' , 'round' ,
927943 'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
928944 'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
929945 'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
0 commit comments