Mingfa/main #30
Conversation
Summary of Changes
Hello @mingfa-w, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request primarily focuses on expanding hardware compatibility by integrating support for NPU devices, ensuring the sparse attention operations can run on these platforms. It also significantly enhances the maintainability and understanding of the core parallel attention kernels through the addition of comprehensive inline comments. Furthermore, the PR improves performance adaptability by implementing more intelligent block size selection based on the underlying GPU architecture.
Highlights
- NPU Device Support: Introduced support for NPU (Neural Processing Unit) devices, with corresponding changes in test configurations to target NPU for tensor operations.
- Enhanced Code Readability: Added extensive and detailed Chinese comments to the core Triton kernels (parallel_nsa_compression_bwd_kernel_dq, parallel_nsa_kernel_topk, parallel_nsa_fwd_kernel) to clarify their functionality, parameters, and internal logic.
- Adaptive Block Size Determination: Refined the logic for determining optimal block sizes (BK and BV) for the Key and Value dimensions in the forward-pass functions, now dynamically adjusting based on CUDA device availability and compute capability (e.g., 256 for compute capability >= 9, 128 otherwise); a minimal sketch of this selection logic follows below.
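As a rough, hedged illustration (not the PR's exact code), the capability-based selection described in the last highlight can be sketched as follows; the 256/128 values and the compute-capability-9 threshold come from the PR description, while the helper name pick_block_sizes and the non-CUDA fallback behavior are assumptions:

import torch
import triton

def pick_block_sizes(K: int, V: int):
    # Default to 256; drop to 128 on CUDA GPUs older than compute capability 9.
    # On non-CUDA backends (e.g. NPU) torch.cuda.is_available() is False, so 256 is kept.
    bkv_cap = 256
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9:
        bkv_cap = 128
    BK = min(bkv_cap, triton.next_power_of_2(K))
    BV = min(bkv_cap, triton.next_power_of_2(V))
    return BK, BV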
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
Feature | Command | Description
---|---|---
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help | /gemini help | Displays a list of available commands.
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
¹ Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution. ↩
This is an automatic vacation reply from QQ Mail.
Hello, I am currently on vacation and unable to reply to your email in person. I will get back to you as soon as possible after my vacation ends.
Code Review
This pull request appears to be focused on adding support for NPU devices, which involves creating new tests, modifying existing code to be device-agnostic, and adding extensive comments to Triton kernels. While the effort to improve code clarity with comments is appreciated, there are a few key areas for improvement. The new comments are in Chinese, which should be translated to English for consistency. A significant test case for variable-length sequences has been commented out, reducing test coverage. Most critically, the new NPU tests seem to compare inequivalent implementations, which could lead to unreliable test results. There are also some minor style guide violations regarding variable naming.
ref = naive_nsa(
    q=q,
    k=k,
    v=v,
    g_slc=g_slc,
    g_swa=g_swa,
    block_indices=block_indices,
    block_counts=block_counts,
    block_size=block_size,
    window_size=window_size,
    scale=scale
)
ref.backward(do)
ref_dq, q.grad = q.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dv, v.grad = v.grad.clone(), None
ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None
if window_size > 0:
    ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None

tri = parallel_nsa(
    q=q,
    k=k,
    v=v,
    g_cmp=g_cmp,
    g_slc=g_slc,
    g_swa=g_swa,
    block_indices=block_indices,
    block_counts=block_counts,
    block_size=block_size,
    window_size=window_size,
    scale=scale
)
There is a critical issue in this test. The reference implementation naive_nsa is being compared against parallel_nsa. However, parallel_nsa is called with the g_cmp parameter, which enables compression-related logic, while naive_nsa is called without it and does not seem to support it. This means the test is comparing two different computations, which will lead to incorrect results and make the test unreliable. Please ensure that the reference and the implementation under test are computing the same function.
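One way to make the two paths comparable, sketched here under the assumption that parallel_nsa treats a missing g_cmp as "no compression" (the same change CodeRabbit suggests further below), is to not pass the compression gate at all:

# Hypothetical fix for the test: disable the compression gate so that
# parallel_nsa follows the same selected-block path as naive_nsa.
g_cmp = None  # assumption: g_cmp=None skips the compressed branch

tri = parallel_nsa(
    q=q, k=k, v=v,
    g_cmp=g_cmp,            # disabled, matching the naive reference
    g_slc=g_slc, g_swa=g_swa,
    block_indices=block_indices,
    block_counts=block_counts,
    block_size=block_size,
    window_size=window_size,
    scale=scale,
)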
# @pytest.mark.parametrize("N", [4])
# @pytest.mark.parametrize("T", [64, 128, 200, 250, 256, 300, 400, 512, 1000, 2048])
# @pytest.mark.parametrize("H", [4])
# @pytest.mark.parametrize("HQ", [64])
# @pytest.mark.parametrize("D", [100, 64])
# @pytest.mark.parametrize("S", [16])
# @pytest.mark.parametrize("block_size", [32])
# @pytest.mark.parametrize("window_size", [0, 32])
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# def test_parallel_varlen(
#     N: int,
#     T: int,
#     H: int,
#     HQ: int,
#     D: int,
#     S: int,
#     block_size: int,
#     window_size: int,
#     dtype: torch.dtype,
# ):
#     torch.manual_seed(42)
#     os.environ['TRITON_F32_DEFAULT'] = 'ieee'
#
#     # randomly split the sequence into N segments
#     offsets = torch.cat([
#         torch.tensor([0], dtype=torch.long),
#         torch.arange(16, T)[torch.randperm(T - 1)[:N-1]],
#         torch.tensor([T], dtype=torch.long)
#     ], 0).npu().sort()[0]
#     # seq-first required for inputs with variable lengths
#     perm_q = torch.randperm(T, device='npu')
#     perm_k = torch.randperm(T, device='npu')
#     perm_v = torch.randperm(T, device='npu')
#     q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True)
#     k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
#     v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
#     g_slc = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
#     g_swa = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
#     do = torch.randn((1, T, HQ, D), dtype=dtype, device='npu')
#
#     token_indices = prepare_token_indices(offsets).tolist()
#     block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='npu')
#     for i in range(T):
#         _, t = token_indices[i]
#         for h in range(H):
#             i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
#             block_indices[0, i, h, :len(i_i)] = i_i
#     block_indices = block_indices.sort(-1)[0]
#     block_counts = torch.randint(1, S + 1, (1, T, H), device='npu')
#
#     ref = naive_nsa(
#         q=q,
#         k=k,
#         v=v,
#         g_slc=g_slc,
#         g_swa=g_swa,
#         block_indices=block_indices,
#         block_counts=block_counts,
#         block_size=block_size,
#         window_size=window_size,
#         cu_seqlens=offsets
#     )
#     ref.backward(do)
#     ref_dq, q.grad = q.grad.clone(), None
#     ref_dk, k.grad = k.grad.clone(), None
#     ref_dv, v.grad = v.grad.clone(), None
#     ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None
#     if window_size > 0:
#         ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None
#
#     tri = parallel_nsa(
#         q=q,
#         k=k,
#         v=v,
#         g_slc=g_slc,
#         g_swa=g_swa,
#         block_indices=block_indices,
#         block_counts=block_counts,
#         block_size=block_size,
#         window_size=window_size,
#         cu_seqlens=offsets
#     )
#     tri.backward(do)
#     tri_dq, q.grad = q.grad.clone(), None
#     tri_dk, k.grad = k.grad.clone(), None
#     tri_dv, v.grad = v.grad.clone(), None
#     tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None
#     if window_size > 0:
#         tri_dg_swa, g_swa.grad = g_swa.grad.clone(), None
#
#     assert_close(" o", ref, tri, 0.004)
#     assert_close("dq", ref_dq, tri_dq, 0.005)
#     assert_close("dk", ref_dk, tri_dk, 0.005)
#     assert_close("dv", ref_dv, tri_dv, 0.005)
#     assert_close("dg_slc", ref_dg_slc, tri_dg_slc, 0.005)
#     if window_size > 0:
#         assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005)
This entire test for variable-length sequences (test_parallel_varlen) has been commented out, which leads to a significant loss in test coverage. If the test is failing, it should be fixed. If it is no longer applicable or needs to be temporarily disabled, please use pytest.mark.skip with a descriptive reason instead of commenting it out.
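For reference, a minimal sketch of the suggested pattern (the reason string is illustrative); a skipped test still shows up in reports, unlike one that has been commented out:

import pytest

# Illustrative only: keep the test collected but skipped, with a visible reason.
@pytest.mark.skip(reason="varlen path temporarily disabled while porting to NPU")
def test_parallel_varlen():
    raise AssertionError("should not run while the skip marker is in place")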
    # Input parameters
    q,  # Query matrix
    k,  # Key matrix
    v,  # Value matrix
    lse,  # log-sum-exp result
    delta,  # softmax derivative correction term
    do,  # gradient of the output o
    dq,  # Query gradient (output)
    scale,  # scaling factor
    offsets,  # sequence offsets
    token_indices,  # token indices
    chunk_offsets,
    T,  # sequence length
    # Constant parameters
    B: tl.constexpr,  # block size
    H: tl.constexpr,  # number of attention heads
    HQ: tl.constexpr,  # dimension per head
    G: tl.constexpr,  # number of groups
    K: tl.constexpr,  # Key dimension
    V: tl.constexpr,  # Value dimension
    BC: tl.constexpr,  # compression block size
    BS: tl.constexpr,  # sequence block size
    BK: tl.constexpr,  # Key block size
    BV: tl.constexpr,  # Value block size
    USE_OFFSETS: tl.constexpr  # flag indicating whether offsets are used
):
BKV_VALUE=256
if torch.cuda.is_available():
    if torch.cuda.get_device_capability()[0] < 9:
        BKV_VALUE=128
BK = min(BKV_VALUE, triton.next_power_of_2(K))
BV = min(BKV_VALUE, triton.next_power_of_2(V))
The variable BKV_VALUE does not follow Python's PEP 8 naming convention for constants, which should be UPPER_CASE_WITH_UNDERSCORES. Please consider renaming it to something more descriptive and in the correct style, for example, BKV_BLOCK_SIZE.
-BKV_VALUE=256
-if torch.cuda.is_available():
-    if torch.cuda.get_device_capability()[0] < 9:
-        BKV_VALUE=128
-BK = min(BKV_VALUE, triton.next_power_of_2(K))
-BV = min(BKV_VALUE, triton.next_power_of_2(V))
+BKV_BLOCK_SIZE=256
+if torch.cuda.is_available():
+    if torch.cuda.get_device_capability()[0] < 9:
+        BKV_BLOCK_SIZE=128
+BK = min(BKV_BLOCK_SIZE, triton.next_power_of_2(K))
+BV = min(BKV_BLOCK_SIZE, triton.next_power_of_2(V))
BKV_VALUE=256
if torch.cuda.is_available():
    if torch.cuda.get_device_capability()[0] < 9:
        BKV_VALUE=128
BK = min(BKV_VALUE, triton.next_power_of_2(K))
BV = min(BKV_VALUE, triton.next_power_of_2(V))
Similar to a previous comment, the variable BKV_VALUE does not follow PEP 8 naming conventions for constants. Please rename it to BKV_BLOCK_SIZE or similar for clarity and consistency.
-BKV_VALUE=256
-if torch.cuda.is_available():
-    if torch.cuda.get_device_capability()[0] < 9:
-        BKV_VALUE=128
-BK = min(BKV_VALUE, triton.next_power_of_2(K))
-BV = min(BKV_VALUE, triton.next_power_of_2(V))
+BKV_BLOCK_SIZE=256
+if torch.cuda.is_available():
+    if torch.cuda.get_device_capability()[0] < 9:
+        BKV_BLOCK_SIZE=128
+BK = min(BKV_BLOCK_SIZE, triton.next_power_of_2(K))
+BV = min(BKV_BLOCK_SIZE, triton.next_power_of_2(V))
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
native_sparse_attention/ops/parallel.py (3)

1310-1311: Invalid use of @torch.compile on an autograd.Function class. torch.compile decorates callables (functions/modules), not autograd.Function classes. This will error or be a no-op.

Remove the decorator:

-@torch.compile
 class ParallelNSAFunction(torch.autograd.Function):
1479-1497: Call to flash_attn_func without availability guard. If flash-attn isn't installed, flash_attn_func is None and this path will crash when window_size > 0. Add a clear guard/fallback.

-    else:
-        o_swa = flash_attn_func(
+    else:
+        if flash_attn_func is None:
+            raise ImportError("flash-attn is required for windowed attention (window_size>0)")
+        o_swa = flash_attn_func(
             q, k, v,
             causal=True,
             window_size=(window_size-1, 0)
         )
1348-1365: Fix backward return arity to match forward signature

In native_sparse_attention/ops/parallel.py, ParallelNSAFunction.backward must return one gradient per forward input (8 total), but currently returns 11. Update the return to:

-        return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
+        return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None
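The arity rule behind this fix: autograd.Function.backward must return exactly one value per argument of forward, using None for inputs that need no gradient. A generic toy example, unrelated to the PR's actual signatures:

import torch

class ScaledAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, scale):       # three forward inputs
        ctx.scale = scale
        return a + scale * b

    @staticmethod
    def backward(ctx, grad_out):
        # Exactly three outputs, one per forward input; None for the plain float.
        return grad_out, grad_out * ctx.scale, None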
🧹 Nitpick comments (6)
tests/test_nsa.py (2)
52-54: Guard NPU-only test and windowed attention.

- Skip when NPU isn't available.
- Windowed attention (window_size > 0) requires flash-attn; skip those cases if flash-attn isn't installed to avoid calling None.

 torch.manual_seed(42)
 os.environ['TRITON_F32_DEFAULT'] = 'ieee'
+if not hasattr(torch, "npu") or not torch.npu.is_available():
+    pytest.skip("NPU not available on this runner")
+try:
+    import importlib.util
+    _flash_attn_ok = importlib.util.find_spec("flash_attn") is not None
+except Exception:
+    _flash_attn_ok = False
+if window_size > 0 and not _flash_attn_ok:
+    pytest.skip("flash-attn is required for windowed attention tests")

Also applies to: 37-39
125-222: Either remove or unskip the varlen test coherently. It's fully commented out. If kept disabled, remove it to reduce noise; if needed, port it to the NPU path and gate it like the main test. I can help port it.
tests/test_nsa_npu.py (3)
8-8: Import kept only for side effects; annotate to silence linters. torch_npu is unused symbol-wise but required to register the "npu" device. Add noqa and a clarifying comment.

-import torch_npu
+import torch_npu  # noqa: F401  # required to register 'npu' device
148-165: Skip when NPU/flash-attn aren't available. Make the test robust across CI environments.

 torch.manual_seed(42)
 os.environ['TRITON_F32_DEFAULT'] = 'ieee'
+if not hasattr(torch, "npu") or not torch.npu.is_available():
+    pytest.skip("NPU not available on this runner")
+try:
+    import importlib.util
+    _flash_attn_ok = importlib.util.find_spec("flash_attn") is not None
+except Exception:
+    _flash_attn_ok = False
+if window_size > 0 and not _flash_attn_ok:
+    pytest.skip("flash-attn is required for windowed attention tests")
167-175: Consistency: dtype for block_counts. Explicitly set dtype to torch.long for clarity (it is the default, but being explicit helps on different backends).

-    block_counts = torch.randint(1, S + 1, (1, T, H), device='npu')
+    block_counts = torch.randint(1, S + 1, (1, T, H), dtype=torch.long, device='npu')

native_sparse_attention/ops/parallel.py (1)
885-889: Minor: collapse nested CUDA capability checks. No behavior change; simplifies logic.

-    BKV_VALUE=256
-    if torch.cuda.is_available():
-        if torch.cuda.get_device_capability()[0] < 9:
-            BKV_VALUE=128
+    BKV_VALUE = 128 if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9) else 256

Also applies to: 1134-1138
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
- native_sparse_attention/ops/parallel.py (5 hunks)
- tests/test_nsa.py (4 hunks)
- tests/test_nsa_npu.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/test_nsa_npu.py (2)
- native_sparse_attention/ops/naive.py (1): naive_nsa (29-170)
- native_sparse_attention/ops/parallel.py (3): parallel_nsa (1385-1500), backward (1057-1071), backward (1348-1364)

native_sparse_attention/ops/parallel.py (1)
- native_sparse_attention/ops/utils.py (1): _bitonic_merge (49-71)
🪛 Ruff (0.12.2)
tests/test_nsa_npu.py

8-8: torch_npu imported but unused. Remove unused import: torch_npu (F401)

native_sparse_attention/ops/parallel.py

566-569: Use ternary operator NS = tl.load(block_counts + (bos + i_t) * H + i_h) if USE_BLOCK_COUNTS else S instead of if-else block. Replace the if-else block with NS = tl.load(block_counts + (bos + i_t) * H + i_h) if USE_BLOCK_COUNTS else S (SIM108)

885-886: Use a single if statement instead of nested if statements. Combine if statements using and (SIM102)

1134-1135: Use a single if statement instead of nested if statements. Combine if statements using and (SIM102)

🪛 Flake8 (7.2.0)

tests/test_nsa_npu.py

[error] 8-8: 'torch_npu' imported but unused (F401)
🔇 Additional comments (1)
native_sparse_attention/ops/parallel.py (1)
129-138: Heuristic/autotune key comments are fine; no action. Config keys and heuristics look consistent across kernels.
Also applies to: 347-356, 508-518
q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(B, T, HQ, D).clone().requires_grad_(True)
k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
g_cmp = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
🛠️ Refactor suggestion
Disable compression here too to match the naive reference.
Same mismatch as in tests/test_nsa.py—g_cmp triggers a compressed path and recomputed Top‑K, while naive_nsa doesn’t. Set g_cmp=None.
- g_cmp = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
+ g_cmp = None # disable compression path to match naive_nsa
@@
- g_cmp=g_cmp,
+ g_cmp=g_cmp,
Also applies to: 196-208
🤖 Prompt for AI Agents
In tests/test_nsa_npu.py around lines 61 (and similarly for lines 196-208), the
test creates g_cmp as a random tensor which enables the compressed path and
recomputed Top-K, causing a mismatch with the naive reference; change g_cmp to
None to disable compression and match the naive_nsa behavior—update the
assignment(s) that currently set g_cmp = torch.rand(...) to g_cmp = None so the
compressed branch is not exercised.
q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(B, T, HQ, D).clone().requires_grad_(True)
k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
g_cmp = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
🛠️ Refactor suggestion
Disable compression in this test to match the naive baseline (set g_cmp=None).
parallel_nsa adds a compressed path when g_cmp is not None and also ignores the passed block_indices (it recomputes Top‑K). naive_nsa doesn’t implement the compressed path, so the outputs/gradients won’t match.
Apply this diff to remove the mismatch:
- g_cmp = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
+ g_cmp = None # disable compression path to match naive_nsa
@@
- tri = parallel_nsa(
+ tri = parallel_nsa(
q=q,
k=k,
v=v,
- g_cmp=g_cmp,
+ g_cmp=g_cmp,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale
)
Also applies to: 95-107
🤖 Prompt for AI Agents
In tests/test_nsa.py around line 61 (and similarly lines 95-107) the test
enables the compressed path by creating g_cmp with random values, causing
parallel_nsa to use compression and recompute Top-K while naive_nsa lacks that
path; set g_cmp=None in these locations to disable compression and ensure both
implementations follow the same (naive) path so outputs and gradients match.
Replace the g_cmp tensor assignments with g_cmp = None at line 61 and the
corresponding assignments in the block covering lines 95-107.
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu')
for b in range(B):
    for t in range(T):
        for h in range(H):
            i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
            block_indices[b, t, h, :len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='cuda')
block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='npu')
💡 Verification agent
🧩 Analysis chain
Use a sentinel like -1 for unfilled block_indices instead of T.
Filling with T will clamp to T-1 in naive_nsa.gather, unintentionally selecting the last block. Use -1 (properly masked later) to avoid bias.
- block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu')
+ block_indices = torch.full((B, T, H, S), -1, dtype=torch.long, device='npu')
🏁 Script executed:
#!/bin/bash
set -e
# Locate the naive_nsa module to inspect its gather implementation
rg -l 'naive_nsa' -g '*.py'
Length of output: 188
🏁 Script executed:
#!/bin/bash
set -e
# Search for gather implementation in naive.py
rg -nP 'def gather' -C5 native_sparse_attention/ops/naive.py
Length of output: 71
🏁 Script executed:
#!/bin/bash
set -e
# Inspect block_indices usage in naive implementation
rg -n -C3 'block_indices' native_sparse_attention/ops/naive.py
# Locate gather implementation and index clamping behavior
rg -n -C3 'def gather' native_sparse_attention/ops/naive.py
rg -n -C3 'gather\(' native_sparse_attention/ops/naive.py
Length of output: 5641
🏁 Script executed:
#!/bin/bash
set -e
# Find any C++/CUDA implementations in ops
find native_sparse_attention/ops -type f \( -name '*.cpp' -o -name '*.cu' \)
# Search for a gather kernel or function in ops directory
rg -n 'gather' -C3 native_sparse_attention/ops
Length of output: 757
🏁 Script executed:
#!/bin/bash
set -e
# Verify index clamping in gather: sentinel T will be clamped to T-1
rg -n "clamp(0, T-1)" native_sparse_attention/ops/naive.py
# Verify masking of out-of-range indices: sentinel -1 vs T
rg -n "masked_fill" native_sparse_attention/ops/naive.py
Length of output: 69
Use -1 as the sentinel for unfilled block_indices
T as a sentinel is clamped to T-1 in the gather step, biasing toward the last block; use -1 (masked downstream) to avoid this.
tests/test_nsa.py (66–73)
- block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu')
+ block_indices = torch.full((B, T, H, S), -1, dtype=torch.long, device='npu')
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu')
-for b in range(B):
-    for t in range(T):
-        for h in range(H):
-            i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
-            block_indices[b, t, h, :len(i_i)] = i_i
-block_indices = block_indices.sort(-1)[0]
-block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='cuda')
-block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='npu')
+block_indices = torch.full((B, T, H, S), -1, dtype=torch.long, device='npu')
+for b in range(B):
+    for t in range(T):
+        for h in range(H):
+            i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
+            block_indices[b, t, h, :len(i_i)] = i_i
+block_indices = block_indices.sort(-1)[0]
+block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='npu')
🤖 Prompt for AI Agents
In tests/test_nsa.py around lines 66 to 73, block_indices is initialized using T
as the sentinel which is clamped to T-1 downstream and biases gathers;
initialize block_indices with -1 as the sentinel (same dtype long, same device)
so unfilled slots are marked -1 and can be masked downstream; keep the rest of
the loop that fills valid entries unchanged and preserve the final sort call so
the -1 sentinels end up at the front/end as expected for downstream masking.
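For context, a -1 sentinel only helps if the consumer masks it before (or after) gathering; a minimal sketch of that pattern, with tensor names that are illustrative rather than taken from naive_nsa:

import torch

# block_indices: last dim holds selected block ids, -1 marks unfilled slots.
block_indices = torch.tensor([[3, 7, -1, -1]])          # (batch=1, S=4)
valid = block_indices >= 0                               # mask of real entries
safe_idx = block_indices.clamp(min=0)                    # make indices legal for gather
blocks = torch.randn(1, 10, 8)                           # (batch, num_blocks, block_dim)
gathered = blocks.gather(1, safe_idx.unsqueeze(-1).expand(-1, -1, blocks.size(-1)))
gathered = gathered * valid.unsqueeze(-1)                # zero out sentinel positions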
Summary by CodeRabbit