We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 98cf560 commit 7b605f8 Copy full SHA for 7b605f8
vllm_ascend/worker/model_runner_v1.py
@@ -550,9 +550,9 @@ def _get_forward_metadata_across_dp(
550
with_prefill: bool) -> tuple[torch.Tensor, bool]:
551
local_forward_metadata = torch.tensor([num_tokens, with_prefill],
552
device="npu",
553
- dtype=torch.int32)
+ dtype=torch.int32).unsqueeze(0)
554
global_forward_metadata = get_dp_group().all_gather(
555
- local_forward_metadata)
+ local_forward_metadata, dim=0)
556
num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
557
with_prefill = bool(global_forward_metadata[:, 1].any())
558
return num_tokens_across_dp, with_prefill
0 commit comments