[CHUNK_PREFILL] add dynamic_stride support #187
YizhouZ wants to merge 2 commits into vllm-project:main
Conversation
Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
force-pushed from c63c0fc to e5b5629
Pull request overview
Adds dynamic stride support for Xe2 chunk prefill / varlen flash-attn paths, enabling non-contiguous Q/K/V (with aligned strides) and extending tests to cover padded-stride layouts.
Changes:
- Pass actual tensor strides (Q/K/V/O) into chunk prefill kernel args and use them for layout/offset computation.
- Relax varlen FlashAttention API contiguity requirements to allow non-contiguous tensors with stride-alignment checks.
- Extend varlen paged-KV tests to exercise non-contiguous Q/K/V via padded head-size strides.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tests/flash_attn/test_flash_attn_varlen_func.py | Adds stride_pad parameterization and constructs non-contiguous Q/K/V to validate dynamic-stride behavior. |
| csrc/xpu/attn/xe_2/kernel/chunk_prefill_kernel.hpp | Switches var-len offset/layout computation to rely on runtime strides. |
| csrc/xpu/attn/xe_2/fmha_xe2.cpp | Extracts Q/K/V/O strides from PyTorch tensors and populates kernel args. |
| csrc/xpu/attn/xe_2/chunk_prefill.hpp | Extends args struct with stride fields and wires them into Cute/CUTLASS stride objects. |
| csrc/utils.h | Introduces a stride alignment check macro for Xe2 2D block load requirements. |
| csrc/flash_attn/flash_api.cpp | Replaces contiguity checks with stride-alignment checks for varlen inputs. |
```cpp
stride_Q = StrideQ{};
get<0>(stride_Q) = args.q_stride_seq;
get<2>(stride_Q) = args.q_stride_heads;
get<3>(stride_Q) = args.q_stride_batch;

stride_K = StrideK{};
get<0>(stride_K) = args.k_stride_seq;
get<2>(stride_K) = args.k_stride_heads;
get<3>(stride_K) = args.k_stride_batch;

stride_V = StrideV{};
get<1>(stride_V) = args.v_stride_seq;
get<2>(stride_V) = args.v_stride_heads;
get<3>(stride_V) = args.v_stride_batch;

stride_O = StrideO{};
get<0>(stride_O) = args.o_stride_seq;
```
The stride components for the head-size (unit-stride) dimension are never set. For Q/K/O in CUTLASS order (seq, head_size, heads, batch), `get<1>(stride_*)` should be 1 (since `stride(-1) == 1` is required). For V in CUTLASS order (head_size, seq, heads, batch), `get<0>(stride_V)` should be 1. Leaving these at default can result in incorrect tensor indexing (e.g., collapsing the head-size dimension if the default is 0) once `make_layout(shape_*, p.d*)` is used.
Suggested change:

```diff
-stride_Q = StrideQ{};
-get<0>(stride_Q) = args.q_stride_seq;
-get<2>(stride_Q) = args.q_stride_heads;
-get<3>(stride_Q) = args.q_stride_batch;
-stride_K = StrideK{};
-get<0>(stride_K) = args.k_stride_seq;
-get<2>(stride_K) = args.k_stride_heads;
-get<3>(stride_K) = args.k_stride_batch;
-stride_V = StrideV{};
-get<1>(stride_V) = args.v_stride_seq;
-get<2>(stride_V) = args.v_stride_heads;
-get<3>(stride_V) = args.v_stride_batch;
-stride_O = StrideO{};
-get<0>(stride_O) = args.o_stride_seq;
+stride_Q = StrideQ{};
+// CUTLASS order for Q: (seq, head_size, heads, batch)
+get<0>(stride_Q) = args.q_stride_seq;
+get<1>(stride_Q) = 1;  // unit stride along head_size
+get<2>(stride_Q) = args.q_stride_heads;
+get<3>(stride_Q) = args.q_stride_batch;
+stride_K = StrideK{};
+// CUTLASS order for K: (seq, head_size, heads, batch)
+get<0>(stride_K) = args.k_stride_seq;
+get<1>(stride_K) = 1;  // unit stride along head_size
+get<2>(stride_K) = args.k_stride_heads;
+get<3>(stride_K) = args.k_stride_batch;
+stride_V = StrideV{};
+// CUTLASS order for V: (head_size, seq, heads, batch)
+get<0>(stride_V) = 1;  // unit stride along head_size
+get<1>(stride_V) = args.v_stride_seq;
+get<2>(stride_V) = args.v_stride_heads;
+get<3>(stride_V) = args.v_stride_batch;
+stride_O = StrideO{};
+// CUTLASS order for O: (seq, head_size, heads, batch)
+get<0>(stride_O) = args.o_stride_seq;
+get<1>(stride_O) = 1;  // unit stride along head_size
```
```cpp
for (int _d = 0; _d < (x).dim() - 1; ++_d) {        \
  TORCH_CHECK(                                      \
      (x).stride(_d) * (x).element_size() % 64 == 0,\
      #x " stride(",                                \
      _d,                                           \
      ")=",                                         \
      (x).stride(_d),                               \
      " is not 64-byte aligned (element_size=",     \
```
This macro expands to a raw `for` loop and is not statement-safe (e.g., `if (cond) CHECK_STRIDE_ALIGNMENT(x); else ...` will not compile/bind as intended). Wrap the body in `do { ... } while (0)` to make it safe in all contexts.
Suggested change:

```diff
-for (int _d = 0; _d < (x).dim() - 1; ++_d) {        \
-  TORCH_CHECK(                                      \
-      (x).stride(_d) * (x).element_size() % 64 == 0,\
-      #x " stride(",                                \
-      _d,                                           \
-      ")=",                                         \
-      (x).stride(_d),                               \
-      " is not 64-byte aligned (element_size=",     \
+do {                                                           \
+  for (int _d = 0; _d < (x).dim() - 1; ++_d) {                 \
+    TORCH_CHECK((x).stride(_d) * (x).element_size() % 64 == 0, \
+                #x " stride(", _d, ")=", (x).stride(_d),       \
+                " is not 64-byte aligned (element_size=",      \
+                (x).element_size(),                            \
+                "). Xe2 2D block loads require 64-byte aligned base " \
+                "pointers.");                                  \
+  }                                                            \
+} while (0)
```
```cpp
// Q strides in CUTLASS order: (seq, head_size=1, heads, batch)
int q_stride_seq = 0;
int q_stride_heads = 0;
int q_stride_batch = 0;
// K strides in CUTLASS order: (seq, head_size=1, heads, batch)
int k_stride_seq = 0;
int k_stride_heads = 0;
int k_stride_batch = 0;
// V strides in CUTLASS order: (head_size=1, seq, heads, batch)
int v_stride_seq = 0;
int v_stride_heads = 0;
int v_stride_batch = 0;
// O strides in CUTLASS order: (seq, head_size=1, heads, batch)
int o_stride_seq = 0;
int o_stride_heads = 0;
int o_stride_batch = 0;
```
PyTorch tensor strides are int64 (and can exceed 32-bit for large shapes). Storing them in `int` risks truncation/overflow, which can lead to incorrect addressing in the kernel when computing offsets. Consider using `int64_t` (or the stride element type expected by Cute/CUTLASS) for these fields and the downstream offset computations that multiply stride by cumulative lengths.
No description provided.