diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py
index 9bd0b99798d7..f197cbb7bbba 100644
--- a/tests/v1/attention/test_attention_backends.py
+++ b/tests/v1/attention/test_attention_backends.py
@@ -198,7 +198,8 @@ def __init__(self, device: torch.device):
 
 
 def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
-                          vllm_config, device: torch.device,
+                          layer_names: list[str], vllm_config,
+                          device: torch.device,
                           common_attn_metadata: CommonAttentionMetadata,
                           query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor,
@@ -211,31 +212,33 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
     if backend == _Backend.FLASHINFER_VLLM_V1:
         import unittest.mock
 
-        from vllm.v1.attention.backends.flashinfer import PerLayerParameters
+        from vllm.v1.attention.backends.utils import PerLayerParameters
 
-        def mock_get_per_layer_parameters(vllm_config, impl_cls):
+        def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
             # Return mock parameters for a single layer
             head_size = vllm_config.model_config.get_head_size()
             return {
-                "mock_layer":
+                layer_name:
                 PerLayerParameters(
                     window_left=-1,  # No sliding window
                     logits_soft_cap=0.0,  # No soft cap
                     sm_scale=1.0 / (head_size**0.5)  # Standard scale
                 )
+                for layer_name in layer_names
             }
 
         with unittest.mock.patch(
                 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
                 mock_get_per_layer_parameters):
-            builder = builder_cls(kv_cache_spec, vllm_config, device)
+            builder = builder_cls(kv_cache_spec, layer_names, vllm_config,
+                                  device)
             attn_metadata = builder.build(
                 common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
             )
     else:
         # Build metadata
-        builder = builder_cls(kv_cache_spec, vllm_config, device)
+        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
         attn_metadata = builder.build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
@@ -427,8 +430,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
             set_kv_cache_layout("HND")
 
         backend_output = run_attention_backend(backend_name, kv_cache_spec,
-                                               vllm_config, device,
-                                               common_attn_metadata,
+                                               ["placeholder"], vllm_config,
+                                               device, common_attn_metadata,
                                                query_vllm, key_vllm,
                                                value_vllm,
                                                kv_cache_for_backend)
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 5c74a286c4a9..e77d276a4f45 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -310,6 +310,7 @@ def create_deterministic_logits(token_ids):
         _Backend.FLASH_ATTN_VLLM_V1)
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
         vllm_config=proposer.vllm_config,
         device=device,
     )
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 7fec4782517c..15554233aaf7 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -749,7 +749,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     layer_4 = "model.layers.4.mixer"
     layer_5 = "model.layers.5.mixer"
 
-    with set_current_vllm_config(vllm_config):
+    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
         hf_config = vllm_config.model_config.hf_config
         fwd_context = {}
         for key in [layer_0, layer_1]:
diff --git a/vllm/config.py b/vllm/config.py
index ea9f7dce894b..2ba5f7415be0 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -624,8 +624,8 @@ def __post_init__(self) -> None:
                                      isinstance(sliding_window, list))
 
         if not self.disable_sliding_window and has_interleaved_attention:
-            if (backend :=
-                    envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
+            if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
+                                         ) in ("XFORMERS", "FLASHINFER"):
                 sliding_window_len_min = get_min_sliding_window(
                     self.hf_text_config.sliding_window)
 
@@ -4922,13 +4922,29 @@ def assert_hashable(text):
 T = TypeVar("T")
 
 
-def get_layers_from_vllm_config(vllm_config: VllmConfig,
-                                layer_type: type[T]) -> dict[str, T]:
+def get_layers_from_vllm_config(
+        vllm_config: VllmConfig,
+        layer_type: type[T],
+        layer_names: Optional[list[str]] = None) -> dict[str, T]:
+    """
+    Get layers from the vLLM config.
+
+    Args:
+        vllm_config: The vLLM config.
+        layer_type: The type of the layer to get.
+        layer_names: The names of the layers to get. If None, return all layers.
+    """
+
+    if layer_names is None:
+        layer_names = list(
+            vllm_config.compilation_config.static_forward_context.keys())
+
+    forward_context = vllm_config.compilation_config.static_forward_context
+
     return {
-        layer_name: layer
-        for layer_name, layer in
-        vllm_config.compilation_config.static_forward_context.items()
-        if isinstance(layer, layer_type)
+        layer_name: forward_context[layer_name]
+        for layer_name in layer_names
+        if isinstance(forward_context[layer_name], layer_type)
     }
 
 
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index 3b6d753863d0..9ed46331863c 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -315,8 +315,8 @@ def get_seq_len_block_table_args(
 
 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device) -> None:
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device) -> None:
         self.kv_cache_spec = kv_cache_spec
         self.vllm_config = vllm_config
         self.scheduler_config = vllm_config.scheduler_config
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 7c8a5e056fea..4c2a6c6b985b 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -148,8 +148,8 @@ class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index b72745ef156e..37ae87753a84 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -21,10 +21,9 @@
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
-    get_kv_cache_layout, get_per_layer_parameters,
-    infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
-    split_decodes_and_prefills)
+    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
+    get_per_layer_parameters, infer_global_hyperparameters,
+    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 if TYPE_CHECKING:
@@ -220,8 +219,8 @@ def __post_init__(self):
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.device = device
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
@@ -229,7 +228,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers
-        self.global_hyperparameters: Optional[PerLayerParameters] = None
+        self.global_hyperparameters = infer_global_hyperparameters(
+            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
 
         self.vllm_config = vllm_config
         self.cache_config = vllm_config.cache_config
@@ -284,10 +284,6 @@ def _get_cascade_wrapper(self):
 
     def _plan(self, num_prefills: int, num_decodes: int,
               attn_metadata: FlashInferMetadata):
-        if self.global_hyperparameters is None:
-            self.global_hyperparameters = infer_global_hyperparameters(
-                get_per_layer_parameters(self.vllm_config, FlashInferImpl))
-
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index ad63f92cd88a..bb0d890c7754 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -258,8 +258,8 @@ def __post_init__(self):
 class FlexAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlexAttentionMetadata]):
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
         self.cache_config = vllm_config.cache_config
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index dca5de46c065..8b702e28d67c 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -87,8 +87,8 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index cf17d9330239..0095d7521785 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -406,6 +406,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
     def __init__(self,
                  kv_cache_spec: AttentionSpec,
+                 layer_names: list[str],
                  vllm_config: VllmConfig,
                  device: torch.device,
                  metadata_cls: Optional[type[M]] = None):
