@@ -332,10 +332,10 @@ def __init__(
 
         self.reorder_batch_threshold: Optional[int] = None
 
-        # Cache spec token ids and num rejected tokens from previous round,
+        # Cache spec token ids and num computed tokens from previous round,
         # used when async scheduling and spec decoding are both enabled
-        self.cached_spec_token_ids = {}
-        self.cached_num_rejected_tokens = {}
+        self.cached_spec_token_ids: dict[str, list[int]] = {}
+        self.cached_num_computed_tokens: dict[str, int] = {}
 
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
@@ -387,7 +387,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.requests.pop(req_id, None)
             self.encoder_cache.pop(req_id, None)
             self.cached_spec_token_ids.pop(req_id, None)
-            self.cached_num_rejected_tokens.pop(req_id, None)
+            self.cached_num_computed_tokens.pop(req_id, None)
+
 
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -500,9 +501,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         req_data = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
-            num_computed_tokens = req_data.num_computed_tokens[i]
-            if req_id in self.cached_num_rejected_tokens:
-                num_computed_tokens -= self.cached_num_rejected_tokens[req_id]
+            if req_id in self.cached_spec_token_ids:
+                scheduler_output.scheduled_spec_decode_tokens[
+                    req_id] = self.cached_spec_token_ids[req_id]
+            if req_id in self.cached_num_computed_tokens:
+                num_computed_tokens = self.cached_num_computed_tokens[req_id]
+            else:
+                num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
 
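A standalone sketch of the lookup order this hunk implements: the worker-side cache, when an entry exists, takes precedence over the count the scheduler sent, presumably because async scheduling lets the scheduler run a round ahead of the worker's actual results (the function name and stub values are hypothetical):

```python
# Sketch: worker-cached token counts override the scheduler-provided ones.
def resolve_num_computed_tokens(
    req_id: str,
    scheduler_num_computed: int,
    cached_num_computed_tokens: dict[str, int],
) -> int:
    # Prefer the worker's own count when async scheduling + spec decoding
    # have cached one; otherwise fall back to the scheduler's value.
    return cached_num_computed_tokens.get(req_id, scheduler_num_computed)


cache = {"req-0": 40}
assert resolve_num_computed_tokens("req-0", 38, cache) == 40  # cached wins
assert resolve_num_computed_tokens("req-1", 38, cache) == 38  # fallback
```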
@@ -563,12 +568,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.input_batch.num_tokens[req_index] = end_token_index
 
             # Add spec_token_ids to token_ids_cpu.
-            if req_id in self.cached_spec_token_ids:
-                spec_token_ids = self.cached_spec_token_ids[req_id]
-            else:
-                spec_token_ids = (
-                    scheduler_output.scheduled_spec_decode_tokens.get(
-                        req_id, ()))
+            spec_token_ids = (
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
             if spec_token_ids:
                 num_spec_tokens = len(spec_token_ids)
                 start_index = self.input_batch.num_tokens_no_spec[req_index]
@@ -1760,8 +1761,16 @@ def execute_model(
             assert spec_token_ids
             for idx, req_id in enumerate(self.input_batch.req_ids):
                 self.cached_spec_token_ids[req_id] = spec_token_ids[idx]
-                self.cached_num_rejected_tokens[req_id] = max_gen_len - len(
+                num_rejected_tokens = max_gen_len - len(
                     valid_sampled_token_ids[idx])
+                if req_id not in self.cached_num_computed_tokens:
+                    self.cached_num_computed_tokens[
+                        req_id] = scheduler_output.num_scheduled_tokens[
+                            req_id] - num_rejected_tokens
+                else:
+                    self.cached_num_computed_tokens[
+                        req_id] += scheduler_output.num_scheduled_tokens[
+                            req_id] - num_rejected_tokens
 
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
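The accounting in this last hunk can be exercised in isolation. A hedged sketch, assuming `max_gen_len` is the per-request generation length including rejected draft tokens and `valid_sampled_token_ids[idx]` holds only the accepted ones, as the surrounding code suggests (the helper name is hypothetical):

```python
# Sketch: per-round update of the computed-token count, mirroring the hunk
# above. Rejected speculative tokens are not counted as computed.
def update_cached_num_computed_tokens(
    cached: dict[str, int],
    req_id: str,
    num_scheduled_tokens: int,
    max_gen_len: int,
    num_valid_sampled: int,
) -> None:
    num_rejected_tokens = max_gen_len - num_valid_sampled
    # The first round seeds the entry; later rounds accumulate onto it,
    # which is what the if/else in the diff does.
    cached[req_id] = cached.get(req_id, 0) + (
        num_scheduled_tokens - num_rejected_tokens)


cache: dict[str, int] = {}
update_cached_num_computed_tokens(cache, "req-0", 8, 4, 3)  # 1 token rejected
assert cache["req-0"] == 7
update_cached_num_computed_tokens(cache, "req-0", 4, 4, 4)  # all accepted
assert cache["req-0"] == 11
```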