
Commit 5fbbfe9

LucasWilkinson authored and simon-mo committed
[BugFix] FA2 MLA Accuracy Issue (#18807)
Signed-off-by: LucasWilkinson <[email protected]>
1 parent 5873877 · commit 5fbbfe9

File tree

3 files changed: +16 −8 lines changed


csrc/attention/merge_attn_states.cu

Lines changed: 8 additions & 0 deletions
@@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
   const uint pack_size = 16 / sizeof(scalar_t);
   TORCH_CHECK(head_size % pack_size == 0,
               "headsize must be multiple of pack_size:", pack_size);
+  TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
+              "output heads must be contiguous in memory");
+  TORCH_CHECK(
+      prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
+      "prefix_output heads must be contiguous in memory");
+  TORCH_CHECK(
+      suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
+      "suffix_output heads must be contiguous in memory");
   float* output_lse_ptr = nullptr;
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
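The new TORCH_CHECKs reject exactly the kind of strided view that slicing a padded head dimension produces. A minimal PyTorch illustration of the failure mode they guard against (shapes are made up for the example; for DeepSeek-style MLA prefill the 128-wide value heads are padded up to the 192-wide query heads when FA2 is used):

import torch

# Toy shapes: 4 tokens, 16 heads, value head size 128 padded to 192.
padded = torch.randn(4, 16, 192)
unpadded = padded[..., :128]  # a view, not a copy

# The padded tensor has contiguous heads: stride(-2) == head_size == 192.
print(padded.stride(-2), padded.stride(-1))      # 192 1
# The sliced view does not: its head_size is 128 but stride(-2) is still 192,
# so a kernel that assumes a packed [num_heads, head_size] layout would
# compute wrong offsets when reading it.
print(unpadded.stride(-2), unpadded.stride(-1))  # 192 1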

vllm/attention/backends/mla/common.py

Lines changed: 4 additions & 4 deletions
@@ -1093,10 +1093,6 @@ def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
         if isinstance(attn_out, tuple):
             attn_out, *rest = attn_out
 
-        # unpad if necessary
-        if self._pad_v:
-            attn_out = attn_out[..., :v.shape[-1]]
-
         # Remain consistent with old `flash_attn_varlen_func` where there
         # is only one output tensor if `return_softmax_lse` is False.
         if return_softmax_lse:
@@ -1294,6 +1290,10 @@ def _forward_prefill(
             suffix_lse=suffix_lse,
         )
 
+        # unpad if necessary
+        if self._pad_v:
+            output = output[..., :v.shape[-1]]
+
         return output.flatten(start_dim=-2)
 
     @abstractmethod
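With this change the unpad no longer happens inside _flash_attn_varlen_diff_headdims (before the prefix/suffix merge) but in _forward_prefill, after the merge, so the tensors handed to the merge kernel stay padded and contiguous. Because the merge rescales each element independently along the head dimension, merging the padded tensors and slicing afterwards leaves the real value columns unchanged. A rough pure-PyTorch sketch of the new ordering follows; this is only the math, not vLLM's kernel, and it ignores the inf handling a real implementation needs:

import torch

def merge_then_unpad(prefix_out, prefix_lse, suffix_out, suffix_lse, v_dim):
    # prefix_out/suffix_out: [tokens, heads, padded_head_size]
    # prefix_lse/suffix_lse: [tokens, heads] log-sum-exp of attention scores
    merged_lse = torch.logaddexp(prefix_lse, suffix_lse)
    p_scale = torch.exp(prefix_lse - merged_lse).unsqueeze(-1)
    s_scale = torch.exp(suffix_lse - merged_lse).unsqueeze(-1)
    merged = p_scale * prefix_out + s_scale * suffix_out
    # Unpad only after the merge, mirroring the new code in _forward_prefill.
    return merged[..., :v_dim]

The scales are per (token, head) and broadcast over the head dimension, so the padded columns never influence the kept ones; the only thing the reorder changes is that the merge now sees contiguous inputs.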

vllm/v1/attention/backends/mla/common.py

Lines changed: 4 additions & 4 deletions
@@ -653,10 +653,6 @@ def _flash_attn_varlen_diff_headdims(self,
         if isinstance(attn_out, tuple):
             attn_out, lse = attn_out[0], attn_out[1]
 
-        # unpad if necessary
-        if self._pad_v:
-            attn_out = attn_out[..., :v.shape[-1]]
-
         # Remain consistent with old `flash_attn_varlen_func` where there
         # is only one output tensor if `return_softmax_lse` is False.
         if return_softmax_lse:
@@ -839,6 +835,10 @@ def _forward_prefill(
             suffix_lse=suffix_lse,
         )
 
+        # unpad if necessary
+        if self._pad_v:
+            output = output[..., :v.shape[-1]]
+
         return output.flatten(start_dim=-2)
 
     @abstractmethod
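The v1 backend gets the identical reorder. As a toy sanity check that moving the unpad after the merge cannot change the result (only the memory layout the merge kernel sees), here is a small comparison using the same log-sum-exp style merge; the helper and shapes are illustrative, not vLLM code:

import torch

t, h, d_pad, d_v = 4, 16, 192, 128
p_out, s_out = torch.randn(t, h, d_pad), torch.randn(t, h, d_pad)
p_lse, s_lse = torch.randn(t, h), torch.randn(t, h)

def merge(a, a_lse, b, b_lse):
    # Elementwise log-sum-exp weighted combination of two partial outputs.
    lse = torch.logaddexp(a_lse, b_lse)
    return (torch.exp(a_lse - lse).unsqueeze(-1) * a
            + torch.exp(b_lse - lse).unsqueeze(-1) * b)

# New order: merge the padded, contiguous tensors, then unpad.
merged_then_sliced = merge(p_out, p_lse, s_out, s_lse)[..., :d_v]
# Old order (conceptually): unpad first, then merge the copied slices.
sliced_then_merged = merge(p_out[..., :d_v].contiguous(), p_lse,
                           s_out[..., :d_v].contiguous(), s_lse)
print(torch.allclose(merged_then_sliced, sliced_then_merged))  # True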
