Skip to content

Commit f825c6b

Browse files
authored
Support encoder_only attention for FlexAttention (#22273)
Signed-off-by: Max de Bayser <[email protected]>
1 parent 41b67f4 commit f825c6b

File tree

2 files changed

+138
-47
lines changed

2 files changed

+138
-47
lines changed

tests/kernels/test_flex_attention.py

Lines changed: 68 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import torch
1010
from packaging import version
1111

12-
from vllm import LLM, SamplingParams
12+
from vllm import SamplingParams
13+
14+
from ..models.utils import check_embeddings_close
1315

1416
TORCH_VERSION = version.parse(torch.__version__)
1517
MINIMUM_TORCH_VERSION = version.parse("2.7.0")
@@ -28,15 +30,15 @@ def set_seed(seed):
2830
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
2931
reason="CUDA not available or PyTorch version < 2.7",
3032
)
31-
def test_flex_attention_vs_default_backend(monkeypatch):
33+
def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
3234
"""Test that FlexAttention produces the same outputs as the default backend.
3335
3436
This test compares the outputs from the FlexAttention backend with
3537
the default backend, ensuring they are identical when using the same seed.
3638
"""
3739
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
3840
seed = 42
39-
max_tokens = 32
41+
max_tokens = 24
4042
prompts = [
4143
"Hello, my name is",
4244
"The president of the United States is",
@@ -54,39 +56,85 @@ def test_flex_attention_vs_default_backend(monkeypatch):
5456
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
5557

5658
set_seed(seed)
57-
58-
llm_flex = LLM(
59-
model_name,
60-
tensor_parallel_size=1,
61-
num_gpu_blocks_override=128,
62-
enforce_eager=True,
63-
)
64-
output_flex = llm_flex.generate(prompts, sampling_params)
59+
with vllm_runner(model_name,
60+
runner="generate",
61+
tensor_parallel_size=1,
62+
num_gpu_blocks_override=128,
63+
enforce_eager=True) as llm_flex:
64+
output_flex = llm_flex.generate(prompts, sampling_params)
6565

6666
# Run with default backend
6767
with monkeypatch.context() as m:
6868
m.setenv("VLLM_USE_V1", "1")
6969
set_seed(seed)
70-
llm_default = LLM(
71-
model_name,
72-
tensor_parallel_size=1,
73-
num_gpu_blocks_override=128,
74-
enforce_eager=True,
75-
)
76-
output_default = llm_default.generate(prompts, sampling_params)
70+
with vllm_runner(model_name,
71+
runner="generate",
72+
tensor_parallel_size=1,
73+
num_gpu_blocks_override=128,
74+
enforce_eager=True) as llm_default:
75+
output_default = llm_default.generate(prompts, sampling_params)
7776

7877
# Compare outputs from both backends
7978
for i, (flex_result,
8079
default_result) in enumerate(zip(output_flex, output_default)):
8180
prompt = prompts[i]
82-
flex_text = flex_result.outputs[0].text
83-
default_text = default_result.outputs[0].text
81+
flex_text = flex_result[1][0]
82+
default_text = default_result[1][0]
8483

8584
assert flex_text == default_text, (
8685
f"FlexAttention output doesn't match default for: {prompt!r}\n"
8786
f"FlexAttention: {flex_text!r}\n"
8887
f"Default: {default_text!r}")
8988

9089

90+
@pytest.mark.skipif(
91+
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
92+
reason="CUDA not available or PyTorch version < 2.7",
93+
)
94+
def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
95+
"""Test that FlexAttention produces the same outputs as the default backend.
96+
97+
This test compares the outputs from the FlexAttention backend with
98+
the default backend for encoder models.
99+
"""
100+
model_name = "BAAI/bge-base-en-v1.5"
101+
prompts = [
102+
"Hello, my name is",
103+
"The president of the United States is",
104+
"The capital of France is",
105+
]
106+
107+
# Run with flex attention
108+
with monkeypatch.context() as m:
109+
m.setenv("VLLM_USE_V1", "1")
110+
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
111+
with vllm_runner(model_name,
112+
runner="pooling",
113+
dtype=torch.bfloat16,
114+
tensor_parallel_size=1,
115+
max_model_len=100,
116+
enforce_eager=True) as llm_flex:
117+
flex_outputs = llm_flex.embed(prompts)
118+
119+
# Run with default backend
120+
with monkeypatch.context() as m:
121+
m.setenv("VLLM_USE_V1", "1")
122+
with vllm_runner(model_name,
123+
runner="pooling",
124+
dtype=torch.bfloat16,
125+
tensor_parallel_size=1,
126+
max_model_len=100,
127+
enforce_eager=True) as llm_default:
128+
default_outputs = llm_default.embed(prompts)
129+
130+
check_embeddings_close(
131+
embeddings_0_lst=flex_outputs,
132+
embeddings_1_lst=default_outputs,
133+
name_0="flex",
134+
name_1="default",
135+
tol=1e-2,
136+
)
137+
138+
91139
if __name__ == "__main__":
92140
pytest.main([__file__])

vllm/v1/attention/backends/flex_attention.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
148148

149149
@dataclass
150150
class FlexAttentionMetadata:
151+
causal: bool
151152
num_actual_tokens: int # Number of tokens excluding padding.
152153
max_query_len: int
153154
query_start_loc: torch.Tensor
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
177178
num_blocks = 0
178179
block_mask: Optional[BlockMask] = None
179180
score_mod: Optional[_score_mod_signature] = None
180-
mask_mod: Optional[_mask_mod_signature] = None
181181
logical_mask_mod: _mask_mod_signature = causal_mask_mod
182182

183-
def get_mask_mod(self) -> _mask_mod_signature:
183+
def get_causal_mask_mod(self) -> _mask_mod_signature:
184184
"""Creates the mask_mod function for FlexAttention.
185185
186186
This function creates the combined mask mod function that handles:
@@ -233,14 +233,39 @@ def final_mask_mod(
233233

234234
return final_mask_mod
235235

236+
def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
237+
"""Creates the encoder mask_mod function for FlexAttention.
238+
239+
Since the encoder bidirectional attention doesn't run with
240+
KV cache, this function creates a mask based on the
241+
packed query sequences.
242+
"""
243+
# Create a lookup mapping from query indices -> request number
244+
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
245+
246+
def final_mask_mod(
247+
b: torch.Tensor,
248+
h: torch.Tensor,
249+
q_idx: torch.Tensor,
250+
kv_idx: torch.Tensor,
251+
) -> torch.Tensor:
252+
return request_lookup[q_idx] == request_lookup[kv_idx]
253+
254+
return final_mask_mod
255+
236256
def build_block_mask(self) -> BlockMask:
237-
assert self.mask_mod is not None
257+
if self.causal:
258+
mask_mod = self.get_causal_mask_mod()
259+
kv_len = self.total_cache_tokens
260+
else:
261+
mask_mod = self.get_bidirectional_mask_mod()
262+
kv_len = self.num_actual_tokens
238263
return create_block_mask_compiled(
239-
self.mask_mod,
264+
mask_mod,
240265
None,
241266
None,
242267
self.num_actual_tokens,
243-
self.total_cache_tokens,
268+
kv_len,
244269
device=self.block_table.device,
245270
)
246271

@@ -251,7 +276,6 @@ def __post_init__(self):
251276
assert self.prefix_kv_lens is None, "Not implemented yet."
252277
assert self.suffix_kv_lens is None, "Not implemented yet."
253278
self.num_blocks = self.total_cache_tokens // self.block_size
254-
self.mask_mod = self.get_mask_mod()
255279
self.block_mask = self.build_block_mask()
256280

257281

@@ -306,6 +330,7 @@ def build(self,
306330
self.device, non_blocking=True)
307331

308332
out = FlexAttentionMetadata(
333+
causal=common_attn_metadata.causal,
309334
num_actual_tokens=num_actual_tokens,
310335
max_query_len=max_query_len,
311336
query_start_loc=query_start_loc,
@@ -350,6 +375,12 @@ def __init__(
350375
self.head_size = head_size
351376
self.scale = float(scale)
352377
self.num_kv_heads = num_kv_heads
378+
self.attn_type = attn_type
379+
380+
if attn_type not in (AttentionType.ENCODER_ONLY,
381+
AttentionType.DECODER):
382+
raise NotImplementedError(
383+
f"FlexAttention does not support {attn_type} attention")
353384

354385
if alibi_slopes is not None:
355386
raise NotImplementedError(
@@ -425,26 +456,38 @@ def forward(
425456

426457
num_actual_tokens = attn_metadata.num_actual_tokens
427458

428-
key_cache, value_cache = kv_cache.unbind(0)
429-
430-
torch.ops._C_cache_ops.reshape_and_cache_flash(
431-
key,
432-
value,
433-
key_cache,
434-
value_cache,
435-
attn_metadata.slot_mapping,
436-
self.kv_cache_dtype,
437-
layer._k_scale,
438-
layer._v_scale,
439-
)
459+
if not attn_metadata.causal:
460+
assert self.attn_type == AttentionType.ENCODER_ONLY
461+
462+
query, key_tensor, value_tensor = map(
463+
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
464+
(query, key, value),
465+
)
466+
467+
else:
468+
assert self.attn_type == AttentionType.DECODER
469+
key_cache, value_cache = kv_cache.unbind(0)
470+
471+
torch.ops._C_cache_ops.reshape_and_cache_flash(
472+
key,
473+
value,
474+
key_cache,
475+
value_cache,
476+
attn_metadata.slot_mapping,
477+
self.kv_cache_dtype,
478+
layer._k_scale,
479+
layer._v_scale,
480+
)
481+
482+
# View out the block_size dim
483+
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
484+
value_cache = value_cache.view(-1, self.num_kv_heads,
485+
self.head_size)
486+
query, key_tensor, value_tensor = map(
487+
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
488+
(query, key_cache, value_cache),
489+
)
440490

441-
# View out the block_size dim
442-
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
443-
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
444-
query, key_cache, value_cache = map(
445-
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
446-
(query, key_cache, value_cache),
447-
)
448491
query = query[:, :, :num_actual_tokens, :]
449492
# Doesn't work for now -> constraint violation
450493
# torch._dynamo.try_mark_dynamic(query, 2)
@@ -465,8 +508,8 @@ def forward(
465508

466509
out = flex_attention_compiled(
467510
query,
468-
key_cache,
469-
value_cache,
511+
key_tensor,
512+
value_tensor,
470513
attn_metadata.score_mod,
471514
attn_metadata.block_mask,
472515
self.scale,

0 commit comments

Comments
 (0)