-
Notifications
You must be signed in to change notification settings - Fork 9
[Enhancement] Enhance attention examples and fix bugs for several corner cases #51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: refactor
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,18 +14,13 @@ def _mla_decode_kernel(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, dtype= | |
| kv_group_num = heads // kv_head_num | ||
| assert kv_head_num == 1, "kv_head_num must be 1" | ||
|
|
||
| @tilelang.jit( | ||
| out_idx=[6], | ||
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, | ||
| }, | ||
| compile_flags=["-O3", "-DENABLE_BF16"]) | ||
| def _mla_decode_func(block_H, block_N, num_split, num_stages, threads=128): | ||
| @tilelang.jit(out_idx=[6], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) | ||
| def mla_decode_func(block_H, block_N, num_split, num_stages, threads=128): | ||
|
|
||
| VALID_BLOCK_H = min(block_H, kv_group_num) | ||
|
|
||
| @T.macro | ||
| def _mla_no_split( | ||
| def mla_no_split( | ||
| Q: T.Tensor([batch, heads, dim], dtype), | ||
| Q_pe: T.Tensor([batch, heads, pe_dim], dtype), | ||
| KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), | ||
|
|
@@ -75,6 +70,8 @@ def _mla_no_split( | |
| T.copy(scores_max, scores_max_prev) | ||
| T.fill(scores_max, -T.infinity(accum_dtype)) | ||
| T.reduce_max(acc_s, scores_max, dim=1, clear=False) | ||
| for i in T.Parallel(block_H): | ||
| scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) | ||
|
Comment on lines
+73
to
+74
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change correctly updates |
||
| for i in T.Parallel(block_H): | ||
| scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) | ||
| for i, j in T.Parallel(block_H, block_N): | ||
|
|
@@ -92,7 +89,7 @@ def _mla_no_split( | |
| T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) | ||
|
|
||
| @T.macro | ||
| def _mla_split( | ||
| def mla_split( | ||
| Q: T.Tensor([batch, heads, dim], dtype), | ||
| Q_pe: T.Tensor([batch, heads, pe_dim], dtype), | ||
| KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), | ||
|
|
@@ -136,7 +133,9 @@ def _mla_split( | |
| kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N | ||
| T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared) | ||
| T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared) | ||
| T.clear(acc_s) | ||
| for i, j in T.Parallel(block_H, block_N): | ||
| acc_s[i, j] = T.if_then_else(kv_start + j >= seqlen_kv, | ||
| -T.infinity(acc_s.dtype), 0) | ||
|
Comment on lines
+136
to
+138
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Replacing |
||
| T.gemm( | ||
| Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) | ||
| T.gemm( | ||
|
|
@@ -148,6 +147,8 @@ def _mla_split( | |
| T.copy(scores_max, scores_max_prev) | ||
| T.fill(scores_max, -T.infinity(accum_dtype)) | ||
| T.reduce_max(acc_s, scores_max, dim=1, clear=False) | ||
| for i in T.Parallel(block_H): | ||
| scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) | ||
|
Comment on lines
+150
to
+151
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| for i in T.Parallel(block_H): | ||
| scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) | ||
| for i, j in T.Parallel(block_H, block_N): | ||
|
|
@@ -215,7 +216,7 @@ def main_split( | |
| Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), | ||
| Output: T.Tensor([batch, heads, dim], dtype), | ||
| ): | ||
| _mla_split(Q, Q_pe, KV, K_pe, glse, Output_partial) | ||
| mla_split(Q, Q_pe, KV, K_pe, glse, Output_partial) | ||
| combine(glse, Output_partial, Output) | ||
|
|
||
| @T.prim_func | ||
|
|
@@ -228,14 +229,14 @@ def main_no_split( | |
| Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), | ||
| Output: T.Tensor([batch, heads, dim], dtype), | ||
| ): | ||
| _mla_no_split(Q, Q_pe, KV, K_pe, Output) | ||
| mla_no_split(Q, Q_pe, KV, K_pe, Output) | ||
|
|
||
| if num_split > 1: | ||
| return main_split | ||
| else: | ||
| return main_no_split | ||
|
|
||
| return _mla_decode_func | ||
| return mla_decode_func | ||
|
|
||
|
|
||
| @torch.library.custom_op("top::mla_decode_wrapped_kernel", mutates_args=()) | ||
|
|
@@ -378,7 +379,7 @@ def _mla_decode_ws_kernel(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, dty | |
| "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" | ||
| ], | ||
| ) | ||
| def _mla_decode_ws_func(block_H, block_N, num_split, num_stages, threads=128): | ||
| def mla_decode_ws_func(block_H, block_N, num_split, num_stages, threads=128): | ||
|
|
||
| VALID_BLOCK_H = min(block_H, kv_group_num) | ||
| @T.macro | ||
|
|
@@ -458,6 +459,8 @@ def flash_attn( | |
|
|
||
| T.copy(m_i, m_i_prev) | ||
| T.reduce_max(acc_s, m_i, dim=1, clear=False) | ||
| for h_i in T.Parallel(block_H): | ||
| m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) | ||
|
Comment on lines
+462
to
+463
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| for h_i in T.Parallel(block_H): | ||
| alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) | ||
| for h_i, bi_i in T.Parallel(block_H, block_N): | ||
|
|
@@ -490,6 +493,8 @@ def flash_attn( | |
|
|
||
| T.copy(m_i, m_i_prev) | ||
| T.reduce_max(acc_s, m_i, dim=1, clear=False) | ||
| for h_i in T.Parallel(block_H): | ||
| m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) | ||
|
Comment on lines
+496
to
+497
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| for h_i in T.Parallel(block_H): | ||
| alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) | ||
| for h_i, bi_i in T.Parallel(block_H, block_N): | ||
|
|
@@ -677,6 +682,8 @@ def flash_attn_split( | |
|
|
||
| T.copy(m_i, m_i_prev) | ||
| T.reduce_max(acc_s, m_i, dim=1, clear=False) | ||
| for h_i in T.Parallel(block_H): | ||
| m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) | ||
|
Comment on lines
+685
to
+686
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| for h_i in T.Parallel(block_H): | ||
| alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) | ||
| for h_i, bi_i in T.Parallel(block_H, block_N): | ||
|
|
@@ -709,6 +716,8 @@ def flash_attn_split( | |
|
|
||
| T.copy(m_i, m_i_prev) | ||
| T.reduce_max(acc_s, m_i, dim=1, clear=False) | ||
| for h_i in T.Parallel(block_H): | ||
| m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) | ||
|
Comment on lines
+719
to
+720
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| for h_i in T.Parallel(block_H): | ||
| alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) | ||
| for h_i, bi_i in T.Parallel(block_H, block_N): | ||
|
|
@@ -888,7 +897,7 @@ def main_no_split( | |
| else: | ||
| return main_no_split | ||
|
|
||
| return _mla_decode_ws_func | ||
| return mla_decode_ws_func | ||
|
|
||
|
|
||
| @torch.library.custom_op("top::mla_decode_ws_wrapped_kernel", mutates_args=()) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,7 +61,7 @@ def _sparse_mla_kernel( | |
| "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" | ||
| ], | ||
| ) | ||
| def _sparse_mla_fwd_func(block_I, threads): | ||
| def sparse_mla_fwd_func(block_I, threads): | ||
|
|
||
| G = kv_group | ||
| heads = head_kv | ||
|
|
@@ -89,7 +89,7 @@ def _sparse_mla_fwd_func(block_I, threads): | |
|
|
||
| H_per_block = padded_H if REPLICATE_H == 1 else 64 | ||
| @T.prim_func | ||
| def _sparse_mla_fwd_main( | ||
| def sparse_mla_fwd_main( | ||
| Q: T.Tensor(q_shape, dtype), # type: ignore | ||
| KV: T.Tensor(kv_shape, dtype), # type: ignore | ||
| Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore | ||
|
|
@@ -176,8 +176,10 @@ def _sparse_mla_fwd_main( | |
| T.barrier_arrive(bar_sScale_and_sS_free) | ||
| T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) | ||
|
|
||
| T.copy(m_i, m_i_prev) | ||
| T.copy(src=m_i, m_i_prev) | ||
| T.reduce_max(acc_s, m_i, dim=1, clear=False) | ||
| for h_i in T.Parallel(H_per_block): | ||
| m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) | ||
|
Comment on lines
+181
to
+182
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| for h_i in T.Parallel(H_per_block): | ||
| alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) | ||
| for h_i, bi_i in T.Parallel(H_per_block, BI): | ||
|
|
@@ -212,6 +214,8 @@ def _sparse_mla_fwd_main( | |
|
|
||
| T.copy(m_i, m_i_prev) | ||
| T.reduce_max(acc_s, m_i, dim=1, clear=False) | ||
| for h_i in T.Parallel(H_per_block): | ||
| m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) | ||
|
Comment on lines
+217
to
+218
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| for h_i in T.Parallel(H_per_block): | ||
| alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) | ||
| for h_i, bi_i in T.Parallel(H_per_block, BI): | ||
|
|
@@ -322,9 +326,9 @@ def _sparse_mla_fwd_main( | |
| D + (tx - 256) % 8 * 8 + v] | ||
| T.cp_async_barrier_noinc(bar_k_1_ready[0]) | ||
|
|
||
| return _sparse_mla_fwd_main | ||
| return sparse_mla_fwd_main | ||
|
|
||
| return _sparse_mla_fwd_func | ||
| return sparse_mla_fwd_func | ||
|
|
||
|
|
||
| @torch.library.custom_op("top::sparse_mla_fwd_wrapped_kernel", mutates_args=()) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -114,16 +114,11 @@ def _mha_bwd_kernel(batch, heads, seq_len, dim, is_causal, dtype="float16"): | |
| shape = [batch, seq_len, heads, dim] | ||
| accum_dtype = "float" | ||
|
|
||
| @tilelang.jit( | ||
| out_idx=[7, 8], | ||
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, | ||
| }, | ||
| compile_flags=["-O3", "-DENABLE_BF16"]) | ||
| def _mha_bwd_func(block_M, block_N, num_stages, threads): | ||
| @tilelang.jit(out_idx=[7, 8], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) | ||
| def mha_bwd_func(block_M, block_N, num_stages, threads): | ||
|
|
||
| @T.prim_func | ||
| def _mha_bwd_main( | ||
| def mha_bwd_main( | ||
| Q: T.Tensor(shape, dtype), # type: ignore | ||
| K: T.Tensor(shape, dtype), # type: ignore | ||
|
Comment on lines
+117
to
123
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| V: T.Tensor(shape, dtype), # type: ignore | ||
|
|
@@ -205,9 +200,9 @@ def _mha_bwd_main( | |
| T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) | ||
| T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) | ||
|
|
||
| return _mha_bwd_main | ||
| return mha_bwd_main | ||
|
|
||
| return _mha_bwd_func | ||
| return mha_bwd_func | ||
|
|
||
|
|
||
| class mha_bwd_kernel(Kernel): | ||
|
|
@@ -270,16 +265,11 @@ def _mha_bwd_wgmma_pipelined_kernel(batch, heads, seq_len, dim, is_causal, dtype | |
| shape = [batch, seq_len, heads, dim] | ||
| accum_dtype = "float" | ||
|
|
||
| @tilelang.jit( | ||
| out_idx=[7, 8], | ||
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, | ||
| }, | ||
| compile_flags=["-O3", "-DENABLE_BF16"]) | ||
| def _mha_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads): | ||
| @tilelang.jit(out_idx=[7, 8], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) | ||
| def mha_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads): | ||
|
|
||
| @T.prim_func | ||
| def _mha_bwd_wgmma_pipelined_main( | ||
| def mha_bwd_wgmma_pipelined_main( | ||
| Q: T.Tensor(shape, dtype), # type: ignore | ||
| K: T.Tensor(shape, dtype), # type: ignore | ||
| V: T.Tensor(shape, dtype), # type: ignore | ||
|
|
@@ -378,9 +368,9 @@ def _mha_bwd_wgmma_pipelined_main( | |
| T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) | ||
| T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) | ||
|
|
||
| return _mha_bwd_wgmma_pipelined_main | ||
| return mha_bwd_wgmma_pipelined_main | ||
|
|
||
| return _mha_bwd_wgmma_pipelined_func | ||
| return mha_bwd_wgmma_pipelined_func | ||
|
|
||
|
|
||
| class mha_bwd_wgmma_pipelined_kernel(Kernel): | ||
|
|
@@ -449,15 +439,11 @@ def _gqa_bwd_kernel(batch, heads, heads_kv, seq_len, dim, is_causal, dtype="floa | |
| groups = heads // heads_kv | ||
| accum_dtype = "float" | ||
|
|
||
| @tilelang.jit( | ||
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, | ||
| }, | ||
| compile_flags=["-O3", "-DENABLE_BF16"]) | ||
| def _gqa_bwd_func(block_M, block_N, num_stages, threads): | ||
| @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) | ||
| def gqa_bwd_func(block_M, block_N, num_stages, threads): | ||
|
|
||
| @T.prim_func | ||
| def _gqa_bwd_main( | ||
| def gqa_bwd_main( | ||
| Q: T.Tensor(q_shape, dtype), # type: ignore | ||
| K: T.Tensor(kv_shape, dtype), # type: ignore | ||
| V: T.Tensor(kv_shape, dtype), # type: ignore | ||
|
|
@@ -534,9 +520,9 @@ def _gqa_bwd_main( | |
| T.copy(dk, dk_shared) | ||
| T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) | ||
|
|
||
| return _gqa_bwd_main | ||
| return gqa_bwd_main | ||
|
|
||
| return _gqa_bwd_func | ||
| return gqa_bwd_func | ||
|
|
||
|
|
||
| class gqa_bwd_kernel(Kernel): | ||
|
|
@@ -610,15 +596,11 @@ def _gqa_bwd_wgmma_pipelined_kernel(batch, | |
| groups = heads // heads_kv | ||
| accum_dtype = "float" | ||
|
|
||
| @tilelang.jit( | ||
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, | ||
| }, | ||
| compile_flags=["-O3", "-DENABLE_BF16"]) | ||
| def _gqa_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads): | ||
| @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) | ||
| def gqa_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads): | ||
|
|
||
| @T.prim_func | ||
| def _gqa_bwd_wgmma_pipelined_main( | ||
| def gqa_bwd_wgmma_pipelined_main( | ||
| Q: T.Tensor(q_shape, dtype), # type: ignore | ||
| K: T.Tensor(kv_shape, dtype), # type: ignore | ||
| V: T.Tensor(kv_shape, dtype), # type: ignore | ||
|
|
@@ -711,9 +693,9 @@ def _gqa_bwd_wgmma_pipelined_main( | |
| T.copy(dk, dk_shared) | ||
| T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared) | ||
|
|
||
| return _gqa_bwd_wgmma_pipelined_main | ||
| return gqa_bwd_wgmma_pipelined_main | ||
|
|
||
| return _gqa_bwd_wgmma_pipelined_func | ||
| return gqa_bwd_wgmma_pipelined_func | ||
|
|
||
|
|
||
| class gqa_bwd_wgmma_pipelined_kernel(Kernel): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The removal of
compile_flagsis a good cleanup, aligning with the PR description to remove legacy flags. This simplifies the JIT decorator configuration.