diff --git a/top/kernels/deepseek_mla/mla_decode.py b/top/kernels/deepseek_mla/mla_decode.py
index 9a49737..77b61e3 100644
--- a/top/kernels/deepseek_mla/mla_decode.py
+++ b/top/kernels/deepseek_mla/mla_decode.py
@@ -14,18 +14,13 @@ def _mla_decode_kernel(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, dtype=
     kv_group_num = heads // kv_head_num
     assert kv_head_num == 1, "kv_head_num must be 1"
 
-    @tilelang.jit(
-        out_idx=[6],
-        pass_configs={
-            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        },
-        compile_flags=["-O3", "-DENABLE_BF16"])
-    def _mla_decode_func(block_H, block_N, num_split, num_stages, threads=128):
+    @tilelang.jit(out_idx=[6], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
+    def mla_decode_func(block_H, block_N, num_split, num_stages, threads=128):
         VALID_BLOCK_H = min(block_H, kv_group_num)
 
         @T.macro
-        def _mla_no_split(
+        def mla_no_split(
                 Q: T.Tensor([batch, heads, dim], dtype),
                 Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
                 KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
@@ -75,6 +70,8 @@ def _mla_no_split(
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_H):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_H, block_N):
@@ -92,7 +89,7 @@ def _mla_no_split(
             T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])
 
         @T.macro
-        def _mla_split(
+        def mla_split(
                 Q: T.Tensor([batch, heads, dim], dtype),
                 Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
                 KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
@@ -136,7 +133,9 @@ def _mla_split(
                 kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
                 T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
                 T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
-                T.clear(acc_s)
+                for i, j in T.Parallel(block_H, block_N):
+                    acc_s[i, j] = T.if_then_else(kv_start + j >= seqlen_kv,
+                                                 -T.infinity(acc_s.dtype), 0)
                 T.gemm(
                     Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                 T.gemm(
@@ -148,6 +147,8 @@ def _mla_split(
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_H):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_H, block_N):
@@ -215,7 +216,7 @@ def main_split(
                 Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
                 Output: T.Tensor([batch, heads, dim], dtype),
         ):
-            _mla_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
+            mla_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
             combine(glse, Output_partial, Output)
 
         @T.prim_func
@@ -228,14 +229,14 @@ def main_no_split(
                 Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
                 Output: T.Tensor([batch, heads, dim], dtype),
         ):
-            _mla_no_split(Q, Q_pe, KV, K_pe, Output)
+            mla_no_split(Q, Q_pe, KV, K_pe, Output)
 
         if num_split > 1:
             return main_split
         else:
             return main_no_split
 
-    return _mla_decode_func
+    return mla_decode_func
 
 
 @torch.library.custom_op("top::mla_decode_wrapped_kernel", mutates_args=())
@@ -378,7 +379,7 @@ def _mla_decode_ws_kernel(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, dty
             "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG"
         ],
     )
-    def _mla_decode_ws_func(block_H, block_N, num_split, num_stages, threads=128):
+    def mla_decode_ws_func(block_H, block_N, num_split, num_stages, threads=128):
         VALID_BLOCK_H = min(block_H, kv_group_num)
 
         @T.macro
@@ -458,6 +459,8 @@ def flash_attn(
                 T.copy(m_i, m_i_prev)
                 T.reduce_max(acc_s, m_i, dim=1, clear=False)
+                for h_i in T.Parallel(block_H):
+                    m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
                 for h_i in T.Parallel(block_H):
                     alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                 for h_i, bi_i in T.Parallel(block_H, block_N):
@@ -490,6 +493,8 @@ def flash_attn(
                 T.copy(m_i, m_i_prev)
                 T.reduce_max(acc_s, m_i, dim=1, clear=False)
+                for h_i in T.Parallel(block_H):
+                    m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
                 for h_i in T.Parallel(block_H):
                     alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                 for h_i, bi_i in T.Parallel(block_H, block_N):
@@ -677,6 +682,8 @@ def flash_attn_split(
                 T.copy(m_i, m_i_prev)
                 T.reduce_max(acc_s, m_i, dim=1, clear=False)
+                for h_i in T.Parallel(block_H):
+                    m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
                 for h_i in T.Parallel(block_H):
                     alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                 for h_i, bi_i in T.Parallel(block_H, block_N):
@@ -709,6 +716,8 @@ def flash_attn_split(
                 T.copy(m_i, m_i_prev)
                 T.reduce_max(acc_s, m_i, dim=1, clear=False)
+                for h_i in T.Parallel(block_H):
+                    m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
                 for h_i in T.Parallel(block_H):
                     alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                 for h_i, bi_i in T.Parallel(block_H, block_N):
@@ -888,7 +897,7 @@ def main_no_split(
         else:
             return main_no_split
 
-    return _mla_decode_ws_func
+    return mla_decode_ws_func
 
 
 @torch.library.custom_op("top::mla_decode_ws_wrapped_kernel", mutates_args=())
diff --git a/top/kernels/deepseek_mla/sparse_mla.py b/top/kernels/deepseek_mla/sparse_mla.py
index 69da597..7bc9285 100644
--- a/top/kernels/deepseek_mla/sparse_mla.py
+++ b/top/kernels/deepseek_mla/sparse_mla.py
@@ -61,7 +61,7 @@ def _sparse_mla_kernel(
             "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG"
         ],
     )
-    def _sparse_mla_fwd_func(block_I, threads):
+    def sparse_mla_fwd_func(block_I, threads):
         G = kv_group
         heads = head_kv
 
@@ -89,7 +89,7 @@ def _sparse_mla_fwd_func(block_I, threads):
         H_per_block = padded_H if REPLICATE_H == 1 else 64
 
         @T.prim_func
-        def _sparse_mla_fwd_main(
+        def sparse_mla_fwd_main(
                 Q: T.Tensor(q_shape, dtype),  # type: ignore
                 KV: T.Tensor(kv_shape, dtype),  # type: ignore
                 Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
@@ -176,8 +176,10 @@ def _sparse_mla_fwd_main(
                 T.barrier_arrive(bar_sScale_and_sS_free)
                 T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)
                 T.copy(m_i, m_i_prev)
                 T.reduce_max(acc_s, m_i, dim=1, clear=False)
+                for h_i in T.Parallel(H_per_block):
+                    m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
                 for h_i in T.Parallel(H_per_block):
                     alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                 for h_i, bi_i in T.Parallel(H_per_block, BI):
@@ -212,6 +214,8 @@ def _sparse_mla_fwd_main(
                 T.copy(m_i, m_i_prev)
                 T.reduce_max(acc_s, m_i, dim=1, clear=False)
+                for h_i in T.Parallel(H_per_block):
+                    m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
                 for h_i in T.Parallel(H_per_block):
                     alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                 for h_i, bi_i in T.Parallel(H_per_block, BI):
@@ -322,9 +326,9 @@ def _sparse_mla_fwd_main(
                                 D + (tx - 256) % 8 * 8 + v]
                 T.cp_async_barrier_noinc(bar_k_1_ready[0])
 
-        return _sparse_mla_fwd_main
+        return sparse_mla_fwd_main
 
-    return _sparse_mla_fwd_func
+    return sparse_mla_fwd_func
 
 
 @torch.library.custom_op("top::sparse_mla_fwd_wrapped_kernel", mutates_args=())
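Note on the recurring change in the hunks above (and in the files that follow): after `T.reduce_max(..., clear=False)`, the patch now folds the previous running maximum back into the freshly reduced block maximum before the `exp2` rescale factors are computed. A minimal NumPy sketch of that online-softmax bookkeeping, with illustrative names only (not code from this repository):

```python
import numpy as np

def online_softmax_rows(score_blocks):
    """Process column blocks of a score matrix, keeping a running max and sum."""
    m = l = None
    for s in score_blocks:                      # s: [rows, block_N]
        block_max = s.max(axis=1)
        # Fold the previous running max in; rescaling with only the block max
        # would use the wrong factor whenever the new block's max is smaller.
        m_new = block_max if m is None else np.maximum(m, block_max)
        p = np.exp(s - m_new[:, None])
        l = p.sum(axis=1) if l is None else l * np.exp(m - m_new) + p.sum(axis=1)
        m = m_new
    return m, l
```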
diff --git a/top/kernels/flash_attn/bwd.py b/top/kernels/flash_attn/bwd.py
index ea35517..3a32061 100644
--- a/top/kernels/flash_attn/bwd.py
+++ b/top/kernels/flash_attn/bwd.py
@@ -114,16 +114,11 @@ def _mha_bwd_kernel(batch, heads, seq_len, dim, is_causal, dtype="float16"):
     shape = [batch, seq_len, heads, dim]
     accum_dtype = "float"
 
-    @tilelang.jit(
-        out_idx=[7, 8],
-        pass_configs={
-            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        },
-        compile_flags=["-O3", "-DENABLE_BF16"])
-    def _mha_bwd_func(block_M, block_N, num_stages, threads):
+    @tilelang.jit(out_idx=[7, 8], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
+    def mha_bwd_func(block_M, block_N, num_stages, threads):
 
         @T.prim_func
-        def _mha_bwd_main(
+        def mha_bwd_main(
                 Q: T.Tensor(shape, dtype),  # type: ignore
                 K: T.Tensor(shape, dtype),  # type: ignore
                 V: T.Tensor(shape, dtype),  # type: ignore
@@ -205,9 +200,9 @@ def _mha_bwd_main(
             T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
             T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])
 
-        return _mha_bwd_main
+        return mha_bwd_main
 
-    return _mha_bwd_func
+    return mha_bwd_func
 
 
 class mha_bwd_kernel(Kernel):
@@ -270,16 +265,11 @@ def _mha_bwd_wgmma_pipelined_kernel(batch, heads, seq_len, dim, is_causal, dtype
     shape = [batch, seq_len, heads, dim]
     accum_dtype = "float"
 
-    @tilelang.jit(
-        out_idx=[7, 8],
-        pass_configs={
-            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        },
-        compile_flags=["-O3", "-DENABLE_BF16"])
-    def _mha_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):
+    @tilelang.jit(out_idx=[7, 8], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
+    def mha_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):
 
         @T.prim_func
-        def _mha_bwd_wgmma_pipelined_main(
+        def mha_bwd_wgmma_pipelined_main(
                 Q: T.Tensor(shape, dtype),  # type: ignore
                 K: T.Tensor(shape, dtype),  # type: ignore
                 V: T.Tensor(shape, dtype),  # type: ignore
@@ -378,9 +368,9 @@ def _mha_bwd_wgmma_pipelined_main(
             T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
             T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])
 
-        return _mha_bwd_wgmma_pipelined_main
+        return mha_bwd_wgmma_pipelined_main
 
-    return _mha_bwd_wgmma_pipelined_func
+    return mha_bwd_wgmma_pipelined_func
 
 
 class mha_bwd_wgmma_pipelined_kernel(Kernel):
@@ -449,15 +439,11 @@ def _gqa_bwd_kernel(batch, heads, heads_kv, seq_len, dim, is_causal, dtype="floa
     groups = heads // heads_kv
     accum_dtype = "float"
 
-    @tilelang.jit(
-        pass_configs={
-            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        },
-        compile_flags=["-O3", "-DENABLE_BF16"])
-    def _gqa_bwd_func(block_M, block_N, num_stages, threads):
+    @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
+    def gqa_bwd_func(block_M, block_N, num_stages, threads):
 
         @T.prim_func
-        def _gqa_bwd_main(
+        def gqa_bwd_main(
                 Q: T.Tensor(q_shape, dtype),  # type: ignore
                 K: T.Tensor(kv_shape, dtype),  # type: ignore
                 V: T.Tensor(kv_shape, dtype),  # type: ignore
@@ -534,9 +520,9 @@ def _gqa_bwd_main(
             T.copy(dk, dk_shared)
             T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)
 
-        return _gqa_bwd_main
+        return gqa_bwd_main
 
-    return _gqa_bwd_func
+    return gqa_bwd_func
 
 
 class gqa_bwd_kernel(Kernel):
@@ -610,15 +596,11 @@ def _gqa_bwd_wgmma_pipelined_kernel(batch,
     groups = heads // heads_kv
     accum_dtype = "float"
 
-    @tilelang.jit(
-        pass_configs={
-            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        },
-        compile_flags=["-O3", "-DENABLE_BF16"])
-    def _gqa_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):
+    @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
+    def gqa_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):
 
         @T.prim_func
-        def _gqa_bwd_wgmma_pipelined_main(
+        def gqa_bwd_wgmma_pipelined_main(
                 Q: T.Tensor(q_shape, dtype),  # type: ignore
                 K: T.Tensor(kv_shape, dtype),  # type: ignore
                 V: T.Tensor(kv_shape, dtype),  # type: ignore
@@ -711,9 +693,9 @@ def _gqa_bwd_wgmma_pipelined_main(
             T.copy(dk, dk_shared)
            T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)
 
-        return _gqa_bwd_wgmma_pipelined_main
+        return gqa_bwd_wgmma_pipelined_main
 
-    return _gqa_bwd_wgmma_pipelined_func
+    return gqa_bwd_wgmma_pipelined_func
 
 
 class gqa_bwd_wgmma_pipelined_kernel(Kernel):
diff --git a/top/kernels/flash_attn/fwd.py b/top/kernels/flash_attn/fwd.py
index e816261..3fc99b1 100644
--- a/top/kernels/flash_attn/fwd.py
+++ b/top/kernels/flash_attn/fwd.py
@@ -18,16 +18,11 @@ def _mha_fwd_kernel(batch, heads, seq_len, dim, is_causal, dtype='float16'):
     shape = [batch, seq_len, heads, dim]
     accum_dtype = "float"
 
-    @tilelang.jit(
-        out_idx=[3, 4],
-        pass_configs={
-            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        },
-        compile_flags=["-O3", "-DENABLE_BF16"])
-    def _mha_fwd_func(block_M, block_N, num_stages, threads):
+    @tilelang.jit(out_idx=[3, 4], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
+    def mha_fwd_func(block_M, block_N, num_stages, threads):
 
         @T.prim_func
-        def _mha_fwd_main(
+        def mha_fwd_main(
                 Q: T.Tensor(shape, dtype),  # type: ignore
                 K: T.Tensor(shape, dtype),  # type: ignore
                 V: T.Tensor(shape, dtype),  # type: ignore
@@ -64,7 +59,10 @@ def _mha_fwd_main(
                         acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                      -T.infinity(acc_s.dtype))
                 else:
-                    T.clear(acc_s)
+                    # Fill positions beyond seq_len with -inf so padded keys are ignored
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                                     -T.infinity(acc_s.dtype), 0)
                 T.gemm(
                     Q_shared,
                     K_shared,
@@ -74,6 +72,8 @@ def _mha_fwd_main(
                 T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                 T.copy(scores_max, scores_max_prev)
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_M):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                 for i in T.Parallel(block_M):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_M, dim):
@@ -92,9 +92,9 @@ def _mha_fwd_main(
                 logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
             T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
 
-        return _mha_fwd_main
+        return mha_fwd_main
 
-    return _mha_fwd_func
+    return mha_fwd_func
 
 
 @torch.library.custom_op("top::mha_fwd_wrapped_kernel", mutates_args=())
@@ -186,13 +186,8 @@ def _mha_fwd_wgmma_pipelined_kernel(batch, heads, seq_len, dim, is_causal, dtype
     shape = [batch, seq_len, heads, dim]
     accum_dtype = "float"
 
-    @tilelang.jit(
-        out_idx=[3, 4],
-        pass_configs={
-            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        },
-        compile_flags=["-O3", "-DENABLE_BF16"])
-    def _mha_fwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):
+    @tilelang.jit(out_idx=[3, 4], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
+    def mha_fwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):
 
         @T.macro
         def MMA0(
@@ -211,7 +206,10 @@ def MMA0(
                     acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                  -T.infinity(acc_s.dtype))
             else:
-                T.clear(acc_s)
+                # Fill positions beyond seq_len with -inf so padded keys are ignored
+                for i, j in T.Parallel(block_M, block_N):
+                    acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                                 -T.infinity(acc_s.dtype), 0)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
         @T.macro
         def Softmax(
@@ -240,6 +238,8 @@ def Softmax(
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))
             T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+            for i in T.Parallel(block_M):
+                scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
             # To do causal softmax, we need to set the scores_max to 0 if it is -inf
             # This process is called Check_inf in FlashAttention3 code, and it only need to be done
             # in the first ceil_div(kBlockM, kBlockN) steps.
@@ -266,7 +266,7 @@ def Rescale(
                 acc_o[i, j] *= scores_scale[i]
 
         @T.prim_func
-        def _mha_fwd_wgmma_pipelined_main(
+        def mha_fwd_wgmma_pipelined_main(
                 Q: T.Tensor(shape, dtype),  # type: ignore
                 K: T.Tensor(shape, dtype),  # type: ignore
                 V: T.Tensor(shape, dtype),  # type: ignore
@@ -303,7 +303,7 @@ def _mha_fwd_wgmma_pipelined_main(
                         num_stages=num_stages,
                         order=[-1, 0, 3, 1, -1, 2],
                         stage=[-1, 0, 0, 1, -1, 1],
-                        group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
+                        group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
                     MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                     Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                             scores_sum, logsum)
@@ -317,9 +317,9 @@ def _mha_fwd_wgmma_pipelined_main(
                 logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
             T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
 
-        return _mha_fwd_wgmma_pipelined_main
+        return mha_fwd_wgmma_pipelined_main
 
-    return _mha_fwd_wgmma_pipelined_func
+    return mha_fwd_wgmma_pipelined_func
 
 
 @torch.library.custom_op("top::mha_fwd_wgmma_pipelined_wrapped_kernel", mutates_args=())
@@ -402,16 +402,11 @@ def _gqa_fwd_kernel(batch, heads, heads_kv, seq_len, dim, is_causal, dtype='floa
     kv_shape = [batch, seq_len, heads_kv, dim]
     accum_dtype = "float"
 
-    @tilelang.jit(
-        out_idx=[3, 4],
-        pass_configs={
-            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        },
-        compile_flags=["-O3", "-DENABLE_BF16"])
-    def _gqa_fwd_func(block_M, block_N, num_stages, threads):
+    @tilelang.jit(out_idx=[3, 4], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
+    def gqa_fwd_func(block_M, block_N, num_stages, threads):
 
         @T.prim_func
-        def _gqa_fwd_main(
+        def gqa_fwd_main(
                 Q: T.Tensor(q_shape, dtype),  # type: ignore
                 K: T.Tensor(kv_shape, dtype),  # type: ignore
                 V: T.Tensor(kv_shape, dtype),  # type: ignore
@@ -448,7 +443,10 @@ def _gqa_fwd_main(
                         acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                      -T.infinity(acc_s.dtype))
                 else:
-                    T.clear(acc_s)
+                    # Fill positions beyond seq_len with -inf so padded keys are ignored
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                                     -T.infinity(acc_s.dtype), 0)
                 T.gemm(
                     Q_shared,
                     K_shared,
@@ -458,6 +456,8 @@ def _gqa_fwd_main(
                 T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
                 T.copy(scores_max, scores_max_prev)
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_M):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                 for i in T.Parallel(block_M):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_M, dim):
@@ -476,9 +476,9 @@ def _gqa_fwd_main(
                 logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
             T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
 
-        return _gqa_fwd_main
+        return gqa_fwd_main
 
-    return _gqa_fwd_func
+    return gqa_fwd_func
 
 
 @torch.library.custom_op("top::gqa_fwd_wrapped_kernel", mutates_args=())
@@ -583,13 +583,8 @@ def _gqa_fwd_wgmma_pipelined_kernel(batch,
     kv_shape = [batch, seq_len, heads_kv, dim]
     accum_dtype = "float"
 
-    @tilelang.jit(
-        out_idx=[3, 4],
-        pass_configs={
-            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        },
-        compile_flags=["-O3", "-DENABLE_BF16"])
-    def _gqa_fwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):
+    @tilelang.jit(out_idx=[3, 4], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
+    def gqa_fwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):
 
         @T.macro
         def MMA0(
@@ -608,7 +603,10 @@ def MMA0(
                     acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                  -T.infinity(acc_s.dtype))
             else:
-                T.clear(acc_s)
+                # Fill positions beyond seq_len with -inf so padded keys are ignored
+                for i, j in T.Parallel(block_M, block_N):
+                    acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                                 -T.infinity(acc_s.dtype), 0)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
         @T.macro
         def Softmax(
@@ -637,6 +635,8 @@ def Softmax(
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))
             T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+            for i in T.Parallel(block_M):
+                scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
             # To do causal softmax, we need to set the scores_max to 0 if it is -inf
             # This process is called Check_inf in FlashAttention3 code, and it only need to be done
             # in the first ceil_div(kBlockM, kBlockN) steps.
@@ -663,7 +663,7 @@ def Rescale(
                 acc_o[i, j] *= scores_scale[i]
 
         @T.prim_func
-        def _gqa_fwd_wgmma_pipelined_main(
+        def gqa_fwd_wgmma_pipelined_main(
                 Q: T.Tensor(q_shape, dtype),  # type: ignore
                 K: T.Tensor(kv_shape, dtype),  # type: ignore
                 V: T.Tensor(kv_shape, dtype),  # type: ignore
@@ -700,7 +700,7 @@ def _gqa_fwd_wgmma_pipelined_main(
                         num_stages=num_stages,
                         order=[-1, 0, 3, 1, -1, 2],
                         stage=[-1, 0, 0, 1, -1, 1],
-                        group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
+                        group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
                     MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                     Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                             scores_sum, logsum)
@@ -714,9 +714,9 @@ def _gqa_fwd_wgmma_pipelined_main(
                 logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
             T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
 
-        return _gqa_fwd_wgmma_pipelined_main
+        return gqa_fwd_wgmma_pipelined_main
 
-    return _gqa_fwd_wgmma_pipelined_func
+    return gqa_fwd_wgmma_pipelined_func
 
 
 @torch.library.custom_op("top::gqa_fwd_wgmma_pipelined_wrapped_kernel", mutates_args=())
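The forward kernels above replace `T.clear(acc_s)` with an initialization that seeds out-of-range key columns with `-inf`, so a tail block that runs past `seq_len` cannot win the row maximum or leak into the softmax sum. A small NumPy sketch of the same initialization (names are illustrative, not from this repository):

```python
import numpy as np

def init_score_tile(block_M, block_N, k, seq_len):
    """Score tile for key block k: valid columns start at 0, padded ones at -inf."""
    acc_s = np.zeros((block_M, block_N), dtype=np.float32)
    key_pos = k * block_N + np.arange(block_N)   # global key indices in this block
    acc_s[:, key_pos >= seq_len] = -np.inf       # exp(-inf) == 0, so padding is ignored
    return acc_s
```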
diff --git a/top/kernels/flash_decode/gqa_decode.py b/top/kernels/flash_decode/gqa_decode.py
index 126b29a..c47ad96 100644
--- a/top/kernels/flash_decode/gqa_decode.py
+++ b/top/kernels/flash_decode/gqa_decode.py
@@ -15,13 +15,8 @@ def _gqa_decode_kernel(batch, heads, groups, seqlen_kv, dim):
     dtype = "float16"
     accum_dtype = "float"
 
-    @tilelang.jit(
-        out_idx=[6],
-        pass_configs={
-            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        },
-        compile_flags=["-O3", "-DENABLE_BF16"])
-    def _gqa_decode_func(block_H, block_N, num_split, num_stages, threads):
+    @tilelang.jit(out_idx=[6], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
+    def gqa_decode_func(block_H, block_N, num_split, num_stages, threads):
 
         shape_q = [batch, heads, dim]
         shape_k = [batch, seqlen_kv, groups, dim]
@@ -34,7 +29,7 @@ def _gqa_decode_func(block_H, block_N, num_split, num_stages, threads):
         valid_block_N = min(block_N, seqlen_kv // num_split)
 
         @T.macro
-        def _gqa_decode_no_split(
+        def gqa_decode_no_split(
                 Q: T.Tensor(shape_q, dtype),
                 K: T.Tensor(shape_k, dtype),
                 V: T.Tensor(shape_v, dtype),
@@ -77,6 +72,8 @@ def _gqa_decode_no_split(
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_H):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_H, block_N):
@@ -97,7 +94,7 @@ def _gqa_decode_no_split(
             T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
 
         @T.macro
-        def _gqa_decode_split(
+        def gqa_decode_split(
                 Q: T.Tensor(shape_q, dtype),
                 K: T.Tensor(shape_k, dtype),
                 V: T.Tensor(shape_v, dtype),
@@ -150,6 +147,8 @@ def _gqa_decode_split(
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_H):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_H, valid_block_N):
@@ -184,7 +183,6 @@ def combine(
                 Output: T.Tensor(shape_o, dtype),
         ):
             with T.Kernel(heads, batch, threads=128) as (by, bz):
-                # 1) 读 glse 到一个 1D fragment,再做 reduce_max
                 glse_vec = T.alloc_fragment([num_split], dtype)
                 for k in T.Parallel(num_split):
                     glse_vec[k] = glse[bz, by, k]
@@ -192,14 +190,12 @@ def combine(
                 T.fill(lse_max, -T.infinity(accum_dtype))
                 T.reduce_max(glse_vec, lse_max, dim=0, clear=False)
 
-                # 2) 计算 logsum(串行或小批并行累加)
                 lse_logsum = T.alloc_local([1], accum_dtype)
                 lse_logsum[0] = 0
                 for k in T.serial(num_split):
                     lse_logsum[0] += T.exp2(glse[bz, by, k] - lse_max[0])
                 lse_logsum[0] = T.log2(lse_logsum[0]) + lse_max[0]
 
-                # 3) 按权重合并 partial 输出
                 o_accum = T.alloc_fragment([dim], accum_dtype)
                 T.clear(o_accum)
                 for k in T.serial(num_split):
@@ -219,7 +215,7 @@ def gqa_decode_split(
                 Output_partial: T.Tensor(part_shape, dtype),
                 Output: T.Tensor(shape_o, dtype),
         ):
-            _gqa_decode_split(Q, K, V, mask, glse, Output_partial)
+            gqa_decode_split(Q, K, V, mask, glse, Output_partial)
             combine(glse, Output_partial, Output)
 
         @T.prim_func
@@ -232,14 +228,14 @@ def gqa_decode_no_split(
                 Output_partial: T.Tensor(part_shape, dtype),
                 Output: T.Tensor(shape_o, dtype),
         ):
-            _gqa_decode_no_split(Q, K, V, mask, Output)
+            gqa_decode_no_split(Q, K, V, mask, Output)
 
         if num_split > 1:
             return gqa_decode_split
         else:
             return gqa_decode_no_split
 
-    return _gqa_decode_func
+    return gqa_decode_func
 
 
 @torch.library.custom_op("top::gqa_decode_wrapped_kernel", mutates_args=())
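For reference, the `combine` step above merges per-split partial outputs with base-2 log-sum-exp weights. A NumPy sketch of the same reduction for a single (batch, head) pair, with illustrative names only:

```python
import numpy as np

def combine_splits(glse, out_partial):
    """glse: [num_split] log2-sum-exp per split; out_partial: [num_split, dim]."""
    lse_max = glse.max()
    lse_logsum = np.log2(np.exp2(glse - lse_max).sum()) + lse_max
    weights = np.exp2(glse - lse_logsum)          # non-negative, sums to 1
    return (weights[:, None] * out_partial).sum(axis=0)
```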
diff --git a/top/kernels/flash_decode/mha_decode.py b/top/kernels/flash_decode/mha_decode.py
index 3dc59a0..5ea8cf4 100644
--- a/top/kernels/flash_decode/mha_decode.py
+++ b/top/kernels/flash_decode/mha_decode.py
@@ -15,20 +15,15 @@ def _mha_decode_kernel(batch, heads, seqlen_q, seqlen_kv, dim, is_causal):
     dtype = "float16"
     accum_dtype = "float"
 
-    @tilelang.jit(
-        out_idx=[5],
-        pass_configs={
-            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-        },
-        compile_flags=["-O3", "-DENABLE_BF16"])
-    def _mha_decode_func(block_M, block_N, num_split, num_stages, threads):
+    @tilelang.jit(out_idx=[5], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
+    def mha_decode_func(block_M, block_N, num_split, num_stages, threads):
 
         shape_q = [batch, seqlen_q, heads, dim]
         shape_kv = [batch, seqlen_kv, heads, dim]
         part_shape = [batch, seqlen_q, heads, num_split, dim]
 
         @T.macro
-        def _mha_decode_no_split(
+        def mha_decode_no_split(
                 Q: T.Tensor(shape_q, dtype),
                 K: T.Tensor(shape_kv, dtype),
                 V: T.Tensor(shape_kv, dtype),
@@ -64,7 +59,9 @@ def _mha_decode_no_split(
                         acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                      -T.infinity(acc_s.dtype))
                 else:
-                    T.clear(acc_s)
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seqlen_kv,
+                                                     -T.infinity(acc_s.dtype), 0)
                 T.gemm(
                     Q_shared,
                     K_shared,
@@ -74,6 +71,8 @@ def _mha_decode_no_split(
                 T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                 T.copy(scores_max, scores_max_prev)
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_M):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                 for i in T.Parallel(block_M):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_M, dim):
@@ -112,7 +111,9 @@ def MMA0(
                     acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0,
                                                  -T.infinity(acc_s.dtype))
             else:
-                T.clear(acc_s)
+                for i, j in T.Parallel(block_M, block_N):
+                    acc_s[i, j] = T.if_then_else(k * block_N + j >= seqlen_kv,
+                                                 -T.infinity(acc_s.dtype), 0)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
         @T.macro
@@ -144,6 +145,8 @@ def Softmax(
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))
             T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+            for i in T.Parallel(block_M):
+                scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
             # To do causal softmax, we need to set the scores_max to 0 if it is -inf
             # This process is called Check_inf in FlashAttention3 code, and it only need to be done
             # in the first ceil_div(kBlockM, kBlockN) steps.
@@ -170,7 +173,7 @@ def Rescale(
                 acc_o[i, j] *= scores_scale[i]
 
         @T.macro
-        def _mha_decode_split(
+        def mha_decode_split(
                 Q: T.Tensor(shape_q, dtype),
                 K: T.Tensor(shape_kv, dtype),
                 V: T.Tensor(shape_kv, dtype),
@@ -290,7 +293,7 @@ def mha_decode_split(
                 Output_partial: T.Tensor(part_shape, dtype),  # [batch, seqlen_q, heads, num_split, dim]
                 Output: T.Tensor(shape_q, dtype),
        ):
-            _mha_decode_split(Q, K, V, glse, Output_partial)
+            mha_decode_split(Q, K, V, glse, Output_partial)
             combine(glse, Output_partial, Output)
 
         @T.prim_func
@@ -302,14 +305,14 @@ def mha_decode_no_split(
                 Output_partial: T.Tensor(part_shape, dtype),  # [batch, seqlen_q, heads, num_split, dim]
                 Output: T.Tensor(shape_q, dtype),
         ):
-            _mha_decode_no_split(Q, K, V, Output)
+            mha_decode_no_split(Q, K, V, Output)
 
         if num_split > 1:
             return mha_decode_split
         else:
             return mha_decode_no_split
 
-    return _mha_decode_func
+    return mha_decode_func
 
 
 @torch.library.custom_op("top::mha_decode_wrapped_kernel", mutates_args=())
diff --git a/top/kernels/gemm/gemm.py b/top/kernels/gemm/gemm.py
index 50f9732..0b62d1a 100644
--- a/top/kernels/gemm/gemm.py
+++ b/top/kernels/gemm/gemm.py
@@ -10,11 +10,11 @@ def _gemm_kernel(M, N, K, dtype='float16'):
 
     accum_dtype = "float"
 
-    @tilelang.jit(out_idx=[-1], compile_flags=["-O3", "-DENABLE_BF16"])
-    def _gemm_func(block_M, block_N, block_K, threads, num_stages, enable_rasteration):
+    @tilelang.jit(out_idx=[-1])
+    def gemm_func(block_M, block_N, block_K, threads, num_stages, enable_rasteration):
 
         @T.prim_func
-        def _gemm_main(
+        def gemm_main(
                 A: T.Tensor((M, K), dtype),  # type: ignore
                 B: T.Tensor((K, N), dtype),  # type: ignore
                 C: T.Tensor((M, N), dtype),  # type: ignore
@@ -41,9 +41,9 @@ def _gemm_main(
             T.copy(C_local, C_shared)
             T.copy(C_shared, C[by * block_M, bx * block_N])
 
-        return _gemm_main
+        return gemm_main
 
-    return _gemm_func
+    return gemm_func
 
 
 class gemm_kernel(Kernel):
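A hypothetical usage sketch of the renamed GEMM factory (tile sizes and shapes below are made-up placeholders, not tuned values from this repository); with `out_idx=[-1]`, the compiled kernel allocates and returns `C` itself:

```python
import torch

gemm_jit = _gemm_kernel(4096, 4096, 4096, dtype="float16")      # returns gemm_func
kernel = gemm_jit(block_M=128, block_N=128, block_K=32,
                  threads=256, num_stages=3, enable_rasteration=True)

A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
B = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
C = kernel(A, B)
```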