
[V1] port xformers backend to v1 #21342

Open

TheEpicDolphin wants to merge 1 commit into main from xformers_attention_v1

Conversation

@TheEpicDolphin commented Jul 22, 2025

Purpose

Port over the xformers backend to the v1 engine. There are several benefits to using XFormers, including:

  1. A built-in heuristic that determines which attention implementation is best suited for the given inputs (see the sketch after this list).
  2. AMD kernel support.
  3. Well suited for certain Meta models.
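
For context on item 1, the heuristic lives behind xformers' memory_efficient_attention entry point, which inspects the inputs and dispatches to whichever of its registered operators it considers best for them. A minimal standalone sketch of that call, separate from this PR's backend code (shapes and scale are illustrative):

```python
import torch
import xformers.ops as xops

# Illustrative shapes: [batch, seq_len, num_heads, head_dim].
q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16)

# memory_efficient_attention picks a concrete kernel for these
# shapes/dtype/device; LowerTriangularMask gives causal (decoder) masking.
out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask(),
    scale=1.0 / 128**0.5,
)
```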

Test Plan

Added a test case to test_attention_backends which verifies the correctness of the xformers v1 backend's attention output.

(py312conda) bash-5.1$ pytest tests/v1/attention/test_attention_backends.py -k test_backend_correctness
================================================================================ test session starts ================================================================================
platform linux -- Python 3.12.9, pytest-8.4.1, pluggy-1.6.0
rootdir: /data/users/gdelfin/gitrepos/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 6 items                                                                                                                                                                   

tests/v1/attention/test_attention_backends.py ......                                                                                                                          [100%]

================================================================================= warnings summary ==================================================================================
tests/v1/attention/test_attention_backends.py::test_backend_correctness[meta-llama/Meta-Llama-3-8B-small_decode]
tests/v1/attention/test_attention_backends.py::test_backend_correctness[meta-llama/Meta-Llama-3-8B-small_decode]
tests/v1/attention/test_attention_backends.py::test_backend_correctness[meta-llama/Meta-Llama-3-8B-small_decode]
tests/v1/attention/test_attention_backends.py::test_backend_correctness[meta-llama/Meta-Llama-3-8B-small_decode]
  /home/gdelfin/.conda/envs/py312conda/lib/python3.12/site-packages/triton/runtime/autotuner.py:108: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================== 6 passed, 4 warnings in 18.72s ===========================================================================
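
The core of the correctness check is comparing the backend's output against a simple reference implementation; a minimal sketch of that idea (not the actual test code, and the tolerances and tensor layout here are assumptions):

```python
import torch
import torch.nn.functional as F

def reference_attention(q, k, v, scale):
    # Plain causal SDPA serves as the ground truth.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)

def check_backend_output(backend_out, q, k, v, scale, atol=2e-2, rtol=2e-2):
    # q, k, v laid out as [batch, num_heads, seq_len, head_dim] for SDPA.
    ref = reference_attention(q, k, v, scale)
    torch.testing.assert_close(backend_out, ref, atol=atol, rtol=rtol)
```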

Benchmark

In addition, I used the following commands to run the LLM service and benchmark the XFormers backend against the FlashAttention 3 and Triton backends:
Server

export VLLM_TORCH_PROFILER_DIR=~/traces/vllm
export LLAMA_MODEL=meta-llama/Llama-3.1-8B-Instruct
export DRAFT_MODEL=yuhuili/EAGLE-LLaMA3.1-Instruct-8B
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=<backend>
python -m vllm.entrypoints.openai.api_server --model $LLAMA_MODEL --disable-log-requests --tensor-parallel-size=1 --max-num-seqs=64 --max-model-len=32768 --block-size=128 --no-enable-prefix-caching --speculative-config="$SPEC_DEC_CONFIG" 2>&1 | tee ~/server_logs/vllm_server.log

Client

export LLAMA_MODEL=meta-llama/Llama-3.1-8B-Instruct
python benchmarks/benchmark_serving.py --model $LLAMA_MODEL --tokenizer $LLAMA_MODEL --host 0.0.0.0 --dataset-name random --ignore-eos --request-rate inf --random-input-len 1000 --random-output-len 300 --max-concurrency 64 --num-prompts 128

Results

| Serving Benchmark Result | Flash Attention 3 | Triton Attention | XFormers Attention |
|---|---|---|---|
| Successful requests | 128 | 128 | 128 |
| Benchmark duration (s) | 15.94 | 13.23 | 13.22 |
| Total input tokens | 127731 | 127731 | 127731 |
| Total generated tokens | 38400 | 38400 | 38400 |
| Request throughput (req/s) | 8.03 | 9.68 | 9.68 |
| Output token throughput (tok/s) | 2408.88 | 2903.54 | 2905.24 |
| Total token throughput (tok/s) | 10421.59 | 12561.66 | 12569.01 |
| Time to First Token | | | |
| Mean TTFT (ms) | 894.77 | 920.44 | 929.93 |
| Median TTFT (ms) | 856.32 | 769.44 | 776.85 |
| P99 TTFT (ms) | 1817.60 | 2063.83 | 2080.56 |
| Time per Output Token (excl. 1st token) | | | |
| Mean TPOT (ms) | 23.49 | 18.94 | 18.87 |
| Median TPOT (ms) | 22.67 | 19.34 | 19.24 |
| P99 TPOT (ms) | 30.50 | 21.57 | 21.51 |
| Inter-token Latency | | | |
| Mean ITL (ms) | 23.49 | 18.94 | 18.87 |
| Median ITL (ms) | 15.14 | 15.48 | 15.35 |
| P99 ITL (ms) | 222.39 | 252.56 | 256.01 |

The v1 XFormers backend performs better than FA3, and is on par with Triton split-k.
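
As a quick sanity check, the throughput rows follow directly from the request/token counts and the benchmark duration; for the XFormers column:

```python
# XFormers column: 128 requests, 127731 input and 38400 output tokens in 13.22 s.
requests, input_tokens, output_tokens, duration_s = 128, 127731, 38400, 13.22

req_throughput = requests / duration_s                              # ~9.68 req/s
output_tok_throughput = output_tokens / duration_s                  # ~2905 tok/s
total_tok_throughput = (input_tokens + output_tokens) / duration_s  # ~12566 tok/s
```

The small differences from the table (e.g. 12569.01 vs ~12566) come from the duration being rounded to two decimals here.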

mergify bot added the v1 label Jul 22, 2025

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request ports the xformers attention backend to the v1 engine, including the implementation, tests, and wiring it into the system. The implementation correctly splits logic for prefill and decode phases for optimization. My review identified two high-severity issues: one is a naming inconsistency for the new backend that would prevent it from being selected correctly, and the other is an inadequate test case in the new test file that only covers the decode path, leaving the prefill path untested. I've provided suggestions to fix both issues.
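
On the first issue (the naming inconsistency), backend selection is driven by matching the string in VLLM_ATTENTION_BACKEND against the name the backend is registered under, so the two must agree exactly. A simplified, generic sketch of that failure mode, not vLLM's actual selector code:

```python
import os

# Hypothetical registry: env-var string -> backend class name.
_BACKEND_REGISTRY = {
    "FLASH_ATTN": "FlashAttentionBackend",
    "XFORMERS": "XFormersAttentionBackend",
}

def resolve_backend() -> str:
    name = os.environ.get("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
    if name not in _BACKEND_REGISTRY:
        # If the backend registers under a different string than users pass in
        # the env var, it can never be selected; this is the bug class flagged above.
        raise ValueError(f"Unknown attention backend: {name}")
    return _BACKEND_REGISTRY[name]
```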

@TheEpicDolphin force-pushed the xformers_attention_v1 branch 2 times, most recently from 82f774a to 137c8c1 on July 22, 2025 at 01:22

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Giancarlo Delfin <[email protected]>
Comment on lines +399 to +417
unified_attention(
q=query[num_decode_tokens:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[num_decode_tokens:num_actual_tokens],
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=prefill_meta.block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
@WoosukKwon (Collaborator) commented Jul 23, 2025

QQ: Why does it fall back to the Triton kernel? IIRC, the Triton kernel here is not very well optimized.

@TheEpicDolphin (Author) replied:

Thanks for the info. Would you recommend using FA3 instead?

self._num_decode_tokens = 0

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
A collaborator commented:

Is there a reason we can't use reorder_batch_to_split_decodes_and_prefills in vllm/v1/attention/backends/utils.py here, like in FlashInfer:

return reorder_batch_to_split_decodes_and_prefills(input_batch,

@TheEpicDolphin (Author) replied:

This must have been added after I started working on this PR. Thanks, I will use this.
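
For context, the swap would roughly amount to delegating reorder_batch to the shared utility, as the FlashInfer backend does. A hedged sketch, assuming the helper takes the input batch and scheduler output as in the FlashInfer line quoted above (its exact signature may differ):

```python
from vllm.v1.attention.backends.utils import (
    reorder_batch_to_split_decodes_and_prefills)

def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    # Let the shared helper group decode requests ahead of prefills instead of
    # re-implementing the reordering inside the XFormers metadata builder.
    return reorder_batch_to_split_decodes_and_prefills(input_batch,
                                                       scheduler_output)
```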


@staticmethod
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
A collaborator commented:

Does xFormers support more head sizes than this? It might be a nice option as an alternative for head size 80 (which currently falls back to FlexAttention).

@TheEpicDolphin (Author) replied:

Thanks for catching this; it turns out xformers supports many more head sizes.
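
One hedged way to check empirically whether xformers handles a given head size is to attempt a tiny memory_efficient_attention call and catch the failure if no operator supports it (a probing sketch, not what this PR does):

```python
import torch
import xformers.ops as xops

def xformers_supports_head_size(head_dim: int) -> bool:
    """Probe support by running a tiny attention call with the given head_dim."""
    q = torch.randn(1, 8, 1, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 8, 1, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 8, 1, head_dim, device="cuda", dtype=torch.float16)
    try:
        xops.memory_efficient_attention(q, k, v)
        return True
    except (NotImplementedError, ValueError, RuntimeError):
        return False

# e.g. head size 80, the FlexAttention fallback case mentioned above:
# print(xformers_supports_head_size(80))
```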
