|
52 | 52 | if TYPE_CHECKING: |
53 | 53 | from vllm.v1.core.sched.output import SchedulerOutput |
54 | 54 |
|
| 55 | +MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 |
| 56 | + |
55 | 57 |
|
56 | 58 | class AscendMLABackend(AttentionBackend): |
57 | 59 |
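For reference, `MAX_O_PROJ_PREFETCH_SIZE` caps the number of bytes prefetched for the output-projection weight. Below is a minimal sketch of how such a cap is typically applied, assuming `torch_npu.npu_prefetch(tensor, dependency, max_size)`; the helper name and call site here are illustrative, not the backend's actual code.

```python
import torch
import torch_npu  # Ascend Extension for PyTorch

MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16 MiB prefetch cap

def prefetch_o_proj_weight(o_proj_weight: torch.Tensor,
                           dependency: torch.Tensor) -> None:
    # Asynchronously stage the o_proj weight ahead of the attention
    # output projection, moving at most MAX_O_PROJ_PREFETCH_SIZE bytes
    # so an oversized weight does not evict hotter data.
    torch_npu.npu_prefetch(o_proj_weight, dependency,
                           MAX_O_PROJ_PREFETCH_SIZE)
```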
|
@@ -808,16 +810,17 @@ def get_and_maybe_dequant_weights(layer: LinearBase): |
808 | 810 | # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) |
809 | 811 | # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) |
810 | 812 |
|
811 | | - # Currently mlapo only supports W8A8 quantization in MLA scenario |
812 | | - # TODO(whx): modify this limitation when mlapo supports floating point |
813 | | - if self.fused_qkv_a_proj is None or not isinstance( |
814 | | - getattr(self.fused_qkv_a_proj.quant_method, 'quant_method', |
815 | | - None), AscendW8A8LinearMethod): |
816 | | - self.enable_mlapo = False |
817 | | - logger.warning_once( |
818 | | - "Currently mlapo only supports W8A8 quantization in MLA scenario." |
819 | | - "Some layers in your model are not quantized with W8A8," |
820 | | - "thus mlapo is disabled for these layers.") |
| 813 | + if self.enable_mlapo: |
| 814 | + # Currently mlapo only supports W8A8 quantization in MLA scenario |
| 815 | + # TODO(whx): modify this limitation when mlapo supports floating point |
| 816 | + if self.fused_qkv_a_proj is None or not isinstance( |
| 817 | + getattr(self.fused_qkv_a_proj.quant_method, 'quant_method', |
| 818 | + None), AscendW8A8LinearMethod): |
| 819 | + self.enable_mlapo = False |
| 820 | + logger.warning_once( |
| 821 | + "Currently mlapo only supports W8A8 quantization in MLA scenario." |
| 822 | + "Some layers in your model are not quantized with W8A8," |
| 823 | + "thus mlapo is disabled for these layers.") |
821 | 824 | if self.enable_mlapo: |
822 | 825 | self._process_weights_for_fused_mlapo(act_dtype) |
823 | 826 |
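The guard above resolves the layer's underlying quant method through a nested `getattr`/`isinstance` check before trusting `enable_mlapo`. A self-contained sketch of the same pattern follows; `AscendW8A8LinearMethod` is a stand-in class here (real import path omitted), and the wrapper/layer classes are hypothetical.

```python
class AscendW8A8LinearMethod:  # stand-in for the real W8A8 method class
    pass

class _QuantMethodWrapper:
    def __init__(self, inner):
        self.quant_method = inner  # wrapper exposing the concrete method

class _FusedProj:  # hypothetical layer holding a quant_method wrapper
    def __init__(self, method):
        self.quant_method = method

def mlapo_supported(fused_qkv_a_proj) -> bool:
    # mlapo currently requires a W8A8-quantized fused_qkv_a_proj; anything
    # else (including an absent layer) disables the fused path.
    if fused_qkv_a_proj is None:
        return False
    inner = getattr(fused_qkv_a_proj.quant_method, 'quant_method', None)
    return isinstance(inner, AscendW8A8LinearMethod)

assert not mlapo_supported(None)
assert mlapo_supported(_FusedProj(_QuantMethodWrapper(AscendW8A8LinearMethod())))
```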
|
@@ -1282,12 +1285,13 @@ def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata): |
1282 | 1285 | def _mla_preprocess(self, layer_name, hidden_states, kv_cache, |
1283 | 1286 | attn_metadata, need_gather_q_kv): |
1284 | 1287 | # MLA Preprocess: |
1285 | | - # 1. Perform q_a_proj and q_a_layernorm to obtain q_c |
1286 | | - # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split |
1287 | | - # 3. If need_gather_q_kv, perform all_gather. |
1288 | | - # 4. Preprocess decode tokens, write kv cache and get: |
| 1288 | + # 1. If fused_qkv_a_proj is available, perform it and q_a_layernorm to obtain q_c and kv_no_split; |
| 1289 | + # otherwise, |
| 1290 | + # perform kv_a_proj_with_mqa to obtain kv_no_split |
| 1291 | + # 2. If need_gather_q_kv, perform all_gather. |
| 1292 | + # 3. Preprocess decode tokens, write kv cache and get: |
1289 | 1293 | # decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope |
1290 | | - # 5. Preprocess prefill tokens, write kv cache and get: |
| 1294 | + # 4. Preprocess prefill tokens, write kv cache and get: |
1291 | 1295 | # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value |
1292 | 1296 | has_decode = attn_metadata.num_decodes > 0 |
1293 | 1297 | has_prefill = attn_metadata.num_prefills > 0 |
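A minimal sketch of step 1 as renumbered above, under the assumption that the fused path is taken whenever `fused_qkv_a_proj` exists; the attribute names (`q_lora_rank`, `kv_lora_rank`, `qk_rope_head_dim`) follow the usual MLA configuration, and the tuple indexing of the linear layers' outputs is illustrative.

```python
import torch

def _qkv_a_stage(self, hidden_states: torch.Tensor):
    # Step 1 of the MLA preprocess comment above (illustrative sketch).
    if self.fused_qkv_a_proj is not None:
        # Fused path: a single GEMM yields both the low-rank q and the
        # latent kv; q_c is then normalized by q_a_layernorm.
        qkv = self.fused_qkv_a_proj(hidden_states)[0]
        q_c, kv_no_split = qkv.split(
            [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
            dim=-1)
        q_c = self.q_a_layernorm(q_c)
    else:
        # Unfused path: only the latent kv is produced at this stage.
        q_c = None
        kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
    return q_c, kv_no_split
```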
|