diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index 184c30891eca..e6acd9cfdf2c 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -68,6 +68,8 @@ def parse_args():
     parser.add_argument("--model-dir", type=str, default=None)
     parser.add_argument("--eagle-dir", type=str, default=None)
     parser.add_argument("--custom-mm-prompts", action="store_true")
+    parser.add_argument("--no-spec-decode", action="store_true")
+    parser.add_argument("--async-scheduling", action="store_true")
     return parser.parse_args()


@@ -127,11 +129,12 @@ def main():
         enable_chunked_prefill=args.enable_chunked_prefill,
         enforce_eager=args.enforce_eager,
         gpu_memory_utilization=0.8,
-        speculative_config=speculative_config,
+        speculative_config=speculative_config if not args.no_spec_decode else None,
         disable_log_stats=False,
         max_model_len=16384,
         limit_mm_per_prompt={"image": 5},
         disable_chunked_mm_input=True,
+        async_scheduling=args.async_scheduling,
     )

     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py
index 3ccefbd81cab..f2e5818478e9 100644
--- a/tests/v1/core/test_async_scheduler.py
+++ b/tests/v1/core/test_async_scheduler.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import deque
+from typing import Optional

 import pytest

@@ -12,16 +13,20 @@


 def _make_model_runner_output(
-        scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput:
+        scheduler_output: SchedulerOutput,
+        sampled_token_ids: Optional[list[list[int]]] = None,
+        spec_token_ids: Optional[list[list[int]]] = None) -> ModelRunnerOutput:
     req_ids = list(scheduler_output.num_scheduled_tokens.keys())
+    if not sampled_token_ids:
+        sampled_token_ids = [[i] for i in range(len(req_ids))]
     return ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index={
             req_id: i
             for i, req_id in enumerate(req_ids)
         },
-        sampled_token_ids=[[i] for i in range(len(req_ids))],
-        spec_token_ids=None,
+        sampled_token_ids=sampled_token_ids,
+        spec_token_ids=spec_token_ids,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -55,6 +60,59 @@ def test_stop_by_max_tokens(max_tokens: int):
     assert req1.num_output_tokens == max_tokens


+def test_spec_decode():
+    max_tokens = 7
+    num_spec_tokens = 3
+    spec_token_ids = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
+                      [-1, -2, -3]]
+    sampled_token_ids = [[0], [1, 2, 13], [4, 15], [16], [-1, -2]]
+    scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens,
+                                 async_scheduling=True)
+    requests = create_requests(num_requests=1, max_tokens=max_tokens)
+    req = requests[0]
+
+    sched_outputs: deque[SchedulerOutput] = deque()
+    scheduler.add_request(req)
+    sched_outputs.append(scheduler.schedule())
+    sched_outputs.append(scheduler.schedule())
+
+    i = 0
+    while sched_outputs:
+        sched_output = sched_outputs.popleft()
+        # Overwrite with cached spec decode tokens as done in GPUModelRunner
+        if i > 0:
+            sched_output.scheduled_spec_decode_tokens[
+                req.request_id] = spec_token_ids[i - 1]
+        model_runner_output = _make_model_runner_output(
+            sched_output, [sampled_token_ids[i]], [spec_token_ids[i]])
+        engine_core_output = scheduler.update_from_output(
+            sched_output, model_runner_output)
+        # Validate spec decode stats
+        if engine_core_output:
+            assert engine_core_output[0].scheduler_stats
+            spec_decoding_stats = engine_core_output[
+                0].scheduler_stats.spec_decoding_stats
+            if i == 0:
+                # No spec decode stats for prefill round
+                assert spec_decoding_stats is None
+            else:
+                assert spec_decoding_stats
+                assert spec_decoding_stats.num_drafts == 1
+                assert spec_decoding_stats.num_draft_tokens == num_spec_tokens
+                assert spec_decoding_stats.num_accepted_tokens == len(
+                    sampled_token_ids[i]) - 1
+        sched_output = scheduler.schedule()
+        if sched_output.num_scheduled_tokens:
+            assert sched_output.num_scheduled_tokens[
+                req.request_id] == 1 + num_spec_tokens
+            sched_outputs.append(sched_output)
+        i += 1
+
+    assert scheduler.get_num_unfinished_requests() == 0
+    assert req.num_output_tokens == max_tokens
+    assert req.output_token_ids._x == [0, 1, 2, 13, 4, 15, 16]
+
+
 def test_abort():
     scheduler = create_scheduler(async_scheduling=True)
     requests = create_requests(num_requests=10, max_tokens=20)
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 41a6da709bec..16c8dc0b262f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1198,13 +1198,6 @@ def create_engine_config(
                 raise ValueError("Async scheduling is not supported with "
                                  "pipeline-parallel-size > 1.")

-            # Currently, async scheduling does not support speculative decoding.
-            # TODO(woosuk): Support it.
-            if self.speculative_config is not None:
-                raise ValueError(
-                    "Currently, speculative decoding is not supported with "
-                    "async scheduling.")
-
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py
index 74ff6261732c..4087f7ed4263 100644
--- a/vllm/v1/core/sched/async_scheduler.py
+++ b/vllm/v1/core/sched/async_scheduler.py
@@ -20,11 +20,10 @@ def _update_after_schedule(
         super()._update_after_schedule(scheduler_output)
         for req_id in scheduler_output.num_scheduled_tokens:
             request = self.requests[req_id]
-            if (request.num_computed_tokens == request.num_tokens +
+            if (request.num_computed_tokens == request.num_tokens_with_spec +
                     request.num_output_placeholders):
                 # The request will generate a new token in this scheduling step.
-                # TODO(woosuk): Support speculative decoding.
-                request.num_output_placeholders += 1
+                request.num_output_placeholders = 1 + self.num_spec_tokens

     def _update_request_with_output(
         self,
@@ -35,13 +34,7 @@ def _update_request_with_output(
         new_token_ids, stopped = super()._update_request_with_output(
             request, new_token_ids)

-        # Update the number of output placeholders.
-        request.num_output_placeholders -= len(new_token_ids)
-        assert request.num_output_placeholders >= 0
-
         # Cache the new tokens. Preempted requests should be skipped.
         if status_before_update == RequestStatus.RUNNING:
-            self.kv_cache_manager.cache_blocks(
-                request,
-                request.num_computed_tokens - request.num_output_placeholders)
+            self.kv_cache_manager.cache_blocks(request, request.num_tokens)
         return new_token_ids, stopped
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index dcb9f4dd36f5..6dbcdaeea8c5 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -204,9 +204,12 @@ def schedule(self) -> SchedulerOutput:
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]

-            num_new_tokens = (request.num_tokens_with_spec +
-                              request.num_output_placeholders -
-                              request.num_computed_tokens)
+            if request.num_output_placeholders:
+                num_new_tokens = request.num_output_placeholders
+            else:
+                num_new_tokens = (request.num_tokens_with_spec -
+                                  request.num_computed_tokens)
+
             if (0 < self.scheduler_config.long_prefill_token_threshold <
                     num_new_tokens):
                 num_new_tokens = (
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 045a06d9278d..2375e4fd6a72 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -336,6 +336,11 @@ def __init__(

         self.reorder_batch_threshold: Optional[int] = None

+        # Cache spec token ids and num computed tokens from previous round,
+        # used when async scheduling and spec decoding are both enabled
+        self.cached_spec_token_ids: dict[str, list[int]] = {}
+        self.cached_num_computed_tokens: dict[str, int] = {}
+
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
         num_reqs = self.input_batch.num_reqs
@@ -420,6 +425,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
             self.encoder_cache.pop(req_id, None)
+            self.cached_spec_token_ids.pop(req_id, None)
+            self.cached_num_computed_tokens.pop(req_id, None)
+
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -488,6 +496,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 lora_request=new_req_data.lora_request,
             )

+            # Cache computed tokens for new request with
+            # speculative decoding + async scheduling
+            if (self.speculative_config
+                    and self.scheduler_config.async_scheduling):
+                self.cached_num_computed_tokens[req_id] = (
+                    new_req_data.num_computed_tokens +
+                    scheduler_output.num_scheduled_tokens[req_id])
+
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
                 image_grid_thw = []
@@ -532,7 +548,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         req_data = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
-            num_computed_tokens = req_data.num_computed_tokens[i]
+            if req_id in self.cached_spec_token_ids:
+                scheduler_output.scheduled_spec_decode_tokens[
+                    req_id] = self.cached_spec_token_ids[req_id]
+            if req_id in self.cached_num_computed_tokens:
+                num_computed_tokens = self.cached_num_computed_tokens[req_id]
+            else:
+                num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]

@@ -1763,6 +1785,17 @@ def execute_model(
                 spec_decode_metadata,
                 spec_decode_common_attn_metadata,
             )
+            # Update cached request states for async scheduling
+            if self.scheduler_config.async_scheduling:
+                for idx, req_id in enumerate(self.input_batch.req_ids):
+                    if req_id in self.cached_spec_token_ids:
+                        # Update num computed tokens for running requests
+                        num_rejected_tokens = max_gen_len - len(
+                            valid_sampled_token_ids[idx])
+                        self.cached_num_computed_tokens[
+                            req_id] += scheduler_output.num_scheduled_tokens[
+                                req_id] - num_rejected_tokens
+                    self.cached_spec_token_ids[req_id] = spec_token_ids[idx]

         self.eplb_step()
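A toy walk-through of the accounting introduced above, with made-up numbers (an illustration only, not code from the diff): the async scheduler now reserves 1 + num_spec_tokens output placeholders per decode step, and the model runner advances its cached token count by the scheduled tokens minus whatever the verifier rejected.

# Illustration only: standalone arithmetic mirroring the new bookkeeping in
# AsyncScheduler and GPUModelRunner. All values below are made up.
num_spec_tokens = 3                             # draft length per step
num_output_placeholders = 1 + num_spec_tokens   # slots reserved by the async scheduler
num_scheduled_tokens = num_output_placeholders  # decode step schedules 1 real + 3 spec tokens
max_gen_len = 1 + num_spec_tokens               # bonus token plus all drafts, if accepted
valid_sampled_token_ids = [101, 102]            # verifier kept 1 draft + 1 bonus token

num_rejected_tokens = max_gen_len - len(valid_sampled_token_ids)  # 4 - 2 = 2
cached_num_computed_tokens = 10                 # value carried over from the previous step
cached_num_computed_tokens += num_scheduled_tokens - num_rejected_tokens
assert cached_num_computed_tokens == 12         # only verified tokens count as computed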
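For end-to-end use, the updated example script accepts the new --async-scheduling and --no-spec-decode flags, and the same combination can be reproduced directly through the offline LLM API. The snippet below is a minimal sketch: the model name and the contents of speculative_config are assumptions chosen for illustration, not values taken from this diff.

# Minimal sketch, not part of the diff: speculative decoding together with
# async scheduling via the offline API. The model name and draft settings are
# assumptions; adjust them to whatever setup you actually run.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # assumed target model
    speculative_config={
        "method": "ngram",            # assumed draft method (needs no draft model)
        "num_speculative_tokens": 3,  # same draft length as the new scheduler test
        "prompt_lookup_max": 5,       # assumed ngram lookup window
    },
    async_scheduling=True,  # no longer rejected by create_engine_config()
)

outputs = llm.generate(
    ["The quick brown fox"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)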