
Support arbitrary KV cache strides in paged_decode for MLA #165

Open
baodii wants to merge 9 commits into vllm-project:main from baodii:mla-stride-support

Conversation


@baodii baodii commented Feb 28, 2026

  • Remove CHECK_CONTIGUOUS for k/v in flash_api.cpp (stride(-1)==1 still enforced)
  • Add KV cache stride fields to paged_decode_args_t
  • Extract actual tensor strides in paged_decode_xe2.cpp
  • Use actual strides in DecodeKernelLauncher::initialize() instead of packed strides
  • Replace make_ordered_layout with make_layout using passed strides for K/V in kernel
  • Add test_decode_with_paged_kv_mla unit test with non-contiguous KV cache slices
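As a hedged sketch of the second bullet, the new stride fields could look like the struct below. The field names match the kernel snippet quoted later in the review; the exact layout inside `paged_decode_args_t` in `paged_decode.hpp` may differ.

```cpp
#include <cstdint>

// Sketch only: in the PR these fields are added to paged_decode_args_t.
// Names k_stride_seq / k_stride_heads / k_stride_page (and the v_*
// counterparts) follow the kernel code quoted in the review. Strides are
// in elements; stride(-1) == 1 is still enforced, so no innermost
// stride field is needed.
struct kv_cache_strides_t {
  std::int64_t k_stride_page;   // stride between cache pages/blocks
  std::int64_t k_stride_seq;    // stride between tokens within a page
  std::int64_t k_stride_heads;  // stride between KV heads
  std::int64_t v_stride_page;
  std::int64_t v_stride_seq;
  std::int64_t v_stride_heads;
};
```

For a packed `[num_blocks, block_size, num_kv_heads, head_size]` cache these reduce to the old implicit values; for an MLA slice of a combined buffer, the seq/page strides are larger than the slice's own extents imply.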

Signed-off-by: baodii <di.bao@intel.com>

@baodii baodii changed the title [WIP] Support arbitrary KV cache strides in paged_decode for MLA Support arbitrary KV cache strides in paged_decode for MLA Mar 4, 2026
@jikunshang jikunshang mentioned this pull request Mar 9, 2026
@baodii baodii marked this pull request as ready for review March 9, 2026 06:19
Copilot AI review requested due to automatic review settings March 9, 2026 06:19

Copilot AI left a comment


Pull request overview

Adds support for non-contiguous (but last-dim contiguous) paged KV cache layouts in XPU MLA decode by plumbing real tensor strides into the decode launcher/kernel and validating via a new unit test.

Changes:

  • Plumb K/V cache strides through paged_decode_args_t and extract real tensor strides at the call site.
  • Use passed-in strides in DecodeKernelLauncher::initialize() and in kernel layouts for K/V.
  • Add a unit test covering MLA-like non-contiguous KV cache slices.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

Reviewed changes per file:

  • tests/flash_attn/test_flash_attn_varlen_func.py: Adds an MLA-like paged KV decode test using non-contiguous K/V views
  • csrc/xpu/attn/xe_2/paged_decode_xe2.cpp: Populates decode args with K/V tensor strides
  • csrc/xpu/attn/xe_2/paged_decode.hpp: Extends args for K/V strides; uses real strides in launcher initialization
  • csrc/xpu/attn/xe_2/kernel/paged_decode_kernel.hpp: Switches K/V layouts to use runtime-provided strides
  • csrc/flash_attn/flash_api.cpp: Removes contiguity requirement for K/V tensors (last-dim contiguous still enforced elsewhere)


Comment on lines +548 to +549
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"

Copilot AI Mar 9, 2026


The trailing comma turns this into a tuple expression, so the failure message is never applied and the line has no effect beyond executing assert_close (and the second line becomes a dead expression). If you want a custom message, pass it via msg=... to torch.testing.assert_close (or remove the comma/backslash and just rely on assert_close raising with its own diff).

Suggested change
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
torch.testing.assert_close(
output,
ref_output,
atol=atol,
rtol=rtol,
msg=f"{torch.max(torch.abs(output - ref_output))}",
)

Comment on lines +476 to +550
torch.set_default_device("xpu")
torch.xpu.set_device("xpu:0")
torch.manual_seed(42)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
scale = head_size**-0.5

query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)

# MLA-like combined KV cache: K and V share a single buffer
combined_head_dim = head_size * combined_head_dim_factor
combined_kv_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
combined_head_dim,
dtype=dtype)

# Non-contiguous slices simulating MLA K/V extraction
key_cache = combined_kv_cache[..., :head_size]
value_cache = combined_kv_cache[..., head_size:2 * head_size]

assert not key_cache.is_contiguous(), "key_cache should be non-contiguous"
assert not value_cache.is_contiguous(), \
"value_cache should be non-contiguous"
assert key_cache.stride(-1) == 1
assert value_cache.stride(-1) == 1

cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
seq_k = torch.tensor(kv_lens, dtype=torch.int32)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_varlen_func(query,
key_cache,
value_cache,
max_query_len,
cu_query_lens,
max_kv_len,
seqused_k=seq_k,
softmax_scale=scale,
causal=False,
block_table=block_tables,
window_size=(-1, -1))

ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
casual=False,
is_paged=True,
sink=None,
window_size_left=-1,
window_size_right=-1)
atol, rtol = 1e-2, 1e-2
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
torch.xpu.empty_cache()

Copilot AI Mar 9, 2026


torch.set_default_device(...) is global process state and can leak into other tests, causing order-dependent failures. Prefer allocating tensors with an explicit device="xpu" (and keeping torch.xpu.set_device(...) only if needed), or restore the prior default device at the end of the test.

Suggested change
prev_default_device = torch.get_default_device()
prev_xpu_device = torch.xpu.current_device()
try:
torch.set_default_device("xpu")
torch.xpu.set_device("xpu:0")
torch.manual_seed(42)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
scale = head_size**-0.5
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
# MLA-like combined KV cache: K and V share a single buffer
combined_head_dim = head_size * combined_head_dim_factor
combined_kv_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
combined_head_dim,
dtype=dtype)
# Non-contiguous slices simulating MLA K/V extraction
key_cache = combined_kv_cache[..., :head_size]
value_cache = combined_kv_cache[..., head_size:2 * head_size]
assert not key_cache.is_contiguous(), \
"key_cache should be non-contiguous"
assert not value_cache.is_contiguous(), \
"value_cache should be non-contiguous"
assert key_cache.stride(-1) == 1
assert value_cache.stride(-1) == 1
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32)
seq_k = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
output = flash_attn_varlen_func(query,
key_cache,
value_cache,
max_query_len,
cu_query_lens,
max_kv_len,
seqused_k=seq_k,
softmax_scale=scale,
causal=False,
block_table=block_tables,
window_size=(-1, -1))
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
casual=False,
is_paged=True,
sink=None,
window_size_left=-1,
window_size_right=-1)
atol, rtol = 1e-2, 1e-2
torch.testing.assert_close(output,
ref_output,
atol=atol,
rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
torch.xpu.empty_cache()
finally:
torch.set_default_device(prev_default_device)
torch.xpu.set_device(prev_xpu_device)

Comment on lines +161 to +170
stride_K = StrideK{
static_cast<int>(args.k_stride_seq),
_1{},
static_cast<int>(args.k_stride_heads),
static_cast<int>(args.k_stride_page)};
stride_V = StrideV{
_1{},
static_cast<int>(args.v_stride_seq),
static_cast<int>(args.v_stride_heads),
static_cast<int>(args.v_stride_page)};

Copilot AI Mar 9, 2026


These int64_t stride values are narrowed to int, which can overflow for large allocations / large strides and silently produce incorrect addressing. Consider keeping stride types as 64-bit (if Cute/Cutlass stride types allow), or add a TORCH_CHECK range validation before casting and fail fast when strides exceed std::numeric_limits<int>::max().
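A minimal standalone sketch of that fail-fast check follows. The helper name is hypothetical; in the actual launcher this would be a `TORCH_CHECK` rather than an exception, placed before `StrideK`/`StrideV` are constructed.

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical helper: narrow an int64_t tensor stride to int, failing
// fast instead of silently wrapping on overflow. In the real code this
// would be TORCH_CHECK(stride <= std::numeric_limits<int>::max(), ...)
// before building the CuTe stride tuples.
inline int checked_stride_cast(std::int64_t stride) {
  if (stride < 0 || stride > std::numeric_limits<int>::max()) {
    throw std::out_of_range("KV cache stride does not fit in int");
  }
  return static_cast<int>(stride);
}
```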

Comment on lines +152 to +158
// KV cache strides
key_cache.stride(0),
key_cache.stride(1),
key_cache.stride(2),
value_cache.stride(0),
value_cache.stride(1),
value_cache.stride(2)};

Copilot AI Mar 9, 2026


The downstream stride construction hard-codes the head-dimension stride to 1 (_1{}), so this path implicitly requires key_cache.stride(-1) == 1 and value_cache.stride(-1) == 1. To prevent silent misbehavior if this code is called outside the guarded API path, add an explicit TORCH_CHECK here (or immediately before kernel launch) enforcing last-dim contiguity for both K and V.
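A torch-free sketch of that guard, for illustration: the real check would be `TORCH_CHECK(key_cache.stride(-1) == 1, ...)` (and likewise for `value_cache`) at this call site or immediately before kernel launch.

```cpp
#include <cstdint>
#include <vector>

// Sketch: verify the innermost dimension has stride 1 (strides given in
// elements, innermost last), matching the _1{} hard-coded in the
// kernel's stride construction. Would be called on both the key and
// value cache strides before launching the kernel.
inline bool last_dim_contiguous(const std::vector<std::int64_t>& strides) {
  return !strides.empty() && strides.back() == 1;
}
```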

Signed-off-by: baodi <di.bao@intel.com>
@baodii baodii requested a review from YizhouZ March 11, 2026 07:28
baodii and others added 5 commits March 11, 2026 17:37
Signed-off-by: baodii <di.bao@intel.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: baodi <di.bao@intel.com>
Signed-off-by: baodi <di.bao@intel.com>
Signed-off-by: baodii <di.bao@intel.com>