Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions cmake/external_projects/flashmla.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand All @@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
set(FlashMLA_SOURCES
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)

set(FlashMLA_INCLUDES
${flashmla_SOURCE_DIR}/csrc/cutlass/include
${flashmla_SOURCE_DIR}/csrc/include)
${flashmla_SOURCE_DIR}/csrc)

set_gencode_flags_for_srcs(
SRCS "${FlashMLA_SOURCES}"
Expand Down
6 changes: 4 additions & 2 deletions csrc/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);

void gather_cache(
void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
int64_t batch_size, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt);
57 changes: 29 additions & 28 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -624,16 +624,17 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
namespace vllm {

// grid is launched with dimensions (batch, num_splits)
template <typename scalar_t>
__global__ void gather_cache(
const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void gather_and_maybe_dequant_cache(
const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
const int32_t block_size, const int32_t entry_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const float* __restrict__ scale,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch

Expand Down Expand Up @@ -675,10 +676,16 @@ __global__ void gather_cache(
if (partial_block_size) full_blocks_end -= 1;
}

auto copy_entry = [&](const scalar_t* __restrict__ _src,
auto copy_entry = [&](const cache_t* __restrict__ _src,
scalar_t* __restrict__ _dst) {
for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
_dst[i] = _src[i];
for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
_dst[i] = static_cast<scalar_t>(_src[i]);
} else {
_dst[i] =
fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
}
}
};

for (int pid = split_start; pid < full_blocks_end; ++pid) {
Expand All @@ -705,25 +712,31 @@ __global__ void gather_cache(
} // namespace vllm

// Macro to dispatch the kernel based on the data type.
#define CALL_GATHER_CACHE(CPY_DTYPE) \
vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
reinterpret_cast<SCALAR_T*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, \
reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);

// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void gather_cache(
void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size,
int64_t batch_size, const std::string& kv_cache_dtype,
torch::Tensor const& scale,
std::optional<torch::Tensor> seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Expand Down Expand Up @@ -761,20 +774,8 @@ void gather_cache(
dim3 grid(batch_size, num_splits);
dim3 block(1024);

TORCH_CHECK(src_cache.dtype() == dst.dtype(),
"src_cache and dst must have the same dtype");

const int dtype_bits = src_cache.element_size() * 8;
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

if (dtype_bits == 32) {
CALL_GATHER_CACHE(uint32_t);
} else if (dtype_bits == 16) {
CALL_GATHER_CACHE(uint16_t);
} else if (dtype_bits == 8) {
CALL_GATHER_CACHE(uint8_t);
} else {
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
}
13 changes: 9 additions & 4 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,11 +676,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"str kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);

// Gather cache blocks from src_cache to dst.
// Gather cache blocks from src_cache to dst, dequantizing from
// src_cache's dtype to dst's dtype if necessary.
cache_ops.def(
"gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
" int batch_size, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
&gather_and_maybe_dequant_cache);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
Expand Down
42 changes: 29 additions & 13 deletions tests/kernels/attention/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,14 +709,15 @@ def test_swap_blocks_mla(
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
["auto"]) # You can also test "fp8" if needed.
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
block_size, num_blocks,
max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
Expand All @@ -742,9 +743,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
perm = torch.randperm(num_blocks, device=device)
block_table[b, :] = perm

dst = torch.zeros((total_tokens, entry_size),
dtype=src_cache.dtype,
device=device)
dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

expected_batches = []
for b in range(batch_size):
Expand All @@ -756,21 +755,38 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,

gathered_rows = []
for i in range(tot - 1):
gathered_rows.append(src_cache[blocks[i]])
block_data = src_cache[blocks[i]]
if kv_cache_dtype == "fp8":
dequantized_block = torch.empty_like(block_data, dtype=dtype)
ops.convert_fp8(dequantized_block, block_data, scale.item())
gathered_rows.append(dequantized_block)
else:
gathered_rows.append(block_data)
remaining = s - (tot - 1) * block_size
gathered_rows.append(src_cache[blocks[-1], :remaining, :])
last_block_data = src_cache[blocks[-1], :remaining, :]
if kv_cache_dtype == "fp8":
dequantized_last_block = torch.empty_like(last_block_data,
dtype=dtype)
ops.convert_fp8(dequantized_last_block, last_block_data,
scale.item())
gathered_rows.append(dequantized_last_block)
else:
gathered_rows.append(last_block_data)

batch_expected = torch.cat(gathered_rows, dim=0)
expected_batches.append(batch_expected)
expected = torch.cat(expected_batches, dim=0)

opcheck(
torch.ops._C_cache_ops.gather_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
scale, None),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)

ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, kv_cache_dtype,
scale, None)
torch.testing.assert_close(dst, expected)


Expand Down
69 changes: 51 additions & 18 deletions tests/kernels/attention/test_flashmla.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@
from vllm.triton_utils import triton


def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
def cal_diff(x: torch.Tensor,
y: torch.Tensor,
name: str,
use_fp8: bool = False) -> None:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max(
(x * x + y * y).sum().item(), 1e-12)
assert cos_diff < 1e-5
if (use_fp8):
assert cos_diff < 1e-4
else:
assert cos_diff < 1e-5

FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
if not is_flashmla_supported()[0] else "FlashMLA is supported"
Expand All @@ -27,28 +33,34 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
reason=FLASH_MLA_UNSUPPORTED_REASON)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1, 2])
@pytest.mark.parametrize("mean_sk", [4096, 8192])
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
@pytest.mark.parametrize("h_kv", [1])
@pytest.mark.parametrize("d", [576])
@pytest.mark.parametrize("dv", [512])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("torch_dtype",
[torch.bfloat16, torch.float16, torch.float8_e4m3fn])
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen, dtype):
varlen, torch_dtype):
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)

print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")

use_fp8 = torch_dtype == torch.float8_e4m3fn
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
Expand All @@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv)

init_dtype = q.dtype
if use_fp8:
fp8_dtype = torch.float8_e4m3fn
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)

q = q.to(fp8_dtype)
blocked_k = blocked_k.to(fp8_dtype)
blocked_v = blocked_v.to(fp8_dtype)
else:
descale_q = None
descale_k = None

def flash_mla():
return flash_mla_with_kvcache(
q,
Expand All @@ -81,6 +106,8 @@ def flash_mla():
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k,
)

def scaled_dot_product_attention(query, key, value, is_causal=False):
Expand All @@ -104,29 +131,35 @@ def scaled_dot_product_attention(query, key, value, is_causal=False):
return attn_weight @ value, lse

def ref_mla():
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
blocked_k_ = (blocked_k.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_k
blocked_v_ = (blocked_v.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_v
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
ref_O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
out_i, lse_i = scaled_dot_product_attention(
q_[i].transpose(0, 1),
blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
is_causal=causal,
)
out[i] = ref_O.transpose(0, 1)
lse[i] = LSE
out[i] = out_i.transpose(0, 1)
lse[i] = lse_i
return out, lse

out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
cal_diff(out_flash, out_torch, "out")
cal_diff(out_flash, out_torch, "out", use_fp8)
cal_diff(lse_flash, lse_torch, "lse")

t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
bytes = (total_seqlens * h_kv * d +
b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
f"{bytes / 10 ** 6 / t:.0f} GB/s")
Loading