
Commit 3b9ddec

fix issue and add unit test

Signed-off-by: qizixi <[email protected]>

1 parent ad1b488

2 files changed, 79 insertions(+), 18 deletions(-)

tests/v1/core/test_async_scheduler.py
Lines changed: 60 additions & 3 deletions
@@ -12,16 +12,20 @@
 
 
 def _make_model_runner_output(
-        scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput:
+        scheduler_output: SchedulerOutput,
+        sampled_token_ids: list[list[int]] | None = None,
+        spec_token_ids: list[list[int]] | None = None) -> ModelRunnerOutput:
     req_ids = list(scheduler_output.num_scheduled_tokens.keys())
+    if not sampled_token_ids:
+        sampled_token_ids = [[i] for i in range(len(req_ids))]
     return ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index={
             req_id: i
             for i, req_id in enumerate(req_ids)
         },
-        sampled_token_ids=[[i] for i in range(len(req_ids))],
-        spec_token_ids=None,
+        sampled_token_ids=sampled_token_ids,
+        spec_token_ids=spec_token_ids,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
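
With the new optional arguments, existing callers that pass only the scheduler output keep the old behavior (one placeholder token per request, no drafts). A minimal sketch of the two call styles, assuming a sched_output covering a single request:

    # Old behavior: defaults fill in one sampled token per request.
    output = _make_model_runner_output(sched_output)

    # New behavior: inject sampled tokens and draft tokens explicitly,
    # as the spec decode test below does on every step.
    output = _make_model_runner_output(sched_output,
                                       sampled_token_ids=[[1, 2, 13]],
                                       spec_token_ids=[[4, 5, 6]])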
@@ -55,6 +59,59 @@ def test_stop_by_max_tokens(max_tokens: int):
     assert req1.num_output_tokens == max_tokens
 
 
+def test_spec_decode():
+    max_tokens = 7
+    num_spec_tokens = 3
+    spec_token_ids = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
+                      [-1, -2, -3]]
+    sampled_token_ids = [[0], [1, 2, 13], [4, 15], [16], [-1, -2]]
+    scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens,
+                                 async_scheduling=True)
+    requests = create_requests(num_requests=1, max_tokens=max_tokens)
+    req = requests[0]
+
+    sched_outputs: deque[SchedulerOutput] = deque()
+    scheduler.add_request(req)
+    sched_outputs.append(scheduler.schedule())
+    sched_outputs.append(scheduler.schedule())
+
+    i = 0
+    while sched_outputs:
+        sched_output = sched_outputs.popleft()
+        # Overwrite with cached spec decode tokens as done in GPUModelRunner
+        if i > 0:
+            sched_output.scheduled_spec_decode_tokens[
+                req.request_id] = spec_token_ids[i - 1]
+        model_runner_output = _make_model_runner_output(
+            sched_output, [sampled_token_ids[i]], [spec_token_ids[i]])
+        engine_core_output = scheduler.update_from_output(
+            sched_output, model_runner_output)
+        # Validate spec decode stats
+        if engine_core_output:
+            assert engine_core_output[0].scheduler_stats
+            spec_decoding_stats = engine_core_output[
+                0].scheduler_stats.spec_decoding_stats
+            if i == 0:
+                # No spec decode stats for prefill round
+                assert spec_decoding_stats is None
+            else:
+                assert spec_decoding_stats
+                assert spec_decoding_stats.num_drafts == 1
+                assert spec_decoding_stats.num_draft_tokens == num_spec_tokens
+                assert spec_decoding_stats.num_accepted_tokens == len(
+                    sampled_token_ids[i]) - 1
+        sched_output = scheduler.schedule()
+        if sched_output.num_scheduled_tokens:
+            assert sched_output.num_scheduled_tokens[
+                req.request_id] == 1 + num_spec_tokens
+            sched_outputs.append(sched_output)
+        i += 1
+
+    assert scheduler.get_num_unfinished_requests() == 0
+    assert req.num_output_tokens == max_tokens
+    assert req.output_token_ids._x == [0, 1, 2, 13, 4, 15, 16]
+
+
 def test_abort():
     scheduler = create_scheduler(async_scheduling=True)
     requests = create_requests(num_requests=10, max_tokens=20)
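
The stats assertions in test_spec_decode rest on the rejection-sampling convention that each decode step emits the accepted draft tokens plus exactly one extra sampled token, so num_accepted_tokens is len(sampled) - 1. A standalone sketch of that arithmetic using the test's own fixtures (plain Python, no vLLM imports):

    spec_token_ids = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
                      [-1, -2, -3]]
    sampled_token_ids = [[0], [1, 2, 13], [4, 15], [16], [-1, -2]]

    # Step i verifies the drafts proposed at step i - 1. At step 1,
    # drafts [1, 2, 3] vs. sampled [1, 2, 13]: 1 and 2 are accepted,
    # 3 is rejected, and 13 is the freshly sampled (bonus) token.
    for step in range(1, len(sampled_token_ids)):
        drafts = spec_token_ids[step - 1]
        sampled = sampled_token_ids[step]
        num_accepted = len(sampled) - 1  # the last token is never a draft
        num_rejected = len(drafts) - num_accepted
        print(f"step {step}: accepted={num_accepted} rejected={num_rejected}")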

vllm/v1/worker/gpu_model_runner.py
Lines changed: 19 additions & 15 deletions
@@ -457,6 +457,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 lora_request=new_req_data.lora_request,
             )
 
+            # Cache computed tokens for new request with
+            # speculative decoding + async scheduling
+            if (self.speculative_config
+                    and self.scheduler_config.async_scheduling):
+                self.cached_num_computed_tokens[req_id] = (
+                    new_req_data.num_computed_tokens +
+                    scheduler_output.num_scheduled_tokens[req_id])
+
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
                 image_grid_thw = []
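
Seeding the cache here, at request admission, means the runner counts every token scheduled in the request's first step as computed and lets the decode path below subtract rejections afterwards. A toy illustration of the seeding with hypothetical numbers (not taken from the diff):

    # Hypothetical new request: 8 prompt tokens already computed via
    # prefix caching, 92 more scheduled for this step.
    num_computed_tokens = 8
    num_scheduled_tokens = 92

    # Same arithmetic as the added code in _update_states above.
    cached_num_computed_tokens = num_computed_tokens + num_scheduled_tokens
    assert cached_num_computed_tokens == 100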
@@ -1754,24 +1762,20 @@ def execute_model(
                 spec_decode_metadata,
                 spec_decode_common_attn_metadata,
             )
+            # Update cached request states for async scheduling
+            if self.scheduler_config.async_scheduling:
+                for idx, req_id in enumerate(self.input_batch.req_ids):
+                    if req_id in self.cached_spec_token_ids:
+                        # Update num computed tokens for running requests
+                        num_rejected_tokens = max_gen_len - len(
+                            valid_sampled_token_ids[idx])
+                        self.cached_num_computed_tokens[
+                            req_id] += scheduler_output.num_scheduled_tokens[
+                                req_id] - num_rejected_tokens
+                        self.cached_spec_token_ids[req_id] = spec_token_ids[idx]
 
         self.eplb_step()
 
-        if self.speculative_config and self.scheduler_config.async_scheduling:
-            assert spec_token_ids
-            for idx, req_id in enumerate(self.input_batch.req_ids):
-                self.cached_spec_token_ids[req_id] = spec_token_ids[idx]
-                num_rejected_tokens = max_gen_len - len(
-                    valid_sampled_token_ids[idx])
-                if req_id not in self.cached_num_computed_tokens:
-                    self.cached_num_computed_tokens[
-                        req_id] = scheduler_output.num_scheduled_tokens[
-                            req_id] - num_rejected_tokens
-                else:
-                    self.cached_num_computed_tokens[
-                        req_id] += scheduler_output.num_scheduled_tokens[
-                            req_id] - num_rejected_tokens
-
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,