From ca448e145fd45d5c83f8bb8ef553cf7f63c9b6cd Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Mon, 21 Jul 2025 18:29:53 +0800
Subject: [PATCH] feat: optimize forward metadata collection across dp ranks

Signed-off-by: Jade Zheng
---
 vllm_ascend/torchair/torchair_worker.py | 10 ----
 vllm_ascend/worker/model_runner_v1.py   | 65 +++++++++++++++----------
 vllm_ascend/worker/worker_v1.py         | 29 ++++++-----
 3 files changed, 56 insertions(+), 48 deletions(-)

diff --git a/vllm_ascend/torchair/torchair_worker.py b/vllm_ascend/torchair/torchair_worker.py
index dd426e352a..16d9128513 100644
--- a/vllm_ascend/torchair/torchair_worker.py
+++ b/vllm_ascend/torchair/torchair_worker.py
@@ -52,13 +52,3 @@ def determine_available_memory(self) -> int:
 
         self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
         return available_kv_cache_memory
-
-    def _get_max_num_tokens_and_with_prefill(self):
-        """Override _get_max_num_tokens_and_with_prefill to update max_num_tokens."""
-
-        max_num_tokens, with_prefill = super(
-        )._get_max_num_tokens_and_with_prefill()
-        if not with_prefill:
-            max_num_tokens = self.model_runner.select_torchair_padded_batch_size(
-                max_num_tokens)
-        return max_num_tokens, with_prefill
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index e027b7cb37..f3fa14b435 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -30,9 +30,7 @@
 import numpy.typing as npt
 import torch
 import torch._dynamo.cache_size
-import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed import ReduceOp
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
@@ -562,16 +560,29 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.input_batch.refresh_sampling_metadata()
 
     def _get_forward_metadata_across_dp(
-            self, total_num_scheduled_tokens: int,
-            with_prefill: bool) -> tuple[int, bool]:
-        forward_metadata = torch.tensor(
-            [total_num_scheduled_tokens, with_prefill],
-            device="cpu",
-            dtype=torch.int32)
-        dist.all_reduce(forward_metadata,
-                        op=ReduceOp.MAX,
-                        group=get_dp_group().cpu_group)
-        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+            self, num_tokens: int,
+            with_prefill: bool) -> tuple[int, torch.Tensor, bool]:
+        local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                              device="npu",
+                                              dtype=torch.int32).unsqueeze(0)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata, dim=0)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_tokens = int(num_tokens_across_dp.max().item())
+            dummy_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_tokens)
+        else:
+            dummy_num_tokens = 1
+
+        # If num_tokens is -1, this indicates a dummy batch and we need to reset
+        # num_tokens accordingly.
+        num_tokens = dummy_num_tokens if num_tokens == -1 else num_tokens
+        num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                          dummy_num_tokens)
+        return num_tokens, num_tokens_across_dp, with_prefill
 
     def get_eagle_atten_dict(
         self,
@@ -1033,22 +1044,22 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
+        num_tokens_across_dp = None
         if self.dp_size > 1:
-            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
-                total_num_scheduled_tokens, with_prefill)
+            _, num_tokens_across_dp, with_prefill = \
+                self._get_forward_metadata_across_dp(num_input_tokens,
+                                                     with_prefill)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
             extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
             extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
 
         # Add graph_pad_size here
         if self.torchair_graph_enabled and not with_prefill:
-            if self.dp_size > 1:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    max_num_tokens)
-            else:
-                padded_batch_size = self.select_torchair_padded_batch_size(
-                    total_num_scheduled_tokens)
+            max_num_tokens = (max_num_tokens
+                              if self.dp_size > 1 else num_input_tokens)
+            padded_batch_size = self.select_torchair_padded_batch_size(
+                max_num_tokens)
             graph_pad_size = padded_batch_size - total_num_scheduled_tokens
-
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
 
         if self.vllm_config.model_config.use_mla:
@@ -1126,7 +1137,8 @@ def _process_reqs(
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             with ProfileExecuteDuration().capture_async("forward"):
                 model_kwargs = {}
                 if self.torchair_graph_enabled:
@@ -1605,6 +1617,7 @@ def _dummy_run(
         num_tokens: int,
         is_compile: bool = False,
         with_prefill: bool = True,
+        num_tokens_across_dp: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1646,9 +1659,11 @@ def _dummy_run(
                     for k, v in self.intermediate_tensors.items()
                 })
 
-        with set_forward_context(None,
-                                 self.vllm_config,
-                                 num_tokens=num_tokens):
+        with set_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                num_tokens_across_dp=num_tokens_across_dp):
             if self.torchair_graph_enabled and not with_prefill:
                 attn_metadata = self.attn_metadata_builder.build_dummy(
                     num_reqs=num_tokens, num_actual_tokens=1)
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 73f2d0b29a..0af68378d0 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -265,20 +265,23 @@ def list_loras(self) -> set[int]:
     def pin_lora(self, lora_id: int) -> bool:
         return self.model_runner.pin_lora(lora_id)
 
-    def _get_max_num_tokens_and_with_prefill(self):
-        max_num_tokens = 1
-        with_prefill = False
-        if self.model_runner.dp_size > 1:
-            max_num_tokens, with_prefill = self.model_runner._get_forward_metadata_across_dp(
-                max_num_tokens, with_prefill)
-        return max_num_tokens, with_prefill
-
     def execute_dummy_batch(self) -> None:
-        max_num_tokens, with_prefill = self._get_max_num_tokens_and_with_prefill(
-        )
-        self.model_runner._dummy_run(max_num_tokens,
-                                     is_compile=False,
-                                     with_prefill=with_prefill)
+        if self.model_runner.dp_size <= 1:
+            raise ValueError(
+                "Dummy batch execution should only be "
+                "performed with data parallelism enabled, but got "
+                f"dp_size={self.model_runner.dp_size}.")
+
+        # Indicate to other data parallel (DP) ranks that this is a dummy run by
+        # using '-1' as the num_tokens flag. The actual batch size will be
+        # determined and set within the model runner after synchronization
+        # across DP ranks.
+        num_tokens, num_tokens_across_dp, with_prefill = \
+            self.model_runner._get_forward_metadata_across_dp(-1, False)
+        self.model_runner._dummy_run(num_tokens,
+                                     is_compile=False,
+                                     num_tokens_across_dp=num_tokens_across_dp,
+                                     with_prefill=with_prefill)
 
     def _init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
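
For reference, the reduction that the new _get_forward_metadata_across_dp performs after the all_gather (any-prefill detection, max-based torchair padding, and replacement of the '-1' dummy-batch sentinel) can be exercised locally without NPUs or a DP process group. The sketch below is not part of the patch: it mirrors that logic on plain CPU tensors, and simulate_forward_metadata, pad_batch_size and PADDED_BATCH_SIZES are hypothetical stand-ins for the gathered metadata and select_torchair_padded_batch_size.

# Standalone sketch (not from vllm-ascend) of the per-rank reduction done in
# _get_forward_metadata_across_dp. The real code gathers the metadata with
# get_dp_group().all_gather on NPU tensors and pads with
# select_torchair_padded_batch_size; the names below are illustrative only.
import torch

# Hypothetical padded batch sizes used in place of the torchair size table.
PADDED_BATCH_SIZES = (4, 8, 16, 32, 64)


def pad_batch_size(num_tokens: int) -> int:
    """Round num_tokens up to the next entry of PADDED_BATCH_SIZES."""
    for size in PADDED_BATCH_SIZES:
        if num_tokens <= size:
            return size
    return PADDED_BATCH_SIZES[-1]


def simulate_forward_metadata(
    num_tokens: int,
    gathered: list[tuple[int, bool]],
    torchair_graph_enabled: bool = True,
) -> tuple[int, torch.Tensor, bool]:
    """Reduce the gathered (num_tokens, with_prefill) pairs for one rank.

    `gathered` plays the role of the all_gather result, one pair per DP
    rank; num_tokens == -1 marks a rank that is running a dummy batch.
    """
    metadata = torch.tensor(gathered, dtype=torch.int32)  # shape [dp_size, 2]
    num_tokens_across_dp = metadata[:, 0].clone()
    with_prefill = bool(metadata[:, 1].any())

    # Decode-only graph mode pads every rank to a common bucketed size;
    # otherwise a dummy batch only needs a single token.
    if torchair_graph_enabled and not with_prefill:
        dummy_num_tokens = pad_batch_size(int(num_tokens_across_dp.max().item()))
    else:
        dummy_num_tokens = 1

    # Ranks that reported -1 (dummy batches) adopt the padded size.
    num_tokens = dummy_num_tokens if num_tokens == -1 else num_tokens
    num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
                                      dummy_num_tokens)
    return num_tokens, num_tokens_across_dp, with_prefill


if __name__ == "__main__":
    # Three DP ranks, all decode-only; rank 1 has no work and reports -1.
    per_rank = [(5, False), (-1, False), (13, False)]
    print(simulate_forward_metadata(-1, per_rank))
    # (16, tensor([ 5, 16, 13], dtype=torch.int32), False)

Gathering the full per-rank vector, rather than all_reduce-ing a single max as the removed code did, is what allows set_forward_context to receive num_tokens_across_dp while each rank still keeps its own token count.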