
Commit bcf331a

Cleaner code
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent 271f14c commit bcf331a

3 files changed: 39 additions, 49 deletions


vllm/v1/attention/backends/utils.py

Lines changed: 33 additions & 22 deletions
@@ -64,7 +64,8 @@ class CommonAttentionMetadata:
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor

-    logits_indices: Optional[torch.Tensor] = None
+    logits_indices_padded: Optional[torch.Tensor] = None
+    num_logits_indices: Optional[int] = None

     causal: bool = True

@@ -537,7 +538,6 @@ def make_local_attention_virtual_batches(
         max_query_len=seqlens_q_local.max(),
         block_table_tensor=block_table_local,
         slot_mapping=common_attn_metadata.slot_mapping,
-        logits_indices=common_attn_metadata.logits_indices,
         causal=True,
     )

@@ -550,14 +550,14 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
         # Skip computing fast prefill path
         return common_attn_metadata

-    if common_attn_metadata.logits_indices is None:
-        # Logits_indices can be None if prompt_logprobs is
-        # set for at least one request in the current iteration
-        # fast prefill is not compatible with prompt_logprobs
-        # so skip computing fast prefill path
+    if (common_attn_metadata.logits_indices_padded is None
+            or common_attn_metadata.num_logits_indices is None):
         return common_attn_metadata

-    logits_indices = common_attn_metadata.logits_indices
+    logits_indices_padded = common_attn_metadata.logits_indices_padded
+    num_logits_indices = common_attn_metadata.num_logits_indices
+    # Get rid of CUDAGraph padding, if any
+    logits_indices = logits_indices_padded[:num_logits_indices]
     num_reqs = common_attn_metadata.num_reqs
     query_start_loc = common_attn_metadata.query_start_loc
     seq_lens = common_attn_metadata.seq_lens

@@ -600,7 +600,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
         max_query_len=decode_max_query_len,
         block_table_tensor=common_attn_metadata.block_table_tensor,
         slot_mapping=common_attn_metadata.slot_mapping,
-        logits_indices=logits_indices,
         causal=True,
     )
     return common_attn_metadata

@@ -611,6 +610,9 @@ def subclass_attention_metadata_builder(
     builder_cls: type[AttentionMetadataBuilder[M]],
     build_preprocess_fn: Callable[[CommonAttentionMetadata],
                                   CommonAttentionMetadata],
+    build_postprocess_fn: Optional[
+        Callable[[AttentionMetadataBuilder[M], CommonAttentionMetadata, Any],
+                 Any]] = None,
 ) -> type[AttentionMetadataBuilder[M]]:
     """
     Return a new subclass of `builder_cls` whose .build(...) method

@@ -622,9 +624,13 @@ def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False):
-        return builder_cls.build(self, common_prefix_len,
-                                 build_preprocess_fn(common_attn_metadata),
-                                 fast_build)
+        metadata = builder_cls.build(self, common_prefix_len,
+                                     build_preprocess_fn(common_attn_metadata),
+                                     fast_build)
+        if build_postprocess_fn is not None:
+            metadata = build_postprocess_fn(self, common_attn_metadata,
+                                            metadata)
+        return metadata

     Wrapped = type(
         name,

@@ -803,25 +809,25 @@ class KVSharingFastPrefillAttentionMetadata(Protocol):


 def create_kv_sharing_fast_prefill_attn_metadata_subclass(
-    attn_metadata_i: Any,
-    logits_indices_padded: torch.Tensor,
-    num_logits_indices: int,
-):
+    self: AttentionMetadataBuilder[M],
+    common_attn_metadata: CommonAttentionMetadata,
+    metadata: Any,
+) -> Any:
     # Dynamically create a a dataclass type that inherits
     # from attention metadata type but includes additional
     # fields logits_indices_padded and num_logits_indices
     # which are required for prefill truncation
     fast_prefill_metadata_type = (
         make_kv_sharing_fast_prefill_attention_metadata(
-            metadata_cls=type(attn_metadata_i), ))  # type: ignore
+            metadata_cls=type(metadata), ))  # type: ignore
     # Avoid deepcopy caused by dict.asdict
     attn_metadata_fields = {}
-    for field in fields(attn_metadata_i.__class__):
-        attn_metadata_fields[field.name] = getattr(attn_metadata_i, field.name)
+    for field in fields(metadata.__class__):
+        attn_metadata_fields[field.name] = getattr(metadata, field.name)
     attn_metadata_i = fast_prefill_metadata_type(
         **attn_metadata_fields,
-        logits_indices_padded=logits_indices_padded,
-        num_logits_indices=num_logits_indices,
+        logits_indices_padded=common_attn_metadata.logits_indices_padded,
+        num_logits_indices=common_attn_metadata.num_logits_indices,
     )
     return attn_metadata_i

@@ -832,14 +838,19 @@ def create_custom_attention_backend(
     underlying_attn_backend: AttentionBackend,
     build_preprocess_fn: Callable[[CommonAttentionMetadata],
                                   CommonAttentionMetadata],
+    build_postprocess_fn: Optional[
+        Callable[[AttentionMetadataBuilder[M], CommonAttentionMetadata, Any],
+                 Any]] = None,
 ) -> type[AttentionBackend]:
     # Dynamically create a new attention backend that wraps the
     # underlying attention backend but applies
     # `build_preproces_fn` before calling `build(...)`
     builder_cls = subclass_attention_metadata_builder(
         name_prefix=prefix,
         builder_cls=underlying_attn_backend.get_builder_cls(),
-        build_preprocess_fn=build_preprocess_fn)
+        build_preprocess_fn=build_preprocess_fn,
+        build_postprocess_fn=build_postprocess_fn,
+    )
     attn_backend = subclass_attention_backend(
         name_prefix=prefix,
         attention_backend_cls=underlying_attn_backend,
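
Note on the refactor above: the new build_postprocess_fn mirrors the existing build_preprocess_fn hook. The generated builder subclass preprocesses the common attention metadata, delegates to the wrapped builder's build(...), and then hands the built metadata (together with the original common metadata) to the postprocess hook. The following is a minimal, self-contained sketch of that wrapping pattern; SimpleCommonMetadata, SimpleMetadata, SimpleBuilder, and the two hook functions are invented stand-ins for illustration, not vLLM's real classes.

from dataclasses import dataclass, replace
from typing import Any, Callable, Optional


@dataclass
class SimpleCommonMetadata:
    num_tokens: int
    num_logits_indices: Optional[int] = None


@dataclass
class SimpleMetadata:
    num_tokens: int


class SimpleBuilder:
    def build(self, common: SimpleCommonMetadata) -> SimpleMetadata:
        return SimpleMetadata(num_tokens=common.num_tokens)


def wrap_builder(
    builder_cls: type,
    build_preprocess_fn: Callable[[SimpleCommonMetadata], SimpleCommonMetadata],
    build_postprocess_fn: Optional[Callable[[Any, SimpleCommonMetadata, Any],
                                            Any]] = None,
) -> type:
    class Wrapped(builder_cls):
        def build(self, common: SimpleCommonMetadata) -> Any:
            # Preprocess the common metadata, then delegate to the base builder.
            metadata = builder_cls.build(self, build_preprocess_fn(common))
            # Optionally rewrite the built metadata; the hook also receives
            # the *original* (unmodified) common metadata.
            if build_postprocess_fn is not None:
                metadata = build_postprocess_fn(self, common, metadata)
            return metadata

    return Wrapped


def truncate(common: SimpleCommonMetadata) -> SimpleCommonMetadata:
    # Stand-in for skipping prefill tokens: cap the token count.
    return replace(common, num_tokens=min(common.num_tokens, 4))


def attach_count(builder: Any, common: SimpleCommonMetadata,
                 metadata: SimpleMetadata) -> tuple:
    # Stand-in for attaching fast-prefill fields to the built metadata.
    return (metadata, common.num_logits_indices)


builder = wrap_builder(SimpleBuilder, truncate, attach_count)()
print(builder.build(SimpleCommonMetadata(num_tokens=10, num_logits_indices=3)))
# -> (SimpleMetadata(num_tokens=4), 3)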

vllm/v1/spec_decode/eagle.py

Lines changed: 0 additions & 1 deletion
@@ -584,7 +584,6 @@ def prepare_inputs(
         max_query_len=new_query_len_per_req.max().item(),
         block_table_tensor=common_attn_metadata.block_table_tensor,
         slot_mapping=common_attn_metadata.slot_mapping[token_indices],
-        logits_indices=common_attn_metadata.logits_indices,
         causal=True,
     )

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 26 deletions
@@ -868,11 +868,11 @@ def _prepare_inputs(

         if (self.cache_config.kv_sharing_fast_prefill
                 and self.input_batch.num_prompt_logprobs):
-            logger.warning_once(
+            raise RuntimeError(
                 "Encountered at least one request with prompt_logprobs set "
                 "with --kv-sharing-fast-prefill enabled. Fast prefill doesn't "
-                "produce correct logits for prompt tokens, so fast prefill "
-                "will be disabled for scheduling rounds with prompt_logprobs.")
+                "produce correct logits for prompt tokens. Please try again "
+                "without the flag --kv-sharing-fast-prefill set.")

         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.

@@ -898,6 +898,8 @@ def _prepare_inputs(
                 max_query_len=max_num_scheduled_tokens,
                 block_table_tensor=blk_table_tensor,
                 slot_mapping=slot_mapping,
+                logits_indices_padded=logits_indices_padded,
+                num_logits_indices=logits_indices.size(0),
                 causal=True,
             )

@@ -918,34 +920,11 @@ def _prepare_inputs(
                 builder,
             )

-            # If there is at least one request with prompt_logprobs set,
-            # we cannot enable this optimization as the logits of prompt
-            # tokens will no longer be valid when doing fast prefill.
-            is_fast_prefill = (
-                attn_group.layer_names[0]
-                in self.kv_sharing_fast_prefill_eligible_layers
-                and not self.input_batch.num_prompt_logprobs)
-            if is_fast_prefill:
-                # If logits_indices is set, builder.build(...) will
-                # preprocess the common metadata to skip prefill tokens
-                common_attn_metadata.logits_indices = logits_indices
-                # TODO(sarckk): Enable cascade attention for fast prefill
-                common_prefix_len = 0
-
             attn_metadata_i = (builder.build(
                 common_prefix_len=common_prefix_len,
                 common_attn_metadata=common_attn_metadata,
             ))

-            if is_fast_prefill:
-                # Eligible layers need extra metadata for use in the model.
-                attn_metadata_i = \
-                    create_kv_sharing_fast_prefill_attn_metadata_subclass(
-                        attn_metadata_i,
-                        logits_indices_padded,
-                        logits_indices.size(0),
-                    )
-
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i

@@ -2790,6 +2769,7 @@ def get_attn_backends_for_layers(
                 "FastPrefill",
                 attn_backend,
                 make_kv_sharing_fast_prefill_common_attn_metadata,
+                create_kv_sharing_fast_prefill_attn_metadata_subclass,
             )

             key = attn_backend.full_cls_name()
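
One detail worth spelling out from the runner change above: instead of a single logits_indices field, CommonAttentionMetadata now carries the CUDA-graph-padded index tensor plus the true count, and make_kv_sharing_fast_prefill_common_attn_metadata strips the padding by slicing. A tiny PyTorch illustration of that convention, with made-up values, follows.

import torch

# Suppose three real logits positions, padded out to eight slots so the
# buffer shape stays fixed for CUDA graph capture (values are made up).
logits_indices = torch.tensor([5, 11, 17])
num_logits_indices = logits_indices.size(0)

logits_indices_padded = torch.zeros(8, dtype=torch.long)
logits_indices_padded[:num_logits_indices] = logits_indices

# Downstream consumers recover the unpadded indices by slicing, mirroring
# logits_indices = logits_indices_padded[:num_logits_indices] in utils.py.
recovered = logits_indices_padded[:num_logits_indices]
assert torch.equal(recovered, logits_indices)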
