
Commit 2b277cd

optimize

1 parent 9d95d97, commit 2b277cd

2 files changed, +9 -24 lines changed


csrc/extensions.cpp

Lines changed: 2 additions & 2 deletions
@@ -439,8 +439,8 @@ void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k
     );
 }
 
-void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key, const at::Tensor& cos, const at::Tensor& sin) {
-    callDiopi(diopiRotaryEmbeddingV2, query, key, cos, sin);
+void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key, const at::Tensor& cos, const at::Tensor& sin, int64_t dim) {
+    callDiopi(diopiRotaryEmbeddingV2, query, key, cos, sin, dim);
 }
 
 void extMatmulAllReduce(at::Tensor& out, const at::Tensor& x1,
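
For callers, the C++ change means the rotary dimension is now an explicit argument that gets forwarded to diopiRotaryEmbeddingV2. A minimal Python-side sketch of the adapted call, assuming the op is bound as rotary_embedding_v2 on the deeplink_ext C++ extension module (the binding name, module path, and tensor shapes are not shown in this diff and are assumed here):

    import torch
    import deeplink_ext.cpp_extensions as ext  # module path assumed

    # Illustrative shapes: query/key as (total_tokens, head_num, head_dim),
    # cos/sin as precomputed rotary tables for the same tokens.
    total_tokens, head_num, head_dim = 16, 32, 128
    query = torch.randn(total_tokens, head_num, head_dim, device="cuda", dtype=torch.float16)
    key = torch.randn(total_tokens, head_num, head_dim, device="cuda", dtype=torch.float16)
    cos = torch.randn(total_tokens, 1, head_dim, device="cuda", dtype=torch.float16)
    sin = torch.randn(total_tokens, 1, head_dim, device="cuda", dtype=torch.float16)

    # After this commit the rotary dimension is passed as an explicit trailing
    # argument; previously the signature ended at sin. Binding name assumed.
    ext.rotary_embedding_v2(query, key, cos, sin, head_dim)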

deeplink_ext/patch_lightllm.py

Lines changed: 7 additions & 22 deletions
@@ -71,12 +71,9 @@ def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
             ext.prompt_flash_attention(single_out, single_q, single_k, single_v, None, mask, [], head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
         return out
 
-    def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len):
-        batch, head, dim = b_start_loc.shape[0], q.shape[1], q.shape[2]
-        numKeyValueHeads = k.shape[1]
-        assert k.shape[1] == v.shape[1]
+    def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
+        batch = b_start_loc.shape[0]
         scale = 1 / math.sqrt(dim)
-
         mask_key_str = str(batch) + ":" + str(max_input_len)
         if mask_key_str not in mask_cache:
             mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
@@ -86,11 +83,7 @@ def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
             print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
 
         mask = mask_cache[mask_key_str]
-        ext.prompt_flash_attention(
-            out.view(batch, max_input_len, head*dim),
-            q.view(batch, max_input_len, head*dim),
-            k.view(batch, max_input_len, numKeyValueHeads*dim),
-            v.view(batch, max_input_len, numKeyValueHeads*dim),
+        ext.prompt_flash_attention(out, q, k, v,
                                    None, mask, b_seq_len, head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
         return out
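
Taken together, the two hunks above slim down fused_context_attention: head, numKeyValueHeads and dim are now supplied by the caller instead of being re-derived from tensor shapes on every call, and the out/q/k/v reshapes are gone, so the function presumably expects tensors already viewed into the (batch, seq_len, head*dim) "BSH" layout. An illustrative caller-side sketch under those assumptions, mirroring the reshapes the removed lines used to do (dummy shapes, not the actual lightllm call site):

    import torch

    # Dummy tensors in the packed (tokens, heads, dim) layout; shapes assumed.
    batch, max_input_len, head, numKeyValueHeads, dim = 2, 8, 32, 32, 128
    q = torch.randn(batch * max_input_len, head, dim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch * max_input_len, numKeyValueHeads, dim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch * max_input_len, numKeyValueHeads, dim, device="cuda", dtype=torch.float16)
    out = torch.empty_like(q)
    b_start_loc = torch.arange(0, batch * max_input_len, max_input_len, device="cuda")
    b_seq_len = torch.full((batch,), max_input_len, dtype=torch.int32, device="cuda")

    # The shape metadata and the "BSH" views now live at the call site;
    # fused_context_attention is the patched function from the hunks above.
    fused_context_attention(
        q.view(batch, max_input_len, head * dim),
        k.view(batch, max_input_len, numKeyValueHeads * dim),
        v.view(batch, max_input_len, numKeyValueHeads * dim),
        out.view(batch, max_input_len, head * dim),
        b_start_loc, b_seq_len, max_input_len,
        head, numKeyValueHeads, dim,
    )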

@@ -101,18 +94,10 @@ def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
 
 
 def patch_paged_token_attention_inference():
-    def paged_token_attention(q, k_cache, v_cache, out, kv_head_num, b_seq_len, block_table:torch.Tensor, block_size):
-        # numKeyValueHeads = k_cache.shape[1]
-        # assert k_cache.shape[1] == v_cache.shape[1]
-        batch, head, dim = q.shape
-        kv_cache_len = k_cache.shape[0]
-        ext.paged_attention(out.view(batch, 1, head*dim),
-                            q.view(batch, 1, head*dim),
-                            k_cache.view(kv_cache_len, 1, kv_head_num*dim),
-                            v_cache.view(kv_cache_len, 1, kv_head_num*dim),
-                            None, None,
-                            b_seq_len, block_table, head, kv_head_num,
-                            1.0 / math.sqrt(dim), "BSH", block_size, 0,
+    def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
+        ext.paged_attention(out, q, k_cache, v_cache, None, None,
+                            b_seq_len, block_table, q_head_num, kv_head_num,
+                            1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
                             None, None, None, None, None, None, None, None
                             )
         return out
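
Similarly, paged_token_attention now takes q_head_num, kv_head_num and head_dim explicitly and no longer reshapes q, the KV caches, or out, so the caller is presumably expected to do what the removed .view() lines did. An illustrative sketch with dummy tensors (shapes follow the removed logic, not the actual lightllm call site):

    import torch

    # Dummy tensors; shapes assumed for illustration.
    batch, q_head_num, kv_head_num, head_dim = 4, 32, 32, 128
    kv_cache_len, block_size = 256, 64
    q = torch.randn(batch, q_head_num, head_dim, device="cuda", dtype=torch.float16)
    out = torch.empty_like(q)
    k_cache = torch.randn(kv_cache_len, kv_head_num, head_dim, device="cuda", dtype=torch.float16)
    v_cache = torch.randn(kv_cache_len, kv_head_num, head_dim, device="cuda", dtype=torch.float16)
    b_seq_len = torch.full((batch,), 128, dtype=torch.int32, device="cuda")
    block_table = torch.zeros(batch, kv_cache_len // block_size, dtype=torch.int32, device="cuda")

    # Head counts, head_dim and the "BSH" views are now the caller's job;
    # paged_token_attention is the patched function from the hunk above.
    paged_token_attention(
        q.view(batch, 1, q_head_num * head_dim),
        k_cache.view(kv_cache_len, 1, kv_head_num * head_dim),
        v_cache.view(kv_cache_len, 1, kv_head_num * head_dim),
        out.view(batch, 1, q_head_num * head_dim),
        q_head_num, kv_head_num, head_dim,
        b_seq_len, block_table, block_size,
    )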
