@@ -332,6 +332,11 @@ def __init__(

         self.reorder_batch_threshold: Optional[int] = None

+        # Cache spec token ids and num rejected tokens from previous round,
+        # used when async scheduling and spec decoding are both enabled
+        self.cached_spec_token_ids = {}
+        self.cached_num_rejected_tokens = {}
+
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
         Update the order of requests in the batch based on the attention
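
Under async scheduling, the scheduler prepares the next step before the current step's sampler output is known, so it cannot tell the worker which draft tokens survived verification; the two dicts added above let the model runner carry that state across rounds itself. A rough sketch of what they hold (types and values are illustrative assumptions, not part of the patch):

    # request id -> draft tokens this worker proposed in the previous round
    cached_spec_token_ids: dict[str, list[int]] = {"req-0": [811, 42, 7]}
    # request id -> how many of those drafts the verifier rejected last round
    cached_num_rejected_tokens: dict[str, int] = {"req-0": 2}
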
@@ -381,6 +386,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
             self.encoder_cache.pop(req_id, None)
+            self.cached_spec_token_ids.pop(req_id, None)
+            self.cached_num_rejected_tokens.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -494,6 +501,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
+            if req_id in self.cached_num_rejected_tokens:
+                num_computed_tokens -= self.cached_num_rejected_tokens[req_id]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]

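The subtraction above reconciles the scheduler's optimistic count with the verifier's outcome: under async scheduling the scheduler already counted the speculated tokens as computed before verification finished. A worked example with assumed numbers:

    num_computed_tokens = 104   # scheduler's optimistic count for this request
    num_rejected = 2            # cached from the previous round's verification
    num_computed_tokens -= num_rejected
    assert num_computed_tokens == 102   # tokens that actually survived
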
@@ -554,8 +563,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.input_batch.num_tokens[req_index] = end_token_index

             # Add spec_token_ids to token_ids_cpu.
-            spec_token_ids = (
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
+            if req_id in self.cached_spec_token_ids:
+                spec_token_ids = self.cached_spec_token_ids[req_id]
+            else:
+                spec_token_ids = (
+                    scheduler_output.scheduled_spec_decode_tokens.get(
+                        req_id, ()))
             if spec_token_ids:
                 num_spec_tokens = len(spec_token_ids)
                 start_index = self.input_batch.num_tokens_no_spec[req_index]
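
The new branch prefers the worker's own cached drafts over `scheduler_output.scheduled_spec_decode_tokens`, which can lag one round behind under async scheduling. The precedence is equivalent to this one-liner, shown only to clarify the logic, not proposed as a change:

    spec_token_ids = self.cached_spec_token_ids.get(
        req_id, scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
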
@@ -1743,6 +1756,13 @@ def execute_model(

         self.eplb_step()

+        if self.speculative_config and self.scheduler_config.async_scheduling:
+            assert spec_token_ids
+            for idx, req_id in enumerate(self.input_batch.req_ids):
+                self.cached_spec_token_ids[req_id] = spec_token_ids[idx]
+                self.cached_num_rejected_tokens[req_id] = max_gen_len - len(
+                    valid_sampled_token_ids[idx])
+
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
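
The cached rejection count is derived from the sampler output: `max_gen_len` is the most tokens a request could emit this round, and `valid_sampled_token_ids[idx]` holds the tokens actually accepted. A worked example, assuming `max_gen_len` equals the number of drafts plus one bonus token:

    max_gen_len = 4                     # 3 drafts + 1 bonus token
    valid_sampled_token_ids = [[5, 9]]  # 1 draft accepted, then 1 sampled token
    num_rejected = max_gen_len - len(valid_sampled_token_ids[0])
    assert num_rejected == 2            # the last 2 drafts were rejected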