
Commit f6149f3

[Model][3/N] Refactor sfa into mla and remove deepseek_v3_2.py (#3769)
This is the follow-up to PR #3189: it continues refactoring sfa into mla and finally removes deepseek_v3_2.py. This is the last PR of the DeepSeek modeling refactor; after it, all DeepSeek-related model code is removed from vllm_ascend. Furthermore, after this PR DeepSeek V3.2 can run chunked prefill with correct accuracy.

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@83f478b

---------

Signed-off-by: whx-sjtu <[email protected]>
1 parent eff3e5f commit f6149f3

File tree

10 files changed (+751, -1935 lines)


vllm_ascend/attention/mla_v1.py

Lines changed: 19 additions & 15 deletions
@@ -52,6 +52,8 @@
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
+MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
+
 
 class AscendMLABackend(AttentionBackend):
 
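The new module-level constant is 16 * 1024 * 1024 bytes, i.e. a 16 MiB budget for prefetching the o_proj weight. As a hedged sketch only, assuming the constant is used to cap prefetch requests, the clamping would look like this (prefetch_budget_bytes is a hypothetical helper, not a torch_npu or vllm_ascend API):

import torch

MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16 MiB, as defined in this hunk


def prefetch_budget_bytes(weight: torch.Tensor,
                          max_bytes: int = MAX_O_PROJ_PREFETCH_SIZE) -> int:
    # Hypothetical helper: clamp a prefetch request for `weight` to the
    # configured byte budget and report how many bytes would be prefetched.
    nbytes = weight.element_size() * weight.numel()
    return min(nbytes, max_bytes)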

@@ -808,16 +810,17 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
         # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
 
-        # Currently mlapo only supports W8A8 quantization in MLA scenario
-        # TODO(whx): modify this limitation when mlapo supports floating point
-        if self.fused_qkv_a_proj is None or not isinstance(
-                getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
-                        None), AscendW8A8LinearMethod):
-            self.enable_mlapo = False
-            logger.warning_once(
-                "Currently mlapo only supports W8A8 quantization in MLA scenario."
-                "Some layers in your model are not quantized with W8A8,"
-                "thus mlapo is disabled for these layers.")
+        if self.enable_mlapo:
+            # Currently mlapo only supports W8A8 quantization in MLA scenario
+            # TODO(whx): modify this limitation when mlapo supports floating point
+            if self.fused_qkv_a_proj is None or not isinstance(
+                    getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
+                            None), AscendW8A8LinearMethod):
+                self.enable_mlapo = False
+                logger.warning_once(
+                    "Currently mlapo only supports W8A8 quantization in MLA scenario."
+                    "Some layers in your model are not quantized with W8A8,"
+                    "thus mlapo is disabled for these layers.")
         if self.enable_mlapo:
             self._process_weights_for_fused_mlapo(act_dtype)
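The restructuring above moves the W8A8 check under an `if self.enable_mlapo:` guard, so the quantization method of fused_qkv_a_proj is only inspected (and the warning only emitted) for layers that actually requested mlapo. A minimal standalone sketch of that gating pattern, with W8A8Method and the layer object as hypothetical stand-ins rather than vllm_ascend classes:

import logging

logger = logging.getLogger(__name__)


class W8A8Method:
    """Hypothetical stand-in for AscendW8A8LinearMethod."""


def resolve_mlapo(enable_mlapo: bool, fused_qkv_a_proj) -> bool:
    # Mirror of the guarded check in this hunk: only probe the layer's
    # quantization method when mlapo was enabled in the first place.
    if enable_mlapo:
        inner = getattr(getattr(fused_qkv_a_proj, "quant_method", None),
                        "quant_method", None)
        if fused_qkv_a_proj is None or not isinstance(inner, W8A8Method):
            enable_mlapo = False
            logger.warning(
                "mlapo only supports W8A8 quantization; disabling it for "
                "layers that are not W8A8-quantized.")
    return enable_mlapo

With enable_mlapo already False, the function returns without touching fused_qkv_a_proj at all, which is the behavioral change this hunk introduces.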

@@ -1282,12 +1285,13 @@ def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
     def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
                         attn_metadata, need_gather_q_kv):
         # MLA Preprocess:
-        # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
-        # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
-        # 3. If need_gather_q_kv, perform all_gather.
-        # 4. Preprocess decode tokens, write kv cache and get:
+        # 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split
+        #    or
+        #    Perform kv_a_proj_with_mqa to obtain kv_no_split
+        # 2. If need_gather_q_kv, perform all_gather.
+        # 3. Preprocess decode tokens, write kv cache and get:
         #    decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
-        # 5. Preprocess prefill tokens, write kv cache and get:
+        # 4. Preprocess prefill tokens, write kv cache and get:
         #    prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
         has_decode = attn_metadata.num_decodes > 0
         has_prefill = attn_metadata.num_prefills > 0
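To make the renumbered comment easier to follow, here is a rough control-flow skeleton of the preprocessing order it describes; the projection callables are hypothetical placeholders rather than the actual vllm_ascend layers, and the decode/prefill token handling (steps 3 and 4) is elided:

from typing import Callable, Optional, Tuple

import torch


def mla_preprocess_sketch(
    hidden_states: torch.Tensor,
    fused_qkv_a_proj: Optional[Callable[[torch.Tensor],
                                        Tuple[torch.Tensor, torch.Tensor]]],
    kv_a_proj_with_mqa: Callable[[torch.Tensor], torch.Tensor],
    q_a_layernorm: Callable[[torch.Tensor], torch.Tensor],
    need_gather_q_kv: bool,
):
    # Step 1: the fused projection yields q_c and kv_no_split together;
    # without it, only kv_no_split is produced at this stage.
    if fused_qkv_a_proj is not None:
        q_c, kv_no_split = fused_qkv_a_proj(hidden_states)
        q_c = q_a_layernorm(q_c)
    else:
        q_c = None
        kv_no_split = kv_a_proj_with_mqa(hidden_states)

    # Step 2: all_gather q_c / kv_no_split across ranks when required
    # (the collective call itself is elided in this sketch).
    if need_gather_q_kv:
        pass

    # Steps 3-4: decode- and prefill-token preprocessing and kv-cache writes,
    # producing the decode_* and prefill_* tensors named in the comment,
    # are elided here.
    return q_c, kv_no_split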
