19 changes: 11 additions & 8 deletions tests/v1/attention/test_attention_backends.py
@@ -198,7 +198,8 @@ def __init__(self, device: torch.device):


def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
vllm_config, device: torch.device,
layer_names: list[str], vllm_config,
device: torch.device,
common_attn_metadata: CommonAttentionMetadata,
query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor,
@@ -211,31 +212,33 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
if backend == _Backend.FLASHINFER_VLLM_V1:
import unittest.mock

from vllm.v1.attention.backends.flashinfer import PerLayerParameters
from vllm.v1.attention.backends.utils import PerLayerParameters

def mock_get_per_layer_parameters(vllm_config, impl_cls):
def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
# Return mock parameters for a single layer
head_size = vllm_config.model_config.get_head_size()
return {
"mock_layer":
layer_name:
PerLayerParameters(
window_left=-1, # No sliding window
logits_soft_cap=0.0, # No soft cap
sm_scale=1.0 / (head_size**0.5) # Standard scale
)
for layer_name in layer_names
}

with unittest.mock.patch(
'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
mock_get_per_layer_parameters):
builder = builder_cls(kv_cache_spec, vllm_config, device)
builder = builder_cls(kv_cache_spec, layer_names, vllm_config,
device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
# Build metadata
builder = builder_cls(kv_cache_spec, vllm_config, device)
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
@@ -427,8 +430,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
set_kv_cache_layout("HND")

backend_output = run_attention_backend(backend_name, kv_cache_spec,
vllm_config, device,
common_attn_metadata,
["placeholder"], vllm_config,
device, common_attn_metadata,
query_vllm, key_vllm,
value_vllm,
kv_cache_for_backend)
1 change: 1 addition & 0 deletions tests/v1/spec_decode/test_eagle.py
@@ -310,6 +310,7 @@ def create_deterministic_logits(token_ids):
_Backend.FLASH_ATTN_VLLM_V1)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
vllm_config=proposer.vllm_config,
device=device,
)
3 changes: 2 additions & 1 deletion tests/v1/worker/test_gpu_model_runner.py
@@ -749,7 +749,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
layer_4 = "model.layers.4.mixer"
layer_5 = "model.layers.5.mixer"

with set_current_vllm_config(vllm_config):
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
hf_config = vllm_config.model_config.hf_config
fwd_context = {}
for key in [layer_0, layer_1]:
32 changes: 24 additions & 8 deletions vllm/config.py
@@ -624,8 +624,8 @@ def __post_init__(self) -> None:
isinstance(sliding_window, list))

if not self.disable_sliding_window and has_interleaved_attention:
if (backend :=
envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
) in ("XFORMERS", "FLASHINFER"):
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)

@@ -4922,13 +4922,29 @@ def assert_hashable(text):
T = TypeVar("T")


def get_layers_from_vllm_config(vllm_config: VllmConfig,
layer_type: type[T]) -> dict[str, T]:
def get_layers_from_vllm_config(
vllm_config: VllmConfig,
layer_type: type[T],
layer_names: Optional[list[str]] = None) -> dict[str, T]:
"""
Get layers from the vLLM config.

Args:
vllm_config: The vLLM config.
layer_type: The type of the layer to get.
layer_names: The names of the layers to get. If None, return all layers.
"""

if layer_names is None:
layer_names = list(
vllm_config.compilation_config.static_forward_context.keys())

forward_context = vllm_config.compilation_config.static_forward_context

return {
layer_name: layer
for layer_name, layer in
vllm_config.compilation_config.static_forward_context.items()
if isinstance(layer, layer_type)
layer_name: forward_context[layer_name]
for layer_name in layer_names
if isinstance(forward_context[layer_name], layer_type)
}
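
For illustration, a self-contained toy sketch of the filtering behaviour this helper now supports (stand-in classes and a stand-in forward context, not vLLM's real objects):

    from typing import Optional, TypeVar

    T = TypeVar("T")

    class Attention:          # stand-in for vLLM's Attention layer class
        pass

    class MambaMixer:         # stand-in for a non-attention layer
        pass

    # Stand-in for vllm_config.compilation_config.static_forward_context.
    forward_context = {
        "model.layers.0.self_attn.attn": Attention(),
        "model.layers.1.mixer": MambaMixer(),
        "model.layers.2.self_attn.attn": Attention(),
    }

    def get_layers(layer_type: type[T],
                   layer_names: Optional[list[str]] = None) -> dict[str, T]:
        # With no explicit subset, every registered layer is considered.
        if layer_names is None:
            layer_names = list(forward_context.keys())
        return {
            name: forward_context[name]
            for name in layer_names
            if isinstance(forward_context[name], layer_type)
        }

    print(list(get_layers(Attention)))
    # ['model.layers.0.self_attn.attn', 'model.layers.2.self_attn.attn']
    print(list(get_layers(Attention, ["model.layers.1.mixer"])))
    # [] -- the named layer exists but is not an Attention instance

The point of the new optional parameter is that per-group metadata builders can restrict the scan to the layers they actually own rather than every registered layer.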


4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/cpu_attn.py
@@ -315,8 +315,8 @@ def get_seq_len_block_table_args(

class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device) -> None:
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device) -> None:
self.kv_cache_spec = kv_cache_spec
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -148,8 +148,8 @@ class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
18 changes: 7 additions & 11 deletions vllm/v1/attention/backends/flashinfer.py
@@ -21,10 +21,9 @@
from vllm.utils import cdiv
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills)
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
get_per_layer_parameters, infer_global_hyperparameters,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec

if TYPE_CHECKING:
@@ -220,16 +219,17 @@ def __post_init__(self):

class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
Collaborator:

might play nicer with: #21588

if we do:

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):

instead

Collaborator (Author):

Thanks. I've fixed it.

self.device = device
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode
self._cascade_wrapper = None # Wrapper for cascade attention

# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))

self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
@@ -284,10 +284,6 @@ def _get_cascade_wrapper(self):

def _plan(self, num_prefills: int, num_decodes: int,
attn_metadata: FlashInferMetadata):
if self.global_hyperparameters is None:
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config, FlashInferImpl))

if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan(
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flex_attention.py
@@ -258,8 +258,8 @@ def __post_init__(self):
class FlexAttentionMetadataBuilder(
AttentionMetadataBuilder[FlexAttentionMetadata]):

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/mamba_attn.py
@@ -87,8 +87,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]):

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/common.py
@@ -406,6 +406,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):

def __init__(self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[type[M]] = None):
@@ -471,7 +472,8 @@ def __init__(self,
BatchPrefillWithRaggedKVCacheWrapper] = []

self._global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, MLACommonImpl))
get_per_layer_parameters(vllm_config, layer_names,
MLACommonImpl))

if self._use_cudnn_prefill:
self.cudnn_workspace = torch.empty(
7 changes: 4 additions & 3 deletions vllm/v1/attention/backends/mla/flashmla.py
@@ -56,9 +56,10 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata)
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
FlashMLAMetadata)

self.compilation_config = vllm_config.compilation_config
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
7 changes: 4 additions & 3 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -66,9 +66,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata)
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
AiterMLAMetadata)
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
"only supports block size 1."

4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -231,8 +231,8 @@ class AiterFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/triton_attn.py
@@ -59,8 +59,8 @@ class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.device = device
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
14 changes: 9 additions & 5 deletions vllm/v1/attention/backends/utils.py
@@ -70,8 +70,8 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
full_cudagraph_supported: ClassVar[bool] = False

@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
self.kv_cache_spec = kv_cache_spec

@abstractmethod
@@ -164,14 +164,14 @@ class PerLayerParameters:


def get_per_layer_parameters(
vllm_config: VllmConfig,
vllm_config: VllmConfig, layer_names: list[str],
cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
Scan layers in `layer_names` and determine some hyperparameters
to use during `plan`.
"""

layers = get_layers_from_vllm_config(vllm_config, Attention)
layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names)
per_layer_params: dict[str, PerLayerParameters] = {}

for key, layer in layers.items():
@@ -208,6 +208,10 @@ def infer_global_hyperparameters(
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
if params.window_left != global_params.window_left:
raise ValueError(
"Window left is not the same for all layers. One potential fix "
"is to set disable_sliding_window=True")
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
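
A standalone sketch of the consistency check added above, using a simplified PerLayerParameters stand-in (the real dataclass and the longer assertion message live in vllm/v1/attention/backends/utils.py):

    from dataclasses import dataclass

    @dataclass
    class PerLayerParameters:          # simplified stand-in
        window_left: int
        logits_soft_cap: float
        sm_scale: float

    def infer_global_hyperparameters(
            per_layer_params: dict[str, PerLayerParameters]
    ) -> PerLayerParameters:
        param_sets = list(per_layer_params.values())
        global_params = param_sets[0]
        for params in param_sets:
            if params.window_left != global_params.window_left:
                # Interleaved sliding-window / full-attention layers cannot
                # share a single set of planning hyperparameters.
                raise ValueError(
                    "Window left is not the same for all layers. One potential "
                    "fix is to set disable_sliding_window=True")
            assert params == global_params
        return global_params

    layers = {
        "layers.0.attn": PerLayerParameters(window_left=-1,
                                            logits_soft_cap=0.0,
                                            sm_scale=0.125),
        "layers.1.attn": PerLayerParameters(window_left=1024,
                                            logits_soft_cap=0.0,
                                            sm_scale=0.125),
    }
    try:
        infer_global_hyperparameters(layers)
    except ValueError as err:
        print(err)    # points the user at disable_sliding_window=True

Raising a ValueError with a suggested workaround surfaces the common sliding-window mismatch more clearly than a bare assertion failure.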
13 changes: 8 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2485,7 +2485,7 @@ def freeze_gc():
elapsed_time, cuda_graph_size / (1 << 30))

def _initialize_single_attn_backend(
self, kv_cache_spec: KVCacheSpec
self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
if isinstance(kv_cache_spec, AttentionSpec):
attn_backend_i = get_attn_backend(
@@ -2515,6 +2515,7 @@ def _initialize_single_attn_backend(

attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
kv_cache_spec,
layer_names,
self.vllm_config,
self.device,
)
@@ -2538,8 +2539,9 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec

attn_backend_i, attn_metadata_builder_i = \
self._initialize_single_attn_backend(kv_cache_spec)
attn_backend_i, attn_metadata_builder_i = (
self._initialize_single_attn_backend(
kv_cache_spec, kv_cache_group_spec.layer_names))
self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)

@@ -2570,8 +2572,9 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
assert len(attn_specs) == len(attn_layers), \
"All or none of the layers are expected to be encoder-only"

attn_backend, attn_metadata_builder = \
self._initialize_single_attn_backend(attn_specs[0])
attn_backend, attn_metadata_builder = (
self._initialize_single_attn_backend(attn_specs[0],
attn_layers.keys()))
self.attn_backends.append(attn_backend)
self.attn_metadata_builders.append(attn_metadata_builder)
self.is_encoder_only_model = True
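
A toy end-to-end sketch of the wiring above, using stand-in names: each KV-cache group hands its own layer names to the builder it constructs (the real path goes through attn_backend_i.get_builder_cls() with a real VllmConfig and device):

    class ToyBuilder:
        """Stand-in for an AttentionMetadataBuilder subclass."""

        def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
            self.kv_cache_spec = kv_cache_spec
            self.layer_names = list(layer_names)

    # Stand-ins for kv_cache_config.kv_cache_groups entries.
    kv_cache_groups = [
        ("full_attention_spec", ["model.layers.0.self_attn.attn",
                                 "model.layers.2.self_attn.attn"]),
        ("mamba_spec", ["model.layers.1.mixer"]),
    ]

    builders = [
        ToyBuilder(spec, layer_names, vllm_config=None, device="cpu")
        for spec, layer_names in kv_cache_groups
    ]
    assert builders[1].layer_names == ["model.layers.1.mixer"]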