Merged
Changes from 23 commits
2 changes: 2 additions & 0 deletions docs/source/user_guide/configuration/additional_config.md
@@ -58,6 +58,8 @@ The details of each config option are as follows:
| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine|
| `enable_pd_transfer` | bool | `False` | Whether to enable pd transfer. When enabled, the decode phase starts only after prefill of all requests is done. This option only takes effect for offline inference. |
| `decode_max_num_seqs` | int | `0` | The max_num_seqs to use for the decode phase when pd transfer is enabled; `0` leaves max_num_seqs unchanged. This option only takes effect when `enable_pd_transfer` is `True`. |

ascend_scheduler_config also supports the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
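
For example, a minimal offline-inference setup that enables pd transfer might look like the sketch below (the model name and sampling settings are placeholders; the ascend_scheduler_config keys are the ones documented above):

from vllm import LLM, SamplingParams

# Sketch: enable pd transfer and a larger decode-phase batch size
# via the additional_config engine argument described in this guide.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    additional_config={
        "ascend_scheduler_config": {
            "enabled": True,
            "enable_pd_transfer": True,
            "decode_max_num_seqs": 48,
        },
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))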

13 changes: 13 additions & 0 deletions tests/ut/core/test_schedule_config.py
@@ -165,3 +165,16 @@ def test_invalid_config_without_chunked_prefill(self):
)
self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
self.assertIn("max_model_len (4096)", str(context.exception))

def test_initialize_from_config_with_pd_transfer(self):
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
enable_pd_transfer=True,
decode_max_num_seqs=48,
max_num_batched_tokens=4096,
max_model_len=4096,
),
)
self.assertEqual(ascend_config.enable_pd_transfer, True)
self.assertEqual(ascend_config.decode_max_num_seqs, 48)
31 changes: 31 additions & 0 deletions tests/ut/core/test_scheduler.py
@@ -896,3 +896,34 @@ def test_memory_leak(self):

# Confirm no memory leak.
self.assert_scheduler_empty(scheduler)

def test_scheduler_with_pd_transfer(self):
scheduler = self.create_scheduler()
scheduler.phase = "prefill"
Contributor review comment (severity: high):

Manually setting the internal state scheduler.phase makes this test brittle and less representative of real usage. If the initialization logic in AscendScheduler.__init__ changes, this test would not catch the regression. A better approach is to initialize the scheduler with enable_pd_transfer=True in its configuration, which would correctly set the initial phase.

To achieve this, you could modify the create_scheduler helper method to accept configuration overrides. For example:

def create_scheduler(self, mock_compute_encoder_budget, scheduler_config_override: Optional[Dict[str, Any]] = None):
    # ... existing setup ...
    scheduler_config = SchedulerConfig(
        # ...
    )
    if scheduler_config_override:
        for key, value in scheduler_config_override.items():
            setattr(scheduler_config, key, value)
    # ... rest of the function ...

Then, the test can be updated to:

scheduler = self.create_scheduler(scheduler_config_override={"enable_pd_transfer": True})
self.assertEqual(scheduler.phase, "prefill")

This change would make the test more robust and also serve to verify the scheduler's initialization logic.

requests = create_requests(num_requests=32)
for request in requests:
scheduler.add_request(request)

# 1st iteration, move 16 requests from waiting to running for prefill
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
first_iter_prefilled_req_num = len(scheduler.running)
self.assertEqual(len(scheduler_output.scheduled_new_reqs),
scheduler.max_num_running_reqs)
self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(scheduler_output.finished_req_ids), 0)

# 2nd iteration, move 16 prefilled requests to finished_prefill_reqs
# and move 16 requests from waiting to running for prefill
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
self.assertEqual(len(scheduler.finished_prefill_reqs),
first_iter_prefilled_req_num)

# 3rd iteration, all requests prefilled, change scheduler phase to decode
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
self.assertEqual(scheduler.phase, "decode")
91 changes: 91 additions & 0 deletions tests/ut/ops/test_min_p_logits_processor.py
@@ -0,0 +1,91 @@
import torch
from pytest_mock import MockerFixture
from vllm.config import SchedulerConfig, VllmConfig
from vllm.v1.sample.logits_processor import MinPLogitsProcessor

from tests.ut.base import PytestBase
from vllm_ascend.ops.min_p_logits_processor import \
min_p_logits_processor_init_func


class TestMinPLogitsProcessorInitFunc(PytestBase):

def test_init_func_without_decode_max_num_seqs(self,
mocker: MockerFixture):
mock_min_p_logits_processor = mocker.MagicMock(
spec=MinPLogitsProcessor)
mock_min_p_logits_processor.min_p_cpu = None

mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
mock_scheduler_config.decode_max_num_seqs = 0
mock_scheduler_config.max_num_seqs = 128
mock_vllm_config.scheduler_config = mock_scheduler_config
mocker.patch(
"vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
return_value=mock_vllm_config)
mocker.patch(
"vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
return_value=None)

min_p_logits_processor_init_func(mock_min_p_logits_processor,
mock_vllm_config, "cpu", True)

assert mock_min_p_logits_processor.min_p_cpu is None

def test_init_func_with_decode_max_num_seqs_and_npu(
self, mocker: MockerFixture):
mock_min_p_logits_processor = mocker.MagicMock(
spec=MinPLogitsProcessor)

mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
mock_scheduler_config.decode_max_num_seqs = 256
mock_scheduler_config.max_num_seqs = 128
mock_vllm_config.scheduler_config = mock_scheduler_config
mocker.patch(
"vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
return_value=mock_vllm_config)
mocker.patch(
"vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
return_value=None)
# torch.zeros/torch.empty raise errors on the online UT machine, so mock them
mock_tensor = torch.zeros((256, ),
dtype=torch.float32,
pin_memory=False)
mocker.patch("torch.zeros", return_value=mock_tensor)
mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
mocker.patch("torch.empty", return_value=mock_empty_tensor)

min_p_logits_processor_init_func(mock_min_p_logits_processor,
mock_vllm_config, "npu", False)

assert mock_min_p_logits_processor.min_p_cpu.shape[0] == 256
assert mock_min_p_logits_processor.use_double_tensor is True

def test_init_func_with_decode_max_num_seqs_and_cpu(
self, mocker: MockerFixture):
mock_min_p_logits_processor = mocker.MagicMock(
spec=MinPLogitsProcessor)

mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
mock_scheduler_config.max_num_seqs = 128
mock_scheduler_config.decode_max_num_seqs = 256
mock_vllm_config.scheduler_config = mock_scheduler_config
mocker.patch(
"vllm_ascend.ops.min_p_logits_processor.get_current_vllm_config",
return_value=mock_vllm_config)
mocker.patch(
"vllm_ascend.ops.min_p_logits_processor.original_min_p_logits_processor_init_func",
return_value=None)
# torch.zeros raises an error on the online UT machine, so mock it
mock_tensor = torch.zeros((256, ),
dtype=torch.float32,
pin_memory=False)
mocker.patch("torch.zeros", return_value=mock_tensor)

min_p_logits_processor_init_func(mock_min_p_logits_processor,
mock_vllm_config, "cpu", False)

assert mock_min_p_logits_processor.use_double_tensor is False
4 changes: 4 additions & 0 deletions vllm_ascend/core/schedule_config.py
@@ -28,6 +28,8 @@ class AscendSchedulerConfig(SchedulerConfig):
num_scheduler_steps: int = 1
scheduler_cls: Union[str, Type[object]] = (
"vllm_ascend.core.scheduler.AscendScheduler")
enable_pd_transfer: bool = False
decode_max_num_seqs: int = 0

@classmethod
def initialize_from_config(
@@ -45,6 +47,8 @@ def initialize_from_config(
scheduler_config["num_scheduler_steps"] = 1
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.scheduler.AscendScheduler")
scheduler_config["enable_pd_transfer"] = False
scheduler_config["decode_max_num_seqs"] = 0
# Override params in original SchedulerConfig with params in ascend_scheduler_config
for k, _ in scheduler_config.items():
if hasattr(ascend_scheduler_config, k):
38 changes: 36 additions & 2 deletions vllm_ascend/core/scheduler.py
@@ -58,6 +58,15 @@ def __init__(
self.scheduled_req_ids: set[str] = set()
self.running: list[Request] = []

self.finished_prefill_reqs: deque[Request] = deque()
enable_pd_transfer = getattr(self.scheduler_config,
'enable_pd_transfer', False)
decode_max_num_seqs = getattr(self.scheduler_config,
'decode_max_num_seqs', 0)
self.phase = "" if not enable_pd_transfer else "prefill"
self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
decode_max_num_seqs)

def schedule(self) -> SchedulerOutput:
if self.scheduler_config.chunked_prefill_enabled:
return super().schedule()
@@ -85,9 +94,25 @@ def schedule(self) -> SchedulerOutput:
# and put back at the head of the waiting queue later
skipped_waiting_requests: deque[Request] = deque()

if self.phase == "prefill":
remaining_running_reqs = []
for request in self.running:
# move requests that have finished prefill to finished_prefill_reqs
if request.num_tokens > request.num_prompt_tokens:
self.finished_prefill_reqs.append(request)
else:
remaining_running_reqs.append(request)
self.running = remaining_running_reqs
# all requests prefilled, change phase to decode
if not self.waiting and not self.running:
self.phase = "decode"
Comment on lines +107 to +108
Contributor review comment (severity: high):

The phase transition from 'prefill' to 'decode' is currently one-way. Once the scheduler enters the 'decode' phase, it never returns to 'prefill'. If new requests arrive while the system is in the 'decode' phase, they will be prefilled and then immediately start decoding, which might not be the most efficient approach for the Ascend hardware this feature is targeting, as it breaks the strict batching of prefill operations.

To improve performance for dynamic workloads, consider adding logic to allow the scheduler to switch back to the 'prefill' phase. For instance, you could add a check at the beginning of the schedule method:

if self.phase == "decode" and not self.running and self.waiting:
    self.phase = "prefill"

This would ensure that if the decoding queue is empty and new requests are waiting, the scheduler can switch back to the more efficient batch prefill mode.


# Schedule prefill requests first.
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
if len(self.running) == (self.decode_max_num_running_reqs
if self.phase == "decode" else
self.max_num_running_reqs):

break

request = self.waiting[0]
@@ -247,6 +272,13 @@ def skip_cur_request():
if skipped_waiting_requests:
self.waiting.extendleft(skipped_waiting_requests)

if self.phase == "decode":
while len(
self.running
) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
request = self.finished_prefill_reqs.popleft()
self.running.append(request)

# If no prefill requests are scheduled,
# Schedule decode requests next.
if len(self.scheduled_req_ids) == 0:
@@ -350,7 +382,9 @@ def skip_cur_request():
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
assert len(self.running) <= (self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs)
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
scheduled_running_reqs) <= len(self.running)

1 change: 1 addition & 0 deletions vllm_ascend/ops/__init__.py
@@ -20,6 +20,7 @@
import vllm_ascend.ops.common_fused_moe # noqa
import vllm_ascend.ops.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.min_p_logits_processor # noqa
import vllm_ascend.ops.register_custom_ops # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
38 changes: 38 additions & 0 deletions vllm_ascend/ops/min_p_logits_processor.py
@@ -0,0 +1,38 @@
import torch
from vllm.config import get_current_vllm_config
from vllm.v1.sample.logits_processor import MinPLogitsProcessor

original_min_p_logits_processor_init_func = MinPLogitsProcessor.__init__


def min_p_logits_processor_init_func(self, *args, **kwargs):
original_min_p_logits_processor_init_func(self, *args, **kwargs)

vllm_config = get_current_vllm_config()
decode_max_num_seqs = getattr(vllm_config.scheduler_config,
'decode_max_num_seqs', 0)
# reinit MinPLogitsProcessor if decode_max_num_seqs configured
if decode_max_num_seqs != 0:
device = args[1]
is_pin_memory = args[2]
max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
decode_max_num_seqs)

self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=is_pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()

self.use_double_tensor = torch.device(device).type != "cpu"

if self.use_double_tensor:
self.min_p_device = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
else:
self.min_p_device = self.min_p_cpu_tensor
self.min_p = self.min_p_device[:0]


MinPLogitsProcessor.__init__ = min_p_logits_processor_init_func
Collaborator review comment:
Do not patch vLLM, you should create ascend build_logitsprocs in npu_input_batch instead.

Contributor (author) reply:
Changed from patching the MinPLogitsProcessor init func to redefining build_logitsprocs; code in 30571f1.
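
For reference, a rough sketch of that direction is shown below: a subclass sized for the decode-phase batch, which an Ascend-specific build_logitsprocs in npu_input_batch could instantiate instead of monkey-patching vLLM. The class name and constructor signature are assumptions; the actual change is in 30571f1.

import torch
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import MinPLogitsProcessor


class AscendMinPLogitsProcessor(MinPLogitsProcessor):
    """Hypothetical subclass whose buffers fit the decode-phase batch."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device,
                 is_pin_memory: bool):
        # Constructor signature assumed to match MinPLogitsProcessor.__init__.
        super().__init__(vllm_config, device, is_pin_memory)

        decode_max_num_seqs = getattr(vllm_config.scheduler_config,
                                      'decode_max_num_seqs', 0)
        if decode_max_num_seqs != 0:
            # Re-allocate the min_p buffers so they can hold the larger
            # decode-phase batch, mirroring the patched init above.
            max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
                               decode_max_num_seqs)
            self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                                                dtype=torch.float32,
                                                device="cpu",
                                                pin_memory=is_pin_memory)
            self.min_p_cpu = self.min_p_cpu_tensor.numpy()
            self.use_double_tensor = torch.device(device).type != "cpu"
            if self.use_double_tensor:
                self.min_p_device = torch.empty((max_num_reqs, ),
                                                dtype=torch.float32,
                                                device=device)
            else:
                self.min_p_device = self.min_p_cpu_tensor
            self.min_p = self.min_p_device[:0]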

5 changes: 4 additions & 1 deletion vllm_ascend/worker/model_runner_v1.py
@@ -178,7 +178,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
self.block_size)
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
decode_max_num_seqs = getattr(self.scheduler_config,
'decode_max_num_seqs', 0)
self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
decode_max_num_seqs)
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.device = device