
Conversation

Contributor

@NSDie NSDie commented Sep 4, 2025

What this PR does / why we need it?

Fix a prefill attention bug: sliding window attention is not supported in the prefill path, and npu_fused_infer_attention_score only supports head_dim equal to 128, not other values.
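
For context, a minimal, self-contained sketch (not this repository's code) of the sliding-window causal masking that a non-fused prefill attention path has to honor; seq_len and window are illustrative parameters:

import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # True marks key positions a query token may attend to: causal
    # (no future tokens) and within the last `window` positions.
    q = torch.arange(seq_len).unsqueeze(1)  # query positions, shape (seq_len, 1)
    k = torch.arange(seq_len).unsqueeze(0)  # key positions, shape (1, seq_len)
    return (k <= q) & ((q - k) < window)

# Example: with seq_len=5 and window=2, token 4 attends only to tokens 3 and 4.
print(sliding_window_causal_mask(5, 2))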

Does this PR introduce any user-facing change?

Remove the npu_fused_infer_attention_score call from the prefill phase.

How was this patch tested?

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in the prefill phase for sliding window attention by removing the specialized code path that used npu_fused_infer_attention_score. While this correctly resolves the main issue, a related change to the _repeat_kv helper function introduces a subtle bug where it implements torch.repeat instead of the documented torch.repeat_interleave. I've provided a critical comment with a suggested fix for this function, and also noted that it appears to be unused after these changes and could potentially be removed.

Comment on lines 277 to 280
     hidden_states = hidden_states[:, None, :, :].expand(
-        num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(num_key_value_heads * n_rep, slen,
+        slen, n_rep, num_key_value_heads, head_dim)
+    return hidden_states.reshape(slen, num_key_value_heads * n_rep,
         head_dim)
Contributor


critical

The implementation of _repeat_kv does not match its docstring, which states it should be equivalent to torch.repeat_interleave. The current implementation performs a torch.repeat operation, not torch.repeat_interleave. This can lead to incorrect attention calculations in Grouped-Query Attention (GQA) scenarios where key and value states are expanded.

For a tensor with shape (slen, num_kv_heads, head_dim), repeat_interleave on dim=1 should result in each head being repeated n_rep times consecutively. The current implementation repeats the whole sequence of heads n_rep times.

Additionally, after this pull request's changes, this function appears to be unused and could potentially be removed.

Suggested change
-    hidden_states = hidden_states[:, None, :, :].expand(
-        slen, n_rep, num_key_value_heads, head_dim)
-    return hidden_states.reshape(slen, num_key_value_heads * n_rep,
-        head_dim)
+    hidden_states = hidden_states.unsqueeze(2).expand(
+        slen, num_key_value_heads, n_rep, head_dim)
+    return hidden_states.reshape(slen, num_key_value_heads * n_rep,
+        head_dim)
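
To see the difference concretely, here is a small standalone check (illustrative only, not part of the PR) showing that the PR's expand/reshape matches torch.Tensor.repeat along the head dimension, while the suggested version matches torch.repeat_interleave as the docstring describes:

import torch

slen, num_key_value_heads, n_rep, head_dim = 4, 2, 3, 8
x = torch.randn(slen, num_key_value_heads, head_dim)

# PR version: the whole block of KV heads is tiled n_rep times (repeat semantics).
pr = (x[:, None, :, :]
      .expand(slen, n_rep, num_key_value_heads, head_dim)
      .reshape(slen, num_key_value_heads * n_rep, head_dim))
assert torch.equal(pr, x.repeat(1, n_rep, 1))

# Suggested version: each KV head is repeated n_rep times consecutively
# (repeat_interleave semantics), matching the documented behavior.
fixed = (x.unsqueeze(2)
         .expand(slen, num_key_value_heads, n_rep, head_dim)
         .reshape(slen, num_key_value_heads * n_rep, head_dim))
assert torch.equal(fixed, torch.repeat_interleave(x, n_rep, dim=1))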

Signed-off-by: nsdie <[email protected]>

github-actions bot commented Sep 4, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description, to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Signed-off-by: nsdie <[email protected]>
key = self._repeat_kv(key, self.num_heads // self.num_kv_heads)
value = self._repeat_kv(value, self.num_heads // self.num_kv_heads)

output, _ = torch_npu.npu_fused_infer_attention_score(
Collaborator


So it's confirmed that this op doesn't work for prefill? Can you paste a link or explain more about it?

Contributor Author


The npu_fused_infer_attention_score operator does not support the prefill phase.

Signed-off-by: nsdie <[email protected]>

codecov bot commented Sep 5, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 72.90%. Comparing base (f86596a) to head (685bfd9).
⚠️ Report is 9 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2758      +/-   ##
==========================================
- Coverage   72.99%   72.90%   -0.09%     
==========================================
  Files         153      153              
  Lines       21338    21368      +30     
==========================================
+ Hits        15575    15579       +4     
- Misses       5763     5789      +26     
Flag Coverage Δ
unittests 72.90% <100.00%> (-0.09%) ⬇️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@NSDie NSDie changed the title 【fix】prefill unsupport sliding window attention [fix]prefill unsupport sliding window attention Sep 6, 2025
@NSDie NSDie changed the title [fix]prefill unsupport sliding window attention [fix] prefill unsupport sliding window attention Sep 6, 2025
@zzzzwwjj
Collaborator

zzzzwwjj commented Sep 6, 2025

This is a fallback PR for #2528
Please explain in the PR message why we need to revert it.

@wangxiyuan wangxiyuan merged commit b2f77d3 into vllm-project:main Sep 7, 2025
38 of 42 checks passed
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Sep 10, 2025
### What this PR does / why we need it?
Fix a prefill attention bug: sliding window attention is not supported in
the prefill path, and npu_fused_infer_attention_score only supports
head_dim equal to 128, not other values.
### Does this PR introduce _any_ user-facing change?
Remove the npu_fused_infer_attention_score call from the prefill phase.
### How was this patch tested?

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@e599e2c

---------

Signed-off-by: nsdie <[email protected]>
offline893 pushed a commit to offline893/vllm-ascend that referenced this pull request Sep 16, 2025
### What this PR does / why we need it?
Fix a prefill attention bug: sliding window attention is not supported in
the prefill path, and npu_fused_infer_attention_score only supports
head_dim equal to 128, not other values.
### Does this PR introduce _any_ user-facing change?
Remove the npu_fused_infer_attention_score call from the prefill phase.
### How was this patch tested?

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@e599e2c

---------

Signed-off-by: nsdie <[email protected]>
Signed-off-by: offline0806 <[email protected]>