Skip to content

Commit a822572

Browse files
committed
Cleaner code
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent 1853ce1 commit a822572

File tree

3 files changed

+39
-49
lines changed

3 files changed

+39
-49
lines changed

vllm/v1/attention/backends/utils.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ class CommonAttentionMetadata:
6464
block_table_tensor: torch.Tensor
6565
slot_mapping: torch.Tensor
6666

67-
logits_indices: Optional[torch.Tensor] = None
67+
logits_indices_padded: Optional[torch.Tensor] = None
68+
num_logits_indices: Optional[int] = None
6869

6970
causal: bool = True
7071

@@ -534,7 +535,6 @@ def make_local_attention_virtual_batches(
534535
max_query_len=seqlens_q_local.max(),
535536
block_table_tensor=block_table_local,
536537
slot_mapping=common_attn_metadata.slot_mapping,
537-
logits_indices=common_attn_metadata.logits_indices,
538538
causal=True,
539539
)
540540

@@ -547,14 +547,14 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
547547
# Skip computing fast prefill path
548548
return common_attn_metadata
549549

550-
if common_attn_metadata.logits_indices is None:
551-
# Logits_indices can be None if prompt_logprobs is
552-
# set for at least one request in the current iteration
553-
# fast prefill is not compatible with prompt_logprobs
554-
# so skip computing fast prefill path
550+
if (common_attn_metadata.logits_indices_padded is None
551+
or common_attn_metadata.num_logits_indices is None):
555552
return common_attn_metadata
556553

