88 changes: 68 additions & 20 deletions tests/kernels/test_flex_attention.py
@@ -9,7 +9,9 @@
import torch
from packaging import version

from vllm import LLM, SamplingParams
from vllm import SamplingParams

from ..models.utils import check_embeddings_close

TORCH_VERSION = version.parse(torch.__version__)
MINIMUM_TORCH_VERSION = version.parse("2.7.0")
@@ -28,15 +30,15 @@ def set_seed(seed):
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_vs_default_backend(monkeypatch):
def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
"""Test that FlexAttention produces the same outputs as the default backend.

This test compares the outputs from the FlexAttention backend with
the default backend, ensuring they are identical when using the same seed.
"""
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
seed = 42
max_tokens = 32
max_tokens = 24
Contributor Author:
I've reduced the max tokens a bit because as the sequence length grows, the chance of divergence increases. On the A100 where I'm testing this, I get the following output on the main branch:

flex_text=   ' Paris. The capital of France is also the capital of which country?\nA) Germany\nB) Italy\nC) Spain\nD) United Kingdom\nE'
default_text=' Paris. The capital of France is also the capital of which country?\nA) Germany\nB) Italy\nC) Spain\nD) Belgium\nE)'

prompts = [
"Hello, my name is",
"The president of the United States is",
@@ -54,39 +56,85 @@ def test_flex_attention_vs_default_backend(monkeypatch):
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")

set_seed(seed)

llm_flex = LLM(
model_name,
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True,
)
output_flex = llm_flex.generate(prompts, sampling_params)
with vllm_runner(model_name,
runner="generate",
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True) as llm_flex:
output_flex = llm_flex.generate(prompts, sampling_params)

# Run with default backend
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
set_seed(seed)
llm_default = LLM(
model_name,
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True,
)
output_default = llm_default.generate(prompts, sampling_params)
with vllm_runner(model_name,
runner="generate",
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True) as llm_default:
output_default = llm_default.generate(prompts, sampling_params)

# Compare outputs from both backends
for i, (flex_result,
default_result) in enumerate(zip(output_flex, output_default)):
prompt = prompts[i]
flex_text = flex_result.outputs[0].text
default_text = default_result.outputs[0].text
flex_text = flex_result[1][0]
default_text = default_result[1][0]

assert flex_text == default_text, (
f"FlexAttention output doesn't match default for: {prompt!r}\n"
f"FlexAttention: {flex_text!r}\n"
f"Default: {default_text!r}")


@pytest.mark.skipif(
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7",
)
def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
"""Test that FlexAttention produces the same outputs as the default backend.

This test compares the outputs from the FlexAttention backend with
the default backend for encoder models.
"""
model_name = "BAAI/bge-base-en-v1.5"
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
]

# Run with flex attention
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
with vllm_runner(model_name,
runner="pooling",
dtype=torch.bfloat16,
tensor_parallel_size=1,
max_model_len=100,
enforce_eager=True) as llm_flex:
flex_outputs = llm_flex.embed(prompts)

# Run with default backend
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
with vllm_runner(model_name,
runner="pooling",
dtype=torch.bfloat16,
tensor_parallel_size=1,
max_model_len=100,
enforce_eager=True) as llm_default:
default_outputs = llm_default.embed(prompts)

check_embeddings_close(
embeddings_0_lst=flex_outputs,
embeddings_1_lst=default_outputs,
name_0="flex",
name_1="default",
tol=1e-2,
)


if __name__ == "__main__":
pytest.main([__file__])
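
For reference, the embedding comparison in the new encoder test tolerates small numerical differences rather than requiring bitwise equality. Below is a minimal sketch of that kind of check, assuming a plain cosine-similarity criterion; the real `check_embeddings_close` helper in `tests/models/utils.py` may use a different metric or reporting.

```python
import torch


def embeddings_close(embeddings_a: list[list[float]],
                     embeddings_b: list[list[float]],
                     tol: float = 1e-2) -> bool:
    """Return True if every paired embedding has cosine similarity >= 1 - tol."""
    for a, b in zip(embeddings_a, embeddings_b):
        ta = torch.tensor(a, dtype=torch.float32)
        tb = torch.tensor(b, dtype=torch.float32)
        cos = torch.nn.functional.cosine_similarity(ta, tb, dim=0)
        if cos.item() < 1.0 - tol:
            return False
    return True
```
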
97 changes: 70 additions & 27 deletions vllm/v1/attention/backends/flex_attention.py
@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,

@dataclass
class FlexAttentionMetadata:
causal: bool
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
num_blocks = 0
block_mask: Optional[BlockMask] = None
score_mod: Optional[_score_mod_signature] = None
mask_mod: Optional[_mask_mod_signature] = None
logical_mask_mod: _mask_mod_signature = causal_mask_mod

def get_mask_mod(self) -> _mask_mod_signature:
def get_causal_mask_mod(self) -> _mask_mod_signature:
"""Creates the mask_mod function for FlexAttention.

This function creates the combined mask mod function that handles:
@@ -233,14 +233,39 @@ def final_mask_mod(

return final_mask_mod

def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
"""Creates the encoder mask_mod function for FlexAttention.

Since the encoder bidirectional attention doesn't run with
KV cache, this function creates a mask based on the
packed query sequences.
"""
# Create a lookup mapping from query indices -> request number
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)

def final_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
) -> torch.Tensor:
return request_lookup[q_idx] == request_lookup[kv_idx]

return final_mask_mod

def build_block_mask(self) -> BlockMask:
assert self.mask_mod is not None
if self.causal:
mask_mod = self.get_causal_mask_mod()
kv_len = self.total_cache_tokens
else:
mask_mod = self.get_bidirectional_mask_mod()
kv_len = self.num_actual_tokens
return create_block_mask_compiled(
self.mask_mod,
mask_mod,
None,
None,
self.num_actual_tokens,
self.total_cache_tokens,
kv_len,
device=self.block_table.device,
)
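
To illustrate the bidirectional mask above: each packed query token is mapped to the request it belongs to via `query_start_loc`, and attention is only allowed between tokens of the same request. A standalone sketch of that idea, using `torch.repeat_interleave` in place of the backend's `_offsets_to_doc_ids_tensor` helper (illustration only, not the backend's code path):

```python
import torch

# Packed batch of three requests with 3, 2, and 4 tokens.
query_start_loc = torch.tensor([0, 3, 5, 9])
lengths = query_start_loc[1:] - query_start_loc[:-1]

# Map each token index to its request id: [0, 0, 0, 1, 1, 2, 2, 2, 2]
request_lookup = torch.repeat_interleave(torch.arange(len(lengths)), lengths)


def bidirectional_mask_mod(b, h, q_idx, kv_idx):
    # A query token may attend to a kv token only within the same request.
    return request_lookup[q_idx] == request_lookup[kv_idx]


# Dense view of the mask over the 9 packed tokens: a block-diagonal pattern.
q = torch.arange(9).view(-1, 1)
kv = torch.arange(9).view(1, -1)
print(bidirectional_mask_mod(0, 0, q, kv).int())
```

For a causal decoder request, `get_causal_mask_mod` instead builds the mask against the full KV cache, which is why `build_block_mask` picks `total_cache_tokens` as `kv_len` in that branch and `num_actual_tokens` in the encoder branch.
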

@@ -251,7 +276,6 @@ def __post_init__(self):
assert self.prefix_kv_lens is None, "Not implemented yet."
assert self.suffix_kv_lens is None, "Not implemented yet."
self.num_blocks = self.total_cache_tokens // self.block_size
self.mask_mod = self.get_mask_mod()
self.block_mask = self.build_block_mask()


@@ -306,6 +330,7 @@ def build(self,
self.device, non_blocking=True)

out = FlexAttentionMetadata(
causal=common_attn_metadata.causal,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
@@ -350,6 +375,12 @@ def __init__(
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.attn_type = attn_type

if attn_type not in (AttentionType.ENCODER_ONLY,
AttentionType.DECODER):
raise NotImplementedError(
f"FlexAttention does not support {attn_type} attention")

if alibi_slopes is not None:
raise NotImplementedError(
@@ -425,26 +456,38 @@ def forward(

num_actual_tokens = attn_metadata.num_actual_tokens

key_cache, value_cache = kv_cache.unbind(0)

torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if not attn_metadata.causal:
assert self.attn_type == AttentionType.ENCODER_ONLY

query, key_tensor, value_tensor = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key, value),
)

else:
assert self.attn_type == AttentionType.DECODER
key_cache, value_cache = kv_cache.unbind(0)

torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)

# View out the block_size dim
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
value_cache = value_cache.view(-1, self.num_kv_heads,
self.head_size)
query, key_tensor, value_tensor = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key_cache, value_cache),
)

# View out the block_size dim
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
query, key_cache, value_cache = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key_cache, value_cache),
)
query = query[:, :, :num_actual_tokens, :]
# Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2)
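
For context on the reshapes above: `flex_attention` expects `(batch, num_heads, seq_len, head_size)` tensors, while this backend receives packed `(num_tokens, num_heads, head_size)` tensors (or flattened cache views in the decoder branch). A rough sketch of the `view_as_4d`-plus-permute step, written here as a free function with made-up sizes; the real method lives on the backend class and may differ:

```python
import torch

num_tokens, num_heads, head_size = 9, 12, 64
packed = torch.randn(num_tokens, num_heads, head_size)


def to_flex_layout(x: torch.Tensor) -> torch.Tensor:
    # (num_tokens, num_heads, head_size) -> (1, num_heads, num_tokens, head_size)
    return x.view(1, x.shape[0], x.shape[1], x.shape[2]).permute(0, 2, 1, 3)


q4d = to_flex_layout(packed)
assert q4d.shape == (1, num_heads, num_tokens, head_size)
```
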
@@ -465,8 +508,8 @@ def forward(

out = flex_attention_compiled(
query,
key_cache,
value_cache,
key_tensor,
value_tensor,
Contributor Author:
I've renamed the variables here because in one case they are key and value, and in the other case key_cache and value_cache.

attn_metadata.score_mod,
attn_metadata.block_mask,
self.scale,