diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index ee6768bce26c..02224cfe3ee8 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1 + GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) set(FlashMLA_SOURCES ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp - ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu - ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu) + ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu) set(FlashMLA_INCLUDES ${flashmla_SOURCE_DIR}/csrc/cutlass/include - ${flashmla_SOURCE_DIR}/csrc/include) + ${flashmla_SOURCE_DIR}/csrc) set_gencode_flags_for_srcs( SRCS "${FlashMLA_SOURCES}" diff --git a/csrc/cache.h b/csrc/cache.h index 0970b704be3a..fb0c353b9613 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); -void gather_cache( +void gather_and_maybe_dequant_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, std::optional seq_starts = std::nullopt); \ No newline at end of file + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, + std::optional seq_starts = std::nullopt); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 131dcb15cd7e..b3a985c2d5bb 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, namespace vllm { // grid is launched with dimensions (batch, num_splits) -template -__global__ void gather_cache( - const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, +template +__global__ void gather_and_maybe_dequant_cache( + const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, // ENTRIES...] scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] @@ -634,6 +634,7 @@ __global__ void gather_cache( const int32_t block_size, const int32_t entry_size, const int64_t block_table_stride, const int64_t cache_block_stride, const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const float* __restrict__ scale, const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per // batch @@ -675,10 +676,16 @@ __global__ void gather_cache( if (partial_block_size) full_blocks_end -= 1; } - auto copy_entry = [&](const scalar_t* __restrict__ _src, + auto copy_entry = [&](const cache_t* __restrict__ _src, scalar_t* __restrict__ _dst) { - for (int i = threadIdx.x; i < entry_size; i += blockDim.x) - _dst[i] = _src[i]; + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + _dst[i] = static_cast(_src[i]); + } else { + _dst[i] = + fp8::scaled_convert(_src[i], *scale); + } + } }; for (int pid = split_start; pid < full_blocks_end; ++pid) { @@ -705,25 +712,31 @@ __global__ void gather_cache( } // namespace vllm // Macro to dispatch the kernel based on the data type. -#define CALL_GATHER_CACHE(CPY_DTYPE) \ - vllm::gather_cache<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst.data_ptr()), \ - block_table.data_ptr(), cu_seq_lens.data_ptr(), \ - block_size, entry_size, block_table_stride, cache_block_stride, \ - cache_entry_stride, dst_entry_stride, seq_starts_ptr); +// SCALAR_T is the data type of the destination tensor. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \ + vllm::gather_and_maybe_dequant_cache \ + <<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, \ + reinterpret_cast(scale.data_ptr()), seq_starts_ptr); // Gather sequences from the cache into the destination tensor. // - cu_seq_lens contains the cumulative sequence lengths for each batch // - block_table contains the cache block indices for each sequence // - Optionally, seq_starts (if provided) offsets the starting block index by // (seq_starts[bid] / page_size) -void gather_cache( +void gather_and_maybe_dequant_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size, + int64_t batch_size, const std::string& kv_cache_dtype, + torch::Tensor const& scale, std::optional seq_starts = std::nullopt) { at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -761,20 +774,8 @@ void gather_cache( dim3 grid(batch_size, num_splits); dim3 block(1024); - TORCH_CHECK(src_cache.dtype() == dst.dtype(), - "src_cache and dst must have the same dtype"); - - const int dtype_bits = src_cache.element_size() * 8; const int32_t* seq_starts_ptr = seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; - if (dtype_bits == 32) { - CALL_GATHER_CACHE(uint32_t); - } else if (dtype_bits == 16) { - CALL_GATHER_CACHE(uint16_t); - } else if (dtype_bits == 8) { - CALL_GATHER_CACHE(uint8_t); - } else { - TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); - } + DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7079671c2eb1..b253c7e3f5ff 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -676,11 +676,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "str kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); - // Gather cache blocks from src_cache to dst. + // Gather cache blocks from src_cache to dst, dequantizing from + // src_cache's dtype to dst's dtype if necessary. cache_ops.def( - "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " - "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); - cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache); + "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, " + " Tensor block_table, Tensor cu_seq_lens, " + " int batch_size, " + " str kv_cache_dtype, " + " Tensor scale, Tensor? seq_starts) -> ()"); + cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA, + &gather_and_maybe_dequant_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 8c3cc8cba9d9..cbf11da63cab 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -709,14 +709,15 @@ def test_swap_blocks_mla( @pytest.mark.parametrize("max_seq_len", [512]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("kv_cache_dtype", - ["auto"]) # You can also test "fp8" if needed. +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, - num_blocks, max_seq_len, batch_size, dtype, - kv_cache_dtype, device): +def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, + block_size, num_blocks, + max_seq_len, batch_size, dtype, + kv_cache_dtype, device): entry_size = kv_lora_rank + qk_rope_head_dim + scale = torch.tensor(0.1, dtype=torch.float32, device=device) src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device) _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) @@ -742,9 +743,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, perm = torch.randperm(num_blocks, device=device) block_table[b, :] = perm - dst = torch.zeros((total_tokens, entry_size), - dtype=src_cache.dtype, - device=device) + dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device) expected_batches = [] for b in range(batch_size): @@ -756,21 +755,38 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, gathered_rows = [] for i in range(tot - 1): - gathered_rows.append(src_cache[blocks[i]]) + block_data = src_cache[blocks[i]] + if kv_cache_dtype == "fp8": + dequantized_block = torch.empty_like(block_data, dtype=dtype) + ops.convert_fp8(dequantized_block, block_data, scale.item()) + gathered_rows.append(dequantized_block) + else: + gathered_rows.append(block_data) remaining = s - (tot - 1) * block_size - gathered_rows.append(src_cache[blocks[-1], :remaining, :]) + last_block_data = src_cache[blocks[-1], :remaining, :] + if kv_cache_dtype == "fp8": + dequantized_last_block = torch.empty_like(last_block_data, + dtype=dtype) + ops.convert_fp8(dequantized_last_block, last_block_data, + scale.item()) + gathered_rows.append(dequantized_last_block) + else: + gathered_rows.append(last_block_data) batch_expected = torch.cat(gathered_rows, dim=0) expected_batches.append(batch_expected) expected = torch.cat(expected_batches, dim=0) opcheck( - torch.ops._C_cache_ops.gather_cache, - (src_cache, dst, block_table, cu_seq_lens, batch_size, None), + torch.ops._C_cache_ops.gather_and_maybe_dequant_cache, + (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, + scale, None), test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) + ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size, kv_cache_dtype, + scale, None) torch.testing.assert_close(dst, expected) diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index 81841be58352..abcfe828d5ac 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -13,11 +13,17 @@ from vllm.triton_utils import triton -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: +def cal_diff(x: torch.Tensor, + y: torch.Tensor, + name: str, + use_fp8: bool = False) -> None: x, y = x.double(), y.double() cos_diff = 1 - 2 * (x * y).sum().item() / max( (x * x + y * y).sum().item(), 1e-12) - assert cos_diff < 1e-5 + if (use_fp8): + assert cos_diff < 1e-4 + else: + assert cos_diff < 1e-5 FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ if not is_flashmla_supported()[0] else "FlashMLA is supported" @@ -27,7 +33,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: reason=FLASH_MLA_UNSUPPORTED_REASON) @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) -@pytest.mark.parametrize("mean_sk", [4096, 8192]) +@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) @pytest.mark.parametrize("h_q", [16, 32, 64, 128]) @pytest.mark.parametrize("h_kv", [1]) @pytest.mark.parametrize("d", [576]) @@ -35,20 +41,26 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: @pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("varlen", [False, True]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("torch_dtype", + [torch.bfloat16, torch.float16, torch.float8_e4m3fn]) @torch.inference_mode() def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, - varlen, dtype): + varlen, torch_dtype): device = torch.device("cuda:0") - torch.set_default_dtype(dtype) + if torch_dtype == torch.float8_e4m3fn: + init_dtype = torch.bfloat16 + else: + init_dtype = torch_dtype + torch.set_default_dtype(init_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) random.seed(0) print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " - f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}") + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") + use_fp8 = torch_dtype == torch.float8_e4m3fn cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) if varlen: for i in range(b): @@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, tile_scheduler_metadata, num_splits = get_mla_metadata( cache_seqlens, s_q * h_q // h_kv, h_kv) + init_dtype = q.dtype + if use_fp8: + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + + q = q.to(fp8_dtype) + blocked_k = blocked_k.to(fp8_dtype) + blocked_v = blocked_v.to(fp8_dtype) + else: + descale_q = None + descale_k = None + def flash_mla(): return flash_mla_with_kvcache( q, @@ -81,6 +106,8 @@ def flash_mla(): tile_scheduler_metadata, num_splits, causal=causal, + descale_q=descale_q, + descale_k=descale_k, ) def scaled_dot_product_attention(query, key, value, is_causal=False): @@ -104,29 +131,35 @@ def scaled_dot_product_attention(query, key, value, is_causal=False): return attn_weight @ value, lse def ref_mla(): + q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q + blocked_k_ = (blocked_k.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_v out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] - ref_O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + out_i, lse_i = scaled_dot_product_attention( + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), is_causal=causal, ) - out[i] = ref_O.transpose(0, 1) - lse[i] = LSE + out[i] = out_i.transpose(0, 1) + lse[i] = lse_i return out, lse out_flash, lse_flash = flash_mla() out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") + cal_diff(out_flash, out_torch, "out", use_fp8) cal_diff(lse_flash, lse_torch, "lse") t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + - b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} " - f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( + b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", + f"{bytes / 10 ** 6 / t:.0f} GB/s") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0d556053f898..22e52c054e1f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1603,14 +1603,18 @@ def convert_fp8(output: torch.Tensor, torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) -def gather_cache(src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - cu_seq_lens: torch.Tensor, - batch_size: int, - seq_starts: Optional[torch.Tensor] = None) -> None: - torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, - cu_seq_lens, batch_size, seq_starts) +def gather_and_maybe_dequant_cache( + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + kv_cache_dtype: str, + scale: torch.Tensor, + seq_starts: Optional[torch.Tensor] = None) -> None: + torch.ops._C_cache_ops.gather_and_maybe_dequant_cache( + src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, + scale, seq_starts) def get_device_attribute(attribute: int, device: int) -> int: diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 8ff7f5674323..9d6ab7e3217b 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -837,8 +837,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], self.context_chunk_workspace_size // num_prefills_with_context # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size + # currently the `gather_and_maybe_dequant_cache` kernel cannot + # handle `context_chunk_starts` that are not aligned to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) @@ -1082,6 +1082,7 @@ def _compute_prefill_context( q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ): prefill_metadata = attn_metadata.prefill_metadata assert prefill_metadata is not None @@ -1103,12 +1104,14 @@ def _compute_prefill_context( for i in range(iters): toks = prefill_metadata.context_chunk_seq_tot[i] - ops.gather_cache( + ops.gather_and_maybe_dequant_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_tables, cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], batch_size=prefill_metadata.num_prefills, + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, seq_starts=prefill_metadata.context_chunk_starts[i], ) @@ -1165,6 +1168,7 @@ def _forward_prefill( k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ) -> torch.Tensor: prefill_metadata = attn_metadata.prefill_metadata @@ -1197,7 +1201,7 @@ def _forward_prefill( # ROCm flash_attn_varlen_func will return 3 objects instead of 2 suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1287,7 +1291,7 @@ def forward( if has_prefill: output[:num_prefill_tokens] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + attn_metadata, layer._k_scale) if has_decode: decode_q_nope, decode_q_pe = decode_q.split( diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 1af26dfc3daa..564042cf8eb1 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -67,6 +67,8 @@ def flash_mla_with_kvcache( num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: @@ -81,6 +83,8 @@ def flash_mla_with_kvcache( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. + descale_q: (batch_size), torch.float32. Descaling factors for Q. + descale_k: (batch_size), torch.float32. Descaling factors for K. Return: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). @@ -98,6 +102,8 @@ def flash_mla_with_kvcache( causal, tile_scheduler_metadata, num_splits, + descale_q, + descale_k, ) return out, softmax_lse diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6fc894827c4a..2dce722837c9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1404,10 +1404,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No Fp8 KV cache so far. if self.kv_cache_dtype != "auto": supported = current_platform.is_kv_cache_dtype_supported( - self.kv_cache_dtype) + self.kv_cache_dtype, model_config) if not supported: _raise_or_fallback(feature_name="--kv-cache-dtype", recommend_to_remove=False) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 321db8287c0f..392927b29092 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -495,16 +495,41 @@ def device_count(cls) -> int: return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: fp8_attention = kv_cache_dtype.startswith("fp8") - will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND") - ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" + attention_backend = envs.VLLM_ATTENTION_BACKEND + supported = False - if cls.is_device_capability(100): - supported = True - elif fp8_attention and will_use_fa: - from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 - supported = flash_attn_supports_fp8() + if model_config is not None and model_config.use_mla: + # Default to CutlassMLA for blackwell, + # FlashMLA otherwise + if attention_backend is None: + if cls.is_device_capability(100): + attention_backend = "CUTLASS_MLA" + else: + attention_backend = "FLASHMLA" + + # Only FlashMLA supports fp8 + if attention_backend == "FLASHMLA": + supported = True + else: + supported = (not fp8_attention) + else: + # Default to FlashAttention + if attention_backend is None: + attention_backend = "FLASH_ATTN_VLLM_V1" + + # All Blackwell backends support fp8 + if cls.is_device_capability(100): + supported = True + elif attention_backend == "FLASH_ATTN_VLLM_V1": + if fp8_attention: + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8) + supported = flash_attn_supports_fp8() + else: + supported = True return supported diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 4017f1ca7eec..521e5d54ba00 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -562,7 +562,8 @@ def stateless_init_device_torch_dist_pg( raise RuntimeError(f"Unsupported torch distributed backend: {backend}") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: """ Returns if the kv_cache_dtype is supported by the current platform. """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 3ede86e15855..317bc401a799 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -459,5 +459,6 @@ def device_count(cls) -> int: return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: return True diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index ba06abd07f08..4324ce7dcb91 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -194,7 +194,8 @@ def validate_request( raise ValueError("Torch XLA does not support per-request seed.") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: return True diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f2610671f769..c8636469cf6d 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -635,8 +635,9 @@ def build(self, if self.aot_schedule: # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size + # currently the `gather_and_maybe_dequant_cache` kernel + # cannot handle `context_chunk_starts` that are not aligned + # to page_size max_context_chunk = round_down(max_context_chunk, self.page_size) @@ -1009,6 +1010,7 @@ def _compute_prefill_context( q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ): assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill @@ -1021,12 +1023,14 @@ def _compute_prefill_context( for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] - ops.gather_cache( + ops.gather_and_maybe_dequant_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_table, cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], batch_size=attn_metadata.num_prefills, + kv_cache_dtype=self.kv_cache_dtype, + scale=k_scale, seq_starts=prefill_metadata.chunked_context.starts[i], ) @@ -1077,6 +1081,7 @@ def _forward_prefill( k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, ) -> torch.Tensor: assert attn_metadata.prefill is not None @@ -1099,7 +1104,7 @@ def _forward_prefill( if has_context: suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1123,6 +1128,7 @@ def _forward_decode( q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, + layer: AttentionLayer, ) -> torch.Tensor: raise NotImplementedError @@ -1150,6 +1156,8 @@ def forward( # same expert outputs. return output.fill_(0) + fp8_attention = self.kv_cache_dtype.startswith("fp8") + num_actual_toks = attn_metadata.num_actual_tokens # Inputs and outputs may be padded for CUDA graphs @@ -1184,10 +1192,13 @@ def forward( scale=layer._k_scale, ) + if fp8_attention: + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + if has_prefill: output[num_decode_tokens:] = self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata) + attn_metadata, layer._k_scale) if has_decode: assert attn_metadata.decode is not None @@ -1200,7 +1211,21 @@ def forward( # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) + if fp8_attention: + ql_nope_shape = decode_ql_nope.shape + decode_ql_nope, _ = ops.scaled_fp8_quant( + decode_ql_nope.reshape([ + ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2] + ]), layer._q_scale) + decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) + q_pe_shape = decode_q_pe.shape + decode_q_pe, _ = ops.scaled_fp8_quant( + decode_q_pe.reshape( + [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale) + decode_q_pe = decode_q_pe.reshape(q_pe_shape) + output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer) return output_padded diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 6e1e5d6533da..7d6efc8830f6 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -7,7 +7,7 @@ import torch import vllm._custom_ops as ops -from vllm.attention.backends.abstract import (AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, is_quantized_kv_cache) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, @@ -278,6 +278,7 @@ def _forward_decode( q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + layer: AttentionLayer, ) -> torch.Tensor: if self._use_old_cutlass_mla: # TODO: Remove the old cutlass MLA kernel after more extensive diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 11674423400c..1c50144d4790 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -6,8 +6,7 @@ import torch -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) @@ -166,16 +165,13 @@ def __init__( "are not implemented for " "FlashMLAImpl") - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "FlashMLA V1 with FP8 KV cache not yet supported") - def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -194,6 +190,8 @@ def _forward_decode( num_splits=attn_metadata.decode.num_splits, softmax_scale=self.scale, causal=True, + descale_q=layer._q_scale.reshape(1), + descale_k=layer._k_scale.reshape(1), ) return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 082c7e6f7c62..870cc600388e 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -7,6 +7,7 @@ import torch import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionLayer from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils import cdiv @@ -221,6 +222,7 @@ def _forward_decode( q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: AiterMLAMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 700fce68953e..f2974ed668d9 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -6,7 +6,7 @@ import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionType, +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, is_quantized_kv_cache) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention @@ -127,6 +127,7 @@ def _forward_decode( q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + layer: AttentionLayer, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None