
Commit 8dfbf29

Dequant within gather_cache kernel
Signed-off-by: Matthew Bonanni <[email protected]>
1 parent: dd7977d

7 files changed (+88, -79 lines)

csrc/cache.h

Lines changed: 4 additions & 2 deletions
@@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);
 
-void gather_cache(
+void gather_and_maybe_dequant_cache(
     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
+    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& scale,
+    std::optional<torch::Tensor> seq_starts = std::nullopt);
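
For reference, a minimal Python sketch of what the two new arguments mean per copied element (a sketch mirroring the kernel change below, not vLLM code): with kv_cache_dtype="auto" the gather is a plain cast into the destination dtype; otherwise each stored fp8 value is dequantized with the per-tensor scale.

import torch

# Per-element semantics of the new arguments (illustrative sketch only).
# Assumes src_entry is already viewed as an fp8 tensor when the cache is fp8.
def reference_entry(src_entry: torch.Tensor, kv_cache_dtype: str,
                    scale: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    if kv_cache_dtype == "auto":
        return src_entry.to(out_dtype)                                # plain cast, no scale
    return (src_entry.to(torch.float32) * scale).to(out_dtype)        # fp8 dequant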

csrc/cache_kernels.cu

Lines changed: 29 additions & 28 deletions
@@ -624,16 +624,17 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
 namespace vllm {
 
 // grid is launched with dimensions (batch, num_splits)
-template <typename scalar_t>
-__global__ void gather_cache(
-    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+__global__ void gather_and_maybe_dequant_cache(
+    const cache_t* __restrict__ src_cache,    // [NUM_BLOCKS, BLOCK_SIZE,
                                               // ENTRIES...]
     scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
     const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
     const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
     const int32_t block_size, const int32_t entry_size,
     const int64_t block_table_stride, const int64_t cache_block_stride,
     const int64_t cache_entry_stride, const int64_t dst_entry_stride,
+    const float* __restrict__ scale,
     const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per
                                                // batch
 
@@ -675,10 +676,16 @@ __global__ void gather_cache(
     if (partial_block_size) full_blocks_end -= 1;
   }
 
-  auto copy_entry = [&](const scalar_t* __restrict__ _src,
+  auto copy_entry = [&](const cache_t* __restrict__ _src,
                         scalar_t* __restrict__ _dst) {
-    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
-      _dst[i] = _src[i];
+    for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+        _dst[i] = static_cast<scalar_t>(_src[i]);
+      } else {
+        _dst[i] =
+            fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
+      }
+    }
   };
 
   for (int pid = split_start; pid < full_blocks_end; ++pid) {
@@ -705,25 +712,31 @@ __global__ void gather_cache(
 }  // namespace vllm
 
 // Macro to dispatch the kernel based on the data type.
-#define CALL_GATHER_CACHE(CPY_DTYPE)                                        \
-  vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>(                \
-      reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()),                   \
-      reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()),                         \
-      block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),     \
-      block_size, entry_size, block_table_stride, cache_block_stride,       \
-      cache_entry_stride, dst_entry_stride, seq_starts_ptr);
+// SCALAR_T is the data type of the destination tensor.
+// CACHE_T is the stored data type of kv-cache.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                       \
+  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE>          \
+      <<<grid, block, 0, stream>>>(                                          \
+          reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                  \
+          reinterpret_cast<SCALAR_T*>(dst.data_ptr()),                       \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),  \
+          block_size, entry_size, block_table_stride, cache_block_stride,    \
+          cache_entry_stride, dst_entry_stride,                              \
+          reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);
 
 // Gather sequences from the cache into the destination tensor.
 // - cu_seq_lens contains the cumulative sequence lengths for each batch
 // - block_table contains the cache block indices for each sequence
 // - Optionally, seq_starts (if provided) offsets the starting block index by
 //   (seq_starts[bid] / page_size)
-void gather_cache(
+void gather_and_maybe_dequant_cache(
     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size,
+    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& scale,
     std::optional<torch::Tensor> seq_starts = std::nullopt) {
   at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -761,20 +774,8 @@ void gather_cache(
   dim3 grid(batch_size, num_splits);
   dim3 block(1024);
 
-  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
-              "src_cache and dst must have the same dtype");
-
-  const int dtype_bits = src_cache.element_size() * 8;
   const int32_t* seq_starts_ptr =
       seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
 
-  if (dtype_bits == 32) {
-    CALL_GATHER_CACHE(uint32_t);
-  } else if (dtype_bits == 16) {
-    CALL_GATHER_CACHE(uint16_t);
-  } else if (dtype_bits == 8) {
-    CALL_GATHER_CACHE(uint8_t);
-  } else {
-    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
-  }
+  DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
 }
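
As a sanity reference for the kernel above, here is a hedged pure-PyTorch sketch of what gather_and_maybe_dequant_cache is expected to compute: gather each sequence's cache blocks into contiguous rows per cu_seq_lens/block_table, casting for "auto" and dequantizing with the per-tensor scale otherwise. Names and the assumption that an fp8 cache is already viewable as a float8 tensor are illustrative, not vLLM code.

import torch

def gather_and_maybe_dequant_ref(src_cache, block_table, cu_seq_lens,
                                 kv_cache_dtype, scale, out_dtype,
                                 seq_starts=None):
    # src_cache: [NUM_BLOCKS, BLOCK_SIZE, ENTRY_SIZE] in the cache dtype
    # block_table: [BATCH, BLOCK_INDICES] int32, cu_seq_lens: [BATCH+1]
    block_size, entry_size = src_cache.shape[1], src_cache.shape[2]
    outs = []
    for b in range(cu_seq_lens.numel() - 1):
        seq_len = int(cu_seq_lens[b + 1] - cu_seq_lens[b])
        # seq_starts[b] offsets the starting block index by whole pages
        first = int(seq_starts[b]) // block_size if seq_starts is not None else 0
        nblocks = (seq_len + block_size - 1) // block_size
        blocks = block_table[b, first:first + nblocks].long()
        toks = src_cache[blocks].reshape(-1, entry_size)[:seq_len]
        if kv_cache_dtype == "auto":
            outs.append(toks.to(out_dtype))
        else:  # fp8 cache: assumes toks is (viewable as) a float8 tensor
            outs.append((toks.to(torch.float32) * scale).to(out_dtype))
    return torch.cat(outs, dim=0)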

csrc/torch_bindings.cpp

Lines changed: 9 additions & 4 deletions
@@ -703,11 +703,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "str kv_cache_dtype) -> ()");
   cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
 
-  // Gather cache blocks from src_cache to dst.
+  // Gather cache blocks from src_cache to dst, dequantizing from
+  // src_cache's dtype to dst's dtype if necessary.
   cache_ops.def(
-      "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
-      "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
-  cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
+      "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
+      "                               Tensor block_table, Tensor cu_seq_lens, "
+      "                               int batch_size, "
+      "                               str kv_cache_dtype, "
+      "                               Tensor scale, Tensor? seq_starts) -> ()");
+  cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
+                 &gather_and_maybe_dequant_cache);
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
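
Once the extension is built, the registered schema can be inspected from Python. A small hedged sketch; it assumes a CUDA build of vLLM where importing vllm._custom_ops loads the compiled _C_cache_ops library.

import torch
import vllm._custom_ops  # noqa: F401  (assumed to load the compiled ops)

op = torch.ops._C_cache_ops.gather_and_maybe_dequant_cache
print(op.default._schema)
# Positional order follows the schema string above: (src_cache, dst,
# block_table, cu_seq_lens, batch_size, kv_cache_dtype, scale, seq_starts);
# `Tensor? seq_starts` accepts None, while kv_cache_dtype and scale are
# now required.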

tests/kernels/attention/test_cache.py

Lines changed: 14 additions & 11 deletions
@@ -709,14 +709,15 @@ def test_swap_blocks_mla(
 @pytest.mark.parametrize("max_seq_len", [512])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("dtype", [torch.float32])
-@pytest.mark.parametrize("kv_cache_dtype",
-                         ["auto"])  # You can also test "fp8" if needed.
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
-                          num_blocks, max_seq_len, batch_size, dtype,
-                          kv_cache_dtype, device):
+def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
+                                            block_size, num_blocks,
+                                            max_seq_len, batch_size, dtype,
+                                            kv_cache_dtype, device):
     entry_size = kv_lora_rank + qk_rope_head_dim
+    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                   kv_cache_dtype, device)
     _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
@@ -742,10 +743,9 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
         perm = torch.randperm(num_blocks, device=device)
         block_table[b, :] = perm
 
-    dst = torch.zeros((total_tokens, entry_size),
-                      dtype=src_cache.dtype,
-                      device=device)
+    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)
 
+    # TODO - do dequant here
     expected_batches = []
     for b in range(batch_size):
         s = seq_len_tensor[b]
@@ -765,12 +765,15 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
     expected = torch.cat(expected_batches, dim=0)
 
     opcheck(
-        torch.ops._C_cache_ops.gather_cache,
-        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
+        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
+        (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
+         scale, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
 
-    ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
+    ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
+                                       cu_seq_lens, batch_size, kv_cache_dtype,
+                                       scale, None)
     torch.testing.assert_close(dst, expected)
 
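The `# TODO - do dequant here` line refers to the fp8 case, where the expected tensor should be built from dequantized cache values rather than raw stored bytes. One hedged way to do that; the helper name and the e4m3 view are assumptions, not part of the test file.

import torch

def dequant_expected(gathered: torch.Tensor, scale: torch.Tensor,
                     out_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # View the gathered fp8 cache bytes as e4m3, upcast, apply the same
    # per-tensor scale the kernel uses, then cast to the test's dtype.
    as_fp8 = gathered.view(torch.float8_e4m3fn)
    return (as_fp8.to(torch.float32) * scale).to(out_dtype)
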
vllm/_custom_ops.py

Lines changed: 12 additions & 8 deletions
@@ -1672,14 +1672,18 @@ def convert_fp8(output: torch.Tensor,
     torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
 
 
-def gather_cache(src_cache: torch.Tensor,
-                 dst: torch.Tensor,
-                 block_table: torch.Tensor,
-                 cu_seq_lens: torch.Tensor,
-                 batch_size: int,
-                 seq_starts: Optional[torch.Tensor] = None) -> None:
-    torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
-                                        cu_seq_lens, batch_size, seq_starts)
+def gather_and_maybe_dequant_cache(
+        src_cache: torch.Tensor,
+        dst: torch.Tensor,
+        block_table: torch.Tensor,
+        cu_seq_lens: torch.Tensor,
+        batch_size: int,
+        kv_cache_dtype: str,
+        scale: torch.Tensor,
+        seq_starts: Optional[torch.Tensor] = None) -> None:
+    torch.ops._C_cache_ops.gather_and_maybe_dequant_cache(
+        src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
+        scale, seq_starts)
 
 
 def get_device_attribute(attribute: int, device: int) -> int:
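
A minimal end-to-end call of the updated wrapper with tiny made-up shapes (a sketch that assumes a CUDA build of vLLM): one sequence of 3 tokens gathered from blocks 2 and 0 of a 4-block cache.

import torch
from vllm import _custom_ops as ops  # requires a CUDA build of vLLM

device = "cuda"
src_cache = torch.randn(4, 2, 8, dtype=torch.float16, device=device)  # [blocks, block_size, entries]
dst = torch.zeros(3, 8, dtype=torch.float16, device=device)           # [total_tokens, entries]
block_table = torch.tensor([[2, 0, 1, 3]], dtype=torch.int32, device=device)
cu_seq_lens = torch.tensor([0, 3], dtype=torch.int32, device=device)
scale = torch.tensor(1.0, dtype=torch.float32, device=device)

ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table, cu_seq_lens,
                                   batch_size=1, kv_cache_dtype="auto",
                                   scale=scale, seq_starts=None)
# With "auto" this is a pure gather into dst; with an fp8 cache and
# kv_cache_dtype="fp8" the same call would also dequantize into dst's dtype.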

vllm/attention/backends/mla/common.py

Lines changed: 9 additions & 5 deletions
@@ -922,8 +922,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                 self.context_chunk_workspace_size // num_prefills_with_context
 
             # align max_context_chunk to page_size by rounding down,
-            # currently the `gather_cache` kernel cannot handle
-            # `context_chunk_starts` that are not aligned to page_size
+            # currently the `gather_and_maybe_dequant_cache` kernel cannot
+            # handle `context_chunk_starts` that are not aligned to page_size
             max_context_chunk = round_down(max_context_chunk, self.page_size)
             assert max_context_chunk > 0
             num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
@@ -1167,6 +1167,7 @@ def _compute_prefill_context(
         q: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
     ):
         prefill_metadata = attn_metadata.prefill_metadata
         assert prefill_metadata is not None
@@ -1188,12 +1189,14 @@ def _compute_prefill_context(
         for i in range(iters):
             toks = prefill_metadata.context_chunk_seq_tot[i]
 
-            ops.gather_cache(
+            ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_tables,
                 cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
                 batch_size=prefill_metadata.num_prefills,
+                kv_cache_dtype=self.kv_cache_dtype,
+                scale=k_scale,
                 seq_starts=prefill_metadata.context_chunk_starts[i],
             )
 
@@ -1250,6 +1253,7 @@ def _forward_prefill(
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
     ) -> torch.Tensor:
 
         prefill_metadata = attn_metadata.prefill_metadata
@@ -1282,7 +1286,7 @@ def _forward_prefill(
             # ROCm flash_attn_varlen_func will return 3 objects instead of 2
             suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
-                q, kv_c_and_k_pe_cache, attn_metadata)
+                q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
 
             output = torch.empty_like(suffix_output)
             merge_attn_states(
@@ -1372,7 +1376,7 @@ def forward(
         if has_prefill:
             output[:num_prefill_tokens] = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
+                attn_metadata, layer._k_scale)
 
         if has_decode:
             decode_q_nope, decode_q_pe = decode_q.split(
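
Note on the `k_scale` argument threaded through this backend: the op's schema takes `Tensor scale`, and the CUDA side reads it through `scale.data_ptr()` as a `const float*`, so `layer._k_scale` is assumed to be a one-element float32 tensor resident on the cache's device rather than a Python float. A tiny sketch of that assumed contract:

import torch

# Assumed shape/dtype contract for the scale passed to
# gather_and_maybe_dequant_cache (illustrative, not vLLM code).
k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
assert k_scale.numel() == 1 and k_scale.dtype == torch.float32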

vllm/v1/attention/backends/mla/common.py

Lines changed: 11 additions & 21 deletions
@@ -427,9 +427,6 @@ def __init__(self,
         self.page_size = self.kv_cache_spec.block_size
 
         if self.chunked_prefill_enabled:
-            workspace_dtype = self.model_config.dtype
-            if cache_config.cache_dtype.startswith("fp8"):
-                workspace_dtype = current_platform.fp8_dtype()
             self.chunked_prefill_workspace_size = min(
                 # Max sure there is enough for 8 full length request or at least
                 # 4 pages of cache per request
@@ -450,7 +447,7 @@ def __init__(self,
             self.chunked_prefill_workspace = torch.empty(
                 (self.chunked_prefill_workspace_size,
                  self.model_config.get_head_size()),
-                dtype=workspace_dtype,
+                dtype=self.model_config.dtype,
                 device=device,
             )
 
@@ -638,8 +635,9 @@ def build(self,
 
             if self.aot_schedule:
                 # align max_context_chunk to page_size by rounding down,
-                # currently the `gather_cache` kernel cannot handle
-                # `context_chunk_starts` that are not aligned to page_size
+                # currently the `gather_and_maybe_dequant_cache` kernel
+                # cannot handle `context_chunk_starts` that are not aligned
+                # to page_size
                 max_context_chunk = round_down(max_context_chunk,
                                                self.page_size)
 
@@ -1016,6 +1014,7 @@ def _compute_prefill_context(
         q: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
    ):
         assert attn_metadata.prefill is not None
         prefill_metadata = attn_metadata.prefill
@@ -1025,17 +1024,17 @@ def _compute_prefill_context(
         iters = len(prefill_metadata.chunked_context.seq_tot)
         workspace = prefill_metadata.chunked_context.workspace
 
-        fp8_attention = self.kv_cache_dtype.startswith("fp8")
-
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
-            ops.gather_cache(
+            ops.gather_and_maybe_dequant_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_table,
                 cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
                 batch_size=attn_metadata.num_prefills,
+                kv_cache_dtype=self.kv_cache_dtype,
+                scale=k_scale,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
             )
 
@@ -1044,16 +1043,6 @@ def _compute_prefill_context(
             k_pe = workspace[:toks]\
                 [..., self.kv_lora_rank:].unsqueeze(1)
 
-            if fp8_attention:
-                target_dtype = self.kv_b_proj.weight.dtype
-                kv_c_normed_dequant = torch.empty_like(kv_c_normed,
-                                                       dtype=target_dtype)
-                k_pe_dequant = torch.empty_like(k_pe, dtype=target_dtype)
-                ops.convert_fp8(kv_c_normed_dequant, kv_c_normed)
-                ops.convert_fp8(k_pe_dequant, k_pe)
-                kv_c_normed = kv_c_normed_dequant
-                k_pe = k_pe_dequant
-
             kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
                 -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
             k_nope, v = kv_nope\
@@ -1096,6 +1085,7 @@ def _forward_prefill(
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
+        k_scale: torch.Tensor,
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None
 
@@ -1118,7 +1108,7 @@ def _forward_prefill(
         if has_context:
             suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
-                q, kv_c_and_k_pe_cache, attn_metadata)
+                q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
 
             output = torch.empty_like(suffix_output)
             merge_attn_states(
@@ -1212,7 +1202,7 @@ def forward(
         if has_prefill:
             output[num_decode_tokens:] = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
+                attn_metadata, layer._k_scale)
 
         if has_decode:
             assert attn_metadata.decode is not None