46 changes: 46 additions & 0 deletions tests/ut/worker/test_model_runner_v1.py
@@ -14,6 +14,7 @@
from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm_ascend.utils import AscendSocVersion
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -104,3 +105,48 @@ def test_select_moe_comm_method_unsupported_soc():
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):

NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)


@patch('vllm_ascend.worker.model_runner_v1.torch_npu')
@patch('vllm_ascend.worker.model_runner_v1.torch')
def test_init_creates_transfer_event_and_pinned_memory(mock_torch,
mock_torch_npu):
"""Test that initialization creates transfer event and pinned CPU memory."""
# This is a simplified test focusing only on the new attributes.
# We bypass the full __init__ via __new__ and exercise only the specific lines added in this change.

# Mock torch.empty to return a mock tensor
mock_pinned_tensor = MagicMock()
mock_torch.empty.return_value = mock_pinned_tensor

# Mock torch_npu.npu.Event - we need to set up the nested mock structure
mock_event = MagicMock()
mock_torch_npu.npu.Event.return_value = mock_event

# Create a runner instance using __new__ to bypass __init__
runner = NPUModelRunner.__new__(NPUModelRunner)

# Manually set the attributes we need for our test
runner.max_model_len = 2048

# Test the specific lines from the commit
runner.transfer_event = mock_torch_npu.npu.Event()
runner.sampled_token_ids_pinned_cpu = mock_torch.empty(
(runner.max_model_len, 1),
dtype=torch.int64,
device="cpu",
pin_memory=True)

# Verify max_model_len is set
assert runner.max_model_len == 2048

# Verify transfer_event is created
assert runner.transfer_event == mock_event
mock_torch_npu.npu.Event.assert_called_once()

# Verify pinned CPU memory is created with correct parameters
assert runner.sampled_token_ids_pinned_cpu == mock_pinned_tensor
mock_torch.empty.assert_called_with((2048, 1),
dtype=torch.int64,
device="cpu",
pin_memory=True)
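
As a complementary check (illustrative only, not part of this diff), `_to_list` from the runner change below could also be exercised end-to-end with plain CPU tensors and a mocked event, so the slicing and copy behaviour is verified rather than re-stated. The test name and the unpinned buffer are assumptions of this sketch:

```python
import torch
from unittest.mock import MagicMock

from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


def test_to_list_copies_into_pinned_buffer():
    # Bypass __init__ and wire up only the attributes _to_list touches.
    runner = NPUModelRunner.__new__(NPUModelRunner)
    runner.transfer_event = MagicMock()  # stands in for torch_npu.npu.Event()
    # A plain (unpinned) CPU buffer is enough off-device; pinning needs a device context.
    runner.sampled_token_ids_pinned_cpu = torch.zeros((8, 1), dtype=torch.int64)

    sampled = torch.tensor([[7], [8], [9]], dtype=torch.int64)
    assert runner._to_list(sampled) == [[7], [8], [9]]

    # The event is recorded and waited on before the host buffer is read.
    runner.transfer_event.record.assert_called_once()
    runner.transfer_event.synchronize.assert_called_once()
```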
24 changes: 23 additions & 1 deletion vllm_ascend/worker/model_runner_v1.py
@@ -227,6 +227,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.block_size = vllm_config.cache_config.block_size
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
self.block_size)
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
decode_max_num_seqs = getattr(self.scheduler_config,
'decode_max_num_seqs', 0)
@@ -401,6 +402,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
# Cached outputs.
self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None
self.transfer_event = torch_npu.npu.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
(self.max_model_len, 1),
dtype=torch.int64,
device="cpu",
pin_memory=True)

# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
self.in_profile_run = False
@@ -1906,7 +1913,7 @@ def execute_model(
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
valid_sampled_token_ids = self._to_list(sampled_token_ids)
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
@@ -3054,3 +3061,18 @@ def get_supported_pooling_tasks(self):

def _build_drafter_prepare_inputs_torchair_param(self):
return False

def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
# This is a short term mitigation for issue mentioned in
# https://github.com/vllm-project/vllm/issues/22754.
# `tolist` would trigger a device-wide stream sync on the NPU,
# which would block copy ops issued from other NPU streams.
# Waiting on an NPU event instead avoids that. Since this is
# on the critical path of every model forward loop, it has
# caused perf issues in disaggregated setups.
pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]]
pinned.copy_(sampled_token_ids, non_blocking=True)
self.transfer_event.record()
self.transfer_event.synchronize()
return pinned.tolist()
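
For readers less familiar with the pattern, a minimal standalone sketch of what the comment above describes (assuming `torch_npu` and an NPU device are available; the helper name and signature are illustrative, not part of this change):

```python
import torch
import torch_npu  # Ascend PyTorch plugin


def async_tolist(device_tensor: torch.Tensor, pinned_buf: torch.Tensor,
                 event: "torch_npu.npu.Event") -> list:
    # Copy into a pre-allocated pinned host buffer asynchronously, then wait
    # only on an event recorded after that copy, instead of calling
    # `device_tensor.tolist()`, which would force a blocking device-to-host
    # copy that stalls the whole stream.
    host_view = pinned_buf[:device_tensor.shape[0]]
    host_view.copy_(device_tensor, non_blocking=True)
    event.record()       # mark the point right after the async copy
    event.synchronize()  # wait for that copy alone, not for unrelated work
    return host_view.tolist()
```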