@@ -471,7 +472,8 @@ def __init__(self,
                 BatchPrefillWithRaggedKVCacheWrapper] = []
 
         self._global_hyperparameters = infer_global_hyperparameters(
-            get_per_layer_parameters(vllm_config, MLACommonImpl))
+            get_per_layer_parameters(vllm_config, layer_names,
+                                     MLACommonImpl))
 
         if self._use_cudnn_prefill:
             self.cudnn_workspace = torch.empty(
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index d3e5300dbbd6..39463b9c0616 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -56,9 +56,10 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # Decode-only
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
-        super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
+                         FlashMLAMetadata)
 
         self.compilation_config = vllm_config.compilation_config
         self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 834c23455835..5c5891f035ae 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -66,9 +66,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # decode only
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
-        super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
+                         AiterMLAMetadata)
         assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
             "only supports block size 1."
 
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 85a5dc8c91c1..dd10b7f02730 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -231,8 +231,8 @@ class AiterFlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 83471ca51b73..195fbd3b1b9c 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -59,8 +59,8 @@ class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True
 
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.device = device
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index b13362f8a8d8..d1599ba10b61 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -70,8 +70,8 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     full_cudagraph_supported: ClassVar[bool] = False
 
     @abstractmethod
-    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
-                 device: torch.device):
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
         self.kv_cache_spec = kv_cache_spec
 
     @abstractmethod
@@ -164,14 +164,14 @@ class PerLayerParameters:
 
 
 def get_per_layer_parameters(
-        vllm_config: VllmConfig,
+        vllm_config: VllmConfig, layer_names: list[str],
         cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
     """
-    Scan all attention layers and determine some hyperparameters
+    Scan layers in `layer_names` and determine some hyperparameters
     to use during `plan`.
     """
 
-    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names)
     per_layer_params: dict[str, PerLayerParameters] = {}
 
     for key, layer in layers.items():
@@ -208,6 +208,10 @@ def infer_global_hyperparameters(
     param_sets = list(per_layer_params.values())
     global_params = param_sets[0]
     for params in param_sets:
+        if params.window_left != global_params.window_left:
+            raise ValueError(
+                "Window left is not the same for all layers. One potential fix "
+                "is to set disable_sliding_window=True")
         assert params == global_params, (
             "FlashInfer backend currently only supports models in which all "
             "layers share the same values for the following hyperparameters: "
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a1eef5e573ed..ddb6b8c3568e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2485,7 +2485,7 @@ def freeze_gc():
                     elapsed_time, cuda_graph_size / (1 << 30))
 
     def _initialize_single_attn_backend(
-        self, kv_cache_spec: KVCacheSpec
+        self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
     ) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
         if isinstance(kv_cache_spec, AttentionSpec):
             attn_backend_i = get_attn_backend(
@@ -2515,6 +2515,7 @@ def _initialize_single_attn_backend(
 
         attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
             kv_cache_spec,
+            layer_names,
            self.vllm_config,
            self.device,
        )
@@ -2538,8 +2539,9 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
                 kv_cache_config.kv_cache_groups):
             kv_cache_spec = kv_cache_group_spec.kv_cache_spec
 
-            attn_backend_i, attn_metadata_builder_i = \
-                self._initialize_single_attn_backend(kv_cache_spec)
+            attn_backend_i, attn_metadata_builder_i = (
+                self._initialize_single_attn_backend(
+                    kv_cache_spec, kv_cache_group_spec.layer_names))
             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)
 
@@ -2570,8 +2572,9 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
             assert len(attn_specs) == len(attn_layers), \
                 "All or none of the layers are expected to be encoder-only"
 
-            attn_backend, attn_metadata_builder = \
-                self._initialize_single_attn_backend(attn_specs[0])
+            attn_backend, attn_metadata_builder = (
+                self._initialize_single_attn_backend(attn_specs[0],
+                                                     attn_layers.keys()))
             self.attn_backends.append(attn_backend)
             self.attn_metadata_builders.append(attn_metadata_builder)
             self.is_encoder_only_model = True