Support arbitrary KV cache strides in paged_decode for MLA #165
baodii wants to merge 9 commits into vllm-project:main from
Conversation
baodii
commented
Feb 28, 2026
- Remove CHECK_CONTIGUOUS for k/v in flash_api.cpp (stride(-1)==1 still enforced)
- Add KV cache stride fields to paged_decode_args_t
- Extract actual tensor strides in paged_decode_xe2.cpp
- Use actual strides in DecodeKernelLauncher::initialize() instead of packed strides
- Replace make_ordered_layout with make_layout using passed strides for K/V in kernel
- Add test_decode_with_paged_kv_mla unit test with non-contiguous KV cache slices
Signed-off-by: baodii <di.bao@intel.com>
qq: will we share the same interface as the current attention (following https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/flashattn_mla.py#L334), instead of adding a new interface like https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/ops/flashmla.py#L140 and https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/cutlass_mla.py#L227?
Pull request overview
Adds support for non-contiguous (but last-dim contiguous) paged KV cache layouts in XPU MLA decode by plumbing real tensor strides into the decode launcher/kernel and validating via a new unit test.
Changes:
- Plumb K/V cache strides through `paged_decode_args_t` and extract real tensor strides at the call site.
- Use passed-in strides in `DecodeKernelLauncher::initialize()` and in kernel layouts for K/V.
- Add a unit test covering MLA-like non-contiguous KV cache slices.
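The layout this PR targets can be sketched independently of the kernel. A minimal NumPy stand-in for the test's combined KV buffer (note NumPy strides are in bytes, unlike PyTorch's element strides; all names and sizes here are illustrative):

```python
import numpy as np

# MLA-style combined KV buffer: K and V live side by side in the last dim.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 1, 64
combined = np.zeros((num_blocks, block_size, num_kv_heads, 2 * head_size),
                    dtype=np.float16)

# Slicing out K and V yields views that are NOT contiguous overall...
key_cache = combined[..., :head_size]
value_cache = combined[..., head_size:2 * head_size]
assert not key_cache.flags["C_CONTIGUOUS"]
assert not value_cache.flags["C_CONTIGUOUS"]

# ...but the last dimension is still densely packed, which is the
# invariant the kernel keeps enforcing (stride(-1) == 1 in PyTorch terms,
# stride == itemsize in NumPy's byte strides).
assert key_cache.strides[-1] == key_cache.itemsize

# The outer strides differ from the packed layout a contiguous tensor
# would have, so the kernel must receive them explicitly instead of
# recomputing them from the shape.
packed_row_stride = num_kv_heads * head_size * key_cache.itemsize
assert key_cache.strides[1] == 2 * packed_row_stride
```

This is why the launcher now takes per-tensor strides rather than deriving packed strides from the shape.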
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tests/flash_attn/test_flash_attn_varlen_func.py | Adds an MLA-like paged KV decode test using non-contiguous K/V views |
| csrc/xpu/attn/xe_2/paged_decode_xe2.cpp | Populates decode args with K/V tensor strides |
| csrc/xpu/attn/xe_2/paged_decode.hpp | Extends args for K/V strides; uses real strides in launcher initialization |
| csrc/xpu/attn/xe_2/kernel/paged_decode_kernel.hpp | Switches K/V layouts to use runtime-provided strides |
| csrc/flash_attn/flash_api.cpp | Removes contiguity requirement for K/V tensors (last-dim contiguous still enforced elsewhere) |
```python
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
    f"{torch.max(torch.abs(output - ref_output))}"
```
The trailing comma turns this into a tuple expression, so the failure message is never applied and the line has no effect beyond executing assert_close (and the second line becomes a dead expression). If you want a custom message, pass it via msg=... to torch.testing.assert_close (or remove the comma/backslash and just rely on assert_close raising with its own diff).
```suggestion
torch.testing.assert_close(
    output,
    ref_output,
    atol=atol,
    rtol=rtol,
    msg=f"{torch.max(torch.abs(output - ref_output))}",
)
```
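The pitfall is easy to reproduce with any function that raises internally; a minimal sketch using a stand-in for `torch.testing.assert_close` (names here are illustrative, not the PR's code):

```python
def assert_close(a, b):
    """Stand-in for torch.testing.assert_close: raises on mismatch."""
    if a != b:
        raise AssertionError(f"{a} != {b}")

# Looks like `assert expr, msg`, but there is no `assert` keyword here:
# the comma builds a tuple. assert_close() runs first, and the f-string
# becomes a dead second tuple element that is never shown to anyone.
line = assert_close(1, 1), \
    f"max diff {abs(1 - 1)}"

print(type(line).__name__)  # → tuple
print(line)                 # → (None, 'max diff 0')
```

The `assert expr, msg` form only attaches a message when `expr` itself is falsy; it cannot decorate a call that raises on its own, which is why passing `msg=...` to `assert_close` is the right fix.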
```python
torch.set_default_device("xpu")
torch.xpu.set_device("xpu:0")
torch.manual_seed(42)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
scale = head_size**-0.5

query = torch.randn(sum(query_lens),
                    num_query_heads,
                    head_size,
                    dtype=dtype)

# MLA-like combined KV cache: K and V share a single buffer
combined_head_dim = head_size * combined_head_dim_factor
combined_kv_cache = torch.randn(num_blocks,
                                block_size,
                                num_kv_heads,
                                combined_head_dim,
                                dtype=dtype)

# Non-contiguous slices simulating MLA K/V extraction
key_cache = combined_kv_cache[..., :head_size]
value_cache = combined_kv_cache[..., head_size:2 * head_size]

assert not key_cache.is_contiguous(), "key_cache should be non-contiguous"
assert not value_cache.is_contiguous(), \
    "value_cache should be non-contiguous"
assert key_cache.stride(-1) == 1
assert value_cache.stride(-1) == 1

cu_query_lens = torch.tensor([0] + query_lens,
                             dtype=torch.int32).cumsum(dim=0,
                                                       dtype=torch.int32)
seq_k = torch.tensor(kv_lens, dtype=torch.int32)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
                             num_blocks,
                             (num_seqs, max_num_blocks_per_seq),
                             dtype=torch.int32)

output = flash_attn_varlen_func(query,
                                key_cache,
                                value_cache,
                                max_query_len,
                                cu_query_lens,
                                max_kv_len,
                                seqused_k=seq_k,
                                softmax_scale=scale,
                                causal=False,
                                block_table=block_tables,
                                window_size=(-1, -1))

ref_output = ref_paged_attn(query=query,
                            key_cache=key_cache,
                            value_cache=value_cache,
                            query_lens=query_lens,
                            kv_lens=kv_lens,
                            block_tables=block_tables,
                            scale=scale,
                            casual=False,
                            is_paged=True,
                            sink=None,
                            window_size_left=-1,
                            window_size_right=-1)
atol, rtol = 1e-2, 1e-2
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
    f"{torch.max(torch.abs(output - ref_output))}"
torch.xpu.empty_cache()
```
torch.set_default_device(...) is global process state and can leak into other tests, causing order-dependent failures. Prefer allocating tensors with an explicit device="xpu" (and keeping torch.xpu.set_device(...) only if needed), or restore the prior default device at the end of the test.
```suggestion
prev_default_device = torch.get_default_device()
prev_xpu_device = torch.xpu.current_device()
try:
    torch.set_default_device("xpu")
    torch.xpu.set_device("xpu:0")
    torch.manual_seed(42)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5
    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    # MLA-like combined KV cache: K and V share a single buffer
    combined_head_dim = head_size * combined_head_dim_factor
    combined_kv_cache = torch.randn(num_blocks,
                                    block_size,
                                    num_kv_heads,
                                    combined_head_dim,
                                    dtype=dtype)
    # Non-contiguous slices simulating MLA K/V extraction
    key_cache = combined_kv_cache[..., :head_size]
    value_cache = combined_kv_cache[..., head_size:2 * head_size]
    assert not key_cache.is_contiguous(), \
        "key_cache should be non-contiguous"
    assert not value_cache.is_contiguous(), \
        "value_cache should be non-contiguous"
    assert key_cache.stride(-1) == 1
    assert value_cache.stride(-1) == 1
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(
                                     dim=0, dtype=torch.int32)
    seq_k = torch.tensor(kv_lens, dtype=torch.int32)
    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    output = flash_attn_varlen_func(query,
                                    key_cache,
                                    value_cache,
                                    max_query_len,
                                    cu_query_lens,
                                    max_kv_len,
                                    seqused_k=seq_k,
                                    softmax_scale=scale,
                                    causal=False,
                                    block_table=block_tables,
                                    window_size=(-1, -1))
    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=query_lens,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                casual=False,
                                is_paged=True,
                                sink=None,
                                window_size_left=-1,
                                window_size_right=-1)
    atol, rtol = 1e-2, 1e-2
    torch.testing.assert_close(output,
                               ref_output,
                               atol=atol,
                               rtol=rtol), \
        f"{torch.max(torch.abs(output - ref_output))}"
    torch.xpu.empty_cache()
finally:
    torch.set_default_device(prev_default_device)
    torch.xpu.set_device(prev_xpu_device)
```
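The save-and-restore pattern above generalizes to any process-global setting; a minimal stdlib sketch of the same idea as a reusable context manager (the device functions here are hypothetical stand-ins for `torch.get_default_device()` / `torch.set_default_device()`, not part of the PR):

```python
from contextlib import contextmanager

# Hypothetical stand-in for process-global state like the default device.
_default_device = "cpu"

def get_default_device():
    return _default_device

def set_default_device(device):
    global _default_device
    _default_device = device

@contextmanager
def default_device(device):
    """Temporarily switch the global default, restoring it even on error."""
    prev = get_default_device()
    set_default_device(device)
    try:
        yield
    finally:
        set_default_device(prev)

with default_device("xpu"):
    assert get_default_device() == "xpu"

# The global is restored once the test body exits, pass or fail, so it
# cannot leak into later tests and cause order-dependent failures.
assert get_default_device() == "cpu"
```

Wrapping the whole test in such a context manager (or a pytest fixture) keeps the restore logic in one place instead of repeating try/finally in every test.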
```cpp
stride_K = StrideK{
    static_cast<int>(args.k_stride_seq),
    _1{},
    static_cast<int>(args.k_stride_heads),
    static_cast<int>(args.k_stride_page)};
stride_V = StrideV{
    _1{},
    static_cast<int>(args.v_stride_seq),
    static_cast<int>(args.v_stride_heads),
    static_cast<int>(args.v_stride_page)};
```
These int64_t stride values are narrowed to int, which can overflow for large allocations / large strides and silently produce incorrect addressing. Consider keeping stride types as 64-bit (if Cute/Cutlass stride types allow), or add a TORCH_CHECK range validation before casting and fail fast when strides exceed std::numeric_limits<int>::max().
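The silent-narrowing hazard is easy to demonstrate; a small sketch of what an unchecked int64-to-int32 cast does to a stride that is plausible for a very large allocation (the guard function is illustrative, modeled on the suggested TORCH_CHECK, not actual PR code):

```python
import ctypes

INT32_MAX = 2**31 - 1

# A stride value that fits in int64 but not in int32.
stride = 3_000_000_000

# What static_cast<int> does when int is 32-bit: silent wraparound,
# producing a negative stride and garbage addressing downstream.
narrowed = ctypes.c_int32(stride).value
print(narrowed)  # → -1294967296

# The fail-fast alternative the review suggests: range-check, then cast.
def checked_narrow(value: int) -> int:
    if not (0 <= value <= INT32_MAX):
        raise OverflowError(f"stride {value} does not fit in int32")
    return value
```

Failing fast at the API boundary turns a silent wrong-answer bug into an immediate, debuggable error.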
```cpp
// KV cache strides
key_cache.stride(0),
key_cache.stride(1),
key_cache.stride(2),
value_cache.stride(0),
value_cache.stride(1),
value_cache.stride(2)};
```
The downstream stride construction hard-codes the head-dimension stride to 1 (_1{}), so this path implicitly requires key_cache.stride(-1) == 1 and value_cache.stride(-1) == 1. To prevent silent misbehavior if this code is called outside the guarded API path, add an explicit TORCH_CHECK here (or immediately before kernel launch) enforcing last-dim contiguity for both K and V.
Signed-off-by: baodi <di.bao@intel.com>
Signed-off-by: baodii <di.bao@intel.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>