[V1] port xformers backend to v1 #21342

Merged
merged 1 commit on Aug 5, 2025
Changes from all commits
2 changes: 2 additions & 0 deletions tests/v1/attention/utils.py
@@ -128,6 +128,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
_Backend.TREE_ATTN:
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
_Backend.XFORMERS_VLLM_V1:
"vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
}

if backend_name not in backend_map:
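
For context, a minimal sketch (assuming vLLM with this change installed) of the lookup pattern the extended map supports: the `_Backend` member indexes a dotted class path, and unknown members are rejected. The helper below is illustrative, not the test utility's actual body.

```python
from vllm.platforms.interface import _Backend

# Illustrative mapping, mirroring the entry added above.
backend_map = {
    _Backend.XFORMERS_VLLM_V1:
        "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
}

def backend_class_path(backend_name: _Backend) -> str:
    # Matches the "if backend_name not in backend_map" guard shown above.
    if backend_name not in backend_map:
        raise ValueError(f"Unknown attention backend: {backend_name}")
    return backend_map[backend_name]

print(backend_class_path(_Backend.XFORMERS_VLLM_V1))
```
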
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -1469,6 +1469,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"TORCH_SDPA_VLLM_V1",
"FLEX_ATTENTION",
"TREE_ATTN",
"XFORMERS_VLLM_V1",
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
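
A hedged usage sketch: with "XFORMERS_VLLM_V1" whitelisted above, the backend can be requested through the VLLM_ATTENTION_BACKEND environment variable before the engine is built. The model name and generation arguments below are placeholders, not part of this change.

```python
import os

# Must be set before the engine (and its platform checks) are constructed.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS_VLLM_V1"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # example model only
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
```
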
4 changes: 4 additions & 0 deletions vllm/platforms/cuda.py
@@ -271,6 +271,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501

if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.")
@@ -291,6 +292,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
elif selected_backend == _Backend.TREE_ATTN:
logger.info_once("Using Tree Attention backend on V1 engine.")
return TREE_ATTN_V1
elif selected_backend == _Backend.XFORMERS_VLLM_V1:
logger.info_once("Using XFormers backend on V1 engine.")
return XFORMERS_V1

from vllm.attention.selector import is_attn_backend_supported

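
The branch above returns a dotted class path (XFORMERS_V1) rather than a class object. A rough sketch of how such a path can be resolved, shown with importlib for illustration; vLLM uses its own resolver, so this is not the platform code's actual mechanism.

```python
import importlib

def resolve_backend_cls(qualname: str):
    # Split "package.module.ClassName" into module path and attribute name.
    module_name, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"
backend_cls = resolve_backend_cls(XFORMERS_V1)
print(backend_cls.__name__)  # -> XFormersAttentionBackend
```
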
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -63,6 +63,7 @@ class _Backend(enum.Enum):
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
XFORMERS_VLLM_V1 = enum.auto()


class PlatformEnum(enum.Enum):
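
A small sanity-check sketch of the new enum member: the backend name string used in V1_BACKENDS corresponds to this member by name, which is plain Python enum behavior rather than anything specific to the selector.

```python
from vllm.platforms.interface import _Backend

selected = _Backend["XFORMERS_VLLM_V1"]  # name-based lookup
assert selected is _Backend.XFORMERS_VLLM_V1
print(selected)
```
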
7 changes: 1 addition & 6 deletions vllm/v1/attention/backends/tree_attn.py
@@ -4,7 +4,7 @@

import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional

import torch

@@ -313,15 +313,10 @@ def __init__(
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"TreeAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
Expand Down