
Commit f6c3690

NSDie authored and offline0806 committed
[fix] prefill unsupport sliding window attention (vllm-project#2758)
### What this PR does / why we need it?
Fix a prefill attention bug: the prefill path does not support sliding window attention, and npu_fused_infer_attention_score only supports head_dim == 128, not other values.

### Does this PR introduce _any_ user-facing change?
Remove npu_fused_infer_attention_score from the prefill phase.

### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@e599e2c

---------

Signed-off-by: nsdie <[email protected]>
Signed-off-by: offline0806 <[email protected]>
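The constraint behind this fix can be illustrated with a small dispatch guard. This is only a sketch, not code from this PR: the helper name `can_use_fused_infer_attention` and its signature are hypothetical, and the only facts taken from the commit message are that npu_fused_infer_attention_score supports head_dim == 128 only and cannot express sliding window attention in the prefill path.

```python
# Illustrative sketch only -- not part of this PR. The helper name and
# signature are hypothetical; the constraints encoded here (head_dim must be
# 128, sliding window unsupported) come from the commit message above.
from typing import Optional


def can_use_fused_infer_attention(head_dim: int,
                                  sliding_window: Optional[int]) -> bool:
    """Return True only when npu_fused_infer_attention_score could apply.

    The fused kernel reportedly supports head_dim == 128 only, and the
    prefill path cannot express sliding window attention with it, so any
    other configuration must fall back to _npu_flash_attention.
    """
    if head_dim != 128:
        return False
    if sliding_window is not None:
        return False
    return True


# Example: a 64-dim head with a sliding window must take the fallback path.
assert can_use_fused_infer_attention(128, None) is True
assert can_use_fused_infer_attention(64, 4096) is False
```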
1 parent 65793c6 commit f6c3690

2 files changed: +9 −72 lines changed

tests/ut/attention/test_attention_v1.py

Lines changed: 0 additions & 30 deletions
```diff
@@ -341,36 +341,6 @@ def test_forward_prefill_no_cache(self, mock_flash_attention,
         mock_flash_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
 
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu._npu_flash_attention')
-    def test_forward_prefill_no_cache_swa(self, mock_flash_attention,
-                                               mock_reshape_cache):
-        """Test forward pass in PrefillNoCache state"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        metadata = self.attn_metadata
-        metadata.attn_state = AscendAttentionState.PrefillNoCache
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.seq_lens = torch.tensor([10])
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        layer = self.layer_no_quant
-        # layer.quant_method.apply.return_value = metadata
-        print(self.layer_no_quant._v_scale_float)
-        output = self.impl_swa.forward(layer,
-                                       query,
-                                       key,
-                                       value,
-                                       kv_cache,
-                                       metadata,
-                                       trace_flag=False)
-
-        mock_reshape_cache.assert_called_once()
-        mock_flash_attention.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_flash_attention_qlens')
     def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
```

vllm_ascend/attention/attention_v1.py

Lines changed: 9 additions & 42 deletions
```diff
@@ -265,20 +265,6 @@ def __init__(
         self.key_cache = None
         self.value_cache = None
 
-    def _repeat_kv(self, hidden_states: torch.Tensor,
-                   n_rep: int) -> torch.Tensor:
-        """
-        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-        """
-        num_key_value_heads, slen, head_dim = hidden_states.shape
-        if n_rep == 1:
-            return hidden_states
-        hidden_states = hidden_states[:, None, :, :].expand(
-            num_key_value_heads, n_rep, slen, head_dim)
-        return hidden_states.reshape(num_key_value_heads * n_rep, slen,
-                                     head_dim)
-
     def _forward_prefill_no_cache(
         self,
         query: torch.Tensor,
```
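The `_repeat_kv` helper removed in the hunk above existed only to expand grouped KV heads before calling the fused kernel; once that branch is removed (next hunk) it has no remaining callers. As its docstring notes, for the (num_kv_heads, seqlen, head_dim) layout used here it matches `torch.repeat_interleave` along the head axis. A minimal sketch of that equivalence, assuming only a standard PyTorch install (shapes are made up for illustration):

```python
# Sketch of the equivalence the removed helper's docstring describes.
# Shapes are arbitrary; this only assumes a standard torch installation.
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the removed _repeat_kv: (num_kv_heads, seqlen, head_dim)
    # -> (num_kv_heads * n_rep, seqlen, head_dim).
    num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, None, :, :].expand(
        num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(num_key_value_heads * n_rep, slen, head_dim)


kv = torch.randn(2, 10, 64)              # 2 KV heads, 10 tokens, head_dim 64
expanded = repeat_kv(kv, 4)              # expand to 8 query heads
reference = torch.repeat_interleave(kv, repeats=4, dim=0)
assert torch.equal(expanded, reference)  # both give shape (8, 10, 64)
```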
```diff
@@ -304,34 +290,15 @@ def _forward_prefill_no_cache(
             mask = torch_npu.npu_format_cast(mask.contiguous(),
                                              ACL_FORMAT_FRACTAL_NZ)
 
-        if self.sliding_window is not None and \
-                attn_metadata.attn_mask.shape[0] > self.sliding_window:
-
-            key = self._repeat_kv(key, self.num_heads // self.num_kv_heads)
-            value = self._repeat_kv(value, self.num_heads // self.num_kv_heads)
-
-            output, _ = torch_npu.npu_fused_infer_attention_score(
-                query,
-                key,
-                value,
-                num_heads=self.num_heads,
-                num_key_value_heads=self.num_kv_heads,
-                input_layout="TND",
-                pre_tokens=self.sliding_window,
-                scale=self.scale,
-                actual_seq_lengths=attn_metadata.seq_lens,
-                actual_seq_lengths_kv=attn_metadata.seq_lens)
-            output = output.view(num_tokens, self.num_heads, self.head_size)
-        else:
-            torch_npu._npu_flash_attention(query=query,
-                                           key=key,
-                                           value=value,
-                                           mask=mask,
-                                           seq_len=attn_metadata.seq_lens,
-                                           scale_value=self.scale,
-                                           num_heads=self.num_heads,
-                                           num_kv_heads=self.num_kv_heads,
-                                           out=output)
+        torch_npu._npu_flash_attention(query=query,
+                                       key=key,
+                                       value=value,
+                                       mask=mask,
+                                       seq_len=attn_metadata.seq_lens,
+                                       scale_value=self.scale,
+                                       num_heads=self.num_heads,
+                                       num_kv_heads=self.num_kv_heads,
+                                       out=output)
         assert output is not None
         return output[:num_tokens, :, :]
 
```
Comments (0)