@@ -708,6 +708,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
708708 # Update the cached states.
709709
710710 req_state .num_computed_tokens = num_computed_tokens
711+ req_index = self .input_batch .req_id_to_index .get (req_id )
711712
712713 if not is_last_rank :
713714 # When using PP, the scheduler sends the sampled tokens back,
@@ -728,19 +729,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
728729 # Some output tokens were discarded due to a sync-KV-load
729730 # failure. Align the cached state.
730731 del req_state .output_token_ids [num_output_tokens :]
731-
732- req_index = self .input_batch .req_id_to_index .get (req_id )
733732 if req_index is not None :
734- old_end_idx = self .input_batch .num_tokens_no_spec [req_index ]
735733 end_idx = (
736734 self .input_batch .num_prompt_tokens [req_index ]
737735 + num_output_tokens
738736 )
739737 self .input_batch .num_tokens [req_index ] = end_idx
740738 self .input_batch .num_tokens_no_spec [req_index ] = end_idx
741- self .input_batch .is_token_ids [req_index , end_idx :old_end_idx ] = (
742- False
743- )
744739
745740 # Update the block IDs.
746741 if not resumed_from_preemption :
@@ -749,12 +744,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
749744 for block_ids , new_ids in zip (req_state .block_ids , new_block_ids ):
750745 block_ids .extend (new_ids )
751746 else :
747+ assert req_index is None
752748 assert new_block_ids is not None
753749 # The request is resumed from preemption.
754750 # Replace the existing block IDs with the new ones.
755751 req_state .block_ids = new_block_ids
756752
757- req_index = self .input_batch .req_id_to_index .get (req_id )
758753 if req_index is None :
759754 # The request is not in the persistent batch.
760755 # The request was either preempted and resumed later, or was not
0 commit comments