Skip to content

Conversation

@Rachmanino
Copy link
Collaborator

  1. Keep rowmax to avoid numerical issues
  2. Correctly fill acc_s for padded positions (rather than default value 0)
  3. Remove legacy compile flags

@Rachmanino Rachmanino marked this pull request as draft November 17, 2025 12:10
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @Rachmanino, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the attention examples by addressing numerical stability issues, correcting padding errors, and removing legacy compile flags. These changes improve the robustness and correctness of the attention mechanism in various corner cases.

Highlights

  • Numerical Stability: Ensures numerical stability by keeping rowmax to avoid potential issues in attention calculations.
  • Padding Correction: Corrects the filling of acc_s for padded positions, using -T.infinity(acc_s.dtype) instead of the default value of 0, which fixes bugs in corner cases.
  • Legacy Flag Removal: Removes legacy compile flags, streamlining the compilation process.
  • Code Modernization: Replaces the naming of internal functions to remove the underscore prefix, and modernizes the T.copy calls to include src=.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces several important enhancements and bug fixes related to attention mechanisms. Key improvements include better numerical stability by correctly updating scores_max across iterations, and accurate handling of padded positions in attention scores by initializing them with negative infinity. Additionally, legacy compile flags have been removed, and internal function names have been standardized by removing leading underscores for improved clarity. One minor inconsistency was noted regarding the removal of TL_ENABLE_FAST_MATH in sparse_mla.py compared to other files.

Comment on lines +62 to +65
# We shall fill -inf for padded positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change correctly initializes acc_s with -T.infinity(acc_s.dtype) for padded positions, preventing numerical issues and ensuring correct attention weight calculation. This directly addresses the PR description point "Correctly fill acc_s for padded positions".

Comment on lines +114 to +116
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seqlen_kv,
-T.infinity(acc_s.dtype), 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Similar to the previous change, this correctly initializes acc_s for padded positions with negative infinity in the split MHA decode path, ensuring numerical stability.

Comment on lines +136 to +138
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(kv_start + j >= seqlen_kv,
-T.infinity(acc_s.dtype), 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Replacing T.clear(acc_s) with a conditional fill ensures that padded positions are correctly initialized with -T.infinity(acc_s.dtype). This is a critical fix for numerical correctness, as T.clear would default to zero, potentially leading to incorrect attention weights for padded elements. This directly addresses the PR description point "Correctly fill acc_s for padded positions".

Comment on lines +62 to +64
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seqlen_kv,
-T.infinity(acc_s.dtype), 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change correctly initializes acc_s with -T.infinity(acc_s.dtype) for padded positions, preventing numerical issues and ensuring correct attention weight calculation. This directly addresses the PR description point "Correctly fill acc_s for padded positions".

Comment on lines +607 to +609
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Similar to other forward passes, this correctly initializes acc_s for padded positions with negative infinity in pipelined GQA, ensuring numerical stability.

order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The group list has been updated, likely reflecting changes in pipelining stages or memory access patterns due to other optimizations. This is a functional change that should be verified for correctness and performance impact.

Comment on lines +17 to 22
@tilelang.jit(out_idx=[6], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def mla_decode_func(block_H, block_N, num_split, num_stages, threads=128):

VALID_BLOCK_H = min(block_H, kv_group_num)

@T.macro
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The removal of compile_flags is a good cleanup, aligning with the PR description to remove legacy flags. This simplifies the JIT decorator configuration.

Comment on lines +21 to 26
@tilelang.jit(out_idx=[3, 4], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def mha_fwd_func(block_M, block_N, num_stages, threads):

@T.prim_func
def _mha_fwd_main(
def mha_fwd_main(
Q: T.Tensor(shape, dtype), # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removal of compile_flags and renaming of _mha_fwd_func to mha_fwd_func and _mha_fwd_main to mha_fwd_main are consistent with the PR's objectives.

Comment on lines +117 to 123
@tilelang.jit(out_idx=[7, 8], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def mha_bwd_func(block_M, block_N, num_stages, threads):

@T.prim_func
def _mha_bwd_main(
def mha_bwd_main(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removal of compile_flags and renaming of _mha_bwd_func to mha_bwd_func and _mha_bwd_main to mha_bwd_main are consistent with the PR's objectives.

Comment on lines +13 to +14
@tilelang.jit(out_idx=[-1])
def gemm_func(block_M, block_N, block_K, threads, num_stages, enable_rasteration):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removal of compile_flags and renaming of _gemm_func to gemm_func and _gemm_main to gemm_main are consistent with the PR's objectives.

@Rachmanino Rachmanino marked this pull request as ready for review November 18, 2025 02:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants