Commit 43b6743

MengqingCao and whx-sjtu committed
[BugFix] Fix a bug of running chunked-prefill with torchair. (#1378)
This PR fixes a bug when running chunked prefill with torchair.

Co-authored-by: whx-sjtu <[email protected]>
Signed-off-by: whx-sjtu <[email protected]>
Signed-off-by: MengqingCao <[email protected]>
1 parent f9dfde0 commit 43b6743
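
For context before the diff: under chunked prefill, a single forward pass packs decode tokens first and prefill tokens after them, so per-token tensors have to be split at num_decode_tokens before the prefill-only KV path runs. A minimal sketch of that layout (the sizes and the hidden_states/slot_mapping stand-ins are illustrative, not the real model code; only the slicing mirrors the patch):

import torch

# Illustrative sizes: 2 decode tokens followed by 3 prefill tokens packed
# into one batch, as chunked-prefill scheduling can produce.
num_decode_tokens, num_prefill_tokens = 2, 3
num_actual_toks = num_decode_tokens + num_prefill_tokens

# Stand-ins for hidden_states_or_kv_c_normed and attn_metadata.slot_mapping.
hidden_states = torch.randn(num_actual_toks, 16)
slot_mapping = torch.arange(num_actual_toks)

# Both tensors are split at the same boundary, mirroring the slicing the
# patch adds before exec_kv_prefill is called.
decode_hs = hidden_states[:num_decode_tokens]
prefill_hs = hidden_states[num_decode_tokens:]
prefill_slots = slot_mapping[num_decode_tokens:]

assert prefill_hs.shape[0] == prefill_slots.shape[0] == num_prefill_tokens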

File tree

1 file changed: +14 -14 lines changed


vllm_ascend/attention/mla_v1.py

Lines changed: 14 additions & 14 deletions
@@ -1051,11 +1051,10 @@ def forward(
         ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            if not self.torchair_graph_enabled:
-                kv_c, k_pe = self.kv_a_proj_with_mqa(
-                    hidden_states_or_kv_c_normed)[0].split(
-                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            kv_c, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states_or_kv_c_normed)[0].split(
+                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
@@ -1074,12 +1073,13 @@ def forward(
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            if not self.torchair_graph_enabled:
-                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-                k_pe = k_pe[:num_actual_toks, ...]
-                k_pe = k_pe.unsqueeze(1)
-                decode_k_pe = k_pe[:num_decode_tokens]
-                prefill_k_pe = k_pe[num_decode_tokens:]
+            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+            prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
+            # if not self.torchair_graph_enabled:
+            k_pe = k_pe[:num_actual_toks, ...]
+            k_pe = k_pe.unsqueeze(1)
+            decode_k_pe = k_pe[:num_decode_tokens]
+            prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
@@ -1153,11 +1153,11 @@ def forward(

             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
             prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
-                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                attn_metadata.slot_mapping)
+                prefill_hs, cos, sin, kv_cache,
+                attn_metadata.slot_mapping[num_decode_tokens:])

             kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
-            prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
+            prefill_k_c_normed = prefill_k_nope
             prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                              -1)
             prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
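
Why the last hunk drops the re-slice: because exec_kv_prefill now receives only the prefill slice of the hidden states (prefill_hs) and the matching slice of attn_metadata.slot_mapping, its prefill_k_nope output already covers just the prefill tokens, so the old prefill_k_nope[num_decode_tokens:] slice is no longer needed and prefill_k_c_normed takes prefill_k_nope directly.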
