[Enhancement] Enhance attention examples and fix bugs for several corner cases #51
base: refactor
Conversation
Rachmanino
commented
Nov 17, 2025
- Keep the running row-max across iterations to avoid numerical issues (see the sketch after this list)
- Correctly fill acc_s for padded positions with -inf (rather than the default value 0)
- Remove legacy compile flags
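For context on the first bullet, here is a minimal plain-NumPy sketch (not the TileLang kernel code; the function and variable names are illustrative only) of why the running row-max must be kept and carried across K/V blocks in an online softmax rather than reset each iteration:

```python
# Minimal NumPy illustration of online softmax over K/V blocks.
# Keeping the running row-max (m) and rescaling the partial sums keeps every
# exp() argument <= 0, so nothing overflows even for very large scores.
import numpy as np

def online_softmax_weights(scores_blocks):
    """scores_blocks: list of 1-D arrays, one block of attention scores per step."""
    m = -np.inf          # running row-max, carried across blocks
    l = 0.0              # running sum of exp(score - m)
    chunks = []
    for s in scores_blocks:
        m_new = max(m, s.max())          # keep the row-max, never reset it
        scale = np.exp(m - m_new)        # rescale previously accumulated sums
        l = l * scale + np.exp(s - m_new).sum()
        chunks = [c * scale for c in chunks]
        chunks.append(np.exp(s - m_new))
        m = m_new
    return np.concatenate(chunks) / l

blocks = [np.array([1e3, 1e3 + 2.0]), np.array([1e3 + 4.0])]
print(online_softmax_weights(blocks))    # finite weights that sum to 1
print(np.exp(np.concatenate(blocks)))    # naive exp() of the raw scores overflows to inf
```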
Summary of Changes
Hello @Rachmanino, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the attention examples by addressing numerical stability issues, correcting padding errors, and removing legacy compile flags. These changes improve the robustness and correctness of the attention mechanism in various corner cases.
Code Review
This pull request introduces several important enhancements and bug fixes related to attention mechanisms. Key improvements include better numerical stability by correctly updating scores_max across iterations, and accurate handling of padded positions in attention scores by initializing them with negative infinity. Additionally, legacy compile flags have been removed, and internal function names have been standardized by removing leading underscores for improved clarity. One minor inconsistency was noted regarding the removal of TL_ENABLE_FAST_MATH in sparse_mla.py compared to other files.
# We shall fill -inf for padded positions
for i, j in T.Parallel(block_M, block_N):
    acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
                                 -T.infinity(acc_s.dtype), 0)
for i, j in T.Parallel(block_M, block_N):
    acc_s[i, j] = T.if_then_else(k * block_N + j >= seqlen_kv,
                                 -T.infinity(acc_s.dtype), 0)
for i, j in T.Parallel(block_H, block_N):
    acc_s[i, j] = T.if_then_else(kv_start + j >= seqlen_kv,
                                 -T.infinity(acc_s.dtype), 0)
Replacing T.clear(acc_s) with a conditional fill ensures that padded positions are correctly initialized with -T.infinity(acc_s.dtype). This is a critical fix for numerical correctness, as T.clear would default to zero, potentially leading to incorrect attention weights for padded elements. This directly addresses the PR description point "Correctly fill acc_s for padded positions".
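To illustrate the failure mode described above, here is a small plain-NumPy sketch (illustrative only, not the kernel code): zero-filled padded scores still receive softmax weight, while -inf-filled ones are masked out exactly.

```python
import numpy as np

def softmax(x):
    x = x - x.max()          # shift by the row max for numerical stability
    e = np.exp(x)
    return e / e.sum()

scores = np.array([0.1, 0.2])   # two valid key positions in this block
pad = 2                         # two padded key positions in this block

zero_fill = np.concatenate([scores, np.zeros(pad)])
inf_fill = np.concatenate([scores, np.full(pad, -np.inf)])

print(softmax(zero_fill))   # each padded slot gets ~0.23 weight -> wrong attention output
print(softmax(inf_fill))    # padded slots get exactly 0 weight
```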
for i, j in T.Parallel(block_M, block_N):
    acc_s[i, j] = T.if_then_else(k * block_N + j >= seqlen_kv,
                                 -T.infinity(acc_s.dtype), 0)
for i, j in T.Parallel(block_M, block_N):
    acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
                                 -T.infinity(acc_s.dtype), 0)
  order=[-1, 0, 3, 1, -1, 2],
  stage=[-1, 0, 0, 1, -1, 1],
- group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
+ group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
@tilelang.jit(out_idx=[6], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def mla_decode_func(block_H, block_N, num_split, num_stages, threads=128):

    VALID_BLOCK_H = min(block_H, kv_group_num)

    @T.macro
@tilelang.jit(out_idx=[3, 4], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def mha_fwd_func(block_M, block_N, num_stages, threads):

    @T.prim_func
-   def _mha_fwd_main(
+   def mha_fwd_main(
        Q: T.Tensor(shape, dtype),  # type: ignore
@tilelang.jit(out_idx=[7, 8], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def mha_bwd_func(block_M, block_N, num_stages, threads):

    @T.prim_func
-   def _mha_bwd_main(
+   def mha_bwd_main(
        Q: T.Tensor(shape, dtype),  # type: ignore
        K: T.Tensor(shape, dtype),  # type: ignore
@tilelang.jit(out_idx=[-1])
def gemm_func(block_M, block_N, block_K, threads, num_stages, enable_rasteration):