
Commit d79f98d

Liccolo authored and offline0806 committed
[main] add pd transfer for ascend scheduler (vllm-project#2753)
### What this PR does / why we need it?

For offline scenarios, adjust the scheduling process to prioritize the prefill phase of all requests, then process the decode phase of all requests.

### How was this patch tested?

```
max_num_seqs=24,
additional_config={
    "ascend_scheduler_config": {
        "enabled": True,
        "enable_pd_transfer": True,
        "decode_max_num_seqs": 24,
        "enable_chunked_prefill": False
    }
},
```

| input | output | num prompts | max_num_seqs | dp | tp | scheduler | tps |
| ------ | ------ | ----------- | ------------ | -- | -- | ----------- | -------------- |
| dapo-math-17K | 2K | 384 | 24 | 2 | 1 | v1 | 234.06 |
| dapo-math-17K | 2K | 384 | 24 | 2 | 1 | pd transfer | 239.59 (+2.4%) |
| dapo-math-17K | 2K | 384 | 24 | 4 | 1 | v1 | 222.85 |
| dapo-math-17K | 2K | 384 | 24 | 4 | 1 | pd transfer | 225.81 (+1.3%) |

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@6fb2788

---------

Signed-off-by: CaranLic <[email protected]>
Signed-off-by: offline0806 <[email protected]>
1 parent de4bd1a commit d79f98d

File tree

9 files changed: +216 -4 lines changed

docs/source/user_guide/configuration/additional_config.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -58,6 +58,8 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine|
+| `enable_pd_transfer` | bool | `False` | Whether to enable pd transfer. When enabled, decode starts only after prefill of all requests is done. This option only takes effect for offline inference. |
+| `decode_max_num_seqs` | int | `0` | The max_num_seqs to use for the decode phase when pd transfer is enabled. This option only takes effect when `enable_pd_transfer` is True. |
 
 ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
```
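For reference, a minimal offline-inference sketch wiring these options through `additional_config` might look like the following. This is illustrative only and not part of the commit: the model name and sampling values are placeholders, and it assumes the vllm-ascend plugin is installed so the ascend scheduler is available.

```python
# Hypothetical usage sketch; model and sampling values are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    max_num_seqs=24,
    additional_config={
        "ascend_scheduler_config": {
            "enabled": True,
            "enable_pd_transfer": True,
            # decode phase may run with a larger batch than prefill
            "decode_max_num_seqs": 48,
            "enable_chunked_prefill": False,
        }
    },
)

outputs = llm.generate(
    ["Prove that the sum of two even numbers is even."],
    SamplingParams(max_tokens=256),
)
for out in outputs:
    print(out.outputs[0].text)
```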

tests/ut/core/test_schedule_config.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -165,3 +165,16 @@ def test_invalid_config_without_chunked_prefill(self):
         )
         self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
         self.assertIn("max_model_len (4096)", str(context.exception))
+
+    def test_initialize_from_config_with_pd_transfer(self):
+        ascend_config = AscendSchedulerConfig.initialize_from_config(
+            self.basic_scheduler_config,
+            AscendSchedulerConfig(
+                enable_pd_transfer=True,
+                decode_max_num_seqs=48,
+                max_num_batched_tokens=4096,
+                max_model_len=4096,
+            ),
+        )
+        self.assertEqual(ascend_config.enable_pd_transfer, True)
+        self.assertEqual(ascend_config.decode_max_num_seqs, 48)
```

tests/ut/core/test_scheduler.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -705,3 +705,34 @@ def test_memory_leak(self):
 
         # Confirm no memory leak.
         self.assert_scheduler_empty(scheduler)
+
+    def test_scheduler_with_pd_transfer(self):
+        scheduler = self.create_scheduler()
+        scheduler.phase = "prefill"
+        requests = create_requests(num_requests=32)
+        for request in requests:
+            scheduler.add_request(request)
+
+        # 1st iteration, move 16 requests from waiting to running for prefill
+        scheduler_output = scheduler.schedule()
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+        first_iter_prefilled_req_num = len(scheduler.running)
+        self.assertEqual(len(scheduler_output.scheduled_new_reqs),
+                         scheduler.max_num_running_reqs)
+        self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
+        self.assertEqual(len(scheduler_output.finished_req_ids), 0)
+
+        # 2nd iteration, move 16 prefilled requests to finished_prefill_reqs
+        # and move 16 requests from waiting to running for prefill
+        scheduler_output = scheduler.schedule()
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+        self.assertEqual(len(scheduler.finished_prefill_reqs),
+                         first_iter_prefilled_req_num)
+
+        # 3rd iteration, all requests prefilled, change scheduler phase to decode
+        scheduler_output = scheduler.schedule()
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+        self.assertEqual(scheduler.phase, "decode")
```
Lines changed: 40 additions & 0 deletions
```diff
@@ -0,0 +1,40 @@
+import torch
+from pytest_mock import MockerFixture
+from vllm.config import SchedulerConfig, VllmConfig
+
+from tests.ut.base import PytestBase
+from vllm_ascend.sample.logits_processor import AscendMinPLogitsProcessor
+
+
+class TestMinPLogitsProcessorInitFunc(PytestBase):
+
+    def test_init_func_with_decode_max_num_seqs(self, mocker: MockerFixture):
+        device_cpu = torch.device("cpu")
+        device_npu = torch.device("npu")
+        is_pin_memory = False
+        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
+        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
+        mock_scheduler_config.decode_max_num_seqs = 0
+        mock_scheduler_config.max_num_seqs = 128
+        mock_vllm_config.scheduler_config = mock_scheduler_config
+        # torch.zeros/torch.empty raise errors on the online UT machine, so mock them
+        mock_tensor = torch.zeros((256, ),
+                                  dtype=torch.float32,
+                                  pin_memory=False)
+        mocker.patch("torch.zeros", return_value=mock_tensor)
+        mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
+        mocker.patch("torch.empty", return_value=mock_empty_tensor)
+
+        processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_cpu,
+                                                  is_pin_memory)
+
+        assert processor_cpu.min_p is not None
+        assert processor_cpu.use_double_tensor is False
+        assert processor_cpu.min_p_cpu.shape[0] == 256
+
+        processor_npu = AscendMinPLogitsProcessor(mock_vllm_config, device_npu,
+                                                  is_pin_memory)
+
+        assert processor_npu.min_p is not None
+        assert processor_npu.use_double_tensor is True
+        assert processor_npu.min_p_cpu.shape[0] == 256
```

vllm_ascend/core/schedule_config.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -28,6 +28,8 @@ class AscendSchedulerConfig(SchedulerConfig):
     num_scheduler_steps: int = 1
     scheduler_cls: Union[str, Type[object]] = (
         "vllm_ascend.core.scheduler.AscendScheduler")
+    enable_pd_transfer: bool = False
+    decode_max_num_seqs: int = 0
 
     @classmethod
     def initialize_from_config(
@@ -45,6 +47,8 @@ def initialize_from_config(
         scheduler_config["num_scheduler_steps"] = 1
         scheduler_config["scheduler_cls"] = (
            "vllm_ascend.core.scheduler.AscendScheduler")
+        scheduler_config["enable_pd_transfer"] = False
+        scheduler_config["decode_max_num_seqs"] = 0
         # Override params in original SchedulerConfig with params in ascend_scheduler_config
         for k, _ in scheduler_config.items():
             if hasattr(ascend_scheduler_config, k):
```

vllm_ascend/core/scheduler.py

Lines changed: 36 additions & 2 deletions
```diff
@@ -52,6 +52,15 @@ def __init__(
         self.scheduled_req_ids: set[str] = set()
         self.running: list[Request] = []
 
+        self.finished_prefill_reqs: deque[Request] = deque()
+        enable_pd_transfer = getattr(self.scheduler_config,
+                                     'enable_pd_transfer', False)
+        decode_max_num_seqs = getattr(self.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        self.phase = "" if not enable_pd_transfer else "prefill"
+        self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
+                                               decode_max_num_seqs)
+
     def schedule(self) -> SchedulerOutput:
         if self.scheduler_config.chunked_prefill_enabled:
             return super().schedule()
@@ -76,9 +85,25 @@ def schedule(self) -> SchedulerOutput:
         # and put back at the head of the waiting queue later
         skipped_waiting_requests: deque[Request] = deque()
 
+        if self.phase == "prefill":
+            remaining_running_reqs = []
+            for request in self.running:
+                # move requests that have finished prefill to finished_prefill_reqs
+                if request.num_tokens > request.num_prompt_tokens:
+                    self.finished_prefill_reqs.append(request)
+                else:
+                    remaining_running_reqs.append(request)
+            self.running = remaining_running_reqs
+            # all requests prefilled, change phase to decode
+            if not self.waiting and not self.running:
+                self.phase = "decode"
+
         # Schedule prefill requests first.
         while self.waiting and token_budget > 0:
-            if len(self.running) == self.max_num_running_reqs:
+            if len(self.running) == (self.decode_max_num_running_reqs
+                                     if self.phase == "decode" else
+                                     self.max_num_running_reqs):
+
                 break
 
             request = self.waiting[0]
@@ -235,6 +260,13 @@ def skip_cur_request():
         if skipped_waiting_requests:
             self.waiting.extendleft(skipped_waiting_requests)
 
+        if self.phase == "decode":
+            while len(
+                    self.running
+            ) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
+                request = self.finished_prefill_reqs.popleft()
+                self.running.append(request)
+
         # If no prefill requests are scheduled,
         # Schedule decode requests next.
         if len(self.scheduled_req_ids) == 0:
@@ -334,7 +366,9 @@ def skip_cur_request():
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
         assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
         assert token_budget >= 0
-        assert len(self.running) <= self.max_num_running_reqs
+        assert len(
+            self.running
+        ) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs
         assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
             scheduled_running_reqs) <= len(self.running)
 
```
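To make the control flow of the diff above easier to follow, here is a minimal, self-contained sketch of the pd-transfer phase logic. It is not the actual `AscendScheduler.schedule()` implementation (which also handles token budgets, KV-cache allocation and preemption); the `Req` and `PdTransferSketch` names are illustrative only.

```python
# Standalone sketch of the prefill -> decode phase switch; not vllm-ascend code.
from collections import deque
from dataclasses import dataclass


@dataclass
class Req:
    req_id: str
    num_prompt_tokens: int
    num_tokens: int  # prompt tokens + tokens generated so far


class PdTransferSketch:
    """Toy scheduler illustrating the pd-transfer phases."""

    def __init__(self, max_num_seqs: int, decode_max_num_seqs: int = 0):
        self.max_num_running_reqs = max_num_seqs
        # Decode may run with a larger batch, as in the real change.
        self.decode_max_num_running_reqs = max(max_num_seqs, decode_max_num_seqs)
        self.phase = "prefill"
        self.waiting: deque = deque()
        self.running: list = []
        self.finished_prefill_reqs: deque = deque()

    def step(self) -> None:
        if self.phase == "prefill":
            # Park requests whose prefill is done (they already produced a token).
            still_running = []
            for req in self.running:
                if req.num_tokens > req.num_prompt_tokens:
                    self.finished_prefill_reqs.append(req)
                else:
                    still_running.append(req)
            self.running = still_running
            # Every request has been prefilled -> switch to the decode phase.
            if not self.waiting and not self.running:
                self.phase = "decode"

        # Admit new prefills, with the cap depending on the current phase.
        limit = (self.decode_max_num_running_reqs
                 if self.phase == "decode" else self.max_num_running_reqs)
        while self.waiting and len(self.running) < limit:
            self.running.append(self.waiting.popleft())

        # In the decode phase, pull parked requests back into the running batch.
        if self.phase == "decode":
            while (len(self.running) < self.decode_max_num_running_reqs
                   and self.finished_prefill_reqs):
                self.running.append(self.finished_prefill_reqs.popleft())


if __name__ == "__main__":
    sched = PdTransferSketch(max_num_seqs=2, decode_max_num_seqs=4)
    sched.waiting.extend(
        Req(f"r{i}", num_prompt_tokens=8, num_tokens=8) for i in range(4))
    while sched.phase == "prefill":
        sched.step()
        # Pretend the model ran one prefill step for every running request.
        for req in sched.running:
            req.num_tokens = req.num_prompt_tokens + 1
    print(sched.phase, len(sched.running))  # -> decode 4
```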

Lines changed: 50 additions & 0 deletions
```diff
@@ -0,0 +1,50 @@
+import itertools
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Union
+
+import torch
+from vllm.logger import init_logger
+from vllm.v1.sample import logits_processor
+from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
+                                                     MinTokensLogitsProcessor)
+from vllm.v1.sample.logits_processor.interface import LogitsProcessor
+from vllm.v1.sample.logits_processor.state import LogitsProcessors
+
+from vllm_ascend.sample.logits_processor.builtin import \
+    AscendMinPLogitsProcessor
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
+logger = init_logger(__name__)
+
+# Error message when the user tries to initialize vLLM with a pooling model
+# and custom logitsprocs
+STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom"
+                                   " logits processors.")
+
+BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
+    MinTokensLogitsProcessor,
+    LogitBiasLogitsProcessor,
+    AscendMinPLogitsProcessor,
+]
+
+
+def build_logitsprocs(
+    vllm_config: "VllmConfig",
+    device: torch.device,
+    is_pin_memory: bool,
+    is_pooling_model: bool,
+    custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
+) -> LogitsProcessors:
+    if is_pooling_model:
+        if custom_logitsprocs:
+            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
+        logger.debug("Skipping logits processor loading because pooling models"
+                     " do not support logits processors.")
+        return LogitsProcessors()
+    custom_logitsprocs_classes = logits_processor._load_custom_logitsprocs(
+        custom_logitsprocs)
+    return LogitsProcessors(
+        ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain(
+            BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
```
Lines changed: 35 additions & 0 deletions
```diff
@@ -0,0 +1,35 @@
+import torch
+from vllm.config import VllmConfig
+from vllm.v1.sample.logits_processor import MinPLogitsProcessor
+
+
+class AscendMinPLogitsProcessor(MinPLogitsProcessor):
+
+    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
+                 is_pin_memory: bool):
+        super().__init__(vllm_config, device, is_pin_memory)
+
+        decode_max_num_seqs = getattr(vllm_config.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        if decode_max_num_seqs != 0:
+            max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
+                               decode_max_num_seqs)
+
+            self.min_p_count: int = 0
+
+            self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
+                                                dtype=torch.float32,
+                                                device="cpu",
+                                                pin_memory=is_pin_memory)
+            self.min_p_cpu = self.min_p_cpu_tensor.numpy()
+
+            self.use_double_tensor = torch.device(device).type != "cpu"
+
+            if self.use_double_tensor:
+                # Pre-allocated device tensor
+                self.min_p_device: torch.Tensor = torch.empty(
+                    (max_num_reqs, ), dtype=torch.float32, device=device)
+            else:
+                self.min_p_device = self.min_p_cpu_tensor
+            # Current slice of the device tensor
+            self.min_p: torch.Tensor = self.min_p_device[:0]
```

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -67,7 +67,6 @@
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
                              LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
-from vllm.v1.sample.logits_processor import build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -92,6 +91,7 @@
 from vllm_ascend.eplb.utils import model_register
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.sample.logits_processor import build_logitsprocs
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.spec_decode import get_spec_decode_method
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
@@ -179,7 +179,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                            self.block_size)
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = self.scheduler_config.max_num_seqs
+        decode_max_num_seqs = getattr(self.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
+                                decode_max_num_seqs)
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
         self.device = device
```
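The `max()` above serves the same purpose as the buffer re-allocation in `AscendMinPLogitsProcessor`: when `decode_max_num_seqs` exceeds `max_num_seqs`, the decode phase runs a larger batch than prefill, so per-request buffers in the model runner must be sized for the larger value. A quick worked check with hypothetical values:

```python
# Hypothetical values: max_num_seqs from the engine arguments,
# decode_max_num_seqs from ascend_scheduler_config.
max_num_seqs = 24
decode_max_num_seqs = 48
max_num_reqs = max(max_num_seqs, decode_max_num_seqs)
assert max_num_reqs == 48  # persistent batch sized for the decode-phase batch
```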
