From c840bc035ee31f1f32c8c36b6786345efc846774 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 7 Aug 2025 20:25:38 +0000 Subject: [PATCH 01/13] Pass layer to _forward_decode, add q and k descale for FlashMLA backend Signed-off-by: Matthew Bonanni --- vllm/attention/ops/flashmla.py | 6 ++++++ vllm/v1/attention/backends/mla/common.py | 3 ++- vllm/v1/attention/backends/mla/cutlass_mla.py | 4 +++- vllm/v1/attention/backends/mla/flashmla.py | 10 +++++----- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 2 ++ vllm/v1/attention/backends/mla/triton_mla.py | 4 +++- 6 files changed, 21 insertions(+), 8 deletions(-) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 1af26dfc3daa..a46cfb441dbe 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -67,6 +67,8 @@ def flash_mla_with_kvcache( num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: @@ -81,6 +83,8 @@ def flash_mla_with_kvcache( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. + descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization. + descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. Return: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). 
@@ -98,6 +102,8 @@ def flash_mla_with_kvcache( causal, tile_scheduler_metadata, num_splits, + descale_q, + descale_k, ) return out, softmax_lse diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index badff67656c2..9205f06f8a5e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1127,6 +1127,7 @@ def _forward_decode( q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, + layer: AttentionLayer, ) -> torch.Tensor: raise NotImplementedError @@ -1205,6 +1206,6 @@ def forward( decode_ql_nope = decode_ql_nope.transpose(0, 1) output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer) return output_padded diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index b23a8f0a5e87..516e33f132c7 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -7,7 +7,8 @@ import torch import vllm._custom_ops as ops -from vllm.attention.backends.abstract import (AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, + AttentionType, is_quantized_kv_cache) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, @@ -266,6 +267,7 @@ def _forward_decode( q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + layer: AttentionLayer, ) -> torch.Tensor: if self._use_old_cutlass_mla: # TODO: Remove the old cutlass MLA kernel after more extensive diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 2b0f52cf80bf..4985f7af0473 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -6,7 +6,8 @@ import torch -from vllm.attention.backends.abstract import 
(AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, + AttentionType, is_quantized_kv_cache) from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, @@ -163,16 +164,13 @@ def __init__( "are not implemented for " "FlashMLAImpl") - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "FlashMLA V1 with FP8 KV cache not yet supported") - def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -191,6 +189,8 @@ def _forward_decode( num_splits=attn_metadata.decode.num_splits, softmax_scale=self.scale, causal=True, + descale_q=layer._q_scale, + descale_k=layer._k_scale, ) return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 8b55e1a30199..ddf18b472214 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -7,6 +7,7 @@ import torch import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionLayer from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils import cdiv @@ -216,6 +217,7 @@ def _forward_decode( q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: AiterMLAMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 700fce68953e..47a6ffe50d01 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -6,7 +6,8 @@ import torch from vllm import envs -from vllm.attention.backends.abstract import 
(AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, + AttentionType, is_quantized_kv_cache) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention @@ -127,6 +128,7 @@ def _forward_decode( q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None From 5854011a48529397bd61cc7f56ad51f3c3e248c2 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 11 Aug 2025 14:53:33 +0000 Subject: [PATCH 02/13] Update tests Signed-off-by: Matthew Bonanni --- tests/kernels/attention/test_flashmla.py | 67 +++++++++++++++++------- 1 file changed, 49 insertions(+), 18 deletions(-) diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index 21b08e45fd6f..0ff38d2966af 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -13,11 +13,14 @@ from vllm.triton_utils import triton -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: +def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -> None: x, y = x.double(), y.double() cos_diff = 1 - 2 * (x * y).sum().item() / max( (x * x + y * y).sum().item(), 1e-12) - assert cos_diff < 1e-5 + if (use_fp8): + assert cos_diff < 1e-4 + else: + assert cos_diff < 1e-5 FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ if not is_flashmla_supported()[0] else "FlashMLA is supported" @@ -27,7 +30,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: reason=FLASH_MLA_UNSUPPORTED_REASON) @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) -@pytest.mark.parametrize("mean_sk", [4096, 8192]) +@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) @pytest.mark.parametrize("h_q", [16, 32, 64, 128]) 
@pytest.mark.parametrize("h_kv", [1]) @pytest.mark.parametrize("d", [576]) @@ -35,21 +38,22 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: @pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("varlen", [False, True]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16, torch.float16, torch.float8_e4m3fn]) @torch.inference_mode() def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, - varlen): - # TODO: parametrize using pytest - dtype = torch.bfloat16 + varlen, torch_dtype): device = torch.device("cuda:0") - torch.set_default_dtype(dtype) + init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype + torch.set_default_dtype(init_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) random.seed(0) print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " - f"{d=}, {dv=}, {causal=}, {varlen=}") + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") + use_fp8 = torch_dtype == torch.float8_e4m3fn cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) if varlen: for i in range(b): @@ -72,6 +76,28 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, tile_scheduler_metadata, num_splits = get_mla_metadata( cache_seqlens, s_q * h_q // h_kv, h_kv) + init_dtype = q.dtype + def prepare_fp8_input(): + q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None + + if use_fp8: + nonlocal q, blocked_k, blocked_v + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + + q_fp8 = q.to(fp8_dtype) + blocked_k_fp8 = blocked_k.to(fp8_dtype) + blocked_v_fp8 = blocked_v.to(fp8_dtype) + + return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k + + q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input() + if use_fp8: + q = q_fp8 + blocked_k = blocked_k_fp8 
+ blocked_v = blocked_v_fp8 + def flash_mla(): return flash_mla_with_kvcache( q, @@ -82,6 +108,8 @@ def flash_mla(): tile_scheduler_metadata, num_splits, causal=causal, + descale_q=descale_q, + descale_k=descale_k, ) def scaled_dot_product_attention(query, key, value, is_causal=False): @@ -105,29 +133,32 @@ def scaled_dot_product_attention(query, key, value, is_causal=False): return attn_weight @ value, lse def ref_mla(): + q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q + blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] - ref_O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + O, LSE = scaled_dot_product_attention( + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), is_causal=causal, ) - out[i] = ref_O.transpose(0, 1) + out[i] = O.transpose(0, 1) lse[i] = LSE return out, lse out_flash, lse_flash = flash_mla() out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") + cal_diff(out_flash, out_torch, "out", use_fp8) cal_diff(lse_flash, lse_torch, "lse") t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + - b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} " - f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits 
// 8) + print( + f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" + ) From cce4e0a07de53fa6a2bd365b945b7d9b21559262 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 11 Aug 2025 18:24:53 +0000 Subject: [PATCH 03/13] Quantize Q and KV, fix shape error with scales Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 17 +++++++++++++++++ vllm/v1/attention/backends/mla/flashmla.py | 8 +++----- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 9205f06f8a5e..a751030afd5e 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1155,6 +1155,8 @@ def forward( # same expert outputs. return output.fill_(0) + fp8_attention = self.kv_cache_dtype == "fp8" + num_actual_toks = attn_metadata.num_actual_tokens # Inputs and outputs may be padded for CUDA graphs @@ -1205,6 +1207,21 @@ def forward( # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) + if fp8_attention: + kv_cache = kv_cache.view(torch.float8_e4m3fn) + ql_nope_shape = decode_ql_nope.shape + decode_ql_nope, _ = ops.scaled_fp8_quant( + decode_ql_nope.reshape([ + ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2] + ]).contiguous(), layer._q_scale) + decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) + q_pe_shape = decode_q_pe.shape + decode_q_pe, _ = ops.scaled_fp8_quant( + decode_q_pe.reshape([ + q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2] + ]).contiguous(), layer._q_scale) + decode_q_pe = decode_q_pe.reshape(q_pe_shape) + output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 4985f7af0473..5980f476f3ed 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py 
@@ -6,9 +6,7 @@ import torch -from vllm.attention.backends.abstract import (AttentionLayer, - AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) @@ -189,8 +187,8 @@ def _forward_decode( num_splits=attn_metadata.decode.num_splits, softmax_scale=self.scale, causal=True, - descale_q=layer._q_scale, - descale_k=layer._k_scale, + descale_q=layer._q_scale.reshape(1), + descale_k=layer._k_scale.reshape(1), ) return self._v_up_proj(o) From 64febac5c1fc96a549bd0fe7bb6e5d5e09e70270 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 11 Aug 2025 18:25:54 +0000 Subject: [PATCH 04/13] Update to reflect FP8 support in FLASHMLA backend Signed-off-by: Matthew Bonanni --- vllm/platforms/cuda.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index dd9356e399c9..2b6ac4d71d1c 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -498,12 +498,16 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: fp8_attention = kv_cache_dtype.startswith("fp8") will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND") ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" + will_use_flashmla = (envs.is_set("VLLM_ATTENTION_BACKEND") + and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA") supported = False if cls.is_device_capability(100): supported = True elif fp8_attention and will_use_fa: from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 supported = flash_attn_supports_fp8() + elif fp8_attention and will_use_flashmla: + supported = True return supported From b548b10e9493dd2f4d2521a39ad9b01f7ca7aa6b Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 11 Aug 2025 18:27:02 +0000 Subject: [PATCH 05/13] Update cmake Signed-off-by: Matthew Bonanni --- cmake/external_projects/flashmla.cmake | 7 ++++--- 1 file changed, 4 
insertions(+), 3 deletions(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index ee6768bce26c..3d050df2e24e 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) set(FlashMLA_SOURCES ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp - ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu - ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu) + ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu) set(FlashMLA_INCLUDES ${flashmla_SOURCE_DIR}/csrc/cutlass/include - ${flashmla_SOURCE_DIR}/csrc/include) + ${flashmla_SOURCE_DIR}/csrc) set_gencode_flags_for_srcs( SRCS "${FlashMLA_SOURCES}" From 8ae24a35b4dfddd4f3589eec68bdb97f38181adf Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 11 Aug 2025 18:47:12 +0000 Subject: [PATCH 06/13] Address comment Signed-off-by: Matthew Bonanni --- tests/kernels/attention/test_flashmla.py | 63 ++++++++++++------------ 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index 0ff38d2966af..abcfe828d5ac 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -13,7 +13,10 @@ from vllm.triton_utils import triton -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -> None: +def cal_diff(x: torch.Tensor, + y: torch.Tensor, + name: str, + use_fp8: bool = False) -> None: x, y = x.double(), y.double() cos_diff = 1 - 2 * (x * y).sum().item() / max( (x * x + y * y).sum().item(), 1e-12) @@ -38,12 +41,16 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, 
name: str, use_fp8: bool=False) - @pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("varlen", [False, True]) -@pytest.mark.parametrize("torch_dtype", [torch.bfloat16, torch.float16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("torch_dtype", + [torch.bfloat16, torch.float16, torch.float8_e4m3fn]) @torch.inference_mode() def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype): device = torch.device("cuda:0") - init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype + if torch_dtype == torch.float8_e4m3fn: + init_dtype = torch.bfloat16 + else: + init_dtype = torch_dtype torch.set_default_dtype(init_dtype) torch.set_default_device(device) torch.cuda.set_device(device) @@ -77,26 +84,17 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, cache_seqlens, s_q * h_q // h_kv, h_kv) init_dtype = q.dtype - def prepare_fp8_input(): - q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None - - if use_fp8: - nonlocal q, blocked_k, blocked_v - fp8_dtype = torch.float8_e4m3fn - descale_q = torch.ones((1), dtype=torch.float32) - descale_k = torch.ones((1), dtype=torch.float32) - - q_fp8 = q.to(fp8_dtype) - blocked_k_fp8 = blocked_k.to(fp8_dtype) - blocked_v_fp8 = blocked_v.to(fp8_dtype) - - return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k - - q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input() if use_fp8: - q = q_fp8 - blocked_k = blocked_k_fp8 - blocked_v = blocked_v_fp8 + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + + q = q.to(fp8_dtype) + blocked_k = blocked_k.to(fp8_dtype) + blocked_v = blocked_v.to(fp8_dtype) + else: + descale_q = None + descale_k = None def flash_mla(): return flash_mla_with_kvcache( @@ -134,21 +132,23 @@ def 
scaled_dot_product_attention(query, key, value, is_causal=False): def ref_mla(): q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q - blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k - blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v + blocked_k_ = (blocked_k.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_v out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] - O, LSE = scaled_dot_product_attention( + out_i, lse_i = scaled_dot_product_attention( q_[i].transpose(0, 1), blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), is_causal=causal, ) - out[i] = O.transpose(0, 1) - lse[i] = LSE + out[i] = out_i.transpose(0, 1) + lse[i] = lse_i return out, lse out_flash, lse_flash = flash_mla() @@ -158,7 +158,8 @@ def ref_mla(): t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) - print( - f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" - ) + bytes = (total_seqlens * h_kv * d + + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( + b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", + f"{bytes / 10 ** 6 / t:.0f} GB/s") From 860f3e03f2003cfcb937aa4e971d0d6f8f2c84b4 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 11 Aug 2025 18:57:11 +0000 Subject: [PATCH 07/13] Address pre-commit hooks Signed-off-by: Matthew Bonanni --- vllm/attention/ops/flashmla.py | 4 
++-- vllm/v1/attention/backends/mla/cutlass_mla.py | 3 +-- vllm/v1/attention/backends/mla/triton_mla.py | 3 +-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index a46cfb441dbe..564042cf8eb1 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -83,8 +83,8 @@ def flash_mla_with_kvcache( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. - descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization. - descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. + descale_q: (batch_size), torch.float32. Descaling factors for Q. + descale_k: (batch_size), torch.float32. Descaling factors for K. Return: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 516e33f132c7..146632e78c3e 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -7,8 +7,7 @@ import torch import vllm._custom_ops as ops -from vllm.attention.backends.abstract import (AttentionLayer, - AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, is_quantized_kv_cache) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 47a6ffe50d01..f2974ed668d9 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -6,8 +6,7 @@ import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionLayer, - AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, 
is_quantized_kv_cache) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention From dd7977dcf80f4ff3f63dcfd3e437e50a0d7fe600 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 11 Aug 2025 21:09:26 +0000 Subject: [PATCH 08/13] Dequant in chunked prefill Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index a751030afd5e..efe3a23abc40 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -427,6 +427,9 @@ def __init__(self, self.page_size = self.kv_cache_spec.block_size if self.chunked_prefill_enabled: + workspace_dtype = self.model_config.dtype + if cache_config.cache_dtype.startswith("fp8"): + workspace_dtype = current_platform.fp8_dtype() self.chunked_prefill_workspace_size = min( # Max sure there is enough for 8 full length request or at least # 4 pages of cache per request @@ -447,7 +450,7 @@ def __init__(self, self.chunked_prefill_workspace = torch.empty( (self.chunked_prefill_workspace_size, self.model_config.get_head_size()), - dtype=self.model_config.dtype, + dtype=workspace_dtype, device=device, ) @@ -1022,6 +1025,8 @@ def _compute_prefill_context( iters = len(prefill_metadata.chunked_context.seq_tot) workspace = prefill_metadata.chunked_context.workspace + fp8_attention = self.kv_cache_dtype.startswith("fp8") + for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] @@ -1039,6 +1044,16 @@ def _compute_prefill_context( k_pe = workspace[:toks]\ [..., self.kv_lora_rank:].unsqueeze(1) + if fp8_attention: + target_dtype = self.kv_b_proj.weight.dtype + kv_c_normed_dequant = torch.empty_like(kv_c_normed, + dtype=target_dtype) + k_pe_dequant = torch.empty_like(k_pe, dtype=target_dtype) + 
ops.convert_fp8(kv_c_normed_dequant, kv_c_normed) + ops.convert_fp8(k_pe_dequant, k_pe) + kv_c_normed = kv_c_normed_dequant + k_pe = k_pe_dequant + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv_nope\ @@ -1155,7 +1170,7 @@ def forward( # same expert outputs. return output.fill_(0) - fp8_attention = self.kv_cache_dtype == "fp8" + fp8_attention = self.kv_cache_dtype.startswith("fp8") num_actual_toks = attn_metadata.num_actual_tokens @@ -1191,6 +1206,9 @@ def forward( scale=layer._k_scale, ) + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + if has_prefill: output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, @@ -1208,7 +1226,6 @@ def forward( decode_ql_nope = decode_ql_nope.transpose(0, 1) if fp8_attention: - kv_cache = kv_cache.view(torch.float8_e4m3fn) ql_nope_shape = decode_ql_nope.shape decode_ql_nope, _ = ops.scaled_fp8_quant( decode_ql_nope.reshape([ From 8dfbf29f704f10a47974c53e012d3ad4e2846b63 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 12 Aug 2025 20:07:34 +0000 Subject: [PATCH 09/13] Dequant within gather_cache kernel Signed-off-by: Matthew Bonanni --- csrc/cache.h | 6 ++- csrc/cache_kernels.cu | 57 ++++++++++++------------ csrc/torch_bindings.cpp | 13 ++++-- tests/kernels/attention/test_cache.py | 25 ++++++----- vllm/_custom_ops.py | 20 +++++---- vllm/attention/backends/mla/common.py | 14 +++--- vllm/v1/attention/backends/mla/common.py | 32 +++++-------- 7 files changed, 88 insertions(+), 79 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 0970b704be3a..fb0c353b9613 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); -void gather_cache( +void gather_and_maybe_dequant_cache( 
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt); \ No newline at end of file + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, + std::optional<torch::Tensor> seq_starts = std::nullopt); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 131dcb15cd7e..b3a985c2d5bb 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, namespace vllm { // grid is launched with dimensions (batch, num_splits) -template <typename scalar_t> -__global__ void gather_cache( - const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, +template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt> +__global__ void gather_and_maybe_dequant_cache( + const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, // ENTRIES...] scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] 
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] @@ -634,6 +634,7 @@ __global__ void gather_cache( const int32_t block_size, const int32_t entry_size, const int64_t block_table_stride, const int64_t cache_block_stride, const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const float* __restrict__ scale, const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per // batch @@ -675,10 +676,16 @@ __global__ void gather_cache( if (partial_block_size) full_blocks_end -= 1; } - auto copy_entry = [&](const scalar_t* __restrict__ _src, + auto copy_entry = [&](const cache_t* __restrict__ _src, scalar_t* __restrict__ _dst) { - for (int i = threadIdx.x; i < entry_size; i += blockDim.x) - _dst[i] = _src[i]; + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + _dst[i] = static_cast(_src[i]); + } else { + _dst[i] = + fp8::scaled_convert(_src[i], *scale); + } + } }; for (int pid = split_start; pid < full_blocks_end; ++pid) { @@ -705,25 +712,31 @@ __global__ void gather_cache( } // namespace vllm // Macro to dispatch the kernel based on the data type. -#define CALL_GATHER_CACHE(CPY_DTYPE) \ - vllm::gather_cache<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst.data_ptr()), \ - block_table.data_ptr(), cu_seq_lens.data_ptr(), \ - block_size, entry_size, block_table_stride, cache_block_stride, \ - cache_entry_stride, dst_entry_stride, seq_starts_ptr); +// SCALAR_T is the data type of the destination tensor. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. 
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \ + vllm::gather_and_maybe_dequant_cache \ + <<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, \ + reinterpret_cast(scale.data_ptr()), seq_starts_ptr); // Gather sequences from the cache into the destination tensor. // - cu_seq_lens contains the cumulative sequence lengths for each batch // - block_table contains the cache block indices for each sequence // - Optionally, seq_starts (if provided) offsets the starting block index by // (seq_starts[bid] / page_size) -void gather_cache( +void gather_and_maybe_dequant_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, std::optional seq_starts = std::nullopt) { at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -761,20 +774,8 @@ void gather_cache( dim3 grid(batch_size, num_splits); dim3 block(1024); - TORCH_CHECK(src_cache.dtype() == dst.dtype(), - "src_cache and dst must have the same dtype"); - - const int dtype_bits = src_cache.element_size() * 8; const int32_t* seq_starts_ptr = seq_starts.has_value() ? 
seq_starts.value().data_ptr() : nullptr; - if (dtype_bits == 32) { - CALL_GATHER_CACHE(uint32_t); - } else if (dtype_bits == 16) { - CALL_GATHER_CACHE(uint16_t); - } else if (dtype_bits == 8) { - CALL_GATHER_CACHE(uint8_t); - } else { - TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); - } + DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 85b6abef00b0..6d9d7c16d126 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -703,11 +703,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "str kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); - // Gather cache blocks from src_cache to dst. + // Gather cache blocks from src_cache to dst, dequantizing from + // src_cache's dtype to dst's dtype if necessary. cache_ops.def( - "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " - "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); - cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache); + "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, " + " Tensor block_table, Tensor cu_seq_lens, " + " int batch_size, " + " str kv_cache_dtype, " + " Tensor scale, Tensor? 
seq_starts) -> ()"); + cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA, + &gather_and_maybe_dequant_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 789507615580..e87b6877d77f 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -709,14 +709,15 @@ def test_swap_blocks_mla( @pytest.mark.parametrize("max_seq_len", [512]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("kv_cache_dtype", - ["auto"]) # You can also test "fp8" if needed. +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, - num_blocks, max_seq_len, batch_size, dtype, - kv_cache_dtype, device): +def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, + block_size, num_blocks, + max_seq_len, batch_size, dtype, + kv_cache_dtype, device): entry_size = kv_lora_rank + qk_rope_head_dim + scale = torch.tensor(0.1, dtype=torch.float32, device=device) src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device) _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) @@ -742,10 +743,9 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, perm = torch.randperm(num_blocks, device=device) block_table[b, :] = perm - dst = torch.zeros((total_tokens, entry_size), - dtype=src_cache.dtype, - device=device) + dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device) + # TODO - do dequant here expected_batches = [] for b in range(batch_size): s = seq_len_tensor[b] @@ -765,12 +765,15 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, expected = torch.cat(expected_batches, dim=0) opcheck( - 
torch.ops._C_cache_ops.gather_cache, - (src_cache, dst, block_table, cu_seq_lens, batch_size, None), + torch.ops._C_cache_ops.gather_and_maybe_dequant_cache, + (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, + scale, None), test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) + ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size, kv_cache_dtype, + scale, None) torch.testing.assert_close(dst, expected) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 70605d3c5f52..9aeedfbdca1e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1672,14 +1672,18 @@ def convert_fp8(output: torch.Tensor, torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) -def gather_cache(src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - cu_seq_lens: torch.Tensor, - batch_size: int, - seq_starts: Optional[torch.Tensor] = None) -> None: - torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, - cu_seq_lens, batch_size, seq_starts) +def gather_and_maybe_dequant_cache( + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + kv_cache_dtype: str, + scale: torch.Tensor, + seq_starts: Optional[torch.Tensor] = None) -> None: + torch.ops._C_cache_ops.gather_and_maybe_dequant_cache( + src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, + scale, seq_starts) def get_device_attribute(attribute: int, device: int) -> int: diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 52c4a9e7da3d..fab422803bf8 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -922,8 +922,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], self.context_chunk_workspace_size // num_prefills_with_context # align max_context_chunk to page_size by rounding down, - 
# currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size + # currently the `gather_and_maybe_dequant_cache` kernel cannot + # handle `context_chunk_starts` that are not aligned to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) @@ -1167,6 +1167,7 @@ def _compute_prefill_context( q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ): prefill_metadata = attn_metadata.prefill_metadata assert prefill_metadata is not None @@ -1188,12 +1189,14 @@ def _compute_prefill_context( for i in range(iters): toks = prefill_metadata.context_chunk_seq_tot[i] - ops.gather_cache( + ops.gather_and_maybe_dequant_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_tables, cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], batch_size=prefill_metadata.num_prefills, + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, seq_starts=prefill_metadata.context_chunk_starts[i], ) @@ -1250,6 +1253,7 @@ def _forward_prefill( k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ) -> torch.Tensor: prefill_metadata = attn_metadata.prefill_metadata @@ -1282,7 +1286,7 @@ def _forward_prefill( # ROCm flash_attn_varlen_func will return 3 objects instead of 2 suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1372,7 +1376,7 @@ def forward( if has_prefill: output[:num_prefill_tokens] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + attn_metadata, layer._k_scale) if has_decode: decode_q_nope, decode_q_pe = decode_q.split( diff 
--git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index efe3a23abc40..6f746e3bfd91 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -427,9 +427,6 @@ def __init__(self, self.page_size = self.kv_cache_spec.block_size if self.chunked_prefill_enabled: - workspace_dtype = self.model_config.dtype - if cache_config.cache_dtype.startswith("fp8"): - workspace_dtype = current_platform.fp8_dtype() self.chunked_prefill_workspace_size = min( # Max sure there is enough for 8 full length request or at least # 4 pages of cache per request @@ -450,7 +447,7 @@ def __init__(self, self.chunked_prefill_workspace = torch.empty( (self.chunked_prefill_workspace_size, self.model_config.get_head_size()), - dtype=workspace_dtype, + dtype=self.model_config.dtype, device=device, ) @@ -638,8 +635,9 @@ def build(self, if self.aot_schedule: # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size + # currently the `gather_and_maybe_dequant_cache` kernel + # cannot handle `context_chunk_starts` that are not aligned + # to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) @@ -1016,6 +1014,7 @@ def _compute_prefill_context( q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ): assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill @@ -1025,17 +1024,17 @@ def _compute_prefill_context( iters = len(prefill_metadata.chunked_context.seq_tot) workspace = prefill_metadata.chunked_context.workspace - fp8_attention = self.kv_cache_dtype.startswith("fp8") - for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] - ops.gather_cache( + ops.gather_and_maybe_dequant_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_table, 
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], batch_size=attn_metadata.num_prefills, + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, seq_starts=prefill_metadata.chunked_context.starts[i], ) @@ -1044,16 +1043,6 @@ def _compute_prefill_context( k_pe = workspace[:toks]\ [..., self.kv_lora_rank:].unsqueeze(1) - if fp8_attention: - target_dtype = self.kv_b_proj.weight.dtype - kv_c_normed_dequant = torch.empty_like(kv_c_normed, - dtype=target_dtype) - k_pe_dequant = torch.empty_like(k_pe, dtype=target_dtype) - ops.convert_fp8(kv_c_normed_dequant, kv_c_normed) - ops.convert_fp8(k_pe_dequant, k_pe) - kv_c_normed = kv_c_normed_dequant - k_pe = k_pe_dequant - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv_nope\ @@ -1096,6 +1085,7 @@ def _forward_prefill( k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ) -> torch.Tensor: assert attn_metadata.prefill is not None @@ -1118,7 +1108,7 @@ def _forward_prefill( if has_context: suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1212,7 +1202,7 @@ def forward( if has_prefill: output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + attn_metadata, layer._k_scale) if has_decode: assert attn_metadata.decode is not None From 926ba4d0fff3b5e52796f5557c595de7b08d397e Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 12 Aug 2025 20:15:53 +0000 Subject: [PATCH 10/13] Update test Signed-off-by: Matthew Bonanni --- tests/kernels/attention/test_cache.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/kernels/attention/test_cache.py 
b/tests/kernels/attention/test_cache.py index e87b6877d77f..c09635ed116d 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -745,7 +745,6 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device) - # TODO - do dequant here expected_batches = [] for b in range(batch_size): s = seq_len_tensor[b] @@ -756,9 +755,23 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, gathered_rows = [] for i in range(tot - 1): - gathered_rows.append(src_cache[blocks[i]]) + block_data = src_cache[blocks[i]] + if kv_cache_dtype == "fp8": + dequantized_block = torch.empty_like(block_data, dtype=dtype) + ops.convert_fp8(dequantized_block, block_data, scale.item()) + gathered_rows.append(dequantized_block) + else: + gathered_rows.append(block_data) remaining = s - (tot - 1) * block_size - gathered_rows.append(src_cache[blocks[-1], :remaining, :]) + last_block_data = src_cache[blocks[-1], :remaining, :] + if kv_cache_dtype == "fp8": + dequantized_last_block = torch.empty_like(last_block_data, + dtype=dtype) + ops.convert_fp8(dequantized_last_block, last_block_data, + scale.item()) + gathered_rows.append(dequantized_last_block) + else: + gathered_rows.append(last_block_data) batch_expected = torch.cat(gathered_rows, dim=0) expected_batches.append(batch_expected) From 56e8135cba0f8470405f4d27a16a55d7dbb911aa Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 18 Aug 2025 17:12:23 +0000 Subject: [PATCH 11/13] Update GIT_TAG Signed-off-by: Matthew Bonanni --- cmake/external_projects/flashmla.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 3d050df2e24e..02224cfe3ee8 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla 
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1 + GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" From 7b86ffbeed9d387f524bcfd263982c0e317bfa83 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 19 Aug 2025 18:57:34 +0000 Subject: [PATCH 12/13] Remove unnecessary contiguous() calls - tensors are already contiguous Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 98420a77ffdc..c8636469cf6d 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1216,13 +1216,13 @@ def forward( decode_ql_nope, _ = ops.scaled_fp8_quant( decode_ql_nope.reshape([ ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2] - ]).contiguous(), layer._q_scale) + ]), layer._q_scale) decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) q_pe_shape = decode_q_pe.shape decode_q_pe, _ = ops.scaled_fp8_quant( - decode_q_pe.reshape([ - q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2] - ]).contiguous(), layer._q_scale) + decode_q_pe.reshape( + [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale) decode_q_pe = decode_q_pe.reshape(q_pe_shape) output[:num_decode_tokens] = self._forward_decode( From f59cf508d5555d99435d6da8e0f1c8a53b7042da Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 19 Aug 2025 20:28:20 +0000 Subject: [PATCH 13/13] Update fp8 platform/backend support logic Signed-off-by: Matthew Bonanni --- vllm/engine/arg_utils.py | 3 +-- vllm/platforms/cuda.py | 45 +++++++++++++++++++++++++++---------- vllm/platforms/interface.py | 3 ++- vllm/platforms/rocm.py | 3 ++- vllm/platforms/tpu.py | 3 ++- 5 files changed, 40 insertions(+), 17 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 
6fc894827c4a..2dce722837c9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1404,10 +1404,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No Fp8 KV cache so far. if self.kv_cache_dtype != "auto": supported = current_platform.is_kv_cache_dtype_supported( - self.kv_cache_dtype) + self.kv_cache_dtype, model_config) if not supported: _raise_or_fallback(feature_name="--kv-cache-dtype", recommend_to_remove=False) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index a6d7f8537c84..392927b29092 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -495,20 +495,41 @@ def device_count(cls) -> int: return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: fp8_attention = kv_cache_dtype.startswith("fp8") - will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND") - ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" - will_use_flashmla = (envs.is_set("VLLM_ATTENTION_BACKEND") - and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA") + attention_backend = envs.VLLM_ATTENTION_BACKEND + supported = False - if cls.is_device_capability(100): - supported = True - elif fp8_attention and will_use_fa: - from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 - supported = flash_attn_supports_fp8() - elif fp8_attention and will_use_flashmla: - supported = True + if model_config is not None and model_config.use_mla: + # Default to CutlassMLA for blackwell, + # FlashMLA otherwise + if attention_backend is None: + if cls.is_device_capability(100): + attention_backend = "CUTLASS_MLA" + else: + attention_backend = "FLASHMLA" + + # Only FlashMLA supports fp8 + if attention_backend == "FLASHMLA": + supported = True + else: + supported = (not fp8_attention) + else: + # Default to FlashAttention + if 
attention_backend is None: + attention_backend = "FLASH_ATTN_VLLM_V1" + + # All Blackwell backends support fp8 + if cls.is_device_capability(100): + supported = True + elif attention_backend == "FLASH_ATTN_VLLM_V1": + if fp8_attention: + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8) + supported = flash_attn_supports_fp8() + else: + supported = True return supported diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 4017f1ca7eec..521e5d54ba00 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -562,7 +562,8 @@ def stateless_init_device_torch_dist_pg( raise RuntimeError(f"Unsupported torch distributed backend: {backend}") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: """ Returns if the kv_cache_dtype is supported by the current platform. """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 3ede86e15855..317bc401a799 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -459,5 +459,6 @@ def device_count(cls) -> int: return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: return True diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index ba06abd07f08..4324ce7dcb91 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -194,7 +194,8 @@ def validate_request( raise ValueError("Torch XLA does not support per-request seed.") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: return True