Commit e9b639d

refactor async_scheduler to keep num_output_placeholders constant
1 parent: c09efff


vllm/v1/core/sched/async_scheduler.py

Lines changed: 2 additions & 7 deletions
@@ -20,11 +20,10 @@ def _update_after_schedule(
         super()._update_after_schedule(scheduler_output)
         for req_id in scheduler_output.num_scheduled_tokens:
             request = self.requests[req_id]
-            if (request.num_computed_tokens == request.num_tokens +
+            if (request.num_computed_tokens == request.num_tokens_with_spec +
                     request.num_output_placeholders):
                 # The request will generate a new token in this scheduling step.
-                # TODO(woosuk): Support speculative decoding.
-                request.num_output_placeholders += 1
+                request.num_output_placeholders = 1 + len(request.spec_token_ids)
 
     def _update_request_with_output(
         self,
@@ -35,10 +34,6 @@ def _update_request_with_output(
         new_token_ids, stopped = super()._update_request_with_output(
             request, new_token_ids)
 
-        # Update the number of output placeholders.
-        request.num_output_placeholders -= len(new_token_ids)
-        assert request.num_output_placeholders >= 0
-
         # Cache the new tokens. Preempted requests should be skipped.
         if status_before_update == RequestStatus.RUNNING:
             self.kv_cache_manager.cache_blocks(
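For illustration, here is a minimal, self-contained sketch of the bookkeeping the new hunk performs in each scheduling step. FakeRequest and update_after_schedule are hypothetical, simplified stand-ins rather than vLLM classes: once every token known to the scheduler (including the currently proposed speculative tokens and any previously reserved placeholders) has been computed, the placeholder count is assigned to 1 + len(spec_token_ids) and stays constant, instead of being incremented here and decremented later in _update_request_with_output.

from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeRequest:
    """Simplified stand-in for the scheduler's per-request state (illustration only)."""
    prompt_and_output_tokens: List[int] = field(default_factory=list)
    spec_token_ids: List[int] = field(default_factory=list)
    num_computed_tokens: int = 0
    num_output_placeholders: int = 0

    @property
    def num_tokens_with_spec(self) -> int:
        # All tokens known to the scheduler, including proposed speculative ones.
        return len(self.prompt_and_output_tokens) + len(self.spec_token_ids)


def update_after_schedule(request: FakeRequest) -> None:
    # Mirrors the new condition: the request emits output in this scheduling
    # step only once every known token (plus any previously reserved
    # placeholders) has been computed.
    if (request.num_computed_tokens ==
            request.num_tokens_with_spec + request.num_output_placeholders):
        # Reserve one guaranteed token plus up to len(spec_token_ids) accepted
        # speculative tokens. The value is assigned rather than incremented,
        # and is no longer decremented when outputs arrive.
        request.num_output_placeholders = 1 + len(request.spec_token_ids)


# A request with 8 known tokens and 2 speculative proposals, all computed.
req = FakeRequest(prompt_and_output_tokens=list(range(8)),
                  spec_token_ids=[101, 102],
                  num_computed_tokens=10)
update_after_schedule(req)
print(req.num_output_placeholders)  # 3 == 1 + len(spec_token_ids)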
