@@ -215,6 +215,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers
+        # TODO: discard this for trtllm-gen backend
         self.global_hyperparameters = infer_global_hyperparameters(
             get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
 
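For context on the new TODO: `infer_global_hyperparameters` exists because the FlashInfer wrappers are planned once per batch, so all attention layers must currently share the same hyperparameters. Once the trtllm-gen path receives `window_left` and `sinks` per call, that global constraint becomes unnecessary. A rough sketch of the idea follows, assuming an illustrative `PerLayerParameters` dataclass whose field names may differ from vLLM's actual definition:

from dataclasses import dataclass


@dataclass
class PerLayerParameters:
    # Illustrative fields only; the real vLLM dataclass may differ.
    window_left: int               # sliding-window size, -1 means no window
    logits_soft_cap: float | None
    sm_scale: float


def infer_global_hyperparameters_sketch(
        per_layer: dict[str, PerLayerParameters]) -> PerLayerParameters:
    # One FlashInfer plan is reused for all layers, so every layer must
    # agree on these values; a backend that takes window/sinks per call
    # (e.g. trtllm-gen) would not need this restriction.
    params = list(per_layer.values())
    global_params = params[0]
    for p in params[1:]:
        assert p == global_params, (
            "FlashInfer backend requires identical attention "
            "hyperparameters across all layers")
    return global_params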
@@ -523,16 +524,12 @@ def build(self,
         head_dim = self.kv_cache_spec.head_size
 
         # currently prefill trtllm attention does not support fp8 kv cache
-        # trtllm may not support sliding window
-        prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
-                              and not cache_dtype.startswith("fp8")
-                              and use_trtllm_attention(
+        prefill_use_trtllm = use_trtllm_attention(
             num_prefill_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim))
-        decode_use_trtllm = (self.global_hyperparameters.window_left == -1
-                             and use_trtllm_attention(
+            num_qo_heads, num_kv_heads, head_dim)
+        decode_use_trtllm = use_trtllm_attention(
             num_decode_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim))
+            num_qo_heads, num_kv_heads, head_dim)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
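Dropping the global `window_left == -1` and fp8 prefill guards here only makes sense if that decision now lives inside `use_trtllm_attention` (or the kernels themselves handle those cases). A minimal sketch of such a gate, with a hypothetical name and parameters limited to what this diff shows; the real vLLM helper may check more (GPU architecture, head-dim support, batch-size heuristics):

def can_use_trtllm_attention_sketch(num_tokens: int,
                                    max_seq_len: int,
                                    kv_cache_dtype: str,
                                    num_qo_heads: int,
                                    num_kv_heads: int,
                                    head_dim: int,
                                    is_prefill: bool = False) -> bool:
    # Hypothetical gate: consolidate eligibility checks in one place
    # instead of duplicating them in build().
    if num_tokens == 0:
        return False
    # Grouped-query attention: query heads must divide evenly over KV heads.
    if num_qo_heads % num_kv_heads != 0:
        return False
    # The comment kept above notes prefill does not support an fp8 KV cache,
    # so that check would belong here rather than in build().
    if is_prefill and kv_cache_dtype.startswith("fp8"):
        return False
    # Placeholder for kernel-support checks on head_dim / max_seq_len.
    _ = (head_dim, max_seq_len)
    return True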
@@ -793,6 +790,8 @@ def forward(
                     batch_size=attn_metadata.num_prefills,
                     cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                     cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
+                    window_left=window_left,
+                    sinks=self.sinks,
                     out=output[num_decode_tokens:],
                 )
 
@@ -839,6 +838,8 @@ def forward(
                     max_seq_len=attn_metadata.max_seq_len,
                     bmm1_scale=layer._k_scale_float * self.scale,
                     bmm2_scale=layer._v_scale_float,
+                    window_left=window_left,
+                    sinks=self.sinks,
                     out=output[:num_decode_tokens],
                 )
         return output_padded
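Both forward-path hunks pass the same two new arguments to the TRT-LLM prefill and decode kernels: `window_left` for sliding-window attention and `self.sinks` for per-head attention-sink logits. A rough sketch of how an attention impl might prepare these values, using illustrative names; the real FlashInferImpl constructor takes more arguments and does more validation:

import torch


class AttentionImplSketch:
    def __init__(self,
                 num_heads: int,
                 sliding_window: int | None = None,
                 sinks: torch.Tensor | None = None) -> None:
        # Convention assumed here: window_left == -1 disables the sliding
        # window, otherwise it is the number of visible tokens to the left.
        self.window_left = (sliding_window - 1
                            if sliding_window is not None else -1)
        # One sink logit per query head; None means the model has no sinks
        # and the kernels behave as before.
        if sinks is not None:
            assert sinks.shape == (num_heads, ), (
                "expected one attention-sink value per query head")
        self.sinks = sinks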