|
12 | 12 | AttentionMetadata, |
13 | 13 | MLAAttentionImpl) |
14 | 14 | from vllm.config import VllmConfig, get_current_vllm_config |
15 | | - |
16 | | -# isort: off |
17 | 15 | from vllm.distributed import (get_dcp_group, |
18 | 16 | get_decode_context_model_parallel_rank, |
19 | 17 | get_decode_context_model_parallel_world_size, |
|
35 | 33 | split_decodes_and_prefills, |
36 | 34 | trans_rope_weight, transdata, |
37 | 35 | wait_for_kv_layer_from_connector) |
38 | | -from vllm_ascend.compilation.acl_graph import get_graph_params |
| 36 | +from vllm_ascend.compilation.acl_graph import (get_graph_params, |
| 37 | + update_graph_params_workspaces) |
39 | 38 | from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch |
40 | 39 | from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod |
41 | 40 | from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, |
42 | | - is_enable_nz, prefill_context_parallel_enable) |
| 41 | + is_enable_nz, prefill_context_parallel_enable, |
| 42 | + weak_ref_tensors) |
43 | 43 | from vllm_ascend.worker.npu_input_batch import InputBatch |
44 | 44 |
|
| 45 | +# isort: off |
45 | 46 | if prefill_context_parallel_enable(): |
46 | 47 | from vllm.distributed import (get_pcp_group, |
47 | 48 | get_prefill_context_model_parallel_rank, |
48 | 49 | get_prefill_context_model_parallel_world_size |
49 | 50 | ) |
50 | | -# isort:on |
| 51 | +# isort: on |
51 | 52 | if TYPE_CHECKING: |
52 | 53 | from vllm.v1.core.sched.output import SchedulerOutput |
53 | 54 |
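Note on the import changes above: `update_graph_params_workspaces` and `weak_ref_tensors` are pulled in so the decode path further down can record per-batch-size workspaces and weakly referenced tensors in the shared ACL-graph parameters. The sketch below only illustrates the kind of per-`num_tokens` bookkeeping this implies; the `GraphParams` fields and the helper body are assumptions for illustration, not the actual `vllm_ascend.compilation.acl_graph` implementation.

```python
# Illustrative sketch only (assumed structure, not the real
# vllm_ascend.compilation.acl_graph module): a process-wide container that
# caches one attention workspace and one list of captured call arguments
# per decode batch size (num_tokens).
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class GraphParams:
    workspaces: Dict[int, Any] = field(default_factory=dict)
    attn_params: Dict[int, List[Tuple]] = field(default_factory=dict)


_graph_params = GraphParams()


def get_graph_params() -> GraphParams:
    return _graph_params


def update_graph_params_workspaces(num_tokens: int, workspace: Any) -> None:
    # Funnel all workspace writes through one helper so capture-time
    # bookkeeping stays in the acl_graph module instead of being scattered
    # across attention backends.
    _graph_params.workspaces[num_tokens] = workspace
```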
|
@@ -743,7 +744,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): |
743 | 744 | getattr(self.fused_qkv_a_proj.quant_method, 'quant_method', |
744 | 745 | None), AscendW8A8LinearMethod): |
745 | 746 | self.enable_mlapo = False |
746 | | - logger.warning( |
| 747 | + logger.warning_once( |
747 | 748 |                 "Currently mlapo only supports W8A8 quantization in the MLA scenario. " |
748 | 749 |                 "Some layers in your model are not quantized with W8A8, " |
749 | 750 |                 "thus mlapo is disabled for these layers.") |
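Switching from `logger.warning` to `logger.warning_once` keeps this message from being emitted again for every non-W8A8 layer in the model. vLLM's logger provides `warning_once`; purely as an illustration of the behavior relied on here, a minimal stand-in for a plain `logging.Logger` could deduplicate on the message text:

```python
# Minimal sketch of "warn once" semantics, assuming deduplication on the
# literal message string. vLLM ships its own logger.warning_once; this
# stand-in only illustrates why the per-layer warning no longer repeats.
import logging
from functools import lru_cache

logger = logging.getLogger("vllm_ascend.example")


@lru_cache(maxsize=None)
def warning_once(msg: str) -> None:
    # lru_cache guarantees each distinct message is emitted at most once
    # per process, no matter how many layers hit this code path.
    logger.warning(msg)
```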
@@ -1115,19 +1116,22 @@ def _forward_decode( |
1115 | 1116 | if workspace is None: |
1116 | 1117 | workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( |
1117 | 1118 | q_nope, k_nope, k_nope, **common_kwargs) |
1118 | | - graph_params.workspaces[num_tokens] = workspace |
| 1119 | + update_graph_params_workspaces(num_tokens, |
| 1120 | + weak_ref_tensors(workspace)) |
1119 | 1121 |
|
1120 | 1122 | attn_output = torch.empty_like(q_nope) |
1121 | 1123 | softmax_lse = torch.empty(num_tokens, |
1122 | 1124 | dtype=q_nope.dtype, |
1123 | 1125 | device=q_nope.device) |
1124 | 1126 |
|
1125 | 1127 | graph_params.attn_params[num_tokens].append( |
1126 | | - (q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads, |
1127 | | - input_layout, spec_attn_mask, sparse_mode, self.scale, |
1128 | | - decode_meta.block_table, block_size, |
1129 | | - decode_meta.seq_lens_list, actual_seq_lengths, workspace, |
1130 | | - attn_output, softmax_lse)) |
| 1128 | + (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope), |
| 1129 | + weak_ref_tensors(q_pe), weak_ref_tensors(k_pe), |
| 1130 | + self.num_heads, self.num_kv_heads, input_layout, |
| 1131 | + weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None |
| 1132 | + else None, sparse_mode, self.scale, decode_meta.block_table, |
| 1133 | + block_size, decode_meta.seq_lens_list, actual_seq_lengths, |
| 1134 | + weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse))) |
1131 | 1135 |
|
1132 | 1136 | torch.npu.graph_task_group_begin(stream) |
1133 | 1137 | torch_npu.npu_fused_infer_attention_score.out( |
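The parameter tuple recorded in `graph_params.attn_params` now wraps every tensor in `weak_ref_tensors`, so the replay metadata captured for the ACL graph does not extend the tensors' lifetimes beyond graph capture. As a rough illustration of that intent (using Python's `weakref` on CPU tensors rather than the real NPU-side utility, whose exact return type is not shown in this diff):

```python
# Conceptual sketch: storing non-owning references in capture metadata.
# Assumption: weak_ref_tensors(t) hands back a handle that aliases t without
# keeping it alive; plain weakref.ref approximates that idea for CPU tensors.
import weakref

import torch

captured_params = []


def capture(t: torch.Tensor) -> None:
    # The graph metadata records only a weak reference, so it never becomes
    # an owner of the activation buffer.
    captured_params.append(weakref.ref(t))


x = torch.empty(4, 128)
capture(x)
assert captured_params[0]() is x     # reachable while the caller still holds x
del x
assert captured_params[0]() is None  # metadata did not prolong the tensor's life
```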