[V1][Spec Decode] Async scheduling integration with spec decode #22262

Open · wants to merge 4 commits into base: main

5 changes: 4 additions & 1 deletion examples/offline_inference/spec_decode.py
@@ -68,6 +68,8 @@ def parse_args():
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--no-spec-decode", action="store_true")
parser.add_argument("--async-scheduling", action="store_true")
return parser.parse_args()


@@ -127,11 +129,12 @@ def main():
enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager,
gpu_memory_utilization=0.8,
speculative_config=speculative_config,
speculative_config=speculative_config if not args.no_spec_decode else None,
disable_log_stats=False,
max_model_len=16384,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
async_scheduling=args.async_scheduling,
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
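For context, here is a minimal stand-alone sketch of what the updated example now permits: passing a speculative_config together with async_scheduling=True to LLM (the equivalent of running the script with --async-scheduling and without --no-spec-decode). The model names and speculative_config values below are assumptions for illustration, not part of this diff.

```python
# Hedged sketch: combine EAGLE speculative decoding with async scheduling.
# Base model and draft model names are assumed; substitute your own.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",            # assumed base model
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",    # assumed EAGLE draft head
        "num_speculative_tokens": 3,
    },
    async_scheduling=True,  # previously rejected whenever speculative_config was set
)

params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```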
64 changes: 61 additions & 3 deletions tests/v1/core/test_async_scheduler.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import deque
from typing import Optional

import pytest

@@ -12,16 +13,20 @@


def _make_model_runner_output(
scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput:
scheduler_output: SchedulerOutput,
sampled_token_ids: Optional[list[list[int]]] = None,
spec_token_ids: Optional[list[list[int]]] = None) -> ModelRunnerOutput:
req_ids = list(scheduler_output.num_scheduled_tokens.keys())
if not sampled_token_ids:
sampled_token_ids = [[i] for i in range(len(req_ids))]
return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index={
req_id: i
for i, req_id in enumerate(req_ids)
},
sampled_token_ids=[[i] for i in range(len(req_ids))],
spec_token_ids=None,
sampled_token_ids=sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -55,6 +60,59 @@ def test_stop_by_max_tokens(max_tokens: int):
assert req1.num_output_tokens == max_tokens


def test_spec_decode():
max_tokens = 7
num_spec_tokens = 3
spec_token_ids = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
[-1, -2, -3]]
sampled_token_ids = [[0], [1, 2, 13], [4, 15], [16], [-1, -2]]
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens,
async_scheduling=True)
requests = create_requests(num_requests=1, max_tokens=max_tokens)
req = requests[0]

sched_outputs: deque[SchedulerOutput] = deque()
scheduler.add_request(req)
sched_outputs.append(scheduler.schedule())
sched_outputs.append(scheduler.schedule())

i = 0
while sched_outputs:
sched_output = sched_outputs.popleft()
# Overwrite with cached spec decode tokens as done in GPUModelRunner
if i > 0:
sched_output.scheduled_spec_decode_tokens[
req.request_id] = spec_token_ids[i - 1]
model_runner_output = _make_model_runner_output(
sched_output, [sampled_token_ids[i]], [spec_token_ids[i]])
engine_core_output = scheduler.update_from_output(
sched_output, model_runner_output)
# Validate spec decode stats
if engine_core_output:
assert engine_core_output[0].scheduler_stats
spec_decoding_stats = engine_core_output[
0].scheduler_stats.spec_decoding_stats
if i == 0:
# No spec decode stats for prefill round
assert spec_decoding_stats is None
else:
assert spec_decoding_stats
assert spec_decoding_stats.num_drafts == 1
assert spec_decoding_stats.num_draft_tokens == num_spec_tokens
assert spec_decoding_stats.num_accepted_tokens == len(
sampled_token_ids[i]) - 1
sched_output = scheduler.schedule()
if sched_output.num_scheduled_tokens:
assert sched_output.num_scheduled_tokens[
req.request_id] == 1 + num_spec_tokens
sched_outputs.append(sched_output)
i += 1

assert scheduler.get_num_unfinished_requests() == 0
assert req.num_output_tokens == max_tokens
assert req.output_token_ids._x == [0, 1, 2, 13, 4, 15, 16]


def test_abort():
scheduler = create_scheduler(async_scheduling=True)
requests = create_requests(num_requests=10, max_tokens=20)
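As a quick sanity check on the new test, the spec-decode stats it asserts can be re-derived by hand for one round. The sketch below is illustrative only; the token lists are copied from the test's spec_token_ids / sampled_token_ids tables.

```python
# Round i=2 of test_spec_decode, recomputed by hand (illustrative, not test code).
scheduled_drafts = [4, 5, 6]   # spec_token_ids[1], injected into the SchedulerOutput
sampled = [4, 15]              # sampled_token_ids[2]: draft 4 accepted, then bonus token 15

num_drafts = 1                               # one draft sequence this step
num_draft_tokens = len(scheduled_drafts)     # 3 == num_spec_tokens
num_accepted_tokens = len(sampled) - 1       # 1: the final sampled token is not a draft

assert (num_drafts, num_draft_tokens, num_accepted_tokens) == (1, 3, 1)
```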
7 changes: 0 additions & 7 deletions vllm/engine/arg_utils.py
@@ -1198,13 +1198,6 @@ def create_engine_config(
raise ValueError("Async scheduling is not supported with "
"pipeline-parallel-size > 1.")

# Currently, async scheduling does not support speculative decoding.
# TODO(woosuk): Support it.
if self.speculative_config is not None:
raise ValueError(
"Currently, speculative decoding is not supported with "
"async scheduling.")

parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
13 changes: 3 additions & 10 deletions vllm/v1/core/sched/async_scheduler.py
@@ -20,11 +20,10 @@ def _update_after_schedule(
super()._update_after_schedule(scheduler_output)
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
if (request.num_computed_tokens == request.num_tokens +
if (request.num_computed_tokens == request.num_tokens_with_spec +
request.num_output_placeholders):
# The request will generate a new token in this scheduling step.
# TODO(woosuk): Support speculative decoding.
request.num_output_placeholders += 1
request.num_output_placeholders = 1 + self.num_spec_tokens

def _update_request_with_output(
self,
@@ -35,13 +34,7 @@ def _update_request_with_output(
new_token_ids, stopped = super()._update_request_with_output(
request, new_token_ids)

# Update the number of output placeholders.
request.num_output_placeholders -= len(new_token_ids)
assert request.num_output_placeholders >= 0

# Cache the new tokens. Preempted requests should be skipped.
if status_before_update == RequestStatus.RUNNING:
self.kv_cache_manager.cache_blocks(
request,
request.num_computed_tokens - request.num_output_placeholders)
self.kv_cache_manager.cache_blocks(request, request.num_tokens)
return new_token_ids, stopped
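To make the AsyncScheduler change concrete: once a running request is fully computed, the scheduler now reserves placeholders for one bonus token plus up to num_spec_tokens accepted draft tokens, instead of a single token, and block caching is done up to request.num_tokens. A rough sketch with assumed numbers:

```python
# Illustrative only; mirrors the condition in _update_after_schedule.
num_spec_tokens = 3
num_tokens_with_spec = 12     # prompt + committed outputs + drafts known to the scheduler
num_output_placeholders = 0
num_computed_tokens = 12      # everything scheduled so far will have been computed

if num_computed_tokens == num_tokens_with_spec + num_output_placeholders:
    # The request samples in this step: reserve 1 bonus token + up to 3 accepted drafts.
    num_output_placeholders = 1 + num_spec_tokens   # -> 4
```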
9 changes: 6 additions & 3 deletions vllm/v1/core/sched/scheduler.py
@@ -204,9 +204,12 @@ def schedule(self) -> SchedulerOutput:
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]

num_new_tokens = (request.num_tokens_with_spec +
request.num_output_placeholders -
request.num_computed_tokens)
if request.num_output_placeholders:
num_new_tokens = request.num_output_placeholders
else:
num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens)

if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
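A small worked example of the new token-count branch in schedule() (values assumed): when the async scheduler has reserved placeholders, a running decode request is scheduled for exactly 1 + num_spec_tokens tokens, which is what test_spec_decode asserts via num_scheduled_tokens.

```python
# Illustrative only; follows the new branch in schedule().
num_output_placeholders = 4    # 1 bonus + 3 spec tokens, reserved by AsyncScheduler
num_tokens_with_spec = 20
num_computed_tokens = 20

if num_output_placeholders:
    num_new_tokens = num_output_placeholders                        # async spec-decode path
else:
    num_new_tokens = num_tokens_with_spec - num_computed_tokens     # original path

assert num_new_tokens == 1 + 3   # matches the 1 + num_spec_tokens check in the test
```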
35 changes: 34 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -336,6 +336,11 @@ def __init__(

self.reorder_batch_threshold: Optional[int] = None

# Cache spec token ids and num computed tokens from previous round,
# used when async scheduling and spec decoding are both enabled
self.cached_spec_token_ids: dict[str, list[int]] = {}
self.cached_num_computed_tokens: dict[str, int] = {}

def _init_model_kwargs(self, num_tokens: int):
model_kwargs = dict[str, Any]()
num_reqs = self.input_batch.num_reqs
@@ -420,6 +425,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
self.cached_spec_token_ids.pop(req_id, None)
self.cached_num_computed_tokens.pop(req_id, None)

# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -488,6 +496,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
lora_request=new_req_data.lora_request,
)

# Cache computed tokens for new request with
# speculative decoding + async scheduling
if (self.speculative_config
and self.scheduler_config.async_scheduling):
self.cached_num_computed_tokens[req_id] = (
new_req_data.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])

# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
image_grid_thw = []
@@ -532,7 +548,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
if req_id in self.cached_spec_token_ids:
scheduler_output.scheduled_spec_decode_tokens[
req_id] = self.cached_spec_token_ids[req_id]
if req_id in self.cached_num_computed_tokens:
num_computed_tokens = self.cached_num_computed_tokens[req_id]
else:
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]

@@ -1763,6 +1785,17 @@ def execute_model(
spec_decode_metadata,
spec_decode_common_attn_metadata,
)
# Update cached request states for async scheduling
if self.scheduler_config.async_scheduling:
for idx, req_id in enumerate(self.input_batch.req_ids):
if req_id in self.cached_spec_token_ids:
# Update num computed tokens for running requests
num_rejected_tokens = max_gen_len - len(
valid_sampled_token_ids[idx])
self.cached_num_computed_tokens[
req_id] += scheduler_output.num_scheduled_tokens[
req_id] - num_rejected_tokens
self.cached_spec_token_ids[req_id] = spec_token_ids[idx]

self.eplb_step()

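Finally, a hedged sketch of how the runner-side caches introduced here evolve across one speculative step under async scheduling; all concrete values are assumed for illustration.

```python
# Illustrative only; mirrors the bookkeeping in _update_states / execute_model.
cached_spec_token_ids = {"req-0": [101, 102, 103]}   # drafts proposed in the previous step
cached_num_computed_tokens = {"req-0": 24}           # runner's view, ahead of the scheduler

num_scheduled_tokens = 4       # 1 verified token + 3 cached drafts scheduled this step
max_gen_len = 4                # at most: 3 accepted drafts + 1 bonus token
valid_sampled_token_ids = [101, 102, 200]            # 2 drafts accepted + bonus token 200
num_rejected_tokens = max_gen_len - len(valid_sampled_token_ids)   # 1 draft rejected

cached_num_computed_tokens["req-0"] += num_scheduled_tokens - num_rejected_tokens   # 24 + 3
cached_spec_token_ids["req-0"] = [301, 302, 303]     # fresh drafts for the next step

assert cached_num_computed_tokens["req-0"] == 27
```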