
Add block_size 16/32 support for chunk prefill and fix paged decode#171

Open
baodii wants to merge 5 commits into vllm-project:main from baodii:block-size-16-32-support

Conversation


@baodii baodii commented Mar 4, 2026

Add page-size-aware chunk prefill policies with smaller K-tiles to support block_size 16 and 32. Without them, get_paged_idx() divides by zero: tiles_per_page = page_size / K_tile evaluates to 0 whenever K_tile > page_size, and the subsequent division by tiles_per_page faults:

  • fmha_utils.hpp: Add 7 new chunk prefill policy structs (5 for p16 with K_tile=16, 2 for p32 with K_tile=32 for head96/128)
  • chunk_prefill_extern.hpp: Add paged-only extern template declarations
  • chunk_prefill_configure.cmake: Add paged-only CMake generation loop
  • fmha_xe2.cpp: Route to p16/p32 policies based on block_size when paged

Fix paged decode kernel issues for small block sizes:

  • flash_api.cpp: Cap num_kv_splits to block_size to prevent overflow of fixed-size arrays in ReduceSplitK (max_num_kv_splits = SGPerWG * sg_size = block_size for all decode policies)
  • chunk_prefill_epilogue.hpp: Initialize rA_max from tA_max in DecodeFwdEpilogue::reduce_A when ReduceK==1 (block_size=16) so that max_logits contains correct values for ReduceSplitK reduction

baodii added 2 commits March 4, 2026 00:00
Add page-size-aware chunk prefill policies with smaller K-tiles to
support block_size 16 and 32. Without them, get_paged_idx() divides by
zero: tiles_per_page = page_size / K_tile evaluates to 0 whenever
K_tile > page_size:

- fmha_utils.hpp: Add 7 new chunk prefill policy structs (5 for p16
  with K_tile=16, 2 for p32 with K_tile=32 for head96/128)
- chunk_prefill_extern.hpp: Add paged-only extern template declarations
- chunk_prefill_configure.cmake: Add paged-only CMake generation loop
- fmha_xe2.cpp: Route to p16/p32 policies based on block_size when paged

Fix paged decode kernel issues for small block sizes:

- flash_api.cpp: Cap num_kv_splits to block_size to prevent overflow of
  fixed-size arrays in ReduceSplitK (max_num_kv_splits = SGPerWG *
  sg_size = block_size for all decode policies)
- chunk_prefill_epilogue.hpp: Initialize rA_max from tA_max in
  DecodeFwdEpilogue::reduce_A when ReduceK==1 (block_size=16) so that
  max_logits contains correct values for ReduceSplitK reduction

Signed-off-by: baodii <di.bao@intel.com>
Add dispatch_by_page_size() and three dispatch_by_head_size_*() helpers
to chunk_prefill_utils.hpp, and replace the tangled head-size-outer /
page-size-inner dispatch in fmha_xe2.cpp with a single call.

Structure:
  dispatch_by_page_size → {not paged, p16, p32, p64+}
    → dispatch_by_head_size_{default,p32,p16} → policy_dispatch_func

This mirrors the paged decode's dispatch_by_page_size pattern and makes
it easy to add new page sizes in one place.

Signed-off-by: baodii <di.bao@intel.com>
@jikunshang

Seems the BMG CI failed due to host OOM. Please try decreasing MAX_JOBS in ut.yaml for the BMG build job.

Signed-off-by: baodii <di.bao@intel.com>
@baodii baodii changed the title from "[WIP] Add block_size 16/32 support for chunk prefill and fix paged decode" to "Add block_size 16/32 support for chunk prefill and fix paged decode" on Mar 5, 2026
@baodii baodii marked this pull request as ready for review March 5, 2026 06:24

Copilot AI left a comment


Pull request overview

Adds page-size-aware dispatch and new policies to support paged chunk-prefill and paged decode with smaller block_size (16/32), and fixes decode reduction/split handling for small tiles.

Changes:

  • Add page-size-specific chunk prefill policies + dispatch routing for paged KV with block_size 16/32.
  • Extend paged decode policy/config/dispatch to handle page_size 16/32.
  • Fix small-block paged decode issues by capping num_kv_splits and initializing reduction state for ReduceK==1.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.

