-
Notifications
You must be signed in to change notification settings - Fork 31
Add block_size 16/32 support for chunk prefill and fix paged decode #171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5ea8477
6b7d92a
f0b9f98
7b5cdf4
f7d6a91
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,117 @@ void policy_dispatch_func( | |
| } | ||
| } | ||
|
|
||
| // Dispatch by head size for non-paged or page_size >= 64 paths. | ||
| // Paged=false is passed as a bool arg so both paged and non-paged can use it. | ||
| template <bool IsPaged> | ||
| void dispatch_by_head_size_default( | ||
| sycl::queue& queue, | ||
| CutlassQKType& cuQKType, | ||
| const chunk_prefill_args_t& args, | ||
| bool is_causal, | ||
| bool is_local, | ||
| bool is_sink) { | ||
| if (args.head_size <= HEAD_SIZE_LIMIT_0) { | ||
| policy_dispatch_func<chunk_policy_head64, IsPaged>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_1) { | ||
| policy_dispatch_func<chunk_policy_head96, IsPaged>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_2) { | ||
| policy_dispatch_func<chunk_policy_head128, IsPaged>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_3) { | ||
| policy_dispatch_func<chunk_policy_head192, IsPaged>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_4) { | ||
| policy_dispatch_func<chunk_policy_head256, IsPaged>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else { | ||
| TORCH_CHECK(false, "Unsupported head size for fmha"); | ||
| } | ||
|
Comment on lines
+39
to
+56
|
||
| } | ||
|
|
||
| // Dispatch by head size for paged KV with page_size=32. | ||
| // head96/128 need a p32 policy (K-tile=32); others fall back to default. | ||
| inline void dispatch_by_head_size_p32( | ||
| sycl::queue& queue, | ||
| CutlassQKType& cuQKType, | ||
| const chunk_prefill_args_t& args, | ||
| bool is_causal, | ||
| bool is_local, | ||
| bool is_sink) { | ||
| if (args.head_size <= HEAD_SIZE_LIMIT_0) { | ||
| policy_dispatch_func<chunk_policy_head64, true>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_1) { | ||
| policy_dispatch_func<chunk_policy_head96_p32, true>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_2) { | ||
| policy_dispatch_func<chunk_policy_head128_p32, true>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_3) { | ||
| policy_dispatch_func<chunk_policy_head192, true>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_4) { | ||
| policy_dispatch_func<chunk_policy_head256, true>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else { | ||
| TORCH_CHECK(false, "Unsupported head size for fmha"); | ||
| } | ||
| } | ||
|
|
||
| // Dispatch by head size for paged KV with page_size=16 (K-tile=16 for all). | ||
| inline void dispatch_by_head_size_p16( | ||
| sycl::queue& queue, | ||
| CutlassQKType& cuQKType, | ||
| const chunk_prefill_args_t& args, | ||
| bool is_causal, | ||
| bool is_local, | ||
| bool is_sink) { | ||
| if (args.head_size <= HEAD_SIZE_LIMIT_0) { | ||
| policy_dispatch_func<chunk_policy_head64_p16, true>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_1) { | ||
| policy_dispatch_func<chunk_policy_head96_p16, true>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_2) { | ||
| policy_dispatch_func<chunk_policy_head128_p16, true>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_3) { | ||
| policy_dispatch_func<chunk_policy_head192_p16, true>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.head_size <= HEAD_SIZE_LIMIT_4) { | ||
| policy_dispatch_func<chunk_policy_head256_p16, true>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else { | ||
| TORCH_CHECK(false, "Unsupported head size for fmha"); | ||
| } | ||
| } | ||
|
|
||
| // Top-level dispatch: select head-size dispatch table by page size. | ||
| inline void dispatch_by_page_size( | ||
| sycl::queue& queue, | ||
| CutlassQKType& cuQKType, | ||
| const chunk_prefill_args_t& args, | ||
| bool is_paged, | ||
| bool is_causal, | ||
| bool is_local, | ||
| bool is_sink) { | ||
| if (!is_paged) { | ||
| dispatch_by_head_size_default<false>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.block_size < 32) { | ||
| dispatch_by_head_size_p16( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else if (args.block_size < 64) { | ||
| dispatch_by_head_size_p32( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } else { | ||
| dispatch_by_head_size_default<true>( | ||
| queue, cuQKType, args, is_causal, is_local, is_sink); | ||
| } | ||
| } | ||
|
|
||
| void cutlass_chunk_prefill_impl( | ||
| sycl::queue& queue, | ||
| const at::Tensor& query, // [seq_q, heads, head_size] | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -586,6 +586,9 @@ class DecodeFwdEpilogue { | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| if constexpr (ReduceK{} == _1{}) { | ||||||||||||||||||||||
| ReduceFragARow rA_max; | ||||||||||||||||||||||
| // Initialize rA_max from tA_max so that max_logits is correct | ||||||||||||||||||||||
| // when num_kv_splits > 1 (used by ReduceSplitK). | ||||||||||||||||||||||
| rA_max(0) = tA_max(0); | ||||||||||||||||||||||
|
Comment on lines
588
to
+591
|
||||||||||||||||||||||
| ReduceFragARow rA_max; | |
| // Initialize rA_max from tA_max so that max_logits is correct | |
| // when num_kv_splits > 1 (used by ReduceSplitK). | |
| rA_max(0) = tA_max(0); | |
| ReduceFragARow rA_max{}; | |
| // Initialize rA_max from tA_max so that max_logits is correct | |
| // when num_kv_splits > 1 (used by ReduceSplitK). | |
| for (int i = 0; i < cute::size(rA_max); ++i) { | |
| rA_max(i) = tA_max(0); | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment says
Paged=false is passed as a bool arg, but this is a template parameter (template <bool IsPaged>), not a runtime bool argument. Update the comment to avoid confusing readers (e.g., clarify that paged-ness is a compile-time template parameter for selecting instantiations).