Skip to content

Commit aafb99a

Browse files
authored
[Core] Small simplification in GPUModelRunner._update_states() (vllm-project#26508)
Signed-off-by: Nick Hill <[email protected]>
1 parent 757fa4a commit aafb99a

File tree

1 file changed

+2
-7
lines changed

1 file changed

+2
-7
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
708708
# Update the cached states.
709709

710710
req_state.num_computed_tokens = num_computed_tokens
711+
req_index = self.input_batch.req_id_to_index.get(req_id)
711712

712713
if not is_last_rank:
713714
# When using PP, the scheduler sends the sampled tokens back,
@@ -728,19 +729,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
728729
# Some output tokens were discarded due to a sync-KV-load
729730
# failure. Align the cached state.
730731
del req_state.output_token_ids[num_output_tokens:]
731-
732-
req_index = self.input_batch.req_id_to_index.get(req_id)
733732
if req_index is not None:
734-
old_end_idx = self.input_batch.num_tokens_no_spec[req_index]
735733
end_idx = (
736734
self.input_batch.num_prompt_tokens[req_index]
737735
+ num_output_tokens
738736
)
739737
self.input_batch.num_tokens[req_index] = end_idx
740738
self.input_batch.num_tokens_no_spec[req_index] = end_idx
741-
self.input_batch.is_token_ids[req_index, end_idx:old_end_idx] = (
742-
False
743-
)
744739

745740
# Update the block IDs.
746741
if not resumed_from_preemption:
@@ -749,12 +744,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
749744
for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
750745
block_ids.extend(new_ids)
751746
else:
747+
assert req_index is None
752748
assert new_block_ids is not None
753749
# The request is resumed from preemption.
754750
# Replace the existing block IDs with the new ones.
755751
req_state.block_ids = new_block_ids
756752

757-
req_index = self.input_batch.req_id_to_index.get(req_id)
758753
if req_index is None:
759754
# The request is not in the persistent batch.
760755
# The request was either preempted and resumed later, or was not

0 commit comments

Comments
 (0)