[main] add pd transfer for ascend scheduler #2753
New test file: unit tests for the patched MinPLogitsProcessor initializer.

@@ -0,0 +1,91 @@
import torch
from pytest_mock import MockerFixture
from vllm.config import SchedulerConfig, VllmConfig
from vllm.v1.sample.logits_processor import MinPLogitsProcessor

from tests.ut.base import PytestBase
from vllm_ascend.ops.min_p_logits_processor import \
    min_p_logits_processor_init_func


class TestMinPLogitsProcessorInitFunc(PytestBase):

    def test_init_func_without_decode_max_num_seqs(self,
                                                   mocker: MockerFixture):
        mock_min_p_logits_processor = mocker.MagicMock(
            spec=MinPLogitsProcessor)
        mock_min_p_logits_processor.min_p_cpu = None

        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
        mock_scheduler_config.decode_max_num_seqs = 0
        mock_scheduler_config.max_num_seqs = 128
        mock_vllm_config.scheduler_config = mock_scheduler_config
        mocker.patch(
            "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
            return_value=mock_vllm_config)
        mocker.patch(
            "vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
            return_value=None)

        min_p_logits_processor_init_func(mock_min_p_logits_processor,
                                         mock_vllm_config, "cpu", True)

        # decode_max_num_seqs is 0, so the patched init must not allocate
        # the enlarged CPU buffer.
        assert mock_min_p_logits_processor.min_p_cpu is None

    def test_init_func_with_decode_max_num_seqs_and_npu(
            self, mocker: MockerFixture):
        mock_min_p_logits_processor = mocker.MagicMock(
            spec=MinPLogitsProcessor)

        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
        mock_scheduler_config.decode_max_num_seqs = 256
        mock_scheduler_config.max_num_seqs = 128
        mock_vllm_config.scheduler_config = mock_scheduler_config
        mocker.patch(
            "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
            return_value=mock_vllm_config)
        mocker.patch(
            "vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
            return_value=None)
        # torch.zeros/torch.empty raise errors on the online UT machine, so mock them.
        mock_tensor = torch.zeros((256, ),
                                  dtype=torch.float32,
                                  pin_memory=False)
        mocker.patch("torch.zeros", return_value=mock_tensor)
        mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
        mocker.patch("torch.empty", return_value=mock_empty_tensor)

        min_p_logits_processor_init_func(mock_min_p_logits_processor,
                                         mock_vllm_config, "npu", False)

        # Buffers are sized by max(max_num_seqs, decode_max_num_seqs), and a
        # separate device tensor is used for non-CPU devices.
        assert mock_min_p_logits_processor.min_p_cpu.shape[0] == 256
        assert mock_min_p_logits_processor.use_double_tensor is True

    def test_init_func_with_decode_max_num_seqs_and_cpu(
            self, mocker: MockerFixture):
        mock_min_p_logits_processor = mocker.MagicMock(
            spec=MinPLogitsProcessor)

        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
        mock_scheduler_config.max_num_seqs = 128
        mock_scheduler_config.decode_max_num_seqs = 256
        mock_vllm_config.scheduler_config = mock_scheduler_config
        mocker.patch(
            "vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
            return_value=mock_vllm_config)
        mocker.patch(
            "vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
            return_value=None)
        # torch.zeros raises an error on the online UT machine, so mock it.
        mock_tensor = torch.zeros((256, ),
                                  dtype=torch.float32,
                                  pin_memory=False)
        mocker.patch("torch.zeros", return_value=mock_tensor)

        min_p_logits_processor_init_func(mock_min_p_logits_processor,
                                         mock_vllm_config, "cpu", False)

        # On CPU the device tensor aliases the CPU tensor, so no double
        # tensor is needed.
        assert mock_min_p_logits_processor.use_double_tensor is False
Changed file: the AscendScheduler.

@@ -58,6 +58,15 @@ def __init__( | |
| self.scheduled_req_ids: set[str] = set() | ||
| self.running: list[Request] = [] | ||
|
|
||
| self.finished_prefill_reqs: deque[Request] = deque() | ||
| enable_pd_transfer = getattr(self.scheduler_config, | ||
| 'enable_pd_transfer', False) | ||
| decode_max_num_seqs = getattr(self.scheduler_config, | ||
| 'decode_max_num_seqs', 0) | ||
| self.phase = "" if not enable_pd_transfer else "prefill" | ||
| self.decode_max_num_running_reqs = max(self.max_num_running_reqs, | ||
| decode_max_num_seqs) | ||
|
|
||
| def schedule(self) -> SchedulerOutput: | ||
| if self.scheduler_config.chunked_prefill_enabled: | ||
| return super().schedule() | ||
|
|
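As a standalone illustration of how the two new knobs combine during initialization: the sketch below uses SimpleNamespace as a stand-in for the real scheduler config, and the field values are assumptions, not defaults from this PR.

```python
from types import SimpleNamespace

# Stand-in for the scheduler config; the real fields come from vLLM-Ascend's config.
scheduler_config = SimpleNamespace(enable_pd_transfer=True, decode_max_num_seqs=256)
max_num_running_reqs = 128  # the scheduler's usual running-request cap

enable_pd_transfer = getattr(scheduler_config, 'enable_pd_transfer', False)
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', 0)

# Phase tracking is only active when pd transfer is enabled; the decode-phase cap
# can exceed the normal cap so prefilled requests can be batched more widely.
phase = "" if not enable_pd_transfer else "prefill"
decode_max_num_running_reqs = max(max_num_running_reqs, decode_max_num_seqs)

print(phase, decode_max_num_running_reqs)  # -> prefill 256
```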
@@ -85,9 +94,25 @@ def schedule(self) -> SchedulerOutput: | |
| # and put back at the head of the waiting queue later | ||
| skipped_waiting_requests: deque[Request] = deque() | ||
|
|
||
| if self.phase == "prefill": | ||
| remaining_running_reqs = [] | ||
| for request in self.running: | ||
| # move request has finished prefill to finished_prefill_reqs | ||
| if request.num_tokens > request.num_prompt_tokens: | ||
| self.finished_prefill_reqs.append(request) | ||
| else: | ||
| remaining_running_reqs.append(request) | ||
| self.running = remaining_running_reqs | ||
| # all request prefilled, change phase to decode | ||
| if not self.waiting and not self.running: | ||
| self.phase = "decode" | ||
|
Comment on lines +107 to +108:

The phase transition from 'prefill' to 'decode' is currently one-way. Once the scheduler enters the 'decode' phase, it never returns to 'prefill'. If new requests arrive while the system is in the 'decode' phase, they will be prefilled and then immediately start decoding, which might not be the most efficient approach for the Ascend hardware this feature is targeting, as it breaks the strict batching of prefill operations. To improve performance for dynamic workloads, consider adding logic to allow the scheduler to switch back to the 'prefill' phase. For instance, you could add a check at the beginning of the `schedule` method:

    if self.phase == "decode" and not self.running and self.waiting:
        self.phase = "prefill"

This would ensure that if the decoding queue is empty and new requests are waiting, the scheduler can switch back to the more efficient batch prefill mode.
         # Schedule prefill requests first.
         while self.waiting and token_budget > 0:
-            if len(self.running) == self.max_num_running_reqs:
+            if len(self.running) == (self.decode_max_num_running_reqs
+                                     if self.phase == "decode" else
+                                     self.max_num_running_reqs):
                 break

             request = self.waiting[0]
@@ -247,6 +272,13 @@ def skip_cur_request(): | |
| if skipped_waiting_requests: | ||
| self.waiting.extendleft(skipped_waiting_requests) | ||
|
|
||
| if self.phase == "decode": | ||
| while len( | ||
| self.running | ||
| ) < self.decode_max_num_running_reqs and self.finished_prefill_reqs: | ||
| request = self.finished_prefill_reqs.popleft() | ||
| self.running.append(request) | ||
|
|
||
| # If no prefill requests are scheduled, | ||
| # Schedule decode requests next. | ||
| if len(self.scheduled_req_ids) == 0: | ||
|
|
@@ -350,7 +382,9 @@ def skip_cur_request(): | |
| total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) | ||
| assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens | ||
| assert token_budget >= 0 | ||
| assert len(self.running) <= self.max_num_running_reqs | ||
| assert len( | ||
| self.running | ||
| ) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs | ||
| assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( | ||
| scheduled_running_reqs) <= len(self.running) | ||
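Taken together, these hunks implement a two-phase schedule: during the 'prefill' phase, requests that have produced their first token are parked in finished_prefill_reqs; once the waiting and running queues are both empty, the phase flips to 'decode' and running is refilled from that deque up to decode_max_num_running_reqs. The toy simulation below walks through that flow in isolation; FakeRequest and demo_pd_transfer are illustrative stand-ins, not code from this PR.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class FakeRequest:
    """Stand-in for vLLM's Request: tracks prompt length and tokens seen so far."""
    req_id: str
    num_prompt_tokens: int
    num_tokens: int  # prompt tokens + generated tokens


def demo_pd_transfer():
    decode_max_num_running_reqs = 4
    waiting = deque(FakeRequest(f"r{i}", 8, 8) for i in range(6))
    running: list[FakeRequest] = []
    finished_prefill_reqs: deque[FakeRequest] = deque()
    phase = "prefill"

    while waiting or running or finished_prefill_reqs:
        if phase == "prefill":
            # Requests whose prefill completed (first token emitted) are parked.
            still_running = []
            for req in running:
                if req.num_tokens > req.num_prompt_tokens:
                    finished_prefill_reqs.append(req)
                else:
                    still_running.append(req)
            running = still_running
            # All requests prefilled: switch to the decode phase.
            if not waiting and not running:
                phase = "decode"

        if phase == "prefill" and waiting:
            # Admit one waiting request per step and "prefill" it.
            req = waiting.popleft()
            running.append(req)
            req.num_tokens += 1  # first generated token => prefill done
        elif phase == "decode":
            # Refill running from the parked queue, then "decode" one token each.
            while len(running) < decode_max_num_running_reqs and finished_prefill_reqs:
                running.append(finished_prefill_reqs.popleft())
            for req in running:
                req.num_tokens += 1
            # Retire requests after four generated tokens (arbitrary stop condition).
            running = [r for r in running if r.num_tokens < r.num_prompt_tokens + 4]

        print(phase, [r.req_id for r in running])


if __name__ == "__main__":
    demo_pd_transfer()
```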
New file: vllm_ascend/ops/min_p_logits_processor.py, which patches MinPLogitsProcessor.__init__ so its buffers are sized by decode_max_num_seqs.

@@ -0,0 +1,38 @@
import torch
from vllm.config import get_current_vllm_config
from vllm.v1.sample.logits_processor import MinPLogitsProcessor

# Keep a handle to the original initializer so the patched version can call it.
original_min_p_logits_processor_init_func = MinPLogitsProcessor.__init__


def min_p_logits_processor_init_func(self, *args, **kwargs):
    original_min_p_logits_processor_init_func(self, *args, **kwargs)

    vllm_config = get_current_vllm_config()
    decode_max_num_seqs = getattr(vllm_config.scheduler_config,
                                  'decode_max_num_seqs', 0)
    # Re-initialize the MinPLogitsProcessor buffers if decode_max_num_seqs is configured.
    if decode_max_num_seqs != 0:
        device = args[1]
        is_pin_memory = args[2]
        max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
                           decode_max_num_seqs)

        self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=is_pin_memory)
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()

        self.use_double_tensor = torch.device(device).type != "cpu"

        if self.use_double_tensor:
            self.min_p_device = torch.empty((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device=device)
        else:
            self.min_p_device = self.min_p_cpu_tensor
        self.min_p = self.min_p_device[:0]


MinPLogitsProcessor.__init__ = min_p_logits_processor_init_func
Review comment:

Manually setting the internal state `scheduler.phase` makes this test brittle and less representative of real usage. If the initialization logic in `AscendScheduler.__init__` changes, this test would not catch the regression. A better approach is to initialize the scheduler with `enable_pd_transfer=True` in its configuration, which would correctly set the initial phase.

To achieve this, you could modify the `create_scheduler` helper method to accept configuration overrides. For example:

Then, the test can be updated to:

This change would make the test more robust and also serve to verify the scheduler's initialization logic.
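The reviewer's example snippets were not captured above. As a rough sketch of the kind of override being suggested: the names below (SchedulerUnderTest, the default config fields, and the create_scheduler signature) are illustrative stand-ins, not the repository's actual helpers.

```python
from types import SimpleNamespace


class SchedulerUnderTest:
    """Illustrative stand-in for AscendScheduler's phase initialization."""

    def __init__(self, scheduler_config):
        self.scheduler_config = scheduler_config
        enable_pd_transfer = getattr(scheduler_config, "enable_pd_transfer", False)
        self.phase = "prefill" if enable_pd_transfer else ""


def create_scheduler(**scheduler_config_overrides):
    # Thread config overrides through instead of poking scheduler.phase afterwards.
    config = SimpleNamespace(max_num_seqs=128,
                             enable_pd_transfer=False,
                             decode_max_num_seqs=0)
    for name, value in scheduler_config_overrides.items():
        setattr(config, name, value)
    return SchedulerUnderTest(config)


def test_phase_is_prefill_when_pd_transfer_enabled():
    scheduler = create_scheduler(enable_pd_transfer=True)
    assert scheduler.phase == "prefill"
```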