557-
logits_indices = common_attn_metadata.logits_indices
554+
logits_indices_padded = common_attn_metadata.logits_indices_padded
555+
num_logits_indices = common_attn_metadata.num_logits_indices
556+
# Get rid of CUDAGraph padding, if any
557+
logits_indices = logits_indices_padded[:num_logits_indices]
558558
num_reqs = common_attn_metadata.num_reqs
559559
query_start_loc = common_attn_metadata.query_start_loc
560560
seq_lens = common_attn_metadata.seq_lens
@@ -597,7 +597,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
597597
max_query_len=decode_max_query_len,
598598
block_table_tensor=common_attn_metadata.block_table_tensor,
599599
slot_mapping=common_attn_metadata.slot_mapping,
600-
logits_indices=logits_indices,
601600
causal=True,
602601
)
603602
return common_attn_metadata
@@ -608,6 +607,9 @@ def subclass_attention_metadata_builder(
608607
builder_cls: type[AttentionMetadataBuilder[M]],
609608
build_preprocess_fn: Callable[[CommonAttentionMetadata],
610609
CommonAttentionMetadata],
610+
build_postprocess_fn: Optional[
611+
Callable[[AttentionMetadataBuilder[M], CommonAttentionMetadata, Any],
612+
Any]] = None,
611613
) -> type[AttentionMetadataBuilder[M]]:
612614
"""
613615
Return a new subclass of `builder_cls` whose .build(...) method
@@ -619,9 +621,13 @@ def build(self,
619621
common_prefix_len: int,
620622
common_attn_metadata: CommonAttentionMetadata,
621623
fast_build: bool = False):
622-
return builder_cls.build(self, common_prefix_len,
623-
build_preprocess_fn(common_attn_metadata),
624-
fast_build)
624+
metadata = builder_cls.build(self, common_prefix_len,
625+
build_preprocess_fn(common_attn_metadata),
626+
fast_build)
627+
if build_postprocess_fn is not None:
628+
metadata = build_postprocess_fn(self, common_attn_metadata,
629+
metadata)
630+
return metadata
625631

626632
Wrapped = type(
627633
name,
@@ -800,25 +806,25 @@ class KVSharingFastPrefillAttentionMetadata(Protocol):
800806

801807

802808
def create_kv_sharing_fast_prefill_attn_metadata_subclass(
803-
attn_metadata_i: Any,
804-
logits_indices_padded: torch.Tensor,
805-
num_logits_indices: int,
806-
):
809+
self: AttentionMetadataBuilder[M],
810+
common_attn_metadata: CommonAttentionMetadata,
811+
metadata: Any,
812+
) -> Any:
807813
# Dynamically create a dataclass type that inherits
808814
# from attention metadata type but includes additional
809815
# fields logits_indices_padded and num_logits_indices
810816
# which are required for prefill truncation
811817
fast_prefill_metadata_type = (
812818
make_kv_sharing_fast_prefill_attention_metadata(
813-
metadata_cls=type(attn_metadata_i), )) # type: ignore
819+
metadata_cls=type(metadata), )) # type: ignore
814820
# Avoid deepcopy caused by dict.asdict
815821
attn_metadata_fields = {}
816-
for field in fields(attn_metadata_i.__class__):
817-
attn_metadata_fields[field.name] = getattr(attn_metadata_i, field.name)
822+
for field in fields(metadata.__class__):
823+
attn_metadata_fields[field.name] = getattr(metadata, field.name)
818824
attn_metadata_i = fast_prefill_metadata_type(
819825
**attn_metadata_fields,
820-
logits_indices_padded=logits_indices_padded,
821-
num_logits_indices=num_logits_indices,
826+
logits_indices_padded=common_attn_metadata.logits_indices_padded,
827+
num_logits_indices=common_attn_metadata.num_logits_indices,
822828
)
823829
return attn_metadata_i
824830

@@ -829,14 +835,19 @@ def create_custom_attention_backend(
829835
underlying_attn_backend: AttentionBackend,
830836
build_preprocess_fn: Callable[[CommonAttentionMetadata],
831837
CommonAttentionMetadata],
838+
build_postprocess_fn: Optional[
839+
Callable[[AttentionMetadataBuilder[M], CommonAttentionMetadata, Any],
840+
Any]] = None,
832841
) -> type[AttentionBackend]:
833842
# Dynamically create a new attention backend that wraps the
834843
# underlying attention backend but applies
835844
# `build_preprocess_fn` before calling `build(...)`
836845
builder_cls = subclass_attention_metadata_builder(
837846
name_prefix=prefix,
838847
builder_cls=underlying_attn_backend.get_builder_cls(),
839-
build_preprocess_fn=build_preprocess_fn)
848+
build_preprocess_fn=build_preprocess_fn,
849+
build_postprocess_fn=build_postprocess_fn,
850+
)
840851
attn_backend = subclass_attention_backend(
841852
name_prefix=prefix,
842853
attention_backend_cls=underlying_attn_backend,

vllm/v1/spec_decode/eagle.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,6 @@ def prepare_inputs(
609609
max_query_len=new_query_len_per_req.max().item(),
610610
block_table_tensor=common_attn_metadata.block_table_tensor,
611611
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
612-
logits_indices=common_attn_metadata.logits_indices,
613612
causal=True,
614613
)
615614

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -806,11 +806,11 @@ def _prepare_inputs(
806806

807807
if (self.cache_config.kv_sharing_fast_prefill
808808
and self.input_batch.num_prompt_logprobs):
809-
logger.warning_once(
809+
raise RuntimeError(
810810
"Encountered at least one request with prompt_logprobs set "
811811
"with --kv-sharing-fast-prefill enabled. Fast prefill doesn't "
812-
"produce correct logits for prompt tokens, so fast prefill "
813-
"will be disabled for scheduling rounds with prompt_logprobs.")
812+
"produce correct logits for prompt tokens. Please try again "
813+
"without the flag --kv-sharing-fast-prefill set.")
814814

815815
# Prepare the attention metadata for each KV cache group and make layers
816816
# in the same group share the same metadata.
@@ -837,6 +837,8 @@ def _prepare_inputs(
837837
max_query_len=max_num_scheduled_tokens,
838838
block_table_tensor=blk_table_tensor,
839839
slot_mapping=slot_mapping,
840+
logits_indices_padded=logits_indices_padded,
841+
num_logits_indices=logits_indices.size(0),
840842
causal=True,
841843
)
842844

@@ -857,34 +859,11 @@ def _prepare_inputs(
857859
builder,
858860
)
859861

860-
# If there is at least one request with prompt_logprobs set,
861-
# we cannot enable this optimization as the logits of prompt
862-
# tokens will no longer be valid when doing fast prefill.
863-
is_fast_prefill = (
864-
attn_group.layer_names[0]
865-
in self.kv_sharing_fast_prefill_eligible_layers
866-
and not self.input_batch.num_prompt_logprobs)
867-
if is_fast_prefill:
868-
# If logits_indices is set, builder.build(...) will
869-
# preprocess the common metadata to skip prefill tokens
870-
common_attn_metadata.logits_indices = logits_indices
871-
# TODO(sarckk): Enable cascade attention for fast prefill
872-
common_prefix_len = 0
873-
874862
attn_metadata_i = (builder.build(
875863
common_prefix_len=common_prefix_len,
876864
common_attn_metadata=common_attn_metadata,
877865
))
878866

879-
if is_fast_prefill:
880-
# Eligible layers need extra metadata for use in the model.
881-
attn_metadata_i = \
882-
create_kv_sharing_fast_prefill_attn_metadata_subclass(
883-
attn_metadata_i,
884-
logits_indices_padded,
885-
logits_indices.size(0),
886-
)
887-
888867
for layer_name in attn_group.layer_names:
889868
attn_metadata[layer_name] = attn_metadata_i
890869

@@ -2577,6 +2556,7 @@ def get_attn_backends_for_layers(
25772556
"FastPrefill",
25782557
attn_backend,
25792558
make_kv_sharing_fast_prefill_common_attn_metadata,
2559+
create_kv_sharing_fast_prefill_attn_metadata_subclass,
25802560
)
25812561

25822562
key = attn_backend.full_cls_name()

0 commit comments

Comments
 (0)