
[CHUNK_PREFILL] add dynamic_stride support #187

Open
YizhouZ wants to merge 2 commits into vllm-project:main from YizhouZ:yizhou/dynamic_stride

Conversation

Collaborator

@YizhouZ YizhouZ commented Mar 10, 2026

No description provided.

Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
Copilot AI review requested due to automatic review settings March 10, 2026 08:43
@YizhouZ YizhouZ force-pushed the yizhou/dynamic_stride branch from c63c0fc to e5b5629 Compare March 10, 2026 08:49
Contributor

Copilot AI left a comment


Pull request overview

Adds dynamic stride support for Xe2 chunk prefill / varlen flash-attn paths, enabling non-contiguous Q/K/V (with aligned strides) and extending tests to cover padded-stride layouts.

Changes:

  • Pass actual tensor strides (Q/K/V/O) into chunk prefill kernel args and use them for layout/offset computation.
  • Relax varlen FlashAttention API contiguity requirements to allow non-contiguous tensors with stride-alignment checks.
  • Extend varlen paged-KV tests to exercise non-contiguous Q/K/V via padded head-size strides.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Summary per file:

tests/flash_attn/test_flash_attn_varlen_func.py: Adds stride_pad parameterization and constructs non-contiguous Q/K/V to validate dynamic-stride behavior.
csrc/xpu/attn/xe_2/kernel/chunk_prefill_kernel.hpp: Switches var-len offset/layout computation to rely on runtime strides.
csrc/xpu/attn/xe_2/fmha_xe2.cpp: Extracts Q/K/V/O strides from PyTorch tensors and populates kernel args.
csrc/xpu/attn/xe_2/chunk_prefill.hpp: Extends the args struct with stride fields and wires them into Cute/CUTLASS stride objects.
csrc/utils.h: Introduces a stride-alignment check macro for Xe2 2D block load requirements.
csrc/flash_attn/flash_api.cpp: Replaces contiguity checks with stride-alignment checks for varlen inputs.


Comment on lines +130 to +146
stride_Q = StrideQ{};
get<0>(stride_Q) = args.q_stride_seq;
get<2>(stride_Q) = args.q_stride_heads;
get<3>(stride_Q) = args.q_stride_batch;

stride_K = StrideK{};
get<0>(stride_K) = args.k_stride_seq;
get<2>(stride_K) = args.k_stride_heads;
get<3>(stride_K) = args.k_stride_batch;

stride_V = StrideV{};
get<1>(stride_V) = args.v_stride_seq;
get<2>(stride_V) = args.v_stride_heads;
get<3>(stride_V) = args.v_stride_batch;

stride_O = StrideO{};
get<0>(stride_O) = args.o_stride_seq;

Copilot AI Mar 10, 2026


The stride components for the head-size (unit-stride) dimension are never set. For Q/K/O in CUTLASS order (seq, head_size, heads, batch), get<1>(stride_*) should be 1 (since stride(-1) == 1 is required). For V in CUTLASS order (head_size, seq, heads, batch), get<0>(stride_V) should be 1. Leaving these at default can result in incorrect tensor indexing (e.g., collapsing the head-size dimension if the default is 0) once make_layout(shape_*, p.d*) is used.

Suggested change (replacing the lines quoted above):

stride_Q = StrideQ{};
// CUTLASS order for Q: (seq, head_size, heads, batch)
get<0>(stride_Q) = args.q_stride_seq;
get<1>(stride_Q) = 1; // unit stride along head_size
get<2>(stride_Q) = args.q_stride_heads;
get<3>(stride_Q) = args.q_stride_batch;

stride_K = StrideK{};
// CUTLASS order for K: (seq, head_size, heads, batch)
get<0>(stride_K) = args.k_stride_seq;
get<1>(stride_K) = 1; // unit stride along head_size
get<2>(stride_K) = args.k_stride_heads;
get<3>(stride_K) = args.k_stride_batch;

stride_V = StrideV{};
// CUTLASS order for V: (head_size, seq, heads, batch)
get<0>(stride_V) = 1; // unit stride along head_size
get<1>(stride_V) = args.v_stride_seq;
get<2>(stride_V) = args.v_stride_heads;
get<3>(stride_V) = args.v_stride_batch;

stride_O = StrideO{};
// CUTLASS order for O: (seq, head_size, heads, batch)
get<0>(stride_O) = args.o_stride_seq;
get<1>(stride_O) = 1; // unit stride along head_size

Comment on lines +14 to +21
for (int _d = 0; _d < (x).dim() - 1; ++_d) { \
TORCH_CHECK( \
(x).stride(_d) * (x).element_size() % 64 == 0, \
#x " stride(", \
_d, \
")=", \
(x).stride(_d), \
" is not 64-byte aligned (element_size=", \

Copilot AI Mar 10, 2026


This macro expands to a raw for loop and is not statement-safe (e.g., if (cond) CHECK_STRIDE_ALIGNMENT(x); else ... will not compile/bind as intended). Wrap the body in do { ... } while (0) to make it safe in all contexts.

Suggested change (replacing the lines quoted above):

do { \
  for (int _d = 0; _d < (x).dim() - 1; ++_d) { \
    TORCH_CHECK((x).stride(_d) * (x).element_size() % 64 == 0, \
                #x " stride(", _d, ")=", (x).stride(_d), \
                " is not 64-byte aligned (element_size=", \
                (x).element_size(), \
                "). Xe2 2D block loads require 64-byte aligned base " \
                "pointers."); \
  } \
} while (0)

Comment on lines +56 to +71
// Q/O strides in CUTLASS order: (seq, head_size=1, heads, batch)
int q_stride_seq = 0;
int q_stride_heads = 0;
int q_stride_batch = 0;
// K strides in CUTLASS order: (seq, head_size=1, heads, batch)
int k_stride_seq = 0;
int k_stride_heads = 0;
int k_stride_batch = 0;
// V strides in CUTLASS order: (head_size=1, seq, heads, batch)
int v_stride_seq = 0;
int v_stride_heads = 0;
int v_stride_batch = 0;
// O strides in CUTLASS order: (seq, head_size=1, heads, batch)
int o_stride_seq = 0;
int o_stride_heads = 0;
int o_stride_batch = 0;

Copilot AI Mar 10, 2026


PyTorch tensor strides are int64 (and can exceed 32-bit for large shapes). Storing them in int risks truncation/overflow, which can lead to incorrect addressing in the kernel when computing offsets. Consider using int64_t (or the stride element type expected by Cute/CUTLASS) for these fields and the downstream offset computations that multiply stride by cumulative lengths.

Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>