@@ -13,79 +13,63 @@ diopiError_t diopiBaddbmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, dio
13
13
diopiConstTensorHandle_t batch2, double beta, double alpha) {
14
14
diopiDtype_t outDtype;
15
15
diopiGetTensorDtype (out, &outDtype);
16
- diopiDtype_t execType;
17
16
18
- // adjust the input's and output's data type
19
- if (outDtype == diopi_dtype_float64) {
20
- execType = diopi_dtype_float32;
21
- } else {
22
- execType = outDtype;
23
- }
24
-
25
- AscendTensor inputCopy (input);
26
- AscendTensor outputCopy (out);
27
- AscendTensor batch1Copy (batch1);
28
- AscendTensor batch2Copy (batch2);
29
- castTensor (ctx, outputCopy, execType);
30
- castTensor (ctx, batch1Copy, execType);
31
- castTensor (ctx, inputCopy, execType);
32
- castTensor (ctx, batch2Copy, execType);
17
+ AscendTensor inputAt (input);
18
+ AscendTensor outputAt (out);
19
+ AscendTensor batch1At (batch1);
20
+ AscendTensor batch2At (batch2);
33
21
34
22
// get the size of batch1 * batch2
35
- AscendTensor asBatch1 = AscendTensor (batch1Copy);
36
- AscendTensor asBatch2 = AscendTensor (batch2Copy);
37
- std::vector<int64_t > batch1Shape = asBatch1.shape ();
38
- std::vector<int64_t > batch2Shape = asBatch2.shape ();
23
+ std::vector<int64_t > batch1Shape = batch1At.shape ();
24
+ std::vector<int64_t > batch2Shape = batch2At.shape ();
39
25
std::vector<int64_t > vectorSizeBatchMatMulTensor = {batch1Shape[0 ], batch1Shape[1 ], batch2Shape[2 ]};
40
26
41
27
// init a tensor according to the size of batch1 * batch2 ;
42
28
diopiSize_t diopiSizeBatchMatMulTensor = vectorToDiopiSize (vectorSizeBatchMatMulTensor);
43
- AscendTensor asBatchMatMulTensor ;
44
- makeTensor (ctx, asBatchMatMulTensor , &diopiSizeBatchMatMulTensor, execType , diopiDevice_t::diopi_device);
29
+ AscendTensor batchMatMulTensorAt ;
30
+ makeTensor (ctx, batchMatMulTensorAt , &diopiSizeBatchMatMulTensor, outDtype , diopiDevice_t::diopi_device);
45
31
46
32
// does batch1/batch2 need to transpose?
47
33
bool isSelfT = false ;
48
34
bool isMat2T = false ;
49
35
50
36
// do batch1 times batch2 -> BatchMatMulTensor
51
37
AclOpRunner<2 , 1 >(" BatchMatMul" , ctx)
52
- .addInput (batch1Copy )
53
- .addInput (batch2Copy )
54
- .addOutput (asBatchMatMulTensor )
38
+ .addInput (batch1At )
39
+ .addInput (batch2At )
40
+ .addOutput (batchMatMulTensorAt )
55
41
.setAttr (" adj_x1" , isSelfT)
56
42
.setAttr (" adj_x2" , isMat2T)
57
43
.run ();
58
44
59
45
// init memory based on the size of alphaMulTensor and betaMulTensor
60
46
AscendTensor alphaMulTensor;
61
47
AscendTensor betaMulTensor;
62
- makeTensorLike (ctx, alphaMulTensor, asBatchMatMulTensor, execType );
63
- makeTensorLike (ctx, betaMulTensor, inputCopy, execType );
48
+ makeTensorLike (ctx, alphaMulTensor, batchMatMulTensorAt, outDtype );
49
+ makeTensorLike (ctx, betaMulTensor, inputAt, outDtype );
64
50
65
51
diopiScalar_t alphaScalar;
66
- alphaScalar.stype = execType ;
52
+ alphaScalar.stype = outDtype ;
67
53
alphaScalar.fval = alpha;
68
54
diopiScalar_t betaScalar;
69
- betaScalar.stype = execType ;
55
+ betaScalar.stype = outDtype ;
70
56
betaScalar.fval = beta;
71
57
72
58
// transform ascendTensor to diopiTensorHandle_t
73
59
diopiTensorHandle_t diopiAlphaMulTensor = const_cast <diopiTensorHandle_t>(alphaMulTensor.tensorHandle ());
74
60
diopiTensorHandle_t diopiBateMulTensor = const_cast <diopiTensorHandle_t>(betaMulTensor.tensorHandle ());
75
- diopiTensorHandle_t diopiAsBatchMatMulTensor = const_cast <diopiTensorHandle_t>(asBatchMatMulTensor .tensorHandle ());
76
- diopiTensorHandle_t diopiInputCopy = const_cast <diopiTensorHandle_t>(inputCopy .tensorHandle ());
61
+ diopiTensorHandle_t diopiAsBatchMatMulTensor = const_cast <diopiTensorHandle_t>(batchMatMulTensorAt .tensorHandle ());
62
+ diopiTensorHandle_t diopiInput = const_cast <diopiTensorHandle_t>(inputAt .tensorHandle ());
77
63
78
64
// alpha times BatchMatMulTensor -> alphaMulTensor and beta times input -> betaMulTensor
79
65
diopiMulScalar (ctx, diopiAlphaMulTensor, diopiAsBatchMatMulTensor, &alphaScalar);
80
- diopiMulScalar (ctx, diopiBateMulTensor, diopiInputCopy , &betaScalar);
66
+ diopiMulScalar (ctx, diopiBateMulTensor, diopiInput , &betaScalar);
81
67
82
68
diopiScalar_t other;
83
69
other.fval = 1 ;
84
70
other.stype = outDtype;
85
- diopiTensorHandle_t diopiOutputCopy = const_cast <diopiTensorHandle_t>(outputCopy.tensorHandle ());
86
- diopiAdd (ctx, diopiOutputCopy, diopiAlphaMulTensor, diopiBateMulTensor, &other);
87
- diopiCastDtype (ctx, out, diopiOutputCopy);
88
-
71
+ diopiTensorHandle_t diopiOutput = const_cast <diopiTensorHandle_t>(outputAt.tensorHandle ());
72
+ diopiAdd (ctx, diopiOutput, diopiAlphaMulTensor, diopiBateMulTensor, &other);
89
73
return diopiSuccess;
90
74
}
91
75
0 commit comments