@@ -64,7 +64,8 @@ class CommonAttentionMetadata:
64
64
block_table_tensor : torch .Tensor
65
65
slot_mapping : torch .Tensor
66
66
67
- logits_indices : Optional [torch .Tensor ] = None
67
+ logits_indices_padded : Optional [torch .Tensor ] = None
68
+ num_logits_indices : Optional [int ] = None
68
69
69
70
causal : bool = True
70
71
@@ -537,7 +538,6 @@ def make_local_attention_virtual_batches(
537
538
max_query_len = seqlens_q_local .max (),
538
539
block_table_tensor = block_table_local ,
539
540
slot_mapping = common_attn_metadata .slot_mapping ,
540
- logits_indices = common_attn_metadata .logits_indices ,
541
541
causal = True ,
542
542
)
543
543
@@ -550,14 +550,14 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
550
550
# Skip computing fast prefill path
551
551
return common_attn_metadata
552
552
553
- if common_attn_metadata .logits_indices is None :
554
- # Logits_indices can be None if prompt_logprobs is
555
- # set for at least one request in the current iteration
556
- # fast prefill is not compatible with prompt_logprobs
557
- # so skip computing fast prefill path
553
+ if (common_attn_metadata .logits_indices_padded is None
554
+ or common_attn_metadata .num_logits_indices is None ):
558
555
return common_attn_metadata
559
556
560
- logits_indices = common_attn_metadata .logits_indices
557
+ logits_indices_padded = common_attn_metadata .logits_indices_padded
558
+ num_logits_indices = common_attn_metadata .num_logits_indices
559
+ # Get rid of CUDAGraph padding, if any
560
+ logits_indices = logits_indices_padded [:num_logits_indices ]
561
561
num_reqs = common_attn_metadata .num_reqs
562
562
query_start_loc = common_attn_metadata .query_start_loc
563
563
seq_lens = common_attn_metadata .seq_lens
@@ -600,7 +600,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
600
600
max_query_len = decode_max_query_len ,
601
601
block_table_tensor = common_attn_metadata .block_table_tensor ,
602
602
slot_mapping = common_attn_metadata .slot_mapping ,
603
- logits_indices = logits_indices ,
604
603
causal = True ,
605
604
)
606
605
return common_attn_metadata
@@ -611,6 +610,9 @@ def subclass_attention_metadata_builder(
611
610
builder_cls : type [AttentionMetadataBuilder [M ]],
612
611
build_preprocess_fn : Callable [[CommonAttentionMetadata ],
613
612
CommonAttentionMetadata ],
613
+ build_postprocess_fn : Optional [
614
+ Callable [[AttentionMetadataBuilder [M ], CommonAttentionMetadata , Any ],
615
+ Any ]] = None ,
614
616
) -> type [AttentionMetadataBuilder [M ]]:
615
617
"""
616
618
Return a new subclass of `builder_cls` whose .build(...) method
@@ -622,9 +624,13 @@ def build(self,
622
624
common_prefix_len : int ,
623
625
common_attn_metadata : CommonAttentionMetadata ,
624
626
fast_build : bool = False ):
625
- return builder_cls .build (self , common_prefix_len ,
626
- build_preprocess_fn (common_attn_metadata ),
627
- fast_build )
627
+ metadata = builder_cls .build (self , common_prefix_len ,
628
+ build_preprocess_fn (common_attn_metadata ),
629
+ fast_build )
630
+ if build_postprocess_fn is not None :
631
+ metadata = build_postprocess_fn (self , common_attn_metadata ,
632
+ metadata )
633
+ return metadata
628
634
629
635
Wrapped = type (
630
636
name ,
@@ -803,25 +809,25 @@ class KVSharingFastPrefillAttentionMetadata(Protocol):
803
809
804
810
805
811
def create_kv_sharing_fast_prefill_attn_metadata_subclass (
806
- attn_metadata_i : Any ,
807
- logits_indices_padded : torch . Tensor ,
808
- num_logits_indices : int ,
809
- ):
812
+ self : AttentionMetadataBuilder [ M ] ,
813
+ common_attn_metadata : CommonAttentionMetadata ,
814
+ metadata : Any ,
815
+ ) -> Any :
810
816
# Dynamically create a dataclass type that inherits
811
817
# from attention metadata type but includes additional
812
818
# fields logits_indices_padded and num_logits_indices
813
819
# which are required for prefill truncation
814
820
fast_prefill_metadata_type = (
815
821
make_kv_sharing_fast_prefill_attention_metadata (
816
- metadata_cls = type (attn_metadata_i ), )) # type: ignore
822
+ metadata_cls = type (metadata ), )) # type: ignore
817
823
# Avoid deepcopy caused by dict.asdict
818
824
attn_metadata_fields = {}
819
- for field in fields (attn_metadata_i .__class__ ):
820
- attn_metadata_fields [field .name ] = getattr (attn_metadata_i , field .name )
825
+ for field in fields (metadata .__class__ ):
826
+ attn_metadata_fields [field .name ] = getattr (metadata , field .name )
821
827
attn_metadata_i = fast_prefill_metadata_type (
822
828
** attn_metadata_fields ,
823
- logits_indices_padded = logits_indices_padded ,
824
- num_logits_indices = num_logits_indices ,
829
+ logits_indices_padded = common_attn_metadata . logits_indices_padded ,
830
+ num_logits_indices = common_attn_metadata . num_logits_indices ,
825
831
)
826
832
return attn_metadata_i
827
833
@@ -832,14 +838,19 @@ def create_custom_attention_backend(
832
838
underlying_attn_backend : AttentionBackend ,
833
839
build_preprocess_fn : Callable [[CommonAttentionMetadata ],
834
840
CommonAttentionMetadata ],
841
+ build_postprocess_fn : Optional [
842
+ Callable [[AttentionMetadataBuilder [M ], CommonAttentionMetadata , Any ],
843
+ Any ]] = None ,
835
844
) -> type [AttentionBackend ]:
836
845
# Dynamically create a new attention backend that wraps the
837
846
# underlying attention backend but applies
838
847
# `build_preprocess_fn` before calling `build(...)`
839
848
builder_cls = subclass_attention_metadata_builder (
840
849
name_prefix = prefix ,
841
850
builder_cls = underlying_attn_backend .get_builder_cls (),
842
- build_preprocess_fn = build_preprocess_fn )
851
+ build_preprocess_fn = build_preprocess_fn ,
852
+ build_postprocess_fn = build_postprocess_fn ,
853
+ )
843
854
attn_backend = subclass_attention_backend (
844
855
name_prefix = prefix ,
845
856
attention_backend_cls = underlying_attn_backend ,
0 commit comments