
Conversation

@mingfa-w commented Aug 29, 2025

Summary by CodeRabbit

  • New Features
    • Added variable-length sequence support via offsets across attention paths.
    • Enhanced Top-K selection over blocks with offset-aware indexing.
    • Adaptive block sizing and improved autotuning for broader device coverage.
    • Optional online softmax/LSE in backward when precomputed values are absent (see the sketch after this list).
  • Tests
    • Introduced NPU-focused parity tests validating outputs and gradients vs. reference.
    • Migrated existing tests to NPU; parameterized scenarios (shapes, dtypes, scale).
    • Temporarily disabled a variable-length test pending updates.
  • Documentation
    • Added clarifying comments and parameter notes.
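
For the online softmax/LSE item above, the sketch below illustrates the general streaming log-sum-exp technique the feature refers to: a running maximum and a rescaled running sum recover lse block by block when no precomputed tensor is available. This is a minimal, self-contained illustration in plain PyTorch, not the kernel's actual code; the helper name online_lse and the block split are assumptions.

import torch

def online_lse(score_blocks):
    # Streaming log-sum-exp over chunks of attention scores for one query row.
    # Keeps a running max m and a running sum s of exp(score - m) so that
    # m + log(s) equals logsumexp over everything seen so far (numerically stable).
    m = torch.tensor(float('-inf'))
    s = torch.tensor(0.0)
    for blk in score_blocks:
        m_new = torch.maximum(m, blk.max())
        s = s * torch.exp(m - m_new) + torch.exp(blk - m_new).sum()
        m = m_new
    return m + torch.log(s)

# Sanity check against a one-shot logsumexp.
scores = torch.randn(1024)
assert torch.allclose(online_lse(scores.split(128)), torch.logsumexp(scores, dim=0), atol=1e-5)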

@gemini-code-assist bot left a comment

Summary of Changes

Hello @mingfa-w, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily expands hardware compatibility by adding support for NPU devices, so the sparse attention operations can run on those platforms. It also improves the maintainability of the core parallel attention kernels by adding comprehensive inline comments, and makes block size selection adapt to the underlying GPU architecture for better performance.

Highlights

  • NPU Device Support: Introduced support for NPU (Neural Processing Unit) devices, with corresponding changes in test configurations to target NPU for tensor operations.
  • Enhanced Code Readability: Added extensive and detailed Chinese comments to the core Triton kernels (parallel_nsa_compression_bwd_kernel_dq, parallel_nsa_kernel_topk, parallel_nsa_fwd_kernel) to clarify their functionality, parameters, and internal logic.
  • Adaptive Block Size Determination: Refined the logic for determining optimal block sizes (BK and BV) for the Key and Value dimensions in the forward pass functions, now dynamically adjusting based on CUDA device availability and compute capability (e.g., 256 for compute capability >= 9, 128 otherwise); see the sketch below.
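
For quick reference, the highlight above corresponds to the pattern below. It is a minimal, self-contained restatement of the hunks reviewed later in this thread, wrapped in an illustrative helper (choose_kv_block_sizes is not a name from the PR), and it folds the nested ifs into one condition as the SIM102 hint further down suggests.

import torch
import triton

def choose_kv_block_sizes(K: int, V: int):
    # 256 for CUDA devices with compute capability >= 9 (and when CUDA is not
    # available, e.g. on NPU), 128 for older CUDA devices; then clamp to the
    # next power of two of the head dimensions K and V.
    BKV_VALUE = 256
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9:
        BKV_VALUE = 128
    BK = min(BKV_VALUE, triton.next_power_of_2(K))
    BV = min(BKV_VALUE, triton.next_power_of_2(V))
    return BK, BV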

@mingfa-w (Author) commented Aug 29, 2025 via email

@gemini-code-assist bot left a comment

Code Review

This pull request appears to be focused on adding support for NPU devices, which involves creating new tests, modifying existing code to be device-agnostic, and adding extensive comments to Triton kernels. While the effort to improve code clarity with comments is appreciated, there are a few key areas for improvement. The new comments are in Chinese, which should be translated to English for consistency. A significant test case for variable-length sequences has been commented out, reducing test coverage. Most critically, the new NPU tests seem to compare inequivalent implementations, which could lead to unreliable test results. There are also some minor style guide violations regarding variable naming.

Comment on lines +75 to +107
    ref = naive_nsa(
        q=q,
        k=k,
        v=v,
        g_slc=g_slc,
        g_swa=g_swa,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        window_size=window_size,
        scale=scale
    )
    ref.backward(do)
    ref_dq, q.grad = q.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dv, v.grad = v.grad.clone(), None
    ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None
    if window_size > 0:
        ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None

    tri = parallel_nsa(
        q=q,
        k=k,
        v=v,
        g_cmp=g_cmp,
        g_slc=g_slc,
        g_swa=g_swa,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        window_size=window_size,
        scale=scale
    )


critical

There is a critical issue in this test. The reference implementation naive_nsa is being compared against parallel_nsa. However, parallel_nsa is called with the g_cmp parameter, which enables compression-related logic, while naive_nsa is called without it and does not seem to support it. This means the test is comparing two different computations, which will lead to incorrect results and make the test unreliable. Please ensure that the reference and the implementation under test are computing the same function.
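
One way to make both sides compute the same function, assuming the goal is to test only the selection path that naive_nsa implements, is to disable the compression gate. This mirrors the g_cmp=None fix proposed in the CodeRabbit comments further down and reuses the tensors already defined in this test:

    g_cmp = None  # hypothetical fix: skip the compressed path so that
                  # parallel_nsa and the naive_nsa reference match
    tri = parallel_nsa(
        q=q, k=k, v=v,
        g_cmp=g_cmp,  # no compression; per the note below, the passed block_indices should then be honored
        g_slc=g_slc, g_swa=g_swa,
        block_indices=block_indices, block_counts=block_counts,
        block_size=block_size, window_size=window_size, scale=scale
    )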

Comment on lines +125 to +221
# @pytest.mark.parametrize("N", [4])
# @pytest.mark.parametrize("T", [64, 128, 200, 250, 256, 300, 400, 512, 1000, 2048])
# @pytest.mark.parametrize("H", [4])
# @pytest.mark.parametrize("HQ", [64])
# @pytest.mark.parametrize("D", [100, 64])
# @pytest.mark.parametrize("S", [16])
# @pytest.mark.parametrize("block_size", [32])
# @pytest.mark.parametrize("window_size", [0, 32])
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# def test_parallel_varlen(
# N: int,
# T: int,
# H: int,
# HQ: int,
# D: int,
# S: int,
# block_size: int,
# window_size: int,
# dtype: torch.dtype,
# ):
# torch.manual_seed(42)
# os.environ['TRITON_F32_DEFAULT'] = 'ieee'

# # randomly split the sequence into N segments
# offsets = torch.cat([
# torch.tensor([0], dtype=torch.long),
# torch.arange(16, T)[torch.randperm(T - 1)[:N-1]],
# torch.tensor([T], dtype=torch.long)
# ], 0).npu().sort()[0]
# # seq-first required for inputs with variable lengths
# perm_q = torch.randperm(T, device='npu')
# perm_k = torch.randperm(T, device='npu')
# perm_v = torch.randperm(T, device='npu')
# q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True)
# k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
# v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
# g_slc = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
# g_swa = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
# do = torch.randn((1, T, HQ, D), dtype=dtype, device='npu')

# token_indices = prepare_token_indices(offsets).tolist()
# block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='npu')
# for i in range(T):
# _, t = token_indices[i]
# for h in range(H):
# i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
# block_indices[0, i, h, :len(i_i)] = i_i
# block_indices = block_indices.sort(-1)[0]
# block_counts = torch.randint(1, S + 1, (1, T, H), device='npu')

# ref = naive_nsa(
# q=q,
# k=k,
# v=v,
# g_slc=g_slc,
# g_swa=g_swa,
# block_indices=block_indices,
# block_counts=block_counts,
# block_size=block_size,
# window_size=window_size,
# cu_seqlens=offsets
# )
# ref.backward(do)
# ref_dq, q.grad = q.grad.clone(), None
# ref_dk, k.grad = k.grad.clone(), None
# ref_dv, v.grad = v.grad.clone(), None
# ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None
# if window_size > 0:
# ref_dg_swa, g_swa.grad = g_swa.grad.clone(), None

# tri = parallel_nsa(
# q=q,
# k=k,
# v=v,
# g_slc=g_slc,
# g_swa=g_swa,
# block_indices=block_indices,
# block_counts=block_counts,
# block_size=block_size,
# window_size=window_size,
# cu_seqlens=offsets
# )
# tri.backward(do)
# tri_dq, q.grad = q.grad.clone(), None
# tri_dk, k.grad = k.grad.clone(), None
# tri_dv, v.grad = v.grad.clone(), None
# tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None
# if window_size > 0:
# tri_dg_swa, g_swa.grad = g_swa.grad.clone(), None

# assert_close(" o", ref, tri, 0.004)
# assert_close("dq", ref_dq, tri_dq, 0.005)
# assert_close("dk", ref_dk, tri_dk, 0.005)
# assert_close("dv", ref_dv, tri_dv, 0.005)
# assert_close("dg_slc", ref_dg_slc, tri_dg_slc, 0.005)
# if window_size > 0:
# assert_close("dg_swa", ref_dg_swa, tri_dg_swa, 0.005)


high

This entire test for variable-length sequences (test_parallel_varlen) has been commented out. This leads to a significant loss in test coverage. If the test is failing, it should be fixed. If it's not applicable anymore or needs to be temporarily disabled, please use pytest.mark.skip with a descriptive reason instead of commenting it out.
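
For reference, a minimal sketch of the pytest.mark.skip approach; the reason string is illustrative:

import pytest

@pytest.mark.skip(reason="temporarily disabled: varlen path pending NPU updates")
def test_parallel_varlen():
    # original test body unchanged; pytest reports the test as skipped
    # together with the reason above instead of silently dropping it
    ...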

Comment on lines +141 to 166
+    # 输入参数
+    q, # Query矩阵
+    k, # Key矩阵
+    v, # Value矩阵
+    lse, # Log Sum Exp结果
+    delta, # softmax导数修正项
+    do, # 输出o的梯度
+    dq, # Query梯度(输出)
+    scale, # 缩放因子
+    offsets, # 序列偏移量
+    token_indices, # token索引
     chunk_offsets,
-    T,
-    B: tl.constexpr,
-    H: tl.constexpr,
-    HQ: tl.constexpr,
-    G: tl.constexpr,
-    K: tl.constexpr,
-    V: tl.constexpr,
-    BC: tl.constexpr,
-    BS: tl.constexpr,
-    BK: tl.constexpr,
-    BV: tl.constexpr,
-    USE_OFFSETS: tl.constexpr
+    T, # 序列长度
+    # 常量参数
+    B: tl.constexpr, # 块大小
+    H: tl.constexpr, # 注意力头数
+    HQ: tl.constexpr, # 每个头的维度
+    G: tl.constexpr, # 组数
+    K: tl.constexpr, # Key维度
+    V: tl.constexpr, # Value维度
+    BC: tl.constexpr, # 压缩块大小
+    BS: tl.constexpr, # 序列块大小
+    BK: tl.constexpr, # Key块大小
+    BV: tl.constexpr, # Value块大小
+    USE_OFFSETS: tl.constexpr # 是否使用偏移量的标志
 ):


medium

The comments added to this function signature and throughout the file are in Chinese. For consistency with the rest of the codebase and to make it accessible to a wider range of contributors, please translate these comments to English. This feedback applies to all new Chinese comments in this file.
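
For reviewers who don't read Chinese, an approximate English rendering of the annotated signature is given below. The glosses are translated literally from the quoted hunk and are not checked for accuracy here.

    # input parameters
    q,               # query matrix
    k,               # key matrix
    v,               # value matrix
    lse,             # log-sum-exp result
    delta,           # softmax derivative correction term
    do,              # gradient of the output o
    dq,              # query gradient (output)
    scale,           # scaling factor
    offsets,         # sequence offsets
    token_indices,   # token indices
    chunk_offsets,
    T,               # sequence length
    # constant parameters
    B: tl.constexpr,           # block size
    H: tl.constexpr,           # number of attention heads
    HQ: tl.constexpr,          # dimension per head
    G: tl.constexpr,           # number of groups
    K: tl.constexpr,           # key dimension
    V: tl.constexpr,           # value dimension
    BC: tl.constexpr,          # compression block size
    BS: tl.constexpr,          # sequence block size
    BK: tl.constexpr,          # key block size
    BV: tl.constexpr,          # value block size
    USE_OFFSETS: tl.constexpr  # whether offsets are used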

Comment on lines +884 to +889
    BKV_VALUE=256
    if torch.cuda.is_available():
        if torch.cuda.get_device_capability()[0] < 9:
            BKV_VALUE=128
    BK = min(BKV_VALUE, triton.next_power_of_2(K))
    BV = min(BKV_VALUE, triton.next_power_of_2(V))


medium

The variable BKV_VALUE does not follow Python's PEP 8 naming convention for constants, which should be UPPER_CASE_WITH_UNDERSCORES. Please consider renaming it to something more descriptive and in the correct style, for example, BKV_BLOCK_SIZE.

Suggested change
-    BKV_VALUE=256
-    if torch.cuda.is_available():
-        if torch.cuda.get_device_capability()[0] < 9:
-            BKV_VALUE=128
-    BK = min(BKV_VALUE, triton.next_power_of_2(K))
-    BV = min(BKV_VALUE, triton.next_power_of_2(V))
+    BKV_BLOCK_SIZE=256
+    if torch.cuda.is_available():
+        if torch.cuda.get_device_capability()[0] < 9:
+            BKV_BLOCK_SIZE=128
+    BK = min(BKV_BLOCK_SIZE, triton.next_power_of_2(K))
+    BV = min(BKV_BLOCK_SIZE, triton.next_power_of_2(V))

Comment on lines +1133 to +1138
    BKV_VALUE=256
    if torch.cuda.is_available():
        if torch.cuda.get_device_capability()[0] < 9:
            BKV_VALUE=128
    BK = min(BKV_VALUE, triton.next_power_of_2(K))
    BV = min(BKV_VALUE, triton.next_power_of_2(V))


medium

Similar to a previous comment, the variable BKV_VALUE does not follow PEP 8 naming conventions for constants. Please rename it to BKV_BLOCK_SIZE or similar for clarity and consistency.

Suggested change
-    BKV_VALUE=256
-    if torch.cuda.is_available():
-        if torch.cuda.get_device_capability()[0] < 9:
-            BKV_VALUE=128
-    BK = min(BKV_VALUE, triton.next_power_of_2(K))
-    BV = min(BKV_VALUE, triton.next_power_of_2(V))
+    BKV_BLOCK_SIZE=256
+    if torch.cuda.is_available():
+        if torch.cuda.get_device_capability()[0] < 9:
+            BKV_BLOCK_SIZE=128
+    BK = min(BKV_BLOCK_SIZE, triton.next_power_of_2(K))
+    BV = min(BKV_BLOCK_SIZE, triton.next_power_of_2(V))

@coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
native_sparse_attention/ops/parallel.py (3)

1310-1311: Invalid use of @torch.compile on an autograd.Function class.

torch.compile decorates callables (functions/modules), not autograd.Function classes. This will error or be a no-op.

Remove the decorator:

-@torch.compile
 class ParallelNSAFunction(torch.autograd.Function):

1479-1497: Call to flash_attn_func without availability guard.

If flash-attn isn’t installed, flash_attn_func is None and this path will crash when window_size > 0. Add a clear guard/fallback.

-        else:
-            o_swa = flash_attn_func(
+        else:
+            if flash_attn_func is None:
+                raise ImportError("flash-attn is required for windowed attention (window_size>0)")
+            o_swa = flash_attn_func(
                 q, k, v,
                 causal=True,
                 window_size=(window_size-1, 0)
             )

1348-1365: Fix backward return arity to match forward signature
In native_sparse_attention/ops/parallel.py, ParallelNSAFunction.backward must return one gradient per forward input (8 total), but currently returns 11. Update the return to:

-        return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
+        return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None
🧹 Nitpick comments (6)
tests/test_nsa.py (2)

52-54: Guard NPU-only test and windowed attention.

  • Skip when NPU isn’t available.
  • Windowed attention (window_size > 0) requires flash-attn; skip those cases if flash-attn isn’t installed to avoid calling None.
     torch.manual_seed(42)
     os.environ['TRITON_F32_DEFAULT'] = 'ieee'
+    if not hasattr(torch, "npu") or not torch.npu.is_available():
+        pytest.skip("NPU not available on this runner")
+    try:
+        import importlib.util
+        _flash_attn_ok = importlib.util.find_spec("flash_attn") is not None
+    except Exception:
+        _flash_attn_ok = False
+    if window_size > 0 and not _flash_attn_ok:
+        pytest.skip("flash-attn is required for windowed attention tests")

Also applies to: 37-39


125-222: Either remove or unskip the varlen test coherently.

It’s fully commented out. If kept disabled, remove to reduce noise; if needed, port it to the NPU path and gate like the main test. I can help port it.

tests/test_nsa_npu.py (3)

8-8: Import kept only for side effects; annotate to silence linters.

torch_npu is unused symbol-wise but required to register the "npu" device. Add noqa and a clarifying comment.

-import torch_npu
+import torch_npu  # noqa: F401  # required to register 'npu' device

148-165: Skip when NPU/flash-attn aren’t available.

Make the test robust across CI environments.

     torch.manual_seed(42)
     os.environ['TRITON_F32_DEFAULT'] = 'ieee'
+    if not hasattr(torch, "npu") or not torch.npu.is_available():
+        pytest.skip("NPU not available on this runner")
+    try:
+        import importlib.util
+        _flash_attn_ok = importlib.util.find_spec("flash_attn") is not None
+    except Exception:
+        _flash_attn_ok = False
+    if window_size > 0 and not _flash_attn_ok:
+        pytest.skip("flash-attn is required for windowed attention tests")

167-175: Consistency: dtype for block_counts.

Explicitly set dtype to torch.long for clarity (it is default, but being explicit helps on different backends).

-    block_counts = torch.randint(1, S + 1, (1, T, H), device='npu')
+    block_counts = torch.randint(1, S + 1, (1, T, H), dtype=torch.long, device='npu')
native_sparse_attention/ops/parallel.py (1)

885-889: Minor: collapse nested CUDA capability checks.

No behavior change; simplifies logic.

-    BKV_VALUE=256
-    if torch.cuda.is_available():
-        if torch.cuda.get_device_capability()[0] < 9:
-            BKV_VALUE=128
+    BKV_VALUE = 128 if (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9) else 256

Also applies to: 1134-1138

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between bd67af5 and dceb419.

📒 Files selected for processing (3)
  • native_sparse_attention/ops/parallel.py (5 hunks)
  • tests/test_nsa.py (4 hunks)
  • tests/test_nsa_npu.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/test_nsa_npu.py (2)
native_sparse_attention/ops/naive.py (1)
  • naive_nsa (29-170)
native_sparse_attention/ops/parallel.py (3)
  • parallel_nsa (1385-1500)
  • backward (1057-1071)
  • backward (1348-1364)
native_sparse_attention/ops/parallel.py (1)
native_sparse_attention/ops/utils.py (1)
  • _bitonic_merge (49-71)
🪛 Ruff (0.12.2)
tests/test_nsa_npu.py

8-8: torch_npu imported but unused

Remove unused import: torch_npu

(F401)

native_sparse_attention/ops/parallel.py

566-569: Use ternary operator NS = tl.load(block_counts + (bos + i_t) * H + i_h) if USE_BLOCK_COUNTS else S instead of if-else-block

Replace if-else-block with NS = tl.load(block_counts + (bos + i_t) * H + i_h) if USE_BLOCK_COUNTS else S

(SIM108)


885-886: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


1134-1135: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)

🪛 Flake8 (7.2.0)
tests/test_nsa_npu.py

[error] 8-8: 'torch_npu' imported but unused

(F401)

🔇 Additional comments (1)
native_sparse_attention/ops/parallel.py (1)

129-138: Heuristic/autotune key comments are fine; no action.

Config keys and heuristics look consistent across kernels.

Also applies to: 347-356, 508-518

    q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(B, T, HQ, D).clone().requires_grad_(True)
    k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
    v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
    g_cmp = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)

🛠️ Refactor suggestion

⚠️ Potential issue

Disable compression here too to match the naive reference.

Same mismatch as in tests/test_nsa.py—g_cmp triggers a compressed path and recomputed Top‑K, while naive_nsa doesn’t. Set g_cmp=None.

-    g_cmp = torch.rand((1, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
+    g_cmp = None  # disable compression path to match naive_nsa
@@
-        g_cmp=g_cmp,
+        g_cmp=g_cmp,

Also applies to: 196-208

🤖 Prompt for AI Agents
In tests/test_nsa_npu.py around lines 61 (and similarly for lines 196-208), the
test creates g_cmp as a random tensor which enables the compressed path and
recomputed Top-K, causing a mismatch with the naive reference; change g_cmp to
None to disable compression and match the naive_nsa behavior—update the
assignment(s) that currently set g_cmp = torch.rand(...) to g_cmp = None so the
compressed branch is not exercised.

    q = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_q].view(1, T, 1, 1).expand(B, T, HQ, D).clone().requires_grad_(True)
    k = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_k].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
    v = torch.linspace(0, 1, steps=T, dtype=dtype, device='npu')[perm_v].view(1, T, 1, 1).expand(B, T, H, D).clone().requires_grad_(True)
    g_cmp = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)

🛠️ Refactor suggestion

⚠️ Potential issue

Disable compression in this test to match the naive baseline (set g_cmp=None).

parallel_nsa adds a compressed path when g_cmp is not None and also ignores the passed block_indices (it recomputes Top‑K). naive_nsa doesn’t implement the compressed path, so the outputs/gradients won’t match.

Apply this diff to remove the mismatch:

-    g_cmp = torch.rand((B, T, HQ), dtype=dtype, device='npu').requires_grad_(True)
+    g_cmp = None  # disable compression path to match naive_nsa

@@
-    tri = parallel_nsa(
+    tri = parallel_nsa(
         q=q,
         k=k,
         v=v,
-        g_cmp=g_cmp,
+        g_cmp=g_cmp,
         g_slc=g_slc,
         g_swa=g_swa,
         block_indices=block_indices,
         block_counts=block_counts,
         block_size=block_size,
         window_size=window_size,
         scale=scale
     )

Also applies to: 95-107

🤖 Prompt for AI Agents
In tests/test_nsa.py around line 61 (and similarly lines 95-107) the test
enables the compressed path by creating g_cmp with random values, causing
parallel_nsa to use compression and recompute Top-K while naive_nsa lacks that
path; set g_cmp=None in these locations to disable compression and ensure both
implementations follow the same (naive) path so outputs and gradients match.
Replace the g_cmp tensor assignments with g_cmp = None at line 61 and the
corresponding assignments in the block covering lines 95-107.

Comment on lines +66 to +73
     block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu')
     for b in range(B):
         for t in range(T):
             for h in range(H):
                 i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
                 block_indices[b, t, h, :len(i_i)] = i_i
     block_indices = block_indices.sort(-1)[0]
-    block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='cuda')
+    block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='npu')

💡 Verification agent

🧩 Analysis chain

Use a sentinel like -1 for unfilled block_indices instead of T.

Filling with T will clamp to T-1 in naive_nsa.gather, unintentionally selecting the last block. Use -1 (properly masked later) to avoid bias.

-    block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu')
+    block_indices = torch.full((B, T, H, S), -1, dtype=torch.long, device='npu')

🏁 Script executed:

#!/bin/bash
set -e
# Locate the naive_nsa module to inspect its gather implementation
rg -l 'naive_nsa' -g '*.py'

Length of output: 188


🏁 Script executed:

#!/bin/bash
set -e
# Search for gather implementation in naive.py
rg -nP 'def gather' -C5 native_sparse_attention/ops/naive.py

Length of output: 71


🏁 Script executed:

#!/bin/bash
set -e
# Inspect block_indices usage in naive implementation
rg -n -C3 'block_indices' native_sparse_attention/ops/naive.py

# Locate gather implementation and index clamping behavior
rg -n -C3 'def gather' native_sparse_attention/ops/naive.py
rg -n -C3 'gather\(' native_sparse_attention/ops/naive.py

Length of output: 5641


🏁 Script executed:

#!/bin/bash
set -e
# Find any C++/CUDA implementations in ops
find native_sparse_attention/ops -type f \( -name '*.cpp' -o -name '*.cu' \)
# Search for a gather kernel or function in ops directory
rg -n 'gather' -C3 native_sparse_attention/ops

Length of output: 757


🏁 Script executed:

#!/bin/bash
set -e
# Verify index clamping in gather: sentinel T will be clamped to T-1
rg -n "clamp(0, T-1)" native_sparse_attention/ops/naive.py

# Verify masking of out-of-range indices: sentinel -1 vs T
rg -n "masked_fill" native_sparse_attention/ops/naive.py

Length of output: 69


Use -1 as the sentinel for unfilled block_indices
T as a sentinel is clamped to T-1 in the gather step, biasing toward the last block; use -1 (masked downstream) to avoid this.

tests/test_nsa.py (66–73)
-    block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu')
+    block_indices = torch.full((B, T, H, S), -1, dtype=torch.long, device='npu')
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='npu')
-    for b in range(B):
-        for t in range(T):
-            for h in range(H):
-                i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
-                block_indices[b, t, h, :len(i_i)] = i_i
-    block_indices = block_indices.sort(-1)[0]
-    block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='cuda')
-    block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='npu')
+    block_indices = torch.full((B, T, H, S), -1, dtype=torch.long, device='npu')
+    for b in range(B):
+        for t in range(T):
+            for h in range(H):
+                i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
+                block_indices[b, t, h, :len(i_i)] = i_i
+    block_indices = block_indices.sort(-1)[0]
+    block_counts = torch.randint(1, S + 1, (B, T, H), dtype=torch.long, device='npu')
🤖 Prompt for AI Agents
In tests/test_nsa.py around lines 66 to 73, block_indices is initialized using T
as the sentinel which is clamped to T-1 downstream and biases gathers;
initialize block_indices with -1 as the sentinel (same dtype long, same device)
so unfilled slots are marked -1 and can be masked downstream; keep the rest of
the loop that fills valid entries unchanged and preserve the final sort call so
the -1 sentinels end up at the front/end as expected for downstream masking.
