@@ -69,6 +69,7 @@ class CommonAttentionMetadata:
69
69
70
70
logits_indices_padded : Optional [torch .Tensor ] = None
71
71
num_logits_indices : Optional [int ] = None
72
+ prompt_logprobs : Optional [bool ] = None
72
73
73
74
causal : bool = True
74
75
@@ -836,13 +837,25 @@ def build(self,
836
837
common_prefix_len : int ,
837
838
common_attn_metadata : CommonAttentionMetadata ,
838
839
fast_build : bool = False ) -> AttentionMetadata :
839
- new_common_attn_metadata = \
840
- make_kv_sharing_fast_prefill_common_attn_metadata (common_attn_metadata )
840
+ # Either not set (None) or prompt_logprobs is False
841
+ if not common_attn_metadata .prompt_logprobs :
842
+ # Fast prefill path
843
+ new_common_attn_metadata = \
844
+ make_kv_sharing_fast_prefill_common_attn_metadata (common_attn_metadata )
845
+ metadata = super (self .__class__ ,
846
+ self ).build (common_prefix_len ,
847
+ new_common_attn_metadata , fast_build )
848
+ return create_kv_sharing_fast_prefill_attn_metadata_subclass (
849
+ metadata , common_attn_metadata )
850
+
851
+ # Default path:
852
+ # Either --kv-sharing-fast-prefill is not set or at least one request
853
+ # in the current scheduling round requests logprobs for prompt tokens
854
+ # which is not compatible with fast prefill
841
855
metadata = super (self .__class__ ,
842
- self ).build (common_prefix_len ,
843
- new_common_attn_metadata , fast_build )
844
- return create_kv_sharing_fast_prefill_attn_metadata_subclass (
845
- metadata , common_attn_metadata )
856
+ self ).build (common_prefix_len , common_attn_metadata ,
857
+ fast_build )
858
+ return metadata
846
859
847
860
# Dynamically create a new attention backend that wraps the
848
861
# underlying attention backend but applies
0 commit comments