Commit 7b605f8

Update vllm_ascend/worker/model_runner_v1.py
Co-authored-by: Angazenn <[email protected]>
Signed-off-by: Jade Zheng <[email protected]>
1 parent 98cf560 commit 7b605f8

1 file changed, +2 -2 lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions
@@ -550,9 +550,9 @@ def _get_forward_metadata_across_dp(
             with_prefill: bool) -> tuple[torch.Tensor, bool]:
         local_forward_metadata = torch.tensor([num_tokens, with_prefill],
                                               device="npu",
-                                              dtype=torch.int32)
+                                              dtype=torch.int32).unsqueeze(0)
         global_forward_metadata = get_dp_group().all_gather(
-            local_forward_metadata)
+            local_forward_metadata, dim=0)
         num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
         with_prefill = bool(global_forward_metadata[:, 1].any())
         return num_tokens_across_dp, with_prefill
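
The commit does not explain the change, but the unchanged lines index the gathered tensor as `[:, 0]` and `[:, 1]`, which only works on a 2-D result with one row per data-parallel rank. The following is a minimal single-process sketch of that shape reasoning, not the vllm-ascend implementation. `fake_all_gather` is a hypothetical stand-in that concatenates per-rank tensors with `torch.cat` on CPU, the three ranks and their values are invented for illustration, and the "before" case assumes the gather concatenated along the last dimension by default, which this page does not state.

```python
import torch


def fake_all_gather(shards, dim):
    # Hypothetical stand-in for a collective all_gather: in one process,
    # simply concatenate the tensors each rank would contribute along `dim`.
    return torch.cat(shards, dim=dim)


def local_metadata(num_tokens, with_prefill, as_row):
    # Mirrors the [num_tokens, with_prefill] tensor from the diff, on CPU.
    t = torch.tensor([num_tokens, int(with_prefill)], dtype=torch.int32)
    # as_row=True corresponds to the added .unsqueeze(0): shape (2,) -> (1, 2).
    return t.unsqueeze(0) if as_row else t


# Invented (num_tokens, with_prefill) values for three DP ranks.
per_rank = [(128, True), (64, False), (32, False)]

# Before (assumed behavior): 1-D tensors joined along the last dimension
# produce one flat vector, so column indexing is impossible.
flat = fake_all_gather(
    [local_metadata(n, p, as_row=False) for n, p in per_rank], dim=-1)
try:
    flat[:, 0]
    flat_indexable = True
except IndexError:
    flat_indexable = False

# After: (1, 2) rows joined along dim 0 produce a (dp_size, 2) matrix.
rows = fake_all_gather(
    [local_metadata(n, p, as_row=True) for n, p in per_rank], dim=0)
num_tokens_across_dp = rows[:, 0]
with_prefill = bool(rows[:, 1].any())

print(tuple(flat.shape), flat_indexable)
print(tuple(rows.shape), num_tokens_across_dp.tolist(), with_prefill)
```

With one row per rank, column 0 holds every rank's token count and column 1 holds every rank's prefill flag, which is the layout the last three lines of the function rely on.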
