diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index 81c5ef6e63..d6b626f085 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -58,6 +58,8 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine|
+| `enable_pd_transfer` | bool | `False` | Whether to enable pd transfer. When enabled, decode starts only after prefill of all requests is done. This option only takes effect for offline inference. |
+| `decode_max_num_seqs` | int | `0` | The max_num_seqs used in the decode phase when pd transfer is enabled. This option only takes effect when `enable_pd_transfer` is `True`. |
 
 ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
 
diff --git a/tests/ut/core/test_schedule_config.py b/tests/ut/core/test_schedule_config.py
index df36b52523..8f1422f1b8 100644
--- a/tests/ut/core/test_schedule_config.py
+++ b/tests/ut/core/test_schedule_config.py
@@ -165,3 +165,16 @@ def test_invalid_config_without_chunked_prefill(self):
         )
         self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
         self.assertIn("max_model_len (4096)", str(context.exception))
+
+    def test_initialize_from_config_with_pd_transfer(self):
+        ascend_config = AscendSchedulerConfig.initialize_from_config(
+            self.basic_scheduler_config,
+            AscendSchedulerConfig(
+                enable_pd_transfer=True,
+                decode_max_num_seqs=48,
+                max_num_batched_tokens=4096,
+                max_model_len=4096,
+            ),
+        )
+        self.assertEqual(ascend_config.enable_pd_transfer, True)
+        self.assertEqual(ascend_config.decode_max_num_seqs, 48)
diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py
index 1855c805bd..330115cc45 100644
--- a/tests/ut/core/test_scheduler.py
+++ b/tests/ut/core/test_scheduler.py
@@ -896,3 +896,34 @@ def test_memory_leak(self):
         # Confirm no memory leak.
         self.assert_scheduler_empty(scheduler)
+
+    def test_scheduler_with_pd_transfer(self):
+        scheduler = self.create_scheduler()
+        scheduler.phase = "prefill"
+        requests = create_requests(num_requests=32)
+        for request in requests:
+            scheduler.add_request(request)
+
+        # 1st iteration, move 16 requests from waiting to running for prefill
+        scheduler_output = scheduler.schedule()
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+        first_iter_prefilled_req_num = len(scheduler.running)
+        self.assertEqual(len(scheduler_output.scheduled_new_reqs),
+                         scheduler.max_num_running_reqs)
+        self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
+        self.assertEqual(len(scheduler_output.finished_req_ids), 0)
+
+        # 2nd iteration, move 16 prefilled requests to finished_prefill_reqs
+        # and move 16 requests from waiting to running for prefill
+        scheduler_output = scheduler.schedule()
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+        self.assertEqual(len(scheduler.finished_prefill_reqs),
+                         first_iter_prefilled_req_num)
+
+        # 3rd iteration, all requests prefilled, change scheduler phase to decode
+        scheduler_output = scheduler.schedule()
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+        self.assertEqual(scheduler.phase, "decode")
diff --git a/tests/ut/sample/logits_processor/test_builtin.py b/tests/ut/sample/logits_processor/test_builtin.py
new file mode 100644
index 0000000000..cecd18624d
--- /dev/null
+++ b/tests/ut/sample/logits_processor/test_builtin.py
@@ -0,0 +1,40 @@
+import torch
+from pytest_mock import MockerFixture
+from vllm.config import SchedulerConfig, VllmConfig
+
+from tests.ut.base import PytestBase
+from vllm_ascend.sample.logits_processor import AscendMinPLogitsProcessor
+
+
+class TestMinPLogitsProcessorInitFunc(PytestBase):
+
+    def test_init_func_with_decode_max_num_seqs(self, mocker: MockerFixture):
+        device_cpu = torch.device("cpu")
+        device_npu = torch.device("npu")
+        is_pin_memory = False
+        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
+        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
+        mock_scheduler_config.decode_max_num_seqs = 0
+        mock_scheduler_config.max_num_seqs = 128
+        mock_vllm_config.scheduler_config = mock_scheduler_config
+        # torch.zeros/torch.empty raise errors on the online UT machine, so mock them
+        mock_tensor = torch.zeros((256, ),
+                                  dtype=torch.float32,
+                                  pin_memory=False)
+        mocker.patch("torch.zeros", return_value=mock_tensor)
+        mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
+        mocker.patch("torch.empty", return_value=mock_empty_tensor)
+
+        processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_cpu,
+                                                  is_pin_memory)
+
+        assert processor_cpu.min_p is not None
+        assert processor_cpu.use_double_tensor is False
+        assert processor_cpu.min_p_cpu.shape[0] == 256
+
+        processor_npu = AscendMinPLogitsProcessor(mock_vllm_config, device_npu,
+                                                  is_pin_memory)
+
+        assert processor_npu.min_p is not None
+        assert processor_npu.use_double_tensor is True
+        assert processor_npu.min_p_cpu.shape[0] == 256
diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py
index 4ee02e7ed4..422ca9aa3f 100644
--- a/vllm_ascend/core/schedule_config.py
+++ b/vllm_ascend/core/schedule_config.py
@@ -28,6 +28,8 @@ class AscendSchedulerConfig(SchedulerConfig):
     num_scheduler_steps: int = 1
     scheduler_cls: Union[str,
                          Type[object]] = (
                              "vllm_ascend.core.scheduler.AscendScheduler")
+    enable_pd_transfer: bool = False
+    decode_max_num_seqs: int = 0
 
     @classmethod
     def initialize_from_config(
@@ -45,6 +47,8 @@ def initialize_from_config(
         scheduler_config["num_scheduler_steps"] = 1
         scheduler_config["scheduler_cls"] = (
             "vllm_ascend.core.scheduler.AscendScheduler")
+        scheduler_config["enable_pd_transfer"] = False
+        scheduler_config["decode_max_num_seqs"] = 0
         # Override params in original SchedulerConfig with params in ascend_scheduler_config
         for k, _ in scheduler_config.items():
             if hasattr(ascend_scheduler_config, k):
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index f8c7f49355..965578155d 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -58,6 +58,15 @@ def __init__(
         self.scheduled_req_ids: set[str] = set()
         self.running: list[Request] = []
 
+        self.finished_prefill_reqs: deque[Request] = deque()
+        enable_pd_transfer = getattr(self.scheduler_config,
+                                     'enable_pd_transfer', False)
+        decode_max_num_seqs = getattr(self.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        self.phase = "prefill" if enable_pd_transfer else ""
+        self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
+                                               decode_max_num_seqs)
+
     def schedule(self) -> SchedulerOutput:
         if self.scheduler_config.chunked_prefill_enabled:
             return super().schedule()
@@ -85,9 +94,25 @@ def schedule(self) -> SchedulerOutput:
         # and put back at the head of the waiting queue later
         skipped_waiting_requests: deque[Request] = deque()
 
+        if self.phase == "prefill":
+            remaining_running_reqs = []
+            for request in self.running:
+                # move requests that finished prefill to finished_prefill_reqs
+                if request.num_tokens > request.num_prompt_tokens:
+                    self.finished_prefill_reqs.append(request)
+                else:
+                    remaining_running_reqs.append(request)
+            self.running = remaining_running_reqs
+            # all requests prefilled, switch phase to decode
+            if not self.waiting and not self.running:
+                self.phase = "decode"
+
         # Schedule prefill requests first.
         while self.waiting and token_budget > 0:
-            if len(self.running) == self.max_num_running_reqs:
+            if len(self.running) == (self.decode_max_num_running_reqs
+                                     if self.phase == "decode" else
+                                     self.max_num_running_reqs):
+                break
 
             request = self.waiting[0]
@@ -247,6 +272,13 @@ def skip_cur_request():
         if skipped_waiting_requests:
             self.waiting.extendleft(skipped_waiting_requests)
 
+        if self.phase == "decode":
+            while len(
+                    self.running
+            ) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
+                request = self.finished_prefill_reqs.popleft()
+                self.running.append(request)
+
         # If no prefill requests are scheduled,
         # Schedule decode requests next.
         if len(self.scheduled_req_ids) == 0:
@@ -350,7 +382,9 @@ def skip_cur_request():
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
         assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
         assert token_budget >= 0
-        assert len(self.running) <= self.max_num_running_reqs
+        assert len(self.running) <= (self.decode_max_num_running_reqs
+                                     if self.phase == "decode" else
+                                     self.max_num_running_reqs)
         assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
             scheduled_running_reqs) <= len(self.running)
 
diff --git a/vllm_ascend/sample/logits_processor/__init__.py b/vllm_ascend/sample/logits_processor/__init__.py
new file mode 100644
index 0000000000..5f810bfcd1
--- /dev/null
+++ b/vllm_ascend/sample/logits_processor/__init__.py
@@ -0,0 +1,50 @@
+import itertools
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Union
+
+import torch
+from vllm.logger import init_logger
+from vllm.v1.sample import logits_processor
+from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
+                                                     MinTokensLogitsProcessor)
+from vllm.v1.sample.logits_processor.interface import LogitsProcessor
+from vllm.v1.sample.logits_processor.state import LogitsProcessors
+
+from vllm_ascend.sample.logits_processor.builtin import \
+    AscendMinPLogitsProcessor
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
+logger = init_logger(__name__)
+
+# Error message when the user tries to initialize vLLM with a pooling model
+# and custom logitsprocs
+STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom"
+                                   " logits processors.")
+
+BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
+    MinTokensLogitsProcessor,
+    LogitBiasLogitsProcessor,
+    AscendMinPLogitsProcessor,
+]
+
+
+def build_logitsprocs(
+    vllm_config: "VllmConfig",
+    device: torch.device,
+    is_pin_memory: bool,
+    is_pooling_model: bool,
+    custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
+) -> LogitsProcessors:
+    if is_pooling_model:
+        if custom_logitsprocs:
+            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
+        logger.debug("Skipping logits processor loading because pooling models"
+                     " do not support logits processors.")
+        return LogitsProcessors()
+    custom_logitsprocs_classes = logits_processor._load_custom_logitsprocs(
+        custom_logitsprocs)
+    return LogitsProcessors(
+        ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain(
+            BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
diff --git a/vllm_ascend/sample/logits_processor/builtin.py b/vllm_ascend/sample/logits_processor/builtin.py
new file mode 100644
index 0000000000..f38d940240
--- /dev/null
+++ b/vllm_ascend/sample/logits_processor/builtin.py
@@ -0,0 +1,35 @@
+import torch
+from vllm.config import VllmConfig
+from vllm.v1.sample.logits_processor import MinPLogitsProcessor
+
+
+class AscendMinPLogitsProcessor(MinPLogitsProcessor):
+
+    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
+                 is_pin_memory: bool):
+        super().__init__(vllm_config, device, is_pin_memory)
+
+        decode_max_num_seqs = getattr(vllm_config.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        if decode_max_num_seqs != 0:
+            max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
+                               decode_max_num_seqs)
+
+            self.min_p_count: int = 0
+
+            self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
+                                                dtype=torch.float32,
+                                                device="cpu",
+                                                pin_memory=is_pin_memory)
+            self.min_p_cpu = self.min_p_cpu_tensor.numpy()
+
+            self.use_double_tensor = torch.device(device).type != "cpu"
+
+            if self.use_double_tensor:
+ # Pre-allocated device tensor + self.min_p_device: torch.Tensor = torch.empty( + (max_num_reqs, ), dtype=torch.float32, device=device) + else: + self.min_p_device = self.min_p_cpu_tensor + # Current slice of the device tensor + self.min_p: torch.Tensor = self.min_p_device[:0] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ba5b4396a2..f33503a3fe 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -66,7 +66,6 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -86,6 +85,7 @@ from vllm_ascend.compilation.acl_graph import ACLGraphWrapper from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform +from vllm_ascend.sample.logits_processor import build_logitsprocs from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.eagle_proposer import EagleProposer @@ -178,7 +178,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, self.block_size) self.max_num_tokens = self.scheduler_config.max_num_batched_tokens - self.max_num_reqs = self.scheduler_config.max_num_seqs + decode_max_num_seqs = getattr(self.scheduler_config, + 'decode_max_num_seqs', 0) + self.max_num_reqs = max(self.scheduler_config.max_num_seqs, + decode_max_num_seqs) self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.device = device
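
As a usage sketch of how the new scheduler options are meant to be combined for offline inference (assuming the vllm-ascend plugin is installed and `additional_config` is forwarded to `AscendSchedulerConfig` as documented in additional_config.md above; the model name, prompts, and batch sizes below are placeholders, not values from this change):

```python
from vllm import LLM, SamplingParams

# Prefill all prompts with max_num_seqs sequences per step, then switch to the
# decode phase and run up to decode_max_num_seqs sequences per step.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    max_num_seqs=16,                   # batch size used during the prefill phase
    additional_config={
        "ascend_scheduler_config": {
            "enabled": True,
            "enable_pd_transfer": True,   # decode starts after all prefills finish
            "decode_max_num_seqs": 48,    # larger batch size for the decode phase
        },
    },
)

outputs = llm.generate(
    ["Hello, my name is"] * 32,          # placeholder prompts
    SamplingParams(temperature=0.8, max_tokens=64),
)
```

Note that `decode_max_num_seqs` only raises the decode-phase limit: the scheduler takes `max(max_num_seqs, decode_max_num_seqs)`, so values at or below `max_num_seqs` leave the batch size unchanged.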