Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions top/kernels/deepseek_mla/mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,13 @@ def _mla_decode_kernel(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, dtype=
kv_group_num = heads // kv_head_num
assert kv_head_num == 1, "kv_head_num must be 1"

@tilelang.jit(
out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
compile_flags=["-O3", "-DENABLE_BF16"])
def _mla_decode_func(block_H, block_N, num_split, num_stages, threads=128):
@tilelang.jit(out_idx=[6], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def mla_decode_func(block_H, block_N, num_split, num_stages, threads=128):

VALID_BLOCK_H = min(block_H, kv_group_num)

@T.macro
Comment on lines +17 to 22
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The removal of compile_flags is a good cleanup, aligning with the PR description to remove legacy flags. This simplifies the JIT decorator configuration.

def _mla_no_split(
def mla_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
Expand Down Expand Up @@ -75,6 +70,8 @@ def _mla_no_split(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
Comment on lines +73 to +74
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This change correctly updates scores_max by taking the maximum of the current and previous values. This is crucial for maintaining numerical stability in the softmax calculation, especially in iterative or pipelined attention computations. This directly addresses the PR description point "Keep rowmax to avoid numerical issues".

for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
Expand All @@ -92,7 +89,7 @@ def _mla_no_split(
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])

@T.macro
def _mla_split(
def mla_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
Expand Down Expand Up @@ -136,7 +133,9 @@ def _mla_split(
kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(kv_start + j >= seqlen_kv,
-T.infinity(acc_s.dtype), 0)
Comment on lines +136 to +138
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Replacing T.clear(acc_s) with a conditional fill ensures that padded positions are correctly initialized with -T.infinity(acc_s.dtype). This is a critical fix for numerical correctness, as T.clear would default to zero, potentially leading to incorrect attention weights for padded elements. This directly addresses the PR description point "Correctly fill acc_s for padded positions".

T.gemm(
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Expand All @@ -148,6 +147,8 @@ def _mla_split(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
Comment on lines +150 to +151
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the mla_no_split macro, this update ensures scores_max is correctly accumulated across iterations for numerical stability in the split attention mechanism.

for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
Expand Down Expand Up @@ -215,7 +216,7 @@ def main_split(
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
_mla_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
mla_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output)

@T.prim_func
Expand All @@ -228,14 +229,14 @@ def main_no_split(
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
_mla_no_split(Q, Q_pe, KV, K_pe, Output)
mla_no_split(Q, Q_pe, KV, K_pe, Output)

if num_split > 1:
return main_split
else:
return main_no_split

return _mla_decode_func
return mla_decode_func


@torch.library.custom_op("top::mla_decode_wrapped_kernel", mutates_args=())
Expand Down Expand Up @@ -378,7 +379,7 @@ def _mla_decode_ws_kernel(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, dty
"--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG"
],
)
def _mla_decode_ws_func(block_H, block_N, num_split, num_stages, threads=128):
def mla_decode_ws_func(block_H, block_N, num_split, num_stages, threads=128):

VALID_BLOCK_H = min(block_H, kv_group_num)
@T.macro
Expand Down Expand Up @@ -458,6 +459,8 @@ def flash_attn(

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
Comment on lines +462 to +463
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This ensures m_i (maximum score) is correctly updated across iterations, which is vital for numerical stability in the FlashAttention mechanism.

for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
Expand Down Expand Up @@ -490,6 +493,8 @@ def flash_attn(

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
Comment on lines +496 to +497
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar numerical stability fix for m_i in the second buffer processing loop.

for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
Expand Down Expand Up @@ -677,6 +682,8 @@ def flash_attn_split(

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
Comment on lines +685 to +686
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Ensures m_i is correctly updated for numerical stability in the split FlashAttention mechanism.

for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
Expand Down Expand Up @@ -709,6 +716,8 @@ def flash_attn_split(

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
Comment on lines +719 to +720
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Consistent numerical stability fix for m_i in the second buffer processing loop of split FlashAttention.

for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
Expand Down Expand Up @@ -888,7 +897,7 @@ def main_no_split(
else:
return main_no_split

return _mla_decode_ws_func
return mla_decode_ws_func


@torch.library.custom_op("top::mla_decode_ws_wrapped_kernel", mutates_args=())
Expand Down
14 changes: 9 additions & 5 deletions top/kernels/deepseek_mla/sparse_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _sparse_mla_kernel(
"--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG"
],
)
def _sparse_mla_fwd_func(block_I, threads):
def sparse_mla_fwd_func(block_I, threads):

G = kv_group
heads = head_kv
Expand Down Expand Up @@ -89,7 +89,7 @@ def _sparse_mla_fwd_func(block_I, threads):

H_per_block = padded_H if REPLICATE_H == 1 else 64
@T.prim_func
def _sparse_mla_fwd_main(
def sparse_mla_fwd_main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Expand Down Expand Up @@ -176,8 +176,10 @@ def _sparse_mla_fwd_main(
T.barrier_arrive(bar_sScale_and_sS_free)
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)

T.copy(m_i, m_i_prev)
T.copy(src=m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
Comment on lines +181 to +182
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This line ensures m_i is correctly updated for numerical stability, similar to other attention kernels.

for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
Expand Down Expand Up @@ -212,6 +214,8 @@ def _sparse_mla_fwd_main(

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
Comment on lines +217 to +218
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Consistent numerical stability fix for m_i in the second buffer processing loop.

for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
Expand Down Expand Up @@ -322,9 +326,9 @@ def _sparse_mla_fwd_main(
D + (tx - 256) % 8 * 8 + v]
T.cp_async_barrier_noinc(bar_k_1_ready[0])

return _sparse_mla_fwd_main
return sparse_mla_fwd_main

return _sparse_mla_fwd_func
return sparse_mla_fwd_func


@torch.library.custom_op("top::sparse_mla_fwd_wrapped_kernel", mutates_args=())
Expand Down
58 changes: 20 additions & 38 deletions top/kernels/flash_attn/bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,11 @@ def _mha_bwd_kernel(batch, heads, seq_len, dim, is_causal, dtype="float16"):
shape = [batch, seq_len, heads, dim]
accum_dtype = "float"

@tilelang.jit(
out_idx=[7, 8],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
compile_flags=["-O3", "-DENABLE_BF16"])
def _mha_bwd_func(block_M, block_N, num_stages, threads):
@tilelang.jit(out_idx=[7, 8], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def mha_bwd_func(block_M, block_N, num_stages, threads):

@T.prim_func
def _mha_bwd_main(
def mha_bwd_main(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
Comment on lines +117 to 123
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removal of compile_flags and renaming of _mha_bwd_func to mha_bwd_func and _mha_bwd_main to mha_bwd_main are consistent with the PR's objectives.

V: T.Tensor(shape, dtype), # type: ignore
Expand Down Expand Up @@ -205,9 +200,9 @@ def _mha_bwd_main(
T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])

return _mha_bwd_main
return mha_bwd_main

return _mha_bwd_func
return mha_bwd_func


class mha_bwd_kernel(Kernel):
Expand Down Expand Up @@ -270,16 +265,11 @@ def _mha_bwd_wgmma_pipelined_kernel(batch, heads, seq_len, dim, is_causal, dtype
shape = [batch, seq_len, heads, dim]
accum_dtype = "float"

@tilelang.jit(
out_idx=[7, 8],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
compile_flags=["-O3", "-DENABLE_BF16"])
def _mha_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):
@tilelang.jit(out_idx=[7, 8], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def mha_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):

@T.prim_func
def _mha_bwd_wgmma_pipelined_main(
def mha_bwd_wgmma_pipelined_main(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
Expand Down Expand Up @@ -378,9 +368,9 @@ def _mha_bwd_wgmma_pipelined_main(
T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])

return _mha_bwd_wgmma_pipelined_main
return mha_bwd_wgmma_pipelined_main

return _mha_bwd_wgmma_pipelined_func
return mha_bwd_wgmma_pipelined_func


class mha_bwd_wgmma_pipelined_kernel(Kernel):
Expand Down Expand Up @@ -449,15 +439,11 @@ def _gqa_bwd_kernel(batch, heads, heads_kv, seq_len, dim, is_causal, dtype="floa
groups = heads // heads_kv
accum_dtype = "float"

@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
compile_flags=["-O3", "-DENABLE_BF16"])
def _gqa_bwd_func(block_M, block_N, num_stages, threads):
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def gqa_bwd_func(block_M, block_N, num_stages, threads):

@T.prim_func
def _gqa_bwd_main(
def gqa_bwd_main(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(kv_shape, dtype), # type: ignore
V: T.Tensor(kv_shape, dtype), # type: ignore
Expand Down Expand Up @@ -534,9 +520,9 @@ def _gqa_bwd_main(
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)

return _gqa_bwd_main
return gqa_bwd_main

return _gqa_bwd_func
return gqa_bwd_func


class gqa_bwd_kernel(Kernel):
Expand Down Expand Up @@ -610,15 +596,11 @@ def _gqa_bwd_wgmma_pipelined_kernel(batch,
groups = heads // heads_kv
accum_dtype = "float"

@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
compile_flags=["-O3", "-DENABLE_BF16"])
def _gqa_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def gqa_bwd_wgmma_pipelined_func(block_M, block_N, num_stages, threads):

@T.prim_func
def _gqa_bwd_wgmma_pipelined_main(
def gqa_bwd_wgmma_pipelined_main(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(kv_shape, dtype), # type: ignore
V: T.Tensor(kv_shape, dtype), # type: ignore
Expand Down Expand Up @@ -711,9 +693,9 @@ def _gqa_bwd_wgmma_pipelined_main(
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx // groups, :], dk_shared)

return _gqa_bwd_wgmma_pipelined_main
return gqa_bwd_wgmma_pipelined_main

return _gqa_bwd_wgmma_pipelined_func
return gqa_bwd_wgmma_pipelined_func


class gqa_bwd_wgmma_pipelined_kernel(Kernel):
Expand Down
Loading
Loading