
Commit b30f23a

support fp8 TRTLLM attn kernel
Signed-off-by: elvischenv <[email protected]>
1 parent: 74333ae

File tree

7 files changed: +363 -102 lines changed

vllm/attention/backends/abstract.py

Lines changed: 9 additions & 0 deletions
@@ -304,6 +304,15 @@ def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
         """
         return False
 
+    def inserted_input_quant_supported(self, dtype: torch.dtype, static: bool,
+                                       group_shape: GroupShape):
+        """
+        Does this attention implementation support inserted input quantization.
+        This is used by the AttnFusionPass to insert input quantization
+        for backends that support it.
+        """
+        return False
+
 
 class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
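For context, a backend that can consume a statically quantized input would override this new hook to advertise support. The sketch below is illustrative only and is not part of this commit: the class name is hypothetical, AttentionImpl and GroupShape are assumed to be in scope as in abstract.py, and the acceptance policy (static fp8 e4m3 only) is an assumption.

# Illustrative sketch, not part of this commit: a hypothetical backend
# advertising support for inserted input quantization. Other required
# AttentionImpl methods are omitted for brevity.
import torch


class HypotheticalFp8AttentionImpl(AttentionImpl):

    def inserted_input_quant_supported(self, dtype: torch.dtype, static: bool,
                                       group_shape: GroupShape) -> bool:
        # Assumed policy: only static fp8 (e4m3) input quantization is
        # accepted; a real backend would also validate group_shape
        # (e.g. per-tensor only).
        return dtype == torch.float8_e4m3fn and static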

vllm/attention/backends/flashinfer.py

Lines changed: 8 additions & 4 deletions
@@ -1114,10 +1114,14 @@ def forward(
             assert decode_meta.decode_wrapper._sm_scale == softmax_scale
             # TODO: @pavanimajety Remove this once the switch happens
             # inside flashinfer.
-            if not use_trtllm_attention(
-                    num_decode_tokens, attn_metadata.max_decode_seq_len,
-                    kv_cache_dtype, attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
+            if not use_trtllm_attention(attn_metadata.num_qo_heads,
+                                        attn_metadata.num_kv_heads,
+                                        attn_metadata.head_dim,
+                                        window_left,
+                                        num_decode_tokens,
+                                        attn_metadata.max_decode_seq_len,
+                                        kv_cache_dtype,
+                                        is_prefill=False):
                 decode_meta.decode_wrapper.run(
                     decode_query,
                     kv_cache.permute(*stride_order),
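The updated call site reorders the use_trtllm_attention arguments: head geometry first, then sliding window, workload size, max decode sequence length, KV-cache dtype, and an explicit is_prefill flag. As a readability sketch only (this helper does not exist in the commit, and use_trtllm_attention is assumed to be in scope as at the call site above), the decode-path gate could be factored out like this:

# Illustrative helper, not part of this commit: wraps the updated
# use_trtllm_attention call with the argument order shown in the diff above.
def _use_trtllm_decode(attn_metadata, window_left, num_decode_tokens,
                       kv_cache_dtype):
    # Head geometry first, then window, workload size, max seq len,
    # KV-cache dtype, and the new is_prefill flag.
    return use_trtllm_attention(attn_metadata.num_qo_heads,
                                attn_metadata.num_kv_heads,
                                attn_metadata.head_dim,
                                window_left,
                                num_decode_tokens,
                                attn_metadata.max_decode_seq_len,
                                kv_cache_dtype,
                                is_prefill=False)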

vllm/attention/layer.py

Lines changed: 14 additions & 2 deletions
@@ -126,10 +126,11 @@ def __init__(
         self._q_scale = torch.tensor(1.0, dtype=torch.float32)
         self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
 
-        # We also keep the float32 versions of k/v_scale for attention
+        # We also keep the float32 versions of k/v/o_scale for attention
         # backends that don't support tensors (Flashinfer)
         self._k_scale_float = 1.0
         self._v_scale_float = 1.0
+        self._o_scale_float = 1.0
 
         self.use_mla = use_mla
         self.num_heads = num_heads
@@ -195,6 +196,9 @@ def __init__(
         self.layer_name = prefix
         self.attn_type = attn_type
 
+        self.enabled_fusion = compilation_config.pass_config.enable_attn_fusion
+        self.fused_quant = False
+
         if kv_sharing_target_layer_name is not None:
             validate_kv_sharing_target(
                 prefix,
@@ -273,7 +277,13 @@ def forward(
                                   output=output)
             else:
                 torch.ops.vllm.unified_attention_with_output(
-                    query, key, value, output, self.layer_name)
+                    query,
+                    key,
+                    value,
+                    output,
+                    self.layer_name,
+                    query_scale=(self._q_scale
+                                 if self.enabled_fusion else None))
             return output.view(-1, hidden_size)
         else:
             if self.use_direct_call:
@@ -476,6 +486,7 @@ def unified_attention_with_output(
     value: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
+    query_scale: Optional[torch.Tensor] = None,
     output_scale: Optional[torch.Tensor] = None,
 ) -> None:
     wait_for_kv_layer_from_connector(layer_name)
@@ -503,6 +514,7 @@ def unified_attention_with_output_fake(
     value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
+    query_scale: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
 ) -> None:
     return
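Threading query_scale through unified_attention_with_output lets a backend (or the attention fusion pass) apply static fp8 query quantization with the layer's _q_scale rather than relying on a separate quant op in the graph. The commit does not show that consumption; the snippet below is only a self-contained illustration of what static per-tensor fp8 (e4m3) quantization of the query with such a scale looks like, with an assumed scale value and tensor shapes.

# Self-contained illustration (not code from this commit) of static
# per-tensor fp8 quantization of the query with a float32 scale.
import torch


def quantize_query_fp8(query: torch.Tensor,
                       q_scale: torch.Tensor) -> torch.Tensor:
    # Divide by the static scale, clamp to the e4m3 range, and cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (query.float() / q_scale).clamp(finfo.min, finfo.max)
    return q.to(torch.float8_e4m3fn)


# Usage with dummy data and an assumed scale value:
query = torch.randn(4, 32, 128, dtype=torch.bfloat16)
q_scale = torch.tensor(0.05, dtype=torch.float32)
q_fp8 = quantize_query_fp8(query, q_scale)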
