diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py
index 9ad336c19c..71d33f0c82 100644
--- a/tests/e2e/multicard/test_torchair_graph_mode.py
+++ b/tests/e2e/multicard/test_torchair_graph_mode.py
@@ -31,6 +31,7 @@ def _deepseek_torchair_test_fixture(
     additional_config: Dict,
     *,
     tensor_parallel_size=2,
+    use_v1_scheduler=False,
 ):
     example_prompts = [
         "Hello, my name is",
@@ -38,14 +39,14 @@ def _deepseek_torchair_test_fixture(
         "The capital of France is",
         "The future of AI is",
     ]
-
-    # torchair is only work without chunked-prefill now
-    kwargs = {
-        "ascend_scheduler_config": {
-            "enabled": True,
-        },
-        "refresh": True,
-    }
+    kwargs = {}
+    if not use_v1_scheduler:
+        kwargs = {
+            "ascend_scheduler_config": {
+                "enabled": True,
+            },
+            "refresh": True,
+        }
     additional_config.update(**kwargs)
 
     with VllmRunner(
@@ -95,6 +96,15 @@ def test_e2e_deepseekv3_with_torchair_ms_mla():
     _deepseek_torchair_test_fixture(additional_config)
 
 
+def test_e2e_deepseekv3_with_torchair_v1scheduler():
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": True,
+        },
+    }
+    _deepseek_torchair_test_fixture(additional_config, use_v1_scheduler=True)
+
+
 def _pangu_torchair_test_fixture(
     additional_config: Dict,
     *,
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 4e247562cf..7771632300 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1079,11 +1079,10 @@ def forward(
         ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            if not self.torchair_graph_enabled:
-                kv_c, k_pe = self.kv_a_proj_with_mqa(
-                    hidden_states_or_kv_c_normed)[0].split(
-                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            kv_c, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states_or_kv_c_normed)[0].split(
+                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
@@ -1102,12 +1101,13 @@ def forward(
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            if not self.torchair_graph_enabled:
-                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-                k_pe = k_pe[:num_actual_toks, ...]
-                k_pe = k_pe.unsqueeze(1)
-                decode_k_pe = k_pe[:num_decode_tokens]
-                prefill_k_pe = k_pe[num_decode_tokens:]
+            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+            prefill_hs = hidden_states_or_kv_c_normed[num_decode_tokens:]
+            # if not self.torchair_graph_enabled:
+            k_pe = k_pe[:num_actual_toks, ...]
+            k_pe = k_pe.unsqueeze(1)
+            decode_k_pe = k_pe[:num_decode_tokens]
+            prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
             if has_decode:
@@ -1167,11 +1167,11 @@ def forward(
                 prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
                 prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping)
+                    prefill_hs, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping[num_decode_tokens:])
 
                 kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
-                prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
+                prefill_k_c_normed = prefill_k_nope
                 prefill_k_pe = prefill_k_pe.view(num_tokens,
                                                  self.num_kv_heads,
                                                  -1)
                 prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
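
The `mla_v1.py` hunks above all lean on the same invariant: the flattened token batch is ordered decode-first, so one `num_decode_tokens` offset splits every per-token tensor (hidden states, `k_pe`, `slot_mapping`) into its decode and prefill halves. Because `exec_kv_prefill` is now handed only the prefill slices (`prefill_hs` and `attn_metadata.slot_mapping[num_decode_tokens:]`), its outputs already cover prefill tokens only, which is why `prefill_k_c_normed = prefill_k_nope` no longer needs a trailing `[num_decode_tokens:]`. A minimal, self-contained sketch of that slicing convention follows; the tensor names and sizes are illustrative stand-ins, not the real module:

```python
import torch

# Illustrative batch layout (assumption: decode tokens come first in the batch).
num_decode_tokens = 3   # one token per decoding request
num_prefill_tokens = 5  # remaining tokens belong to prefill requests
num_actual_toks = num_decode_tokens + num_prefill_tokens

# Stand-ins for hidden_states_or_kv_c_normed and attn_metadata.slot_mapping.
hidden_states = torch.randn(num_actual_toks, 16)
slot_mapping = torch.arange(num_actual_toks)

# The same split the diff applies before calling exec_kv_prefill:
decode_hs = hidden_states[:num_decode_tokens]
prefill_hs = hidden_states[num_decode_tokens:]
prefill_slots = slot_mapping[num_decode_tokens:]

# The pre-sliced inputs stay aligned element-for-element, so a function that
# consumes (prefill_hs, prefill_slots) needs no further offsetting, and its
# per-token outputs map one-to-one onto prefill tokens.
assert prefill_hs.shape[0] == prefill_slots.shape[0] == num_prefill_tokens
```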