Route sub/rsub through diopiSub and diopiSubScalar.

dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
  * sub.Scalar: new entry. Allocates `out` with UnaryOpInferrer and calls
    diopiSubScalar.
  * sub.Tensor: drops `dummy_call_diopi` and the forwarding to dipu_add_out
    with a negated alpha; the entry now calls diopiSub itself.
      - `other` is a CPU scalar: run alpha_check against self's dtype and
        return dipu_sub_scalar(self, other.item(), alpha).
      - `self` is a CPU scalar: copy its value into a 1-element tensor
        (`selfTmp`) created on other's device, and use that as the minuend.
      - alpha_check is run against selfTmp's dtype before `out` is inferred.
  * rsub.Tensor: when `self` is a CPU scalar, return
    dipu_sub_scalar(other, self.item(), alpha); otherwise fall through to the
    existing diopiSub(ctx, out, other, self, alpha) interface call instead of
    returning early via dipu_sub_out.

dipu/tests/python/unittests/test_rsub.py
  * test_rsub_scalar: passes alpha=1 explicitly and asserts that alpha=2.5
    raises RuntimeError with the "integral input tensors" message.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. The hunk line counts below match their @@
headers, but the leading whitespace of every line is reconstructed -- confirm
it against the target files before applying.

diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
index 2308ba7a3..9a972d52c 100755
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -76,12 +76,31 @@
   custom_code_at_the_beginning: |
     return dipu_add__tensor(self, other, -alpha);
 
+- schema: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+  custom_code_at_the_beginning: |
+    at::Tensor out = UnaryOpInferrer().infer_out(self);
+  interface: diopiSubScalar(ctx, out, self, other, alpha)
+
 - schema: "sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
-  dummy_call_diopi: True
+  ins: [selfTmp]
   custom_code_at_the_beginning: |
     at::native::sub_check(self, other);
-    auto out = BinaryOpInferrer().infer_out(self, other);
-    return dipu_add_out(self, other, -alpha, out);
+
+    if (is_scalar_on_cpu(other)) {
+      at::native::alpha_check(self.scalar_type(), alpha);
+      return dipu_sub_scalar(self, other.item(), alpha);
+    }
+
+    at::Tensor selfTmp = self;
+    if (is_scalar_on_cpu(selfTmp)) {
+      selfTmp = nodispatch::empty({1}, self.options().device(other.device()));
+      dipu_fill__scalar(selfTmp, self.item());
+    }
+
+    at::native::alpha_check(selfTmp.scalar_type(), alpha);
+    at::Tensor out = BinaryOpInferrer().infer_out(selfTmp, other);
+
+  interface: diopiSub(ctx, out, selfTmp, other, alpha)
 
 - schema: "div.Scalar(Tensor self, Scalar other) -> Tensor"
   custom_code_at_the_beginning: |
@@ -1341,9 +1360,10 @@
 
 - schema: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
   custom_code_at_the_beginning: |
+    if (is_scalar_on_cpu(self)) {
+      return dipu_sub_scalar(other, self.item(), alpha);
+    }
     auto out = nodispatch::empty_like(self);
-    // NOLINTNEXTLINE(readability-suspicious-call-argument)
-    return dipu_sub_out(other, self, alpha, out);
   interface: diopiSub(ctx, out, other, self, alpha)
 
 - schema: "unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor out, Tensor indices, Tensor counts)"
diff --git a/dipu/tests/python/unittests/test_rsub.py b/dipu/tests/python/unittests/test_rsub.py
index cbe46ff6a..e734739e1 100644
--- a/dipu/tests/python/unittests/test_rsub.py
+++ b/dipu/tests/python/unittests/test_rsub.py
@@ -23,8 +23,15 @@ def test_rsub(self):
         self._test_rsub(torch.ones(4, 5) * 1.1, torch.ones(4, 5) * 5, alpha=4)
 
     def test_rsub_scalar(self):
-        self._test_rsub_scalar(torch.ones(4, 5), 10)
-        self._test_rsub_scalar(torch.ones(4, 5), 10, 2.5)
+        # from torch:
+        # For integral input tensors, argument alpha must not be a floating point number
+        # Boolean alpha only supported for Boolean results
+        self._test_rsub_scalar(torch.ones(4, 5), 10, alpha=1)
+        self.assertRaisesRegex(
+            RuntimeError,
+            r"For integral input tensors, argument alpha must not be a floating point number\.",
+            lambda: self._test_rsub_scalar(torch.ones(4, 5), 10, 2.5),
+        )
 
 
 if __name__ == "__main__":