@@ -64,7 +64,8 @@ class CommonAttentionMetadata:
64
64
block_table_tensor : torch .Tensor
65
65
slot_mapping : torch .Tensor
66
66
67
- logits_indices : Optional [torch .Tensor ] = None
67
+ logits_indices_padded : Optional [torch .Tensor ] = None
68
+ num_logits_indices : Optional [int ] = None
68
69
69
70
causal : bool = True
70
71
@@ -534,7 +535,6 @@ def make_local_attention_virtual_batches(
534
535
max_query_len = seqlens_q_local .max (),
535
536
block_table_tensor = block_table_local ,
536
537
slot_mapping = common_attn_metadata .slot_mapping ,
537
- logits_indices = common_attn_metadata .logits_indices ,
538
538
causal = True ,
539
539
)
540
540
@@ -547,14 +547,14 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
547
547
# Skip computing fast prefill path
548
548
return common_attn_metadata
549
549
550
- if common_attn_metadata .logits_indices is None :
551
- # Logits_indices can be None if prompt_logprobs is
552
- # set for at least one request in the current iteration
553
- # fast prefill is not compatible with prompt_logprobs
554
- # so skip computing fast prefill path
550
+ if (common_attn_metadata .logits_indices_padded is None
551
+ or common_attn_metadata .num_logits_indices is None ):
555
552
return common_attn_metadata
556
553
557
- logits_indices = common_attn_metadata .logits_indices
554
+ logits_indices_padded = common_attn_metadata .logits_indices_padded
555
+ num_logits_indices = common_attn_metadata .num_logits_indices
556
+ # Get rid of CUDAGraph padding, if any
557
+ logits_indices = logits_indices_padded [:num_logits_indices ]
558
558
num_reqs = common_attn_metadata .num_reqs
559
559
query_start_loc = common_attn_metadata .query_start_loc
560
560
seq_lens = common_attn_metadata .seq_lens
@@ -597,7 +597,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
597
597
max_query_len = decode_max_query_len ,
598
598
block_table_tensor = common_attn_metadata .block_table_tensor ,
599
599
slot_mapping = common_attn_metadata .slot_mapping ,
600
- logits_indices = logits_indices ,
601
600
causal = True ,
602
601
)
603
602
return common_attn_metadata
@@ -608,6 +607,9 @@ def subclass_attention_metadata_builder(
608
607
builder_cls : type [AttentionMetadataBuilder [M ]],
609
608
build_preprocess_fn : Callable [[CommonAttentionMetadata ],
610
609
CommonAttentionMetadata ],
610
+ build_postprocess_fn : Optional [
611
+ Callable [[AttentionMetadataBuilder [M ], CommonAttentionMetadata , Any ],
612
+ Any ]] = None ,
611
613
) -> type [AttentionMetadataBuilder [M ]]:
612
614
"""
613
615
Return a new subclass of `builder_cls` whose .build(...) method
@@ -619,9 +621,13 @@ def build(self,
619
621
common_prefix_len : int ,
620
622
common_attn_metadata : CommonAttentionMetadata ,
621
623
fast_build : bool = False ):
622
- return builder_cls .build (self , common_prefix_len ,
623
- build_preprocess_fn (common_attn_metadata ),
624
- fast_build )
624
+ metadata = builder_cls .build (self , common_prefix_len ,
625
+ build_preprocess_fn (common_attn_metadata ),
626
+ fast_build )
627
+ if build_postprocess_fn is not None :
628
+ metadata = build_postprocess_fn (self , common_attn_metadata ,
629
+ metadata )
630
+ return metadata
625
631
626
632
Wrapped = type (
627
633
name ,
@@ -800,25 +806,25 @@ class KVSharingFastPrefillAttentionMetadata(Protocol):
800
806
801
807
802
808
def create_kv_sharing_fast_prefill_attn_metadata_subclass (
803
- attn_metadata_i : Any ,
804
- logits_indices_padded : torch . Tensor ,
805
- num_logits_indices : int ,
806
- ):
809
+ self : AttentionMetadataBuilder [ M ] ,
810
+ common_attn_metadata : CommonAttentionMetadata ,
811
+ metadata : Any ,
812
+ ) -> Any :
807
813
# Dynamically create a a dataclass type that inherits
808
814
# from attention metadata type but includes additional
809
815
# fields logits_indices_padded and num_logits_indices
810
816
# which are required for prefill truncation
811
817
fast_prefill_metadata_type = (
812
818
make_kv_sharing_fast_prefill_attention_metadata (
813
- metadata_cls = type (attn_metadata_i ), )) # type: ignore
819
+ metadata_cls = type (metadata ), )) # type: ignore
814
820
# Avoid deepcopy caused by dict.asdict
815
821
attn_metadata_fields = {}
816
- for field in fields (attn_metadata_i .__class__ ):
817
- attn_metadata_fields [field .name ] = getattr (attn_metadata_i , field .name )
822
+ for field in fields (metadata .__class__ ):
823
+ attn_metadata_fields [field .name ] = getattr (metadata , field .name )
818
824
attn_metadata_i = fast_prefill_metadata_type (
819
825
** attn_metadata_fields ,
820
- logits_indices_padded = logits_indices_padded ,
821
- num_logits_indices = num_logits_indices ,
826
+ logits_indices_padded = common_attn_metadata . logits_indices_padded ,
827
+ num_logits_indices = common_attn_metadata . num_logits_indices ,
822
828
)
823
829
return attn_metadata_i
824
830
@@ -829,14 +835,19 @@ def create_custom_attention_backend(
829
835
underlying_attn_backend : AttentionBackend ,
830
836
build_preprocess_fn : Callable [[CommonAttentionMetadata ],
831
837
CommonAttentionMetadata ],
838
+ build_postprocess_fn : Optional [
839
+ Callable [[AttentionMetadataBuilder [M ], CommonAttentionMetadata , Any ],
840
+ Any ]] = None ,
832
841
) -> type [AttentionBackend ]:
833
842
# Dynamically create a new attention backend that wraps the
834
843
# underlying attention backend but applies
835
844
# `build_preproces_fn` before calling `build(...)`
836
845
builder_cls = subclass_attention_metadata_builder (
837
846
name_prefix = prefix ,
838
847
builder_cls = underlying_attn_backend .get_builder_cls (),
839
- build_preprocess_fn = build_preprocess_fn )
848
+ build_preprocess_fn = build_preprocess_fn ,
849
+ build_postprocess_fn = build_postprocess_fn ,
850
+ )
840
851
attn_backend = subclass_attention_backend (
841
852
name_prefix = prefix ,
842
853
attention_backend_cls = underlying_attn_backend ,
0 commit comments