File Description
tests/flash_attn/test_flash_attn_varlen_func.py Expands test matrix to include BLOCK_SIZES 16/32.
csrc/xpu/attn/xe_2/paged_decode_utils.hpp Adds runtime dispatch cases for page_size 16/32.
csrc/xpu/attn/xe_2/paged_decode_extern.hpp Adds extern policy list entries for p16/p32.
csrc/xpu/attn/xe_2/paged_decode_configure.cmake Generates kernels for pagesize_list including 16/32.
csrc/xpu/attn/xe_2/paged_decode.hpp Adds decode policy aliases for p16/p32.
csrc/xpu/attn/xe_2/fmha_xe2.cpp Routes chunk prefill to page-size-aware dispatch.
csrc/xpu/attn/xe_2/fmha_utils.hpp Adds chunk prefill policy structs for p16/p32 and decode policy specializations for _16/_32.
csrc/xpu/attn/xe_2/collective/chunk_prefill_epilogue.hpp Fixes ReduceK==1 reduction init for split-K decode.
csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp Introduces dispatch helpers for default/p16/p32 and top-level page-size routing.
csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp Adds paged-only extern instantiations for new p16/p32 policies.
csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake Generates paged-only instantiations for the new policies.
csrc/flash_attn/flash_api.cpp Caps num_kv_splits by block_size to avoid fixed-array overflow.
.github/workflows/ut.yaml Reduces MAX_JOBS for editable install in CI.


Comment on lines 588 to +591
ReduceFragARow rA_max;
// Initialize rA_max from tA_max so that max_logits is correct
// when num_kv_splits > 1 (used by ReduceSplitK).
rA_max(0) = tA_max(0);

Copilot AI Mar 5, 2026


ReduceFragARow rA_max; is default-initialized and only element 0 is set. If ReduceFragARow contains more than one element, the remaining values are uninitialized and can corrupt ReduceSplitK reductions. Initialize the entire fragment deterministically (e.g., fill from tA_max for all lanes/elements, or value-initialize then assign all required indices).

Suggested change
ReduceFragARow rA_max;
// Initialize rA_max from tA_max so that max_logits is correct
// when num_kv_splits > 1 (used by ReduceSplitK).
rA_max(0) = tA_max(0);
ReduceFragARow rA_max{};
// Initialize rA_max from tA_max so that max_logits is correct
// when num_kv_splits > 1 (used by ReduceSplitK).
for (int i = 0; i < cute::size(rA_max); ++i) {
rA_max(i) = tA_max(0);
}

}

// Dispatch by head size for non-paged or page_size >= 64 paths.
// Paged=false is passed as a bool arg so both paged and non-paged can use it.

Copilot AI Mar 5, 2026


The comment says Paged=false is passed as a bool arg, but this is a template parameter (template <bool IsPaged>), not a runtime bool argument. Update the comment to avoid confusing readers (e.g., clarify that paged-ness is a compile-time template parameter for selecting instantiations).

Suggested change
// Paged=false is passed as a bool arg so both paged and non-paged can use it.
// Paged-ness is selected via the compile-time template parameter IsPaged so
// both paged and non-paged paths can share this implementation.

Comment on lines +39 to +56
if (args.head_size <= HEAD_SIZE_LIMIT_0) {
policy_dispatch_func<chunk_policy_head64, IsPaged>(
queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_1) {
policy_dispatch_func<chunk_policy_head96, IsPaged>(
queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_2) {
policy_dispatch_func<chunk_policy_head128, IsPaged>(
queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_3) {
policy_dispatch_func<chunk_policy_head192, IsPaged>(
queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_4) {
policy_dispatch_func<chunk_policy_head256, IsPaged>(
queue, cuQKType, args, is_causal, is_local, is_sink);
} else {
TORCH_CHECK(false, "Unsupported head size for fmha");
}

Copilot AI Mar 5, 2026


The head-size dispatch ladder is duplicated across dispatch_by_head_size_default, dispatch_by_head_size_p16, and dispatch_by_head_size_p32, increasing the chance of future inconsistencies (e.g., updating limits/policies in one path but not others). Consider factoring the common ladder into a single helper that takes the per-head policy types as template parameters (or a small traits mapping), so only the policy selection differs while the control flow stays in one place.

@jikunshang jikunshang mentioned this pull request Mar 5, 2026
@jikunshang

No major comments from my side. Leaving it to @YizhouZ to approve and merge.

baodii added 2 commits March 11, 2026 16:43
Signed-off-by: baodi <di.bao@intel.com>
Signed-off-by: baodii <di.bao@intel.com>
@jikunshang jikunshang mentioned this pull request Mar 12, 2026