Skip to content

Commit 3d84367

Browse files
authored
[ascend]Zq/fix baddbmm (DeepLink-org#912)
* fix hardtanh * fix baddbmm
1 parent cd36478 commit 3d84367

File tree

2 files changed

+20
-66
lines changed

2 files changed

+20
-66
lines changed

impl/ascend/device_configs.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,36 +18,6 @@
1818
),
1919
),
2020

21-
'baddbmm': dict(
22-
name=['baddbmm'],
23-
atol=1e-2,
24-
rtol=1e-2,
25-
# temp for 910B
26-
tensor_para=dict(
27-
args=[
28-
{
29-
"ins": ['input'],
30-
"dtype": [Skip(np.float16),Skip(np.float32),Skip(np.float64),],
31-
},
32-
]
33-
),
34-
),
35-
36-
'baddbmm_without_inplace': dict(
37-
name=['baddbmm'],
38-
atol=1e-2,
39-
rtol=1e-2,
40-
# temp for 910B
41-
tensor_para=dict(
42-
args=[
43-
{
44-
"ins": ['input'],
45-
"dtype": [Skip(np.float16),Skip(np.float32),Skip(np.float64),],
46-
},
47-
]
48-
),
49-
),
50-
5121
'conv_2d': dict(
5222
name=['conv2d'],
5323
atol=1e-1,

impl/ascend/functions/baddbmm.cpp

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,79 +13,63 @@ diopiError_t diopiBaddbmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, dio
1313
diopiConstTensorHandle_t batch2, double beta, double alpha) {
1414
diopiDtype_t outDtype;
1515
diopiGetTensorDtype(out, &outDtype);
16-
diopiDtype_t execType;
1716

18-
// adjust the input's and output's data type
19-
if (outDtype == diopi_dtype_float64) {
20-
execType = diopi_dtype_float32;
21-
} else {
22-
execType = outDtype;
23-
}
24-
25-
AscendTensor inputCopy(input);
26-
AscendTensor outputCopy(out);
27-
AscendTensor batch1Copy(batch1);
28-
AscendTensor batch2Copy(batch2);
29-
castTensor(ctx, outputCopy, execType);
30-
castTensor(ctx, batch1Copy, execType);
31-
castTensor(ctx, inputCopy, execType);
32-
castTensor(ctx, batch2Copy, execType);
17+
AscendTensor inputAt(input);
18+
AscendTensor outputAt(out);
19+
AscendTensor batch1At(batch1);
20+
AscendTensor batch2At(batch2);
3321

3422
// get the size of batch1 * batch2
35-
AscendTensor asBatch1 = AscendTensor(batch1Copy);
36-
AscendTensor asBatch2 = AscendTensor(batch2Copy);
37-
std::vector<int64_t> batch1Shape = asBatch1.shape();
38-
std::vector<int64_t> batch2Shape = asBatch2.shape();
23+
std::vector<int64_t> batch1Shape = batch1At.shape();
24+
std::vector<int64_t> batch2Shape = batch2At.shape();
3925
std::vector<int64_t> vectorSizeBatchMatMulTensor = {batch1Shape[0], batch1Shape[1], batch2Shape[2]};
4026

4127
// init a tensor with the shape of the batch1 * batch2 product
4228
diopiSize_t diopiSizeBatchMatMulTensor = vectorToDiopiSize(vectorSizeBatchMatMulTensor);
43-
AscendTensor asBatchMatMulTensor;
44-
makeTensor(ctx, asBatchMatMulTensor, &diopiSizeBatchMatMulTensor, execType, diopiDevice_t::diopi_device);
29+
AscendTensor batchMatMulTensorAt;
30+
makeTensor(ctx, batchMatMulTensorAt, &diopiSizeBatchMatMulTensor, outDtype, diopiDevice_t::diopi_device);
4531

4632
// does batch1/batch2 need to transpose?
4733
bool isSelfT = false;
4834
bool isMat2T = false;
4935

5036
// do batch1 times batch2 -> BatchMatMulTensor
5137
AclOpRunner<2, 1>("BatchMatMul", ctx)
52-
.addInput(batch1Copy)
53-
.addInput(batch2Copy)
54-
.addOutput(asBatchMatMulTensor)
38+
.addInput(batch1At)
39+
.addInput(batch2At)
40+
.addOutput(batchMatMulTensorAt)
5541
.setAttr("adj_x1", isSelfT)
5642
.setAttr("adj_x2", isMat2T)
5743
.run();
5844

5945
// init memory based on the size of alphaMulTensor and betaMulTensor
6046
AscendTensor alphaMulTensor;
6147
AscendTensor betaMulTensor;
62-
makeTensorLike(ctx, alphaMulTensor, asBatchMatMulTensor, execType);
63-
makeTensorLike(ctx, betaMulTensor, inputCopy, execType);
48+
makeTensorLike(ctx, alphaMulTensor, batchMatMulTensorAt, outDtype);
49+
makeTensorLike(ctx, betaMulTensor, inputAt, outDtype);
6450

6551
diopiScalar_t alphaScalar;
66-
alphaScalar.stype = execType;
52+
alphaScalar.stype = outDtype;
6753
alphaScalar.fval = alpha;
6854
diopiScalar_t betaScalar;
69-
betaScalar.stype = execType;
55+
betaScalar.stype = outDtype;
7056
betaScalar.fval = beta;
7157

7258
// transform ascendTensor to diopiTensorHandle_t
7359
diopiTensorHandle_t diopiAlphaMulTensor = const_cast<diopiTensorHandle_t>(alphaMulTensor.tensorHandle());
7460
diopiTensorHandle_t diopiBateMulTensor = const_cast<diopiTensorHandle_t>(betaMulTensor.tensorHandle());
75-
diopiTensorHandle_t diopiAsBatchMatMulTensor = const_cast<diopiTensorHandle_t>(asBatchMatMulTensor.tensorHandle());
76-
diopiTensorHandle_t diopiInputCopy = const_cast<diopiTensorHandle_t>(inputCopy.tensorHandle());
61+
diopiTensorHandle_t diopiAsBatchMatMulTensor = const_cast<diopiTensorHandle_t>(batchMatMulTensorAt.tensorHandle());
62+
diopiTensorHandle_t diopiInput = const_cast<diopiTensorHandle_t>(inputAt.tensorHandle());
7763

7864
// alpha times BatchMatMulTensor -> alphaMulTensor and beta times input -> betaMulTensor
7965
diopiMulScalar(ctx, diopiAlphaMulTensor, diopiAsBatchMatMulTensor, &alphaScalar);
80-
diopiMulScalar(ctx, diopiBateMulTensor, diopiInputCopy, &betaScalar);
66+
diopiMulScalar(ctx, diopiBateMulTensor, diopiInput, &betaScalar);
8167

8268
diopiScalar_t other;
8369
other.fval = 1;
8470
other.stype = outDtype;
85-
diopiTensorHandle_t diopiOutputCopy = const_cast<diopiTensorHandle_t>(outputCopy.tensorHandle());
86-
diopiAdd(ctx, diopiOutputCopy, diopiAlphaMulTensor, diopiBateMulTensor, &other);
87-
diopiCastDtype(ctx, out, diopiOutputCopy);
88-
71+
diopiTensorHandle_t diopiOutput = const_cast<diopiTensorHandle_t>(outputAt.tensorHandle());
72+
diopiAdd(ctx, diopiOutput, diopiAlphaMulTensor, diopiBateMulTensor, &other);
8973
return diopiSuccess;
9074
}
9175

0 commit comments

Comments
 (0)