Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,28 @@
custom_code_at_the_beginning: |
return dipu_add__tensor(self, other, -alpha);

- schema: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
custom_code_at_the_beginning: |
at::Tensor out = UnaryOpInferrer().infer_out(self);
interface: diopiSubScalar(ctx, out, self, other, alpha)

- schema: "sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
dummy_call_diopi: True
ins: [selfTmp]
custom_code_at_the_beginning: |
at::native::sub_check(self, other);
auto out = BinaryOpInferrer().infer_out(self, other);
return dipu_add_out(self, other, -alpha, out);
if (is_scalar_on_cpu(other)) {
return dipu_sub_scalar(self, other.item(), alpha);
}

at::Tensor selfTmp = self;
if (is_scalar_on_cpu(selfTmp)) {
selfTmp = nodispatch::empty({1}, self.options().device(other.device()));
dipu_fill__scalar(selfTmp, self.item());
}

at::native::sub_check(selfTmp, other);
at::Tensor out = BinaryOpInferrer().infer_out(selfTmp, other);

interface: diopiSub(ctx, out, selfTmp, other, alpha)

- schema: "div.Scalar(Tensor self, Scalar other) -> Tensor"
custom_code_at_the_beginning: |
Expand Down