Skip to content

Commit 555e722

Browse files
authored
[v1][attention] Support Hybrid Allocator + FlashInfer (#21412)
Signed-off-by: Chen Zhang <[email protected]>
1 parent 0e36abf commit 555e722

File tree

16 files changed

+85 additions, -57 deletions (lines changed)

tests/v1/attention/test_attention_backends.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -198,7 +198,8 @@ def __init__(self, device: torch.device):
198198

199199

200200
def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
201-
vllm_config, device: torch.device,
201+
layer_names: list[str], vllm_config,
202+
device: torch.device,
202203
common_attn_metadata: CommonAttentionMetadata,
203204
query: torch.Tensor, key: torch.Tensor,
204205
value: torch.Tensor,
@@ -211,31 +212,33 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
211212
if backend == _Backend.FLASHINFER_VLLM_V1:
212213
import unittest.mock
213214

214-
from vllm.v1.attention.backends.flashinfer import PerLayerParameters
215+
from vllm.v1.attention.backends.utils import PerLayerParameters
215216

216-
def mock_get_per_layer_parameters(vllm_config, impl_cls):
217+
def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
217218
# Return mock parameters for a single layer
218219
head_size = vllm_config.model_config.get_head_size()
219220
return {
220-
"mock_layer":
221+
layer_name:
221222
PerLayerParameters(
222223
window_left=-1, # No sliding window
223224
logits_soft_cap=0.0, # No soft cap
224225
sm_scale=1.0 / (head_size**0.5) # Standard scale
225226
)
227+
for layer_name in layer_names
226228
}
227229

228230
with unittest.mock.patch(
229231
'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
230232
mock_get_per_layer_parameters):
231-
builder = builder_cls(kv_cache_spec, vllm_config, device)
233+
builder = builder_cls(kv_cache_spec, layer_names, vllm_config,
234+
device)
232235
attn_metadata = builder.build(
233236
common_prefix_len=0,
234237
common_attn_metadata=common_attn_metadata,
235238
)
236239
else:
237240
# Build metadata
238-
builder = builder_cls(kv_cache_spec, vllm_config, device)
241+
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
239242
attn_metadata = builder.build(
240243
common_prefix_len=0,
241244
common_attn_metadata=common_attn_metadata,
@@ -427,8 +430,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
427430
set_kv_cache_layout("HND")
428431

429432
backend_output = run_attention_backend(backend_name, kv_cache_spec,
430-
vllm_config, device,
431-
common_attn_metadata,
433+
["placeholder"], vllm_config,
434+
device, common_attn_metadata,
432435
query_vllm, key_vllm,
433436
value_vllm,
434437
kv_cache_for_backend)

tests/v1/spec_decode/test_eagle.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -305,6 +305,7 @@ def create_deterministic_logits(token_ids):
305305
_Backend.FLASH_ATTN_VLLM_V1)
306306
attn_metadata_builder = attn_metadata_builder_cls(
307307
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
308+
layer_names=proposer.attn_layer_names,
308309
vllm_config=proposer.vllm_config,
309310
device=device,
310311
)

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -745,7 +745,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
745745
layer_4 = "model.layers.4.mixer"
746746
layer_5 = "model.layers.5.mixer"
747747

748-
with set_current_vllm_config(vllm_config):
748+
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
749+
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
749750
hf_config = vllm_config.model_config.hf_config
750751
fwd_context = {}
751752
for key in [layer_0, layer_1]:

vllm/config.py

Lines changed: 24 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -740,8 +740,8 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
740740
isinstance(sliding_window, list))
741741

742742
if not self.disable_sliding_window and has_interleaved_attention:
743-
if (backend :=
744-
envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
743+
if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
744+
) in ("XFORMERS", "FLASHINFER"):
745745
sliding_window_len_min = get_min_sliding_window(
746746
self.hf_text_config.sliding_window)
747747

@@ -5065,13 +5065,29 @@ def assert_hashable(text):
50655065
T = TypeVar("T")
50665066

50675067

5068-
def get_layers_from_vllm_config(vllm_config: VllmConfig,
5069-
layer_type: type[T]) -> dict[str, T]:
5068+
def get_layers_from_vllm_config(
5069+
vllm_config: VllmConfig,
5070+
layer_type: type[T],
5071+
layer_names: Optional[list[str]] = None) -> dict[str, T]:
5072+
"""
5073+
Get layers from the vLLM config.
5074+
5075+
Args:
5076+
vllm_config: The vLLM config.
5077+
layer_type: The type of the layer to get.
5078+
layer_names: The names of the layers to get. If None, return all layers.
5079+
"""
5080+
5081+
if layer_names is None:
5082+
layer_names = list(
5083+
vllm_config.compilation_config.static_forward_context.keys())
5084+
5085+
forward_context = vllm_config.compilation_config.static_forward_context
5086+
50705087
return {
5071-
layer_name: layer
5072-
for layer_name, layer in
5073-
vllm_config.compilation_config.static_forward_context.items()
5074-
if isinstance(layer, layer_type)
5088+
layer_name: forward_context[layer_name]
5089+
for layer_name in layer_names
5090+
if isinstance(forward_context[layer_name], layer_type)
50755091
}
50765092

