Commit 9a6f640

(wip) async scheduling + spec decode
1 parent e9b639d commit 9a6f640

4 files changed: +28 −11 lines

examples/offline_inference/spec_decode.py

Lines changed: 4 additions & 1 deletion
@@ -68,6 +68,8 @@ def parse_args():
     parser.add_argument("--model-dir", type=str, default=None)
     parser.add_argument("--eagle-dir", type=str, default=None)
     parser.add_argument("--custom-mm-prompts", action="store_true")
+    parser.add_argument("--no-spec-decode", action="store_true")
+    parser.add_argument("--async-scheduling", action="store_true")
     return parser.parse_args()
 
 
@@ -127,11 +129,12 @@ def main():
         enable_chunked_prefill=args.enable_chunked_prefill,
         enforce_eager=args.enforce_eager,
         gpu_memory_utilization=0.8,
-        speculative_config=speculative_config,
+        speculative_config=speculative_config if not args.no_spec_decode else None,
         disable_log_stats=False,
         max_model_len=16384,
         limit_mm_per_prompt={"image": 5},
         disable_chunked_mm_input=True,
+        async_scheduling=args.async_scheduling,
     )
 
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
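The two new flags let the example toggle speculative decoding off and async scheduling on from the command line. A minimal sketch of the resulting LLM construction, assuming a placeholder model name and an ngram-style speculative_config (both illustrative; the example script builds its own config from --eagle-dir and the other arguments):

# Minimal sketch of what examples/offline_inference/spec_decode.py ends up
# building after this change. The model name and speculative_config contents
# below are assumptions for illustration; only the speculative_config /
# async_scheduling wiring mirrors the diff above.
from vllm import LLM, SamplingParams

no_spec_decode = False      # corresponds to --no-spec-decode
async_scheduling = True     # corresponds to --async-scheduling

# Hypothetical ngram-based draft config; the example script builds its own.
speculative_config = {
    "method": "ngram",
    "num_speculative_tokens": 3,
    "prompt_lookup_max": 4,
}

llm = LLM(
    model="facebook/opt-125m",  # placeholder model for illustration
    speculative_config=speculative_config if not no_spec_decode else None,
    async_scheduling=async_scheduling,
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)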

vllm/engine/arg_utils.py

Lines changed: 0 additions & 7 deletions
@@ -1194,13 +1194,6 @@ def create_engine_config(
                 raise ValueError("Async scheduling is not supported with "
                                  "pipeline-parallel-size > 1.")
 
-            # Currently, async scheduling does not support speculative decoding.
-            # TODO(woosuk): Support it.
-            if self.speculative_config is not None:
-                raise ValueError(
-                    "Currently, speculative decoding is not supported with "
-                    "async scheduling.")
-
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
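The only functional change here is deleting the guard that made async scheduling and speculative decoding mutually exclusive at config-build time. A minimal sketch (hypothetical helper, not vLLM's actual code) of the shape of the validation that remains:

# A minimal sketch, not vLLM's actual code: after this commit, async scheduling
# still rejects pipeline parallelism > 1, but a non-None speculative_config no
# longer raises in create_engine_config().
from typing import Optional


def check_async_scheduling_compat(async_scheduling: bool,
                                  pipeline_parallel_size: int,
                                  speculative_config: Optional[dict]) -> None:
    if not async_scheduling:
        return
    if pipeline_parallel_size > 1:
        raise ValueError("Async scheduling is not supported with "
                         "pipeline-parallel-size > 1.")
    # Before this commit, a non-None speculative_config also raised here;
    # that guard is removed.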

vllm/v1/core/sched/async_scheduler.py

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@ def _update_after_schedule(
             if (request.num_computed_tokens == request.num_tokens_with_spec +
                     request.num_output_placeholders):
                 # The request will generate a new token in this scheduling step.
-                request.num_output_placeholders = 1 + len(request.spec_token_ids)
+                request.num_output_placeholders = 1 + len(
+                    request.spec_token_ids)
 
     def _update_request_with_output(
         self,
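With speculative decoding, a scheduling step that samples can commit up to one bonus token plus all accepted draft tokens, so the async scheduler reserves 1 + len(spec_token_ids) output placeholders rather than a single one. A toy illustration of that accounting (hypothetical helper, not vLLM code):

# Toy illustration of the placeholder accounting in
# AsyncScheduler._update_after_schedule: a sampling step may produce one
# "bonus" token plus up to len(spec_token_ids) accepted draft tokens, so that
# many output slots are reserved up front.
def num_placeholders_for_step(spec_token_ids: list[int]) -> int:
    return 1 + len(spec_token_ids)


# With 3 draft tokens proposed, 4 placeholders are reserved; if the model later
# rejects 2 of the drafts, only 2 placeholders are actually filled, and the
# model runner has to account for the difference (see gpu_model_runner.py
# below).
assert num_placeholders_for_step([11, 12, 13]) == 4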

vllm/v1/worker/gpu_model_runner.py

Lines changed: 22 additions & 2 deletions
@@ -332,6 +332,11 @@ def __init__(
 
         self.reorder_batch_threshold: Optional[int] = None
 
+        # Cache spec token ids and num rejected tokens from previous round,
+        # used when async scheduling and spec decoding are both enabled
+        self.cached_spec_token_ids = {}
+        self.cached_num_rejected_tokens = {}
+
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
         Update the order of requests in the batch based on the attention
 
@@ -381,6 +386,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
             self.encoder_cache.pop(req_id, None)
+            self.cached_spec_token_ids.pop(req_id, None)
+            self.cached_num_rejected_tokens.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
 
@@ -494,6 +501,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
+            if req_id in self.cached_num_rejected_tokens:
+                num_computed_tokens -= self.cached_num_rejected_tokens[req_id]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
 
@@ -554,8 +563,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.input_batch.num_tokens[req_index] = end_token_index
 
             # Add spec_token_ids to token_ids_cpu.
-            spec_token_ids = (
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
+            if req_id in self.cached_spec_token_ids:
+                spec_token_ids = self.cached_spec_token_ids[req_id]
+            else:
+                spec_token_ids = (
+                    scheduler_output.scheduled_spec_decode_tokens.get(
+                        req_id, ()))
             if spec_token_ids:
                 num_spec_tokens = len(spec_token_ids)
                 start_index = self.input_batch.num_tokens_no_spec[req_index]
 
@@ -1743,6 +1756,13 @@ def execute_model(
 
         self.eplb_step()
 
+        if self.speculative_config and self.scheduler_config.async_scheduling:
+            assert spec_token_ids
+            for idx, req_id in enumerate(self.input_batch.req_ids):
+                self.cached_spec_token_ids[req_id] = spec_token_ids[idx]
+                self.cached_num_rejected_tokens[req_id] = max_gen_len - len(
+                    valid_sampled_token_ids[idx])
+
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
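Because async scheduling plans the next step before the current step's sampling results are known, the scheduler's bookkeeping assumes every speculated token was accepted. The runner therefore caches, per request, the draft tokens it just proposed and the number of speculated tokens that were actually rejected, then applies both corrections when the next scheduler output arrives: num_computed_tokens is reduced by the cached rejection count, and the cached drafts take the place of the scheduler's stale spec_token_ids. A self-contained toy model of that bookkeeping (hypothetical class and names, not vLLM's GPUModelRunner):

# Toy model of the caching added in this commit for async scheduling + spec
# decode. The scheduler runs one step ahead and optimistically counts every
# draft token as computed; the runner remembers how many were rejected and
# which drafts it proposed, and corrects the next step's view accordingly.
class SpecDecodeCache:
    def __init__(self) -> None:
        self.cached_spec_token_ids: dict[str, list[int]] = {}
        self.cached_num_rejected_tokens: dict[str, int] = {}

    def record_step(self, req_id: str, spec_token_ids: list[int],
                    max_gen_len: int, num_valid_sampled: int) -> None:
        # Mirrors execute_model(): cache this step's drafts and how many of
        # the optimistically scheduled tokens were rejected.
        self.cached_spec_token_ids[req_id] = spec_token_ids
        self.cached_num_rejected_tokens[req_id] = max_gen_len - num_valid_sampled

    def corrected_num_computed(self, req_id: str, scheduled: int) -> int:
        # Mirrors _update_states(): the scheduler over-counted computed tokens
        # by the number of rejections from the previous round.
        return scheduled - self.cached_num_rejected_tokens.get(req_id, 0)

    def finish(self, req_id: str) -> None:
        # Mirrors the finished_req_ids cleanup.
        self.cached_spec_token_ids.pop(req_id, None)
        self.cached_num_rejected_tokens.pop(req_id, None)


cache = SpecDecodeCache()
# Step N: 3 drafts proposed, but only 2 of the 4 possible new tokens were
# valid, so 2 speculated tokens were rejected.
cache.record_step("req-0", spec_token_ids=[7, 8, 9], max_gen_len=4,
                  num_valid_sampled=2)
# Step N+1: the scheduler reports num_computed_tokens as if all 4 landed.
assert cache.corrected_num_computed("req-0", scheduled=100) == 98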
