From 6110d3c8d6216a735befe1db106b964a733726d3 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 22 Jul 2025 14:34:02 -0700 Subject: [PATCH 01/10] hybrid allocator + flashinfer Signed-off-by: Chen Zhang --- tests/v1/attention/test_attention_backends.py | 21 ++++++------ vllm/config.py | 32 ++++++++++++++----- vllm/v1/attention/backends/cpu_attn.py | 8 ++--- vllm/v1/attention/backends/flash_attn.py | 8 ++--- vllm/v1/attention/backends/flashinfer.py | 24 +++++++------- vllm/v1/attention/backends/flex_attention.py | 10 +++--- vllm/v1/attention/backends/mamba_attn.py | 7 ++-- vllm/v1/attention/backends/mla/common.py | 11 ++++--- vllm/v1/attention/backends/mla/flashmla.py | 9 +++--- .../attention/backends/mla/rocm_aiter_mla.py | 9 +++--- vllm/v1/attention/backends/rocm_aiter_fa.py | 10 +++--- vllm/v1/attention/backends/triton_attn.py | 10 +++--- vllm/v1/attention/backends/utils.py | 14 ++++---- vllm/v1/worker/gpu_model_runner.py | 2 +- 14 files changed, 100 insertions(+), 75 deletions(-) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index b4e0101a0d4b..9733e136c28d 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -12,7 +12,7 @@ get_attention_backend) from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.kv_cache_interface import FullAttentionSpec +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec BACKENDS_TO_TEST = [ _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1, @@ -212,16 +212,17 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, from vllm.v1.attention.backends.flashinfer import PerLayerParameters - def mock_get_per_layer_parameters(vllm_config): + def mock_get_per_layer_parameters(vllm_config, layer_names, cls_): # Return mock parameters for a single layer head_size = vllm_config.model_config.get_head_size() return { - "mock_layer": + layer_name: PerLayerParameters( window_left=-1, # No sliding window logits_soft_cap=0.0, # No soft cap sm_scale=1.0 / (head_size**0.5) # Standard scale ) + for layer_name in layer_names } with unittest.mock.patch( @@ -301,6 +302,10 @@ def test_backend_correctness(batch_spec_name: str, model: str): device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + kv_cache_group_spec = KVCacheGroupSpec( + kv_cache_spec=kv_cache_spec, + layer_names=["placeholder"], + ) # 1. 
Setup batch_size = batch_spec.batch_size @@ -419,12 +424,10 @@ def test_backend_correctness(batch_spec_name: str, model: str): if backend_name == _Backend.FLASHINFER_VLLM_V1: kv_cache_for_backend = kv_cache.transpose(0, 1) - backend_output = run_attention_backend(backend_name, kv_cache_spec, - vllm_config, device, - common_attn_metadata, - query_vllm, key_vllm, - value_vllm, - kv_cache_for_backend) + backend_output = run_attention_backend( + backend_name, kv_cache_group_spec, vllm_config, device, + common_attn_metadata, query_vllm, key_vllm, value_vllm, + kv_cache_for_backend) # Check shape and dtype consistency assert backend_output.shape == sdpa_output.shape, ( diff --git a/vllm/config.py b/vllm/config.py index d649eb75033f..00031ea48690 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -614,8 +614,8 @@ def __post_init__(self) -> None: isinstance(sliding_window, list)) if not self.disable_sliding_window and has_interleaved_attention: - if (backend := - envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"): + if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND + ) in ("XFORMERS", "FLASHINFER"): sliding_window_len_min = get_min_sliding_window( self.hf_text_config.sliding_window) @@ -4982,13 +4982,29 @@ def assert_hashable(text): T = TypeVar("T") -def get_layers_from_vllm_config(vllm_config: VllmConfig, - layer_type: type[T]) -> dict[str, T]: +def get_layers_from_vllm_config( + vllm_config: VllmConfig, + layer_type: type[T], + layer_names: Optional[list[str]] = None) -> dict[str, T]: + """ + Get layers from the vLLM config. + + Args: + vllm_config: The vLLM config. + layer_type: The type of the layer to get. + layer_names: The names of the layers to get. If None, return all layers. + """ + + if layer_names is None: + layer_names = list( + vllm_config.compilation_config.static_forward_context.keys()) + + forward_context = vllm_config.compilation_config.static_forward_context + return { - layer_name: layer - for layer_name, layer in - vllm_config.compilation_config.static_forward_context.items() - if isinstance(layer, layer_type) + layer_name: forward_context[layer_name] + for layer_name in layer_names + if isinstance(forward_context[layer_name], layer_type) } diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 3b6d753863d0..0dc1df7d0a74 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -17,7 +17,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import KVCacheGroupSpec from vllm.v1.worker.gpu_input_batch import InputBatch try: @@ -315,9 +315,9 @@ def get_seq_len_block_table_args( class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device) -> None: - self.kv_cache_spec = kv_cache_spec + def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + vllm_config: VllmConfig, device: torch.device) -> None: + self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec self.vllm_config = vllm_config self.scheduler_config = vllm_config.scheduler_config diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5fe274f2c65b..ecff648b7f10 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ 
-28,7 +28,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import KVCacheGroupSpec logger = init_logger(__name__) @@ -146,8 +146,8 @@ class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device): + def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config @@ -160,7 +160,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self.num_heads_kv = self.model_config.get_num_kv_heads( self.parallel_config) self.headdim = self.model_config.get_head_size() - self.block_size = kv_cache_spec.block_size + self.block_size = kv_cache_group_spec.kv_cache_spec.block_size self.max_num_splits = 0 # No upper bound on the number of splits. self.aot_schedule = (get_flash_attn_version() == 3) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 953ef26c8143..8f34f8b50994 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -20,11 +20,10 @@ from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters, - get_kv_cache_layout, get_per_layer_parameters, - infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills, - split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec + AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, + get_per_layer_parameters, infer_global_hyperparameters, + reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheGroupSpec if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -224,8 +223,8 @@ def __post_init__(self): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device): + def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + vllm_config: VllmConfig, device: torch.device): self.device = device self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append @@ -233,11 +232,15 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers - self.global_hyperparameters: Optional[PerLayerParameters] = None + self.global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(vllm_config, + kv_cache_group_spec.layer_names, + FlashInferImpl)) self.vllm_config = vllm_config self.cache_config = vllm_config.cache_config - self.kv_cache_spec = kv_cache_spec + assert isinstance(kv_cache_group_spec.kv_cache_spec, AttentionSpec) + self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: @@ -282,9 +285,6 @@ def 
_get_cascade_wrapper(self): def _plan(self, num_prefills: int, num_decodes: int, attn_metadata: FlashInferMetadata): - if self.global_hyperparameters is None: - self.global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(self.vllm_config, FlashInferImpl)) if attn_metadata.use_cascade: attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index ad63f92cd88a..844a0be9fabf 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -19,7 +19,7 @@ from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import KVCacheGroupSpec logger = init_logger(__name__) @@ -258,8 +258,8 @@ def __post_init__(self): class FlexAttentionMetadataBuilder( AttentionMetadataBuilder[FlexAttentionMetadata]): - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device): + def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + vllm_config: VllmConfig, device: torch.device): self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config @@ -269,8 +269,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self.num_heads_kv = self.model_config.get_num_kv_heads( vllm_config.parallel_config) self.headdim = self.model_config.get_head_size() - self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec + self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec + self.block_size = self.kv_cache_spec.block_size self.device = device def build(self, diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index dca5de46c065..7757b36b8d53 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -11,7 +11,7 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec +from vllm.v1.kv_cache_interface import KVCacheGroupSpec, MambaSpec if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -87,8 +87,9 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( AttentionMetadataBuilder[Mamba2AttentionMetadata]): - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device): + def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + vllm_config: VllmConfig, device: torch.device): + kv_cache_spec = kv_cache_group_spec.kv_cache_spec assert isinstance(kv_cache_spec, MambaSpec) self.kv_cache_spec = kv_cache_spec self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index cf17d9330239..e579fa28ebae 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -213,7 +213,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, get_per_layer_parameters, infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec 
+from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheGroupSpec try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -405,13 +405,14 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ def __init__(self, - kv_cache_spec: AttentionSpec, + kv_cache_group_spec: KVCacheGroupSpec, vllm_config: VllmConfig, device: torch.device, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata - self.kv_cache_spec = kv_cache_spec + assert isinstance(kv_cache_group_spec.kv_cache_spec, AttentionSpec) + self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec self.device = device scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config @@ -471,7 +472,9 @@ def __init__(self, BatchPrefillWithRaggedKVCacheWrapper] = [] self._global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(vllm_config, MLACommonImpl)) + get_per_layer_parameters(vllm_config, + kv_cache_group_spec.layer_names, + MLACommonImpl)) if self._use_cudnn_prefill: self.cudnn_workspace = torch.empty( diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index d3e5300dbbd6..f8eddb1c83c7 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -18,7 +18,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import KVCacheGroupSpec logger = init_logger(__name__) @@ -56,9 +56,10 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): full_cudagraph_supported: ClassVar[bool] = True # Decode-only - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device): - super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata) + def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_group_spec, vllm_config, device, + FlashMLAMetadata) self.compilation_config = vllm_config.compilation_config self.num_q_heads = vllm_config.model_config.get_num_attention_heads( diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 834c23455835..02c10e1117ba 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -17,7 +17,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import KVCacheGroupSpec # yapf: enable @@ -66,9 +66,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): full_cudagraph_supported: ClassVar[bool] = True # decode only - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device): - super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata) + def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_group_spec, vllm_config, device, + AiterMLAMetadata) assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ "only supports block size 1." 
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 0739d2596676..58df7ca4b9c1 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -14,7 +14,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import KVCacheGroupSpec if current_platform.is_rocm(): import aiter @@ -165,8 +165,8 @@ def flash_attn_varlen_func_fake( class AiterFlashAttentionMetadataBuilder: - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device): + def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config @@ -178,8 +178,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self.num_heads_kv = self.model_config.get_num_kv_heads( self.parallel_config) self.headdim = self.model_config.get_head_size() - self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec + self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec + self.block_size = self.kv_cache_spec.block_size # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 83471ca51b73..6f49cb22594a 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -20,7 +20,7 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import KVCacheGroupSpec logger = init_logger(__name__) @@ -59,11 +59,11 @@ class TritonAttentionMetadataBuilder( AttentionMetadataBuilder[TritonAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = True - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device): + def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + vllm_config: VllmConfig, device: torch.device): self.device = device - self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec + self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec + self.block_size = self.kv_cache_spec.block_size model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index fc8649d587ee..fde6b73807c9 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -22,7 +22,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) from vllm.logger import init_logger -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import KVCacheGroupSpec logger = init_logger(__name__) _KV_CACHE_LAYOUT_OVERRIDE = None @@ -68,9 +68,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): full_cudagraph_supported: ClassVar[bool] = False @abstractmethod - def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, - device: torch.device): - self.kv_cache_spec = kv_cache_spec 
+ def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + vllm_config: VllmConfig, device: torch.device): + pass @abstractmethod def build(self, @@ -162,14 +162,14 @@ class PerLayerParameters: def get_per_layer_parameters( - vllm_config: VllmConfig, + vllm_config: VllmConfig, layer_names: list[str], cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]: """ - Scan all attention layers and determine some hyperparameters + Scan layers in `layer_names` and determine some hyperparameters to use during `plan`. """ - layers = get_layers_from_vllm_config(vllm_config, Attention) + layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names) per_layer_params: dict[str, PerLayerParameters] = {} for key, layer in layers.items(): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4c14ac3be3c0..cf39c1d56056 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2432,7 +2432,7 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: f"Unknown KV cache spec type: {type(kv_cache_spec)}") attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - kv_cache_spec, + kv_cache_group_spec, self.vllm_config, self.device, ) From e195f5ed1417c2ab88cff05db60da0ca921c5d80 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 22 Jul 2025 14:47:25 -0700 Subject: [PATCH 02/10] add comment on disable_sliding_window Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index fde6b73807c9..9bb34fdbedf7 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -206,6 +206,10 @@ def infer_global_hyperparameters( param_sets = list(per_layer_params.values()) global_params = param_sets[0] for params in param_sets: + if params.window_left != global_params.window_left: + raise ValueError( + "Window left is not the same for all layers. 
One potential fix " + "is to set disable_sliding_window=True") assert params == global_params, ( "FlashInfer backend currently only supports models in which all " "layers share the same values for the following hyperparameters: " From e61b8e065d3c94b2a27d699fef4b2a279d0f17fe Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 22 Jul 2025 14:58:21 -0700 Subject: [PATCH 03/10] small cleanup Signed-off-by: Chen Zhang --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 00031ea48690..3eb0f43e35c1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4996,7 +4996,7 @@ def get_layers_from_vllm_config( """ if layer_names is None: - layer_names = list( + layer_names = ( vllm_config.compilation_config.static_forward_context.keys()) forward_context = vllm_config.compilation_config.static_forward_context From 39feab4bf94ee48cd529f31aeeb1057882c2388a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 22 Jul 2025 23:00:26 -0700 Subject: [PATCH 04/10] try mypy Signed-off-by: Chen Zhang --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 3eb0f43e35c1..00031ea48690 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4996,7 +4996,7 @@ def get_layers_from_vllm_config( """ if layer_names is None: - layer_names = ( + layer_names = list( vllm_config.compilation_config.static_forward_context.keys()) forward_context = vllm_config.compilation_config.static_forward_context From 4165433720e00215967f2805af8c33403ff0ab5e Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Jul 2025 17:32:33 -0700 Subject: [PATCH 05/10] change interface Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/cpu_attn.py | 6 +++--- vllm/v1/attention/backends/flash_attn.py | 6 +++--- vllm/v1/attention/backends/flashinfer.py | 12 +++++------- vllm/v1/attention/backends/flex_attention.py | 6 +++--- vllm/v1/attention/backends/mamba_attn.py | 5 ++--- vllm/v1/attention/backends/mla/common.py | 12 ++++++------ vllm/v1/attention/backends/mla/flashmla.py | 6 +++--- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 6 +++--- vllm/v1/attention/backends/rocm_aiter_fa.py | 6 +++--- vllm/v1/attention/backends/triton_attn.py | 6 +++--- vllm/v1/attention/backends/utils.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 3 ++- 12 files changed, 38 insertions(+), 40 deletions(-) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 0dc1df7d0a74..8f8ac2ec3c8c 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -17,7 +17,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import KVCacheGroupSpec +from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.worker.gpu_input_batch import InputBatch try: @@ -315,9 +315,9 @@ def get_seq_len_block_table_args( class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): - def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device) -> None: - self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec + self.kv_cache_spec = kv_cache_spec self.vllm_config = vllm_config self.scheduler_config = vllm_config.scheduler_config diff --git a/vllm/v1/attention/backends/flash_attn.py 
b/vllm/v1/attention/backends/flash_attn.py index ecff648b7f10..0d2a6db492ea 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -28,7 +28,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout) -from vllm.v1.kv_cache_interface import KVCacheGroupSpec +from vllm.v1.kv_cache_interface import KVCacheSpec logger = init_logger(__name__) @@ -146,7 +146,7 @@ class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 - def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -160,7 +160,7 @@ def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, self.num_heads_kv = self.model_config.get_num_kv_heads( self.parallel_config) self.headdim = self.model_config.get_head_size() - self.block_size = kv_cache_group_spec.kv_cache_spec.block_size + self.block_size = kv_cache_spec.block_size self.max_num_splits = 0 # No upper bound on the number of splits. self.aot_schedule = (get_flash_attn_version() == 3) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 8f34f8b50994..15de9778e3d4 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -23,7 +23,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheGroupSpec +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheSpec if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -223,7 +223,7 @@ def __post_init__(self): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.device = device self._workspace_buffer = None @@ -233,14 +233,12 @@ def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, # Global hyperparameters shared by all attention layers self.global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(vllm_config, - kv_cache_group_spec.layer_names, - FlashInferImpl)) + get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)) self.vllm_config = vllm_config self.cache_config = vllm_config.cache_config - assert isinstance(kv_cache_group_spec.kv_cache_spec, AttentionSpec) - self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec + assert isinstance(kv_cache_spec, AttentionSpec) + self.kv_cache_spec = kv_cache_spec def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 844a0be9fabf..3db124cfe12f 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -19,7 +19,7 @@ from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) -from vllm.v1.kv_cache_interface import KVCacheGroupSpec +from 
vllm.v1.kv_cache_interface import KVCacheSpec logger = init_logger(__name__) @@ -258,7 +258,7 @@ def __post_init__(self): class FlexAttentionMetadataBuilder( AttentionMetadataBuilder[FlexAttentionMetadata]): - def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config @@ -269,7 +269,7 @@ def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, self.num_heads_kv = self.model_config.get_num_kv_heads( vllm_config.parallel_config) self.headdim = self.model_config.get_head_size() - self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec + self.kv_cache_spec = kv_cache_spec self.block_size = self.kv_cache_spec.block_size self.device = device diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 7757b36b8d53..7f8100faedd9 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -11,7 +11,7 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import KVCacheGroupSpec, MambaSpec +from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -87,9 +87,8 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( AttentionMetadataBuilder[Mamba2AttentionMetadata]): - def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - kv_cache_spec = kv_cache_group_spec.kv_cache_spec assert isinstance(kv_cache_spec, MambaSpec) self.kv_cache_spec = kv_cache_spec self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index e579fa28ebae..013463c26feb 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -213,7 +213,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, get_per_layer_parameters, infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheGroupSpec +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheSpec try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -405,14 +405,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ def __init__(self, - kv_cache_group_spec: KVCacheGroupSpec, + kv_cache_spec: KVCacheSpec, + layer_names: list[str], vllm_config: VllmConfig, device: torch.device, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata - assert isinstance(kv_cache_group_spec.kv_cache_spec, AttentionSpec) - self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec + assert isinstance(kv_cache_spec, AttentionSpec) + self.kv_cache_spec = kv_cache_spec self.device = device scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config @@ -472,8 +473,7 @@ def __init__(self, BatchPrefillWithRaggedKVCacheWrapper] = [] self._global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(vllm_config, - kv_cache_group_spec.layer_names, 
+ get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)) if self._use_cudnn_prefill: diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index f8eddb1c83c7..9f836ba3e093 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -18,7 +18,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) -from vllm.v1.kv_cache_interface import KVCacheGroupSpec +from vllm.v1.kv_cache_interface import KVCacheSpec logger = init_logger(__name__) @@ -56,9 +56,9 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): full_cudagraph_supported: ClassVar[bool] = True # Decode-only - def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_group_spec, vllm_config, device, + super().__init__(kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata) self.compilation_config = vllm_config.compilation_config diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 02c10e1117ba..1cea6729724d 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -17,7 +17,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) -from vllm.v1.kv_cache_interface import KVCacheGroupSpec +from vllm.v1.kv_cache_interface import KVCacheSpec # yapf: enable @@ -66,9 +66,9 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): full_cudagraph_supported: ClassVar[bool] = True # decode only - def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_group_spec, vllm_config, device, + super().__init__(kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata) assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ "only supports block size 1." 
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 58df7ca4b9c1..ffd85a0325ef 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -14,7 +14,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.kv_cache_interface import KVCacheGroupSpec +from vllm.v1.kv_cache_interface import KVCacheSpec if current_platform.is_rocm(): import aiter @@ -165,7 +165,7 @@ def flash_attn_varlen_func_fake( class AiterFlashAttentionMetadataBuilder: - def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -178,7 +178,7 @@ def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, self.num_heads_kv = self.model_config.get_num_kv_heads( self.parallel_config) self.headdim = self.model_config.get_head_size() - self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec + self.kv_cache_spec = kv_cache_spec self.block_size = self.kv_cache_spec.block_size # Sliding window size to be used with the AOT scheduler will be diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 6f49cb22594a..468a75b75a5b 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -20,7 +20,7 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) -from vllm.v1.kv_cache_interface import KVCacheGroupSpec +from vllm.v1.kv_cache_interface import KVCacheSpec logger = init_logger(__name__) @@ -59,10 +59,10 @@ class TritonAttentionMetadataBuilder( AttentionMetadataBuilder[TritonAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = True - def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.device = device - self.kv_cache_spec = kv_cache_group_spec.kv_cache_spec + self.kv_cache_spec = kv_cache_spec self.block_size = self.kv_cache_spec.block_size model_config = vllm_config.model_config diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 9bb34fdbedf7..fd0501af3ca6 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -22,7 +22,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) from vllm.logger import init_logger -from vllm.v1.kv_cache_interface import KVCacheGroupSpec +from vllm.v1.kv_cache_interface import KVCacheSpec logger = init_logger(__name__) _KV_CACHE_LAYOUT_OVERRIDE = None @@ -68,7 +68,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): full_cudagraph_supported: ClassVar[bool] = False @abstractmethod - def __init__(self, kv_cache_group_spec: KVCacheGroupSpec, + def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): pass diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cf39c1d56056..d627906b5fef 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2432,7 +2432,8 @@ def initialize_attn_backend(self, 
kv_cache_config: KVCacheConfig) -> None: f"Unknown KV cache spec type: {type(kv_cache_spec)}") attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - kv_cache_group_spec, + kv_cache_spec, + kv_cache_group_spec.layer_names, self.vllm_config, self.device, ) From 6ea5def24ef387c039bebb9e5d7da49bfe57970f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Jul 2025 17:48:39 -0700 Subject: [PATCH 06/10] remove unrelated changes Signed-off-by: Chen Zhang --- tests/v1/attention/test_attention_backends.py | 26 +++++++++---------- vllm/v1/attention/backends/cpu_attn.py | 4 +-- vllm/v1/attention/backends/flash_attn.py | 4 +-- vllm/v1/attention/backends/flashinfer.py | 5 ++-- vllm/v1/attention/backends/flex_attention.py | 6 ++--- vllm/v1/attention/backends/mamba_attn.py | 4 +-- vllm/v1/attention/backends/mla/common.py | 5 ++-- vllm/v1/attention/backends/mla/flashmla.py | 4 +-- .../attention/backends/mla/rocm_aiter_mla.py | 4 +-- vllm/v1/attention/backends/rocm_aiter_fa.py | 6 ++--- vllm/v1/attention/backends/triton_attn.py | 6 ++--- vllm/v1/attention/backends/utils.py | 6 ++--- 12 files changed, 39 insertions(+), 41 deletions(-) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 9733e136c28d..67446988b529 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -12,7 +12,7 @@ get_attention_backend) from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec +from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1, @@ -197,7 +197,8 @@ def __init__(self, device: torch.device): def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, - vllm_config, device: torch.device, + layer_names: list[str], vllm_config, + device: torch.device, common_attn_metadata: CommonAttentionMetadata, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -210,7 +211,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, if backend == _Backend.FLASHINFER_VLLM_V1: import unittest.mock - from vllm.v1.attention.backends.flashinfer import PerLayerParameters + from vllm.v1.attention.backends.utils import PerLayerParameters def mock_get_per_layer_parameters(vllm_config, layer_names, cls_): # Return mock parameters for a single layer @@ -228,14 +229,15 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, cls_): with unittest.mock.patch( 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters', mock_get_per_layer_parameters): - builder = builder_cls(kv_cache_spec, vllm_config, device) + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, + device) attn_metadata = builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, ) else: # Build metadata - builder = builder_cls(kv_cache_spec, vllm_config, device) + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) attn_metadata = builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -302,10 +304,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) - kv_cache_group_spec = KVCacheGroupSpec( - kv_cache_spec=kv_cache_spec, - layer_names=["placeholder"], - ) # 1. 
Setup batch_size = batch_spec.batch_size @@ -424,10 +422,12 @@ def test_backend_correctness(batch_spec_name: str, model: str): if backend_name == _Backend.FLASHINFER_VLLM_V1: kv_cache_for_backend = kv_cache.transpose(0, 1) - backend_output = run_attention_backend( - backend_name, kv_cache_group_spec, vllm_config, device, - common_attn_metadata, query_vllm, key_vllm, value_vllm, - kv_cache_for_backend) + backend_output = run_attention_backend(backend_name, kv_cache_spec, + ["placeholder"], vllm_config, + device, common_attn_metadata, + query_vllm, key_vllm, + value_vllm, + kv_cache_for_backend) # Check shape and dtype consistency assert backend_output.shape == sdpa_output.shape, ( diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 8f8ac2ec3c8c..9ed46331863c 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -17,7 +17,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.gpu_input_batch import InputBatch try: @@ -315,7 +315,7 @@ def get_seq_len_block_table_args( class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): - def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device) -> None: self.kv_cache_spec = kv_cache_spec self.vllm_config = vllm_config diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 0d2a6db492ea..ea2662fcde9c 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -28,7 +28,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout) -from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -146,7 +146,7 @@ class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 - def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config self.model_config = vllm_config.model_config diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 15de9778e3d4..ed4c9ee21185 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -23,7 +23,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -223,7 +223,7 @@ def __post_init__(self): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: 
torch.device): self.device = device self._workspace_buffer = None @@ -237,7 +237,6 @@ def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], self.vllm_config = vllm_config self.cache_config = vllm_config.cache_config - assert isinstance(kv_cache_spec, AttentionSpec) self.kv_cache_spec = kv_cache_spec def reorder_batch(self, input_batch: InputBatch, diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 3db124cfe12f..bb0d890c7754 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -19,7 +19,7 @@ from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) -from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -258,7 +258,7 @@ def __post_init__(self): class FlexAttentionMetadataBuilder( AttentionMetadataBuilder[FlexAttentionMetadata]): - def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config @@ -269,8 +269,8 @@ def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], self.num_heads_kv = self.model_config.get_num_kv_heads( vllm_config.parallel_config) self.headdim = self.model_config.get_head_size() + self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.block_size = self.kv_cache_spec.block_size self.device = device def build(self, diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 7f8100faedd9..8b702e28d67c 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -11,7 +11,7 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -87,7 +87,7 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( AttentionMetadataBuilder[Mamba2AttentionMetadata]): - def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): assert isinstance(kv_cache_spec, MambaSpec) self.kv_cache_spec = kv_cache_spec diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 013463c26feb..0095d7521785 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -213,7 +213,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, get_per_layer_parameters, infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -405,14 +405,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ def __init__(self, - kv_cache_spec: KVCacheSpec, + kv_cache_spec: AttentionSpec, layer_names: 
list[str], vllm_config: VllmConfig, device: torch.device, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata - assert isinstance(kv_cache_spec, AttentionSpec) self.kv_cache_spec = kv_cache_spec self.device = device scheduler_config = vllm_config.scheduler_config diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 9f836ba3e093..39463b9c0616 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -18,7 +18,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) -from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -56,7 +56,7 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): full_cudagraph_supported: ClassVar[bool] = True # Decode-only - def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): super().__init__(kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 1cea6729724d..5c5891f035ae 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -17,7 +17,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) -from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec # yapf: enable @@ -66,7 +66,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): full_cudagraph_supported: ClassVar[bool] = True # decode only - def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): super().__init__(kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index ffd85a0325ef..9393c5333cb5 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -14,7 +14,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec if current_platform.is_rocm(): import aiter @@ -165,7 +165,7 @@ def flash_attn_varlen_func_fake( class AiterFlashAttentionMetadataBuilder: - def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -179,7 +179,7 @@ def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], self.parallel_config) self.headdim = self.model_config.get_head_size() self.kv_cache_spec = kv_cache_spec - self.block_size = self.kv_cache_spec.block_size + self.block_size = kv_cache_spec.block_size # Sliding window size to be used with the AOT scheduler will be # populated 
on first build() call. diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 468a75b75a5b..40649e8a8a25 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -20,7 +20,7 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) -from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -59,11 +59,11 @@ class TritonAttentionMetadataBuilder( AttentionMetadataBuilder[TritonAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = True - def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.device = device self.kv_cache_spec = kv_cache_spec - self.block_size = self.kv_cache_spec.block_size + self.block_size = kv_cache_spec.block_size model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index fd0501af3ca6..2df6393fd44d 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -22,7 +22,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) from vllm.logger import init_logger -from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) _KV_CACHE_LAYOUT_OVERRIDE = None @@ -68,9 +68,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): full_cudagraph_supported: ClassVar[bool] = False @abstractmethod - def __init__(self, kv_cache_spec: KVCacheSpec, layer_names: list[str], + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - pass + self.kv_cache_spec = kv_cache_spec @abstractmethod def build(self, From 65d4bcb61011ad04fdb2d4b8f327f50a1ea4a21f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Jul 2025 17:53:02 -0700 Subject: [PATCH 07/10] remove unrelated changes Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/rocm_aiter_fa.py | 2 +- vllm/v1/attention/backends/triton_attn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 9393c5333cb5..7949c285e86c 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -178,8 +178,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.num_heads_kv = self.model_config.get_num_kv_heads( self.parallel_config) self.headdim = self.model_config.get_head_size() - self.kv_cache_spec = kv_cache_spec self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. 
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 40649e8a8a25..195fbd3b1b9c 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -62,8 +62,8 @@ class TritonAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.device = device - self.kv_cache_spec = kv_cache_spec self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( From add4caab476875b6560dba2983c14a2577a57fa1 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 29 Jul 2025 00:42:38 -0700 Subject: [PATCH 08/10] fix test Signed-off-by: Chen Zhang --- tests/v1/worker/test_gpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 7fec4782517c..15554233aaf7 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -749,7 +749,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): layer_4 = "model.layers.4.mixer" layer_5 = "model.layers.5.mixer" - with set_current_vllm_config(vllm_config): + with set_current_vllm_config(vllm_config), monkeypatch.context() as m: + m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") hf_config = vllm_config.model_config.hf_config fwd_context = {} for key in [layer_0, layer_1]: From 167277a9a3fa5a614a46c2db335f37a773644e22 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 29 Jul 2025 00:44:33 -0700 Subject: [PATCH 09/10] revert regex change Signed-off-by: Chen Zhang --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f18cc661f71a..d46e678e7aa4 100644 --- a/setup.py +++ b/setup.py @@ -6,12 +6,12 @@ import json import logging import os +import re import subprocess import sys from pathlib import Path from shutil import which -import regex as re import torch from packaging.version import Version, parse from setuptools import Extension, setup From 46f1416a16fd99d353778fe3c05cdbff83f038bb Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 29 Jul 2025 10:28:33 -0700 Subject: [PATCH 10/10] fix eagle test Signed-off-by: Chen Zhang --- tests/v1/spec_decode/test_eagle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 5c74a286c4a9..e77d276a4f45 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -310,6 +310,7 @@ def create_deterministic_logits(token_ids): _Backend.FLASH_ATTN_VLLM_V1) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=proposer.attn_layer_names, vllm_config=proposer.vllm_config, device=device, )
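
For reference, a minimal sketch of the layer-filtering behaviour this series adds to `get_layers_from_vllm_config` and threads through the metadata builders (passing each KV-cache group's `layer_names` alongside its `kv_cache_spec`). This is a toy illustration, not vLLM code: `Attention` and `MambaMixer` below are stand-in classes, the layer names are invented, and `get_layers` simply mirrors the filtering logic from the patched helper.

```python
# Toy sketch of the filtering semantics introduced in this series:
# a per-group metadata builder only looks at its own layers.
from typing import Optional


class Attention:  # stand-in for vLLM's attention layer class
    def __init__(self, sliding_window: Optional[int] = None):
        self.sliding_window = sliding_window


class MambaMixer:  # stand-in for a non-attention layer type
    pass


# Stand-in for vllm_config.compilation_config.static_forward_context
forward_context = {
    "model.layers.0.self_attn": Attention(sliding_window=None),
    "model.layers.1.self_attn": Attention(sliding_window=1024),
    "model.layers.2.mixer": MambaMixer(),
}


def get_layers(forward_context, layer_type, layer_names=None):
    # Mirrors the patched get_layers_from_vllm_config: if layer_names is
    # None, consider every layer; otherwise restrict to the given names,
    # keeping only instances of layer_type.
    if layer_names is None:
        layer_names = list(forward_context.keys())
    return {
        name: forward_context[name]
        for name in layer_names
        if isinstance(forward_context[name], layer_type)
    }


# Old behaviour: all attention layers, regardless of KV-cache group.
print(sorted(get_layers(forward_context, Attention)))

# New behaviour: the builder for one KV-cache group passes its own
# layer_names, so hyperparameter inference (e.g. FlashInfer's
# infer_global_hyperparameters) never mixes sliding-window and
# full-attention layers from different groups.
group_layer_names = ["model.layers.0.self_attn"]
print(sorted(get_layers(forward_context, Attention, group_layer_names)))
```

Under this sketch, the second call returns only `model.layers.0.self_attn`, which is why patch 02 can raise a targeted error (suggesting `disable_sliding_window=True`) only when layers *within the same group* disagree on `window_left`.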