50775093

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -315,8 +315,8 @@ def get_seq_len_block_table_args(
315315

316316
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
317317

318-
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
319-
device: torch.device) -> None:
318+
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
319+
vllm_config: VllmConfig, device: torch.device) -> None:
320320
self.kv_cache_spec = kv_cache_spec
321321
self.vllm_config = vllm_config
322322
self.scheduler_config = vllm_config.scheduler_config

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -148,8 +148,8 @@ class FlashAttentionMetadataBuilder(
148148
AttentionMetadataBuilder[FlashAttentionMetadata]):
149149
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
150150

151-
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
152-
device: torch.device):
151+
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
152+
vllm_config: VllmConfig, device: torch.device):
153153
self.vllm_config = vllm_config
154154
self.model_config = vllm_config.model_config
155155
self.parallel_config = vllm_config.parallel_config

vllm/v1/attention/backends/flashinfer.py

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -21,10 +21,9 @@
2121
from vllm.utils import cdiv
2222
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
2323
from vllm.v1.attention.backends.utils import (
24-
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
25-
get_kv_cache_layout, get_per_layer_parameters,
26-
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
27-
split_decodes_and_prefills)
24+
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
25+
get_per_layer_parameters, infer_global_hyperparameters,
26+
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
2827
from vllm.v1.kv_cache_interface import AttentionSpec
2928

3029
if TYPE_CHECKING:
@@ -219,16 +218,17 @@ def __post_init__(self):
219218

220219
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
221220

222-
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
223-
device: torch.device):
221+
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
222+
vllm_config: VllmConfig, device: torch.device):
224223
self.device = device
225224
self._workspace_buffer = None
226225
self._prefill_wrapper = None # Wrapper for prefill/append
227226
self._decode_wrapper = None # Wrapper for decode
228227
self._cascade_wrapper = None # Wrapper for cascade attention
229228

230229
# Global hyperparameters shared by all attention layers
231-
self.global_hyperparameters: Optional[PerLayerParameters] = None
230+
self.global_hyperparameters = infer_global_hyperparameters(
231+
get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
232232

233233
self.vllm_config = vllm_config
234234
self.cache_config = vllm_config.cache_config
@@ -283,10 +283,6 @@ def _get_cascade_wrapper(self):
283283

284284
def _plan(self, num_prefills: int, num_decodes: int,
285285
attn_metadata: FlashInferMetadata):
286-
if self.global_hyperparameters is None:
287-
self.global_hyperparameters = infer_global_hyperparameters(
288-
get_per_layer_parameters(self.vllm_config, FlashInferImpl))
289-
290286
if attn_metadata.use_cascade:
291287
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
292288
attn_metadata.cascade_wrapper.plan(

vllm/v1/attention/backends/flex_attention.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -258,8 +258,8 @@ def __post_init__(self):
258258
class FlexAttentionMetadataBuilder(
259259
AttentionMetadataBuilder[FlexAttentionMetadata]):
260260

261-
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
262-
device: torch.device):
261+
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
262+
vllm_config: VllmConfig, device: torch.device):
263263
self.model_config = vllm_config.model_config
264264
self.parallel_config = vllm_config.parallel_config
265265
self.cache_config = vllm_config.cache_config

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -87,8 +87,8 @@ class Mamba2AttentionMetadata:
8787
class Mamba2AttentionMetadataBuilder(
8888
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
8989

90-
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
91-
device: torch.device):
90+
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
91+
vllm_config: VllmConfig, device: torch.device):
9292
assert isinstance(kv_cache_spec, MambaSpec)
9393
self.kv_cache_spec = kv_cache_spec
9494
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()

vllm/v1/attention/backends/mla/common.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -406,6 +406,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
406406

407407
def __init__(self,
408408
kv_cache_spec: AttentionSpec,
409+
layer_names: list[str],
409410
vllm_config: VllmConfig,
410411
device: torch.device,
411412
metadata_cls: Optional[type[M]] = None):
@@ -471,7 +472,8 @@ def __init__(self,
471472
BatchPrefillWithRaggedKVCacheWrapper] = []
472473

473474
self._global_hyperparameters = infer_global_hyperparameters(
474-
get_per_layer_parameters(vllm_config, MLACommonImpl))
475+
get_per_layer_parameters(vllm_config, layer_names,
476+
MLACommonImpl))
475477

476478
if self._use_cudnn_prefill:
477479
self.cudnn_workspace = torch.empty(

0 commit comments

Comments
 (0)