Commit 8ab8111

[Fix] Prevent memory leak in MLA decode graph (#3743)
### What this PR does / why we need it?

The cache for MLA decode graph parameters was holding strong references to tensors, preventing them from being garbage collected and leading to increased memory usage. This change wraps the cached tensors in weak references, allowing them to be deallocated when no longer in use and reducing overall memory pressure.

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

None.

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@c9461e0

---------

Signed-off-by: Yizhou Liu <[email protected]>
1 parent afc5818 commit 8ab8111

File tree

4 files changed (+29, -19 lines)

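Before the per-file diffs, here is a toy sketch of the failure mode described above. It is illustrative code, not taken from vLLM or vllm-ascend; the names `_workspaces` and `cache_workspace` are invented for the example.

```python
import torch

# A process-lifetime cache keyed by token count, similar in spirit to the
# graph-parameter cache this commit touches.
_workspaces: dict[int, torch.Tensor] = {}


def cache_workspace(num_tokens: int) -> None:
    # Scratch buffer that a captured kernel will reuse at replay time.
    workspace = torch.empty(num_tokens, 1024)
    # Strong reference: the buffer can never be reclaimed while the cache is
    # alive, even though replay only needs to know where the buffer lives.
    _workspaces[num_tokens] = workspace
```

The fix below instead stores a non-owning reference produced by `weak_ref_tensors`, so the original tensor can be freed when it goes out of scope while the replay path can still resolve the buffer's address.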

vllm_ascend/attention/attention_v1.py

Lines changed: 3 additions & 2 deletions
@@ -562,7 +562,8 @@ def _forward_decode_only(
             block_table=attn_metadata.block_tables,
             context_lens=attn_metadata.seq_lens,
             out=output)
-        update_graph_params_workspaces(num_tokens, workspace)
+        update_graph_params_workspaces(
+            num_tokens, weak_ref_tensors(workspace))

         # Handle graph capturing mode
         stream = torch_npu.npu.current_stream()
@@ -578,7 +579,7 @@ def _forward_decode_only(
                 self.num_kv_heads,
                 self.num_heads,
                 self.scale,
-                weak_ref_tensors(attn_metadata.block_tables),
+                attn_metadata.block_tables,
                 attn_metadata.seq_lens,
                 weak_ref_tensors(output),
             ))

vllm_ascend/attention/mla_v1.py

Lines changed: 16 additions & 12 deletions
@@ -12,8 +12,6 @@
     AttentionMetadata,
     MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
-
-# isort: off
 from vllm.distributed import (get_dcp_group,
                               get_decode_context_model_parallel_rank,
                               get_decode_context_model_parallel_world_size,
@@ -35,19 +33,22 @@
     split_decodes_and_prefills,
     trans_rope_weight, transdata,
     wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz, prefill_context_parallel_enable)
+                               is_enable_nz, prefill_context_parallel_enable,
+                               weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch

+# isort: off
 if prefill_context_parallel_enable():
     from vllm.distributed import (get_pcp_group,
                                   get_prefill_context_model_parallel_rank,
                                   get_prefill_context_model_parallel_world_size
                                   )
-# isort:on
+# isort: on
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput

@@ -743,7 +744,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                 getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
                         None), AscendW8A8LinearMethod):
             self.enable_mlapo = False
-            logger.warning(
+            logger.warning_once(
                 "Currently mlapo only supports W8A8 quantization in MLA scenario."
                 "Some layers in your model are not quantized with W8A8,"
                 "thus mlapo is disabled for these layers.")
@@ -1115,19 +1116,22 @@ def _forward_decode(
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                     q_nope, k_nope, k_nope, **common_kwargs)
-                graph_params.workspaces[num_tokens] = workspace
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))

             attn_output = torch.empty_like(q_nope)
             softmax_lse = torch.empty(num_tokens,
                                       dtype=q_nope.dtype,
                                       device=q_nope.device)

             graph_params.attn_params[num_tokens].append(
-                (q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads,
-                 input_layout, spec_attn_mask, sparse_mode, self.scale,
-                 decode_meta.block_table, block_size,
-                 decode_meta.seq_lens_list, actual_seq_lengths, workspace,
-                 attn_output, softmax_lse))
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                 weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
+                 self.num_heads, self.num_kv_heads, input_layout,
+                 weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
+                 else None, sparse_mode, self.scale, decode_meta.block_table,
+                 block_size, decode_meta.seq_lens_list, actual_seq_lengths,
+                 weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))

             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
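The hunk above applies the wrapping rule by hand when rebuilding the cached parameter tuple. As a hedged sketch of the idea (the helper name, the dict-based interface, and the `keep_strong` escape hatch are illustrative, not part of the diff):

```python
from typing import Callable, Collection

import torch


def wrap_params_for_capture(
        params: dict,
        weak_ref: Callable[[torch.Tensor], torch.Tensor],
        keep_strong: Collection[str] = ()) -> dict:
    """Return a copy of `params` safe to keep in a long-lived capture cache."""
    wrapped = {}
    for name, value in params.items():
        if isinstance(value, torch.Tensor) and name not in keep_strong:
            # Non-owning reference: the captured tensor can still be freed.
            wrapped[name] = weak_ref(value)
        else:
            # Scalars, lists, None, and tensors owned elsewhere (e.g. the
            # block table) are stored unchanged.
            wrapped[name] = value
    return wrapped
```

In the actual change, `q_nope`, `k_nope`, `q_pe`, `k_pe`, the optional `spec_attn_mask`, `attn_output`, and `softmax_lse` are wrapped, while `decode_meta.block_table` and the scalar arguments stay as-is.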

vllm_ascend/compilation/acl_graph.py

Lines changed: 3 additions & 5 deletions
@@ -212,7 +212,6 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
            seq_lens,
            output,
        ) = param
-       # block_table = forward_context.attn_metadata[key].block_tables
        seq_lens = forward_context.attn_metadata[key].seq_lens
        torch_npu_check = version_check()

@@ -258,8 +257,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
        ):
            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
             spec_attn_mask, sparse_mode, scale, block_table, block_size,
-            seq_lens_list, actual_seq_lengths, workspace, attn_output,
-            softmax_lse) = param
+            seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
            seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
            if speculative_config and speculative_config.method == "deepseek_mtp":
                actual_seq_lengths = forward_context.attn_metadata[
@@ -295,7 +293,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
                block_size=block_size,
                actual_seq_lengths_kv=seq_lens_list,
                actual_seq_lengths=actual_seq_lengths,
-               workspace=workspace,
+               workspace=graph_params.workspaces.get(runtime_shape),
                out=[attn_output, softmax_lse])
            torch.npu.graph_task_update_end(update_stream)

@@ -329,7 +327,7 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
    )


-def update_graph_params_workspaces(num_tokens: int, workspace: int):
+def update_graph_params_workspaces(num_tokens: int, workspace: Any):
    global _graph_params
    if _graph_params is not None:
        _graph_params.workspaces[num_tokens] = workspace
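For orientation, here is a minimal sketch of the cache these hunks mutate, reconstructed from the accesses visible above (`_graph_params`, `workspaces`, `attn_params`). The `GraphParams` dataclass and its field types are assumptions, not copied from the file.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class GraphParams:
    # Both caches are keyed by the captured token count / runtime shape.
    workspaces: dict[int, Any] = field(default_factory=dict)
    attn_params: dict[int, list] = field(default_factory=dict)


_graph_params: Optional[GraphParams] = None


def update_graph_params_workspaces(num_tokens: int, workspace: Any) -> None:
    # After this change the callers pass weak_ref_tensors(workspace), so the
    # cache records the buffer without keeping it alive.
    global _graph_params
    if _graph_params is not None:
        _graph_params.workspaces[num_tokens] = workspace
```

At replay time, `update_mla_attn_params` now fetches the workspace with `graph_params.workspaces.get(runtime_shape)` instead of carrying it in each cached parameter tuple.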

vllm_ascend/utils.py

Lines changed: 7 additions & 0 deletions
@@ -697,6 +697,13 @@ def weak_ref_tensors(
    """
    Convenience function to create weak references to tensors,
    for single tensor, list of tensors or tuple of tensors.
+
+    This function should be used in the following scenario:
+    When a tensor is created during graph capture, and it's held by a method
+    that's not part of the graph, we don't really need to store it, but we
+    **do need** its buffer pointer. If we don't handle this, it cannot
+    be garbage collected, leading to a memory leak. To avoid this,
+    we should create a weak reference to the tensor.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
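The hunk only shows the single-tensor branch. Below is a minimal sketch of the dispatch the docstring describes, parameterized over the existing single-tensor helper `weak_ref_tensor`; the list/tuple handling shown here is an assumption about the rest of the function, not copied from it.

```python
from typing import Callable, Union

import torch

TensorOrSeq = Union[torch.Tensor, list, tuple]


def weak_ref_tensors_sketch(
        tensors: TensorOrSeq,
        weak_ref_tensor: Callable[[torch.Tensor], torch.Tensor]) -> TensorOrSeq:
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)  # the branch visible in the hunk
    if isinstance(tensors, list):
        return [weak_ref_tensor(t) for t in tensors]
    if isinstance(tensors, tuple):
        return tuple(weak_ref_tensor(t) for t in tensors)
    raise ValueError("Invalid type for tensors")
```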
