Skip to content

Commit 0d5a442

Browse files
committed
Fall back to full prefill if prompt_logprobs set
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent a6514dd commit 0d5a442

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

vllm/v1/attention/backends/utils.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class CommonAttentionMetadata:
6969

7070
logits_indices_padded: Optional[torch.Tensor] = None
7171
num_logits_indices: Optional[int] = None
72+
prompt_logprobs: Optional[bool] = None
7273

7374
causal: bool = True
7475

@@ -836,13 +837,25 @@ def build(self,
836837
common_prefix_len: int,
837838
common_attn_metadata: CommonAttentionMetadata,
838839
fast_build: bool = False) -> AttentionMetadata:
839-
new_common_attn_metadata =\
840-
make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
840+
# Either not set (None) or prompt_logprobs is False
841+
if not common_attn_metadata.prompt_logprobs:
842+
# Fast prefill path
843+
new_common_attn_metadata =\
844+
make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
845+
metadata = super(self.__class__,
846+
self).build(common_prefix_len,
847+
new_common_attn_metadata, fast_build)
848+
return create_kv_sharing_fast_prefill_attn_metadata_subclass(
849+
metadata, common_attn_metadata)
850+
851+
# Default path:
852+
# Either --kv-sharing-fast-prefill is not set or at least one request
853+
# in the current scheduling round requests logprobs for prompt tokens
854+
# which is not compatible with fast prefill
841855
metadata = super(self.__class__,
842-
self).build(common_prefix_len,
843-
new_common_attn_metadata, fast_build)
844-
return create_kv_sharing_fast_prefill_attn_metadata_subclass(
845-
metadata, common_attn_metadata)
856+
self).build(common_prefix_len, common_attn_metadata,
857+
fast_build)
858+
return metadata
846859

847860
# Dynamically create a new attention backend that wraps the
848861
# underlying attention backend but applies

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -867,11 +867,11 @@ def _prepare_inputs(
867867

868868
if (self.cache_config.kv_sharing_fast_prefill
869869
and self.input_batch.num_prompt_logprobs):
870-
raise RuntimeError(
870+
logger.warning(
871871
"Encountered at least one request with prompt_logprobs set "
872872
"with --kv-sharing-fast-prefill enabled. Fast prefill doesn't "
873-
"produce correct logits for prompt tokens. Please try again "
874-
"without the flag --kv-sharing-fast-prefill set.")
873+
"produce correct logits for prompt tokens, so fast prefill will"
874+
" be disabled for this iteration.")
875875

876876
# Prepare the attention metadata for each KV cache group and make layers
877877
# in the same group share the same metadata.
@@ -900,6 +900,7 @@ def _prepare_inputs(
900900
slot_mapping=slot_mapping,
901901
logits_indices_padded=logits_indices_padded,
902902
num_logits_indices=logits_indices.size(0),
903+
prompt_logprobs=len(self.input_batch.num_prompt_logprobs) > 0,
903904
causal=True,
904905
)
905906

0 commit comments

Comments (0)