@@ -12,16 +12,20 @@


 def _make_model_runner_output(
-        scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput:
+        scheduler_output: SchedulerOutput,
+        sampled_token_ids: list[list[int]] | None = None,
+        spec_token_ids: list[list[int]] | None = None) -> ModelRunnerOutput:
     req_ids = list(scheduler_output.num_scheduled_tokens.keys())
+    if not sampled_token_ids:
+        sampled_token_ids = [[i] for i in range(len(req_ids))]
     return ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index={
             req_id: i
             for i, req_id in enumerate(req_ids)
         },
-        sampled_token_ids=[[i] for i in range(len(req_ids))],
-        spec_token_ids=None,
+        sampled_token_ids=sampled_token_ids,
+        spec_token_ids=spec_token_ids,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -55,6 +59,59 @@ def test_stop_by_max_tokens(max_tokens: int):
     assert req1.num_output_tokens == max_tokens


+def test_spec_decode():
+    max_tokens = 7
+    num_spec_tokens = 3
+    spec_token_ids = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
+                      [-1, -2, -3]]
+    sampled_token_ids = [[0], [1, 2, 13], [4, 15], [16], [-1, -2]]
+    scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens,
+                                 async_scheduling=True)
+    requests = create_requests(num_requests=1, max_tokens=max_tokens)
+    req = requests[0]
+
+    sched_outputs: deque[SchedulerOutput] = deque()
+    scheduler.add_request(req)
+    sched_outputs.append(scheduler.schedule())
+    sched_outputs.append(scheduler.schedule())
+
+    i = 0
+    while sched_outputs:
+        sched_output = sched_outputs.popleft()
+        # Overwrite with cached spec decode tokens as done in GPUModelRunner
+        if i > 0:
+            sched_output.scheduled_spec_decode_tokens[
+                req.request_id] = spec_token_ids[i - 1]
+        model_runner_output = _make_model_runner_output(
+            sched_output, [sampled_token_ids[i]], [spec_token_ids[i]])
+        engine_core_output = scheduler.update_from_output(
+            sched_output, model_runner_output)
+        # Validate spec decode stats
+        if engine_core_output:
+            assert engine_core_output[0].scheduler_stats
+            spec_decoding_stats = engine_core_output[
+                0].scheduler_stats.spec_decoding_stats
+            if i == 0:
+                # No spec decode stats for prefill round
+                assert spec_decoding_stats is None
+            else:
+                assert spec_decoding_stats
+                assert spec_decoding_stats.num_drafts == 1
+                assert spec_decoding_stats.num_draft_tokens == num_spec_tokens
+                assert spec_decoding_stats.num_accepted_tokens == len(
+                    sampled_token_ids[i]) - 1
+        sched_output = scheduler.schedule()
+        if sched_output.num_scheduled_tokens:
+            assert sched_output.num_scheduled_tokens[
+                req.request_id] == 1 + num_spec_tokens
+            sched_outputs.append(sched_output)
+        i += 1
+
+    assert scheduler.get_num_unfinished_requests() == 0
+    assert req.num_output_tokens == max_tokens
+    assert req.output_token_ids._x == [0, 1, 2, 13, 4, 15, 16]
+
+
 def test_abort():
     scheduler = create_scheduler(async_scheduling=True)
     requests = create_requests(num_requests=10, max_tokens=20)
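As a quick sanity check on the new test (a back-of-the-envelope sketch, not part of the change): each round's sampled_token_ids entry holds the accepted draft tokens plus one final token sampled by the target model, as the num_accepted_tokens == len(sampled) - 1 assertions imply, so the rounds concatenate into the final output until max_tokens is reached, and the trailing [-1, -2] round, which was already scheduled when the request finished, is discarded.

    # Sketch only: reconstruct the expected output token ids from the
    # per-round sampled tokens used in test_spec_decode above.
    max_tokens = 7
    sampled_token_ids = [[0], [1, 2, 13], [4, 15], [16], [-1, -2]]

    output_token_ids: list[int] = []
    for round_tokens in sampled_token_ids:
        for token_id in round_tokens:
            if len(output_token_ids) == max_tokens:
                break  # request already finished; later rounds are discarded
            output_token_ids.append(token_id)

    # Matches the assertion at the end of the test.
    assert output_token_ids == [0, 1, 2, 13, 4, 15, 16]

    # Per decode round, one fewer draft token is accepted than tokens sampled,
    # which is what the num_accepted_tokens checks in the test verify.
    accepted_per_round = [len(tokens) - 1 for tokens in sampled_token_ids[1:4]]
    assert accepted_per_round == [2, 1, 0]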