Add block_size 16/32 support for chunk prefill and fix paged decode #171
baodii wants to merge 5 commits into vllm-project:main
Conversation
Add page-size-aware chunk prefill policies with smaller K-tiles to support block_size 16 and 32 without division-by-zero in get_paged_idx(), where tiles_per_page = page_size / K_tile would yield 0 when K_tile > page_size:
- fmha_utils.hpp: Add 7 new chunk prefill policy structs (5 for p16 with K_tile=16, 2 for p32 with K_tile=32 for head96/128)
- chunk_prefill_extern.hpp: Add paged-only extern template declarations
- chunk_prefill_configure.cmake: Add paged-only CMake generation loop
- fmha_xe2.cpp: Route to p16/p32 policies based on block_size when paged

Fix paged decode kernel issues for small block sizes:
- flash_api.cpp: Cap num_kv_splits to block_size to prevent overflow of fixed-size arrays in ReduceSplitK (max_num_kv_splits = SGPerWG * sg_size = block_size for all decode policies)
- chunk_prefill_epilogue.hpp: Initialize rA_max from tA_max in DecodeFwdEpilogue::reduce_A when ReduceK==1 (block_size=16) so that max_logits contains correct values for the ReduceSplitK reduction

Signed-off-by: baodii <di.bao@intel.com>
Add dispatch_by_page_size() and three dispatch_by_head_size_*() helpers
to chunk_prefill_utils.hpp, and replace the tangled head-size-outer /
page-size-inner dispatch in fmha_xe2.cpp with a single call.
Structure:
dispatch_by_page_size → {not paged, p16, p32, p64+}
→ dispatch_by_head_size_{default,p32,p16} → policy_dispatch_func
This mirrors the paged decode's dispatch_by_page_size pattern and makes
it easy to add new page sizes in one place.
Signed-off-by: baodii <di.bao@intel.com>
Seems the bmg CI failed due to host OOM. Please try decreasing MAX_JOBS in the bmg build job in ut.yaml.
Signed-off-by: baodii <di.bao@intel.com>
Pull request overview
Adds page-size-aware dispatch and new policies to support paged chunk-prefill and paged decode with smaller block_size (16/32), and fixes decode reduction/split handling for small tiles.
Changes:
- Add page-size-specific chunk prefill policies and dispatch routing for paged KV with block_size 16/32.
- Extend the paged decode policy/config/dispatch to handle page_size 16/32.
- Fix small-block paged decode issues by capping `num_kv_splits` and initializing reduction state for `ReduceK==1`.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/flash_attn/test_flash_attn_varlen_func.py | Expands test matrix to include BLOCK_SIZES 16/32. |
| csrc/xpu/attn/xe_2/paged_decode_utils.hpp | Adds runtime dispatch cases for page_size 16/32. |
| csrc/xpu/attn/xe_2/paged_decode_extern.hpp | Adds extern policy list entries for p16/p32. |
| csrc/xpu/attn/xe_2/paged_decode_configure.cmake | Generates kernels for pagesize_list including 16/32. |
| csrc/xpu/attn/xe_2/paged_decode.hpp | Adds decode policy aliases for p16/p32. |
| csrc/xpu/attn/xe_2/fmha_xe2.cpp | Routes chunk prefill to page-size-aware dispatch. |
| csrc/xpu/attn/xe_2/fmha_utils.hpp | Adds chunk prefill policy structs for p16/p32 and decode policy specializations for _16/_32. |
| csrc/xpu/attn/xe_2/collective/chunk_prefill_epilogue.hpp | Fixes ReduceK==1 reduction init for split-K decode. |
| csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp | Introduces dispatch helpers for default/p16/p32 and top-level page-size routing. |
| csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp | Adds paged-only extern instantiations for new p16/p32 policies. |
| csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake | Generates paged-only instantiations for the new policies. |
| csrc/flash_attn/flash_api.cpp | Caps num_kv_splits by block_size to avoid fixed-array overflow. |
| .github/workflows/ut.yaml | Reduces MAX_JOBS for editable install in CI. |
```cpp
ReduceFragARow rA_max;
// Initialize rA_max from tA_max so that max_logits is correct
// when num_kv_splits > 1 (used by ReduceSplitK).
rA_max(0) = tA_max(0);
```
`ReduceFragARow rA_max;` is default-initialized and only element 0 is set. If `ReduceFragARow` contains more than one element, the remaining values are uninitialized and can corrupt `ReduceSplitK` reductions. Initialize the entire fragment deterministically (e.g., fill from `tA_max` for all lanes/elements, or value-initialize and then assign all required indices).
Suggested change:

```cpp
ReduceFragARow rA_max{};
// Initialize rA_max from tA_max so that max_logits is correct
// when num_kv_splits > 1 (used by ReduceSplitK).
for (int i = 0; i < cute::size(rA_max); ++i) {
  rA_max(i) = tA_max(0);
}
```
```cpp
// Dispatch by head size for non-paged or page_size >= 64 paths.
// Paged=false is passed as a bool arg so both paged and non-paged can use it.
```
The comment says `Paged=false` is passed as a bool arg, but this is a template parameter (`template <bool IsPaged>`), not a runtime bool argument. Update the comment to avoid confusing readers (e.g., clarify that paged-ness is a compile-time template parameter for selecting instantiations).
Suggested change:

```cpp
// Paged-ness is selected via the compile-time template parameter IsPaged so
// both paged and non-paged paths can share this implementation.
```
```cpp
if (args.head_size <= HEAD_SIZE_LIMIT_0) {
  policy_dispatch_func<chunk_policy_head64, IsPaged>(
      queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_1) {
  policy_dispatch_func<chunk_policy_head96, IsPaged>(
      queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_2) {
  policy_dispatch_func<chunk_policy_head128, IsPaged>(
      queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_3) {
  policy_dispatch_func<chunk_policy_head192, IsPaged>(
      queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_4) {
  policy_dispatch_func<chunk_policy_head256, IsPaged>(
      queue, cuQKType, args, is_causal, is_local, is_sink);
} else {
  TORCH_CHECK(false, "Unsupported head size for fmha");
}
```
The head-size dispatch ladder is duplicated across dispatch_by_head_size_default, dispatch_by_head_size_p16, and dispatch_by_head_size_p32, increasing the chance of future inconsistencies (e.g., updating limits/policies in one path but not others). Consider factoring the common ladder into a single helper that takes the per-head policy types as template parameters (or a small traits mapping), so only the policy selection differs while the control flow stays in one place.
No major comments from my side. Leaving it to @YizhouZ to approve and merge.
Signed-off-by: baodii <di.bao@intel.com>
Signed-off-by: baodii <di.bao@intel.com>