diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index f50fe56e77..f214b77d03 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1051,11 +1051,10 @@ def forward(
         ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            if not self.torchair_graph_enabled:
-                kv_c, k_pe = self.kv_a_proj_with_mqa(
-                    hidden_states_or_kv_c_normed)[0].split(
-                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            kv_c, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states_or_kv_c_normed)[0].split(
+                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
             assert attn_metadata.num_decodes is not None and \
@@ -1074,12 +1073,13 @@ def forward(
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            if not self.torchair_graph_enabled:
-                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-                k_pe = k_pe[:num_actual_toks, ...]
-                k_pe = k_pe.unsqueeze(1)
-                decode_k_pe = k_pe[:num_decode_tokens]
-                prefill_k_pe = k_pe[num_decode_tokens:]
+            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+            prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
+            # if not self.torchair_graph_enabled:
+            k_pe = k_pe[:num_actual_toks, ...]
+            k_pe = k_pe.unsqueeze(1)
+            decode_k_pe = k_pe[:num_decode_tokens]
+            prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
@@ -1153,11 +1153,11 @@ def forward(
             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
             prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
-                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                attn_metadata.slot_mapping)
+                prefill_hs, cos, sin, kv_cache,
+                attn_metadata.slot_mapping[num_decode_tokens:])
 
             kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
-            prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
+            prefill_k_c_normed = prefill_k_nope
             prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                              -1)
             prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
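
For context, the patch relies on the batch layout in which decode tokens occupy the front of the flattened token batch and prefill tokens the back, so one split index (`num_decode_tokens`) partitions hidden states, `k_pe`, and `slot_mapping` consistently. The sketch below is a minimal, standalone illustration of that partitioning with made-up shapes; it is not the vLLM-Ascend implementation, and the tensor names merely mirror those in the diff.

```python
# Standalone sketch (assumed shapes): decode tokens first, prefill tokens
# after, padding trimmed via num_actual_toks -- as in the patched forward().
import torch

num_decode_tokens = 2   # decode tokens come first in the batch
num_actual_toks = 5     # total real tokens (decode + prefill)
hidden_size = 8

hidden_states = torch.randn(num_actual_toks, hidden_size)
k_pe = torch.randn(num_actual_toks, 4)
slot_mapping = torch.arange(num_actual_toks)

# Split hidden states into decode and prefill segments.
decode_hs = hidden_states[:num_decode_tokens]
prefill_hs = hidden_states[num_decode_tokens:]

# k_pe gets a kv-head dimension before the same split, as in the patch.
k_pe = k_pe.unsqueeze(1)
decode_k_pe = k_pe[:num_decode_tokens]
prefill_k_pe = k_pe[num_decode_tokens:]

# The fix in the last hunk: slice slot_mapping the same way, so the
# prefill KV-cache write only touches slots that belong to prefill tokens.
prefill_slots = slot_mapping[num_decode_tokens:]
assert prefill_hs.shape[0] == prefill_slots.shape[0]
```

With this layout, passing `prefill_hs` and `slot_mapping[num_decode_tokens:]` to the prefill KV path means its outputs are already prefill-only, which is why the diff can drop the `[num_decode_tokens:]` re-slice when assigning `prefill_k_c_normed`.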