
Commit ad2c046

zq/fix token_attention (DeepLink-org#873)
fix token_attention
1 parent d14fe8a

2 files changed, +4 -4 lines changed

impl/ascend_npu/diopi_impl/functions_ext/token_attention_inference.cpp

Lines changed: 2 additions & 2 deletions
@@ -25,9 +25,9 @@ diopiError_t diopiTokenAttentionInference(diopiContextHandle_t ctx, diopiTensorH
     for (int i = 0; i < batch; ++i) {
         int curSeqLen = bSeqLenAt[i].item<int>();
         int curSeqStartLoc = bStartLocAt[i].item<int>();
-        at::Tensor kLoc = at::index_select(bLocAt[i], 0, acl_op::arange(maxInputLen - curSeqLen, maxInputLen, at::kLong, layout, device));
+        at::Tensor kLoc = at::index_select(bLocAt[i], 0, acl_op::arange(maxInputLen - curSeqLen, maxInputLen, at::kInt, layout, device));
         at::Tensor key = at::index(kAt, {kLoc}).view({1, curSeqLen, head, dim}).transpose(1, 2);
-        at::Tensor outLoc = acl_op::arange(curSeqStartLoc, curSeqStartLoc + curSeqLen, at::kLong, layout, device);
+        at::Tensor outLoc = acl_op::arange(curSeqStartLoc, curSeqStartLoc + curSeqLen, at::kInt, layout, device);
         at::Tensor values =
             (at::matmul(at::index(qAt, {torch::scalar_to_tensor(i)}).toType(at::kFloat), key.transpose(2, 3).toType(at::kFloat)) / std::sqrt(dim))
                 .view({head, curSeqLen})
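Note: the only functional change in this file is the dtype passed to acl_op::arange when building the index tensors: at::kLong becomes at::kInt. As a rough illustration of the pattern only (not the NPU code path), here is a minimal standalone sketch assuming plain LibTorch on CPU, with at::arange standing in for the Ascend-specific acl_op::arange and made-up values for the slot table and lengths:

// Minimal standalone sketch, assuming plain LibTorch on CPU; at::arange stands
// in for acl_op::arange, and bLoc / maxInputLen / curSeqLen are made-up values.
#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t maxInputLen = 8;
    const int64_t curSeqLen = 3;

    // Stand-in for bLocAt[i]: one request's slot table into the KV cache.
    at::Tensor bLoc = at::arange(100, 100 + maxInputLen, at::kInt);

    // The patched line: build the index range with an int32 arange
    // (previously at::kLong). index_select accepts IntTensor or LongTensor.
    at::Tensor idx = at::arange(maxInputLen - curSeqLen, maxInputLen, at::kInt);
    at::Tensor kLoc = at::index_select(bLoc, 0, idx);

    std::cout << kLoc << std::endl;  // prints the last curSeqLen slots of bLoc
    return 0;
}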

impl/ascend_npu/diopi_impl/functions_ext/token_softmax_reducev.cpp

Lines changed: 2 additions & 2 deletions
@@ -24,11 +24,11 @@ diopiError_t diopiTokenSoftmaxReduceVInference(diopiContextHandle_t ctx, diopiTe
     for (int i = 0; i < batch; ++i) {
         int curSeqLen = bSeqLenAt[i].item<int>();
         int curSeqStartLoc = bStartLocAt[i].item<int>();
-        at::Tensor p = at::index(logicsAt, {at::Tensor(), acl_op::arange(curSeqStartLoc, curSeqStartLoc + curSeqLen, at::kLong, layout, device)})
+        at::Tensor p = at::index(logicsAt, {at::Tensor(), acl_op::arange(curSeqStartLoc, curSeqStartLoc + curSeqLen, at::kInt, layout, device)})
                            .softmax(-1)
                            .reshape({head, 1, 1, curSeqLen})
                            .transpose(0, 1);
-        at::Tensor vLoc = bLocAt[i].index_select(0, acl_op::arange(maxInputLen - curSeqLen, maxInputLen, at::kLong, layout, device));
+        at::Tensor vLoc = bLocAt[i].index_select(0, acl_op::arange(maxInputLen - curSeqLen, maxInputLen, at::kInt, layout, device));
         at::Tensor v = at::index(vAt, {vLoc}).view({1, curSeqLen, head, dim}).transpose(1, 2);
         at::Tensor values = at::matmul(p.toType(at::kFloat), v.toType(at::kFloat)).view({head, dim}).toType(dtype);
         at::index_put_(outAt, {torch::scalar_to_tensor(i)}, values);
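The reduce-V kernel gets the same one-line dtype switch at both of its index-construction sites. For orientation only, a rough CPU sketch of the surrounding computation follows, with hypothetical shapes and a toy identity slot table; at::arange again stands in for acl_op::arange, and the int32 ranges are converted to int64 before Tensor::index purely to keep the sketch portable across LibTorch versions (the actual kernel passes the int32 tensors straight into at::index on the Ascend stack):

// Rough CPU sketch of the softmax-reduceV step, with made-up shapes.
#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t head = 2, dim = 4;
    const int64_t totalTokens = 10, maxInputLen = 10;
    const int64_t curSeqStartLoc = 2, curSeqLen = 3;

    at::Tensor logics = at::rand({head, totalTokens});    // per-head attention scores
    at::Tensor v = at::rand({totalTokens, head, dim});    // value cache
    at::Tensor bLoc = at::arange(totalTokens, at::kInt);  // toy slot table (identity mapping)

    // Patched lines: both index ranges are now built with int32 aranges.
    at::Tensor cols = at::arange(curSeqStartLoc, curSeqStartLoc + curSeqLen, at::kInt);
    at::Tensor vLoc = bLoc.index_select(0, at::arange(maxInputLen - curSeqLen, maxInputLen, at::kInt));

    using namespace torch::indexing;
    at::Tensor p = logics.index({Slice(), cols.to(at::kLong)})
                       .softmax(-1)
                       .reshape({head, 1, 1, curSeqLen})
                       .transpose(0, 1);                  // [1, head, 1, curSeqLen]
    at::Tensor vSel = v.index({vLoc.to(at::kLong)})
                          .view({1, curSeqLen, head, dim})
                          .transpose(1, 2);               // [1, head, curSeqLen, dim]

    at::Tensor values = at::matmul(p, vSel).view({head, dim});
    std::cout << values.sizes() << std::endl;             // [head, dim]
    return 0;
}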
