@@ -24,11 +24,11 @@ diopiError_t diopiTokenSoftmaxReduceVInference(diopiContextHandle_t ctx, diopiTe
24
24
for (int i = 0 ; i < batch; ++i) {
25
25
int curSeqLen = bSeqLenAt[i].item <int >();
26
26
int curSeqStartLoc = bStartLocAt[i].item <int >();
27
- at::Tensor p = at::index (logicsAt, {at::Tensor (), acl_op::arange (curSeqStartLoc, curSeqStartLoc + curSeqLen, at::kLong , layout, device)})
27
+ at::Tensor p = at::index (logicsAt, {at::Tensor (), acl_op::arange (curSeqStartLoc, curSeqStartLoc + curSeqLen, at::kInt , layout, device)})
28
28
.softmax (-1 )
29
29
.reshape ({head, 1 , 1 , curSeqLen})
30
30
.transpose (0 , 1 );
31
- at::Tensor vLoc = bLocAt[i].index_select (0 , acl_op::arange (maxInputLen - curSeqLen, maxInputLen, at::kLong , layout, device));
31
+ at::Tensor vLoc = bLocAt[i].index_select (0 , acl_op::arange (maxInputLen - curSeqLen, maxInputLen, at::kInt , layout, device));
32
32
at::Tensor v = at::index (vAt, {vLoc}).view ({1 , curSeqLen, head, dim}).transpose (1 , 2 );
33
33
at::Tensor values = at::matmul (p.toType (at::kFloat ), v.toType (at::kFloat )).view ({head, dim}).toType (dtype);
34
34
at::index_put_ (outAt, {torch::scalar_to_tensor (i)}, values);
0 commit comments