Commit ad1b488

further fixes
Signed-off-by: qizixi <[email protected]>
1 parent 82deff1 commit ad1b488

File tree: 3 files changed (+31, -22 lines)

vllm/v1/core/sched/async_scheduler.py

Lines changed: 2 additions & 5 deletions

@@ -23,8 +23,7 @@ def _update_after_schedule(
             if (request.num_computed_tokens == request.num_tokens_with_spec +
                     request.num_output_placeholders):
                 # The request will generate a new token in this scheduling step.
-                request.num_output_placeholders = 1 + len(
-                    request.spec_token_ids)
+                request.num_output_placeholders = 1 + self.num_spec_tokens

     def _update_request_with_output(
         self,
@@ -37,7 +36,5 @@ def _update_request_with_output(

         # Cache the new tokens. Preempted requests should be skipped.
         if status_before_update == RequestStatus.RUNNING:
-            self.kv_cache_manager.cache_blocks(
-                request,
-                request.num_computed_tokens - request.num_output_placeholders)
+            self.kv_cache_manager.cache_blocks(request, request.num_tokens)
         return new_token_ids, stopped
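
The placeholder arithmetic this hunk changes can be illustrated with a small standalone sketch (plain Python, not vLLM code; the function names and the acceptance count below are illustrative): before sampling results arrive, a step reserves one bonus token plus the full speculative budget, and the unused part of the reservation is released once the accepted draft count is known.

# Standalone sketch of output-placeholder accounting under async scheduling
# with speculative decoding. Illustrative only; not vLLM code.

def reserve_placeholders(num_spec_tokens: int) -> int:
    """Worst-case tokens a request can emit in one step: one bonus token
    plus every speculative draft token, before acceptance is known."""
    return 1 + num_spec_tokens

def settle_placeholders(num_reserved: int, num_accepted: int) -> int:
    """Once sampling results arrive, the real output is one bonus token
    plus the accepted drafts; the remaining reservation is released."""
    num_emitted = 1 + num_accepted
    return num_reserved - num_emitted  # placeholders to release

# Example: 4 drafts reserved -> 5 placeholders; 2 accepted -> 3 tokens
# emitted and 2 placeholders released.
assert reserve_placeholders(4) == 5
assert settle_placeholders(5, 2) == 2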

vllm/v1/core/sched/scheduler.py

Lines changed: 6 additions & 3 deletions

@@ -204,9 +204,12 @@ def schedule(self) -> SchedulerOutput:
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]

-            num_new_tokens = (request.num_tokens_with_spec +
-                              request.num_output_placeholders -
-                              request.num_computed_tokens)
+            if request.num_output_placeholders:
+                num_new_tokens = request.num_output_placeholders
+            else:
+                num_new_tokens = (request.num_tokens_with_spec -
+                                  request.num_computed_tokens)
+
             if (0 < self.scheduler_config.long_prefill_token_threshold <
                     num_new_tokens):
                 num_new_tokens = (
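
The new budget rule in isolation, as a hedged standalone sketch (the Request dataclass below is a stand-in carrying only the fields the hunk touches, not the vLLM Request class): a request that already holds output placeholders from an in-flight async step is scheduled for exactly that many tokens; otherwise the count falls back to the tokens not yet computed.

# Illustrative sketch of the num_new_tokens rule from this hunk.
from dataclasses import dataclass

@dataclass
class Request:
    num_tokens_with_spec: int      # prompt + generated + draft tokens
    num_computed_tokens: int       # tokens already run through the model
    num_output_placeholders: int   # slots reserved by an in-flight async step

def num_new_tokens(req: Request) -> int:
    if req.num_output_placeholders:
        # An async step already reserved slots: schedule exactly those.
        return req.num_output_placeholders
    # Otherwise schedule whatever has not been computed yet.
    return req.num_tokens_with_spec - req.num_computed_tokens

# A decode step with 4 reserved placeholders schedules 4 tokens.
assert num_new_tokens(Request(100, 100, 4)) == 4
# A step with no placeholders schedules the remaining tokens.
assert num_new_tokens(Request(100, 60, 0)) == 40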

vllm/v1/worker/gpu_model_runner.py

Lines changed: 23 additions & 14 deletions

@@ -332,10 +332,10 @@ def __init__(

         self.reorder_batch_threshold: Optional[int] = None

-        # Cache spec token ids and num rejected tokens from previous round,
+        # Cache spec token ids and num computed tokens from previous round,
         # used when async scheduling and spec decoding are both enabled
-        self.cached_spec_token_ids = {}
-        self.cached_num_rejected_tokens = {}
+        self.cached_spec_token_ids: dict[str, list[int]] = {}
+        self.cached_num_computed_tokens: dict[str, int] = {}

     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
@@ -387,7 +387,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.requests.pop(req_id, None)
             self.encoder_cache.pop(req_id, None)
             self.cached_spec_token_ids.pop(req_id, None)
-            self.cached_num_rejected_tokens.pop(req_id, None)
+            self.cached_num_computed_tokens.pop(req_id, None)
+
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -500,9 +501,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         req_data = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
-            num_computed_tokens = req_data.num_computed_tokens[i]
-            if req_id in self.cached_num_rejected_tokens:
-                num_computed_tokens -= self.cached_num_rejected_tokens[req_id]
+            if req_id in self.cached_spec_token_ids:
+                scheduler_output.scheduled_spec_decode_tokens[
+                    req_id] = self.cached_spec_token_ids[req_id]
+            if req_id in self.cached_num_computed_tokens:
+                num_computed_tokens = self.cached_num_computed_tokens[req_id]
+            else:
+                num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]

@@ -563,12 +568,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.input_batch.num_tokens[req_index] = end_token_index

             # Add spec_token_ids to token_ids_cpu.
-            if req_id in self.cached_spec_token_ids:
-                spec_token_ids = self.cached_spec_token_ids[req_id]
-            else:
-                spec_token_ids = (
-                    scheduler_output.scheduled_spec_decode_tokens.get(
-                        req_id, ()))
+            spec_token_ids = (
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
             if spec_token_ids:
                 num_spec_tokens = len(spec_token_ids)
                 start_index = self.input_batch.num_tokens_no_spec[req_index]
@@ -1760,8 +1761,16 @@ def execute_model(
             assert spec_token_ids
             for idx, req_id in enumerate(self.input_batch.req_ids):
                 self.cached_spec_token_ids[req_id] = spec_token_ids[idx]
-                self.cached_num_rejected_tokens[req_id] = max_gen_len - len(
+                num_rejected_tokens = max_gen_len - len(
                     valid_sampled_token_ids[idx])
+                if req_id not in self.cached_num_computed_tokens:
+                    self.cached_num_computed_tokens[
+                        req_id] = scheduler_output.num_scheduled_tokens[
+                            req_id] - num_rejected_tokens
+                else:
+                    self.cached_num_computed_tokens[
+                        req_id] += scheduler_output.num_scheduled_tokens[
+                            req_id] - num_rejected_tokens

         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
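
To make the runner-side bookkeeping concrete, here is a minimal sketch of the cached_num_computed_tokens update in the last hunk (standalone Python; the dict and helper below are illustrative, not the vLLM API): every async step advances the cached count by the tokens scheduled for the request minus the draft tokens that rejection sampling discarded.

# Standalone sketch of the cached_num_computed_tokens bookkeeping from the
# last hunk. Plain dicts stand in for model-runner state; illustrative only.

cached_num_computed_tokens: dict[str, int] = {}

def record_step(req_id: str, num_scheduled: int, max_gen_len: int,
                num_valid_sampled: int) -> None:
    """After a forward pass, fold this step's progress into the cache.

    num_rejected is the gap between the longest possible generation for
    the step and the tokens that survived rejection sampling.
    """
    num_rejected = max_gen_len - num_valid_sampled
    delta = num_scheduled - num_rejected
    cached_num_computed_tokens[req_id] = (
        cached_num_computed_tokens.get(req_id, 0) + delta)

# Two async steps for one request: 5 tokens scheduled each time,
# 2 drafts rejected in the first step, none in the second.
record_step("req-0", num_scheduled=5, max_gen_len=5, num_valid_sampled=3)
record_step("req-0", num_scheduled=5, max_gen_len=5, num_valid_sampled=5)
assert cached_num_computed_tokens["req-0"] == 8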
