diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index d72e50e5196b..7bc7f44dd7ab 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random -from typing import Optional, Union import pytest import torch @@ -10,12 +9,6 @@ from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationLevel from vllm.distributed import cleanup_dist_env_and_memory -from vllm.forward_context import get_forward_context -from vllm.model_executor.models.gemma3n_mm import ( - Gemma3nForConditionalGeneration) -from vllm.model_executor.models.registry import ModelRegistry -from vllm.model_executor.models.utils import extract_layer_index -from vllm.sequence import IntermediateTensors from ...utils import fork_new_process_for_each_test @@ -23,54 +16,6 @@ SEED = 42 -class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration): - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = super().forward(input_ids, positions, - intermediate_tensors, inputs_embeds, - **kwargs) - attn_metadata = get_forward_context().attn_metadata - # attn_metadata is None during dummy runs - if (attn_metadata is not None - and self.language_model.cache_config.kv_sharing_fast_prefill): - assert isinstance(attn_metadata, dict) # true in V1 - # Gemma3n-E2B has 30 layers, with last 20 layers being - # cross-decoder layers. Check attention metadata is correct - for layer_name, metadata in attn_metadata.items(): - layer_idx = extract_layer_index(layer_name) - if layer_idx >= 20: - assert hasattr(metadata, 'logits_indices_padded') - assert hasattr(metadata, 'num_logits_indices') - else: - assert not hasattr(metadata, 'logits_indices_padded') - assert not hasattr(metadata, 'num_logits_indices') - - # Last layer will be a KV sharing layer - layer_attn_metadata = attn_metadata[ - self.language_model.model.layers[-1].self_attn.attn.layer_name] - logits_indices_padded = (layer_attn_metadata.logits_indices_padded) - assert logits_indices_padded is not None - num_logits_indices = layer_attn_metadata.num_logits_indices - assert num_logits_indices > 0 - # Reset hidden states to random values and - # only set logits at logits_indices to valid values - # Because logits_indices are the only positions that are used - # for output token sampling, this still produces same outputs - logits_hs = hidden_states[logits_indices_padded] - hidden_states = torch.randn_like(hidden_states) - gen_indices = logits_indices_padded[:num_logits_indices] - hidden_states[gen_indices] = logits_hs[:num_logits_indices] - - return hidden_states - - @pytest.fixture def test_prompts(): """ @@ -124,8 +69,6 @@ def test_kv_sharing_fast_prefill( enforce_eager: bool, test_prompts: list[str], ): - ModelRegistry.register_model("Gemma3nForConditionalGeneration", - TestGemma3nForConditionalGeneration) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) compilation_config = CompilationConfig( # This allows vLLM compilation backend to handle allocating and diff --git a/vllm/config/cache.py b/vllm/config/cache.py index a9550d4390ad..a658aa522ebe 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -145,12 +145,19 @@ def 
__post_init__(self) -> None: self._verify_cache_dtype() self._verify_prefix_caching() + self._verify_kv_sharing_fast_prefill() def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info return {key: str(value) for key, value in self.__dict__.items()} + def _verify_kv_sharing_fast_prefill(self) -> None: + if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1: + raise NotImplementedError( + "Fast prefill optimization for KV sharing is not supported " + "in V0 currently.") + @model_validator(mode='after') def _verify_args(self) -> Self: if self.cpu_offload_gb < 0: @@ -162,11 +169,6 @@ def _verify_args(self) -> Self: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") - if self.kv_sharing_fast_prefill: - logger.warning_once( - "--kv-sharing-fast-prefill is currently work in progress " - "and not functional yet (i.e. no prefill savings)") - return self def _verify_cache_dtype(self) -> None: diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index ffec3408702c..0e0e191e75fc 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -23,9 +23,11 @@ from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig from vllm.attention import Attention +from vllm.compilation.backends import set_model_tag from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, GeluAndMul, @@ -45,6 +47,7 @@ default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata from .interfaces import SupportsQuant from .utils import (AutoWeightsLoader, extract_layer_index, @@ -533,7 +536,178 @@ def forward( return corrected_predictions -@support_torch_compile +# This enables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. 
+ kv_sharing_fast_prefill) +class Gemma3nSelfDecoder(nn.Module): + """ + Includes altup embedding and self decoder layers + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma3nDecoderLayer], + layer_idx_start: int, + per_layer_model_projection: ColumnParallelLinear, + embed_scale_per_layer: torch.Tensor, + embed_tokens_per_layer: VocabParallelEmbedding, + per_layer_projection_norm: RMSNorm, + per_layer_input_scale: torch.Tensor, + altup_projections: nn.ModuleList, + eps: torch.Tensor, + embed_tokens: VocabParallelEmbedding, + embed_scale: torch.Tensor, + ): + super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + self.per_layer_model_projection = per_layer_model_projection + self.config = vllm_config.model_config.hf_config + self.embed_scale_per_layer = embed_scale_per_layer + self.embed_tokens_per_layer = embed_tokens_per_layer + self.per_layer_projection_norm = per_layer_projection_norm + self.per_layer_input_scale = per_layer_input_scale + self.altup_projections = altup_projections + self.eps = eps + self.embed_tokens = embed_tokens + self.embed_scale = embed_scale + + def get_per_layer_input_embeddings( + self, input_ids: torch.Tensor) -> torch.Tensor: + # Deal with the fact that vocab_size_per_layer_input < vocab_size + # which causes us to have some out of vocab tokens by setting + # those token ids to 0. This matches the HF implementation. + per_layer_inputs_mask = torch.logical_and( + input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input) + per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, + torch.zeros_like(input_ids)) + return self.embed_tokens_per_layer( + per_layer_inputs_tokens) * self.embed_scale_per_layer + + def get_per_layer_inputs( + self, + hidden_states_0: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor], + ) -> torch.Tensor: + per_layer_projection = self.per_layer_model_projection(hidden_states_0) + per_layer_projection = per_layer_projection.reshape( + *hidden_states_0.shape[:-1], + self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm( + per_layer_projection) + if per_layer_inputs is not None: + # Profiling run does not compute per_layer_inputs + per_layer_inputs = per_layer_projection + per_layer_inputs + per_layer_inputs *= self.per_layer_input_scale + else: + per_layer_inputs = per_layer_projection + return per_layer_inputs + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) * self.embed_scale + + def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor: + # Altup embed. 
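Note: in the new self-decoder, the altup streams are stacked on the last dimension and only permuted to `[altup_num_inputs, num_tokens, hidden_size]` around the decoder layers (see `altup_embed` and `forward` below). A toy-sized layout sketch, with made-up dimensions rather than Gemma3n's real ones:

```python
# Toy-sized sketch of the altup layout used by the self/cross decoder split
# (sizes are illustrative, not Gemma3n's real config).
import torch

num_tokens, hidden_size, altup_num_inputs = 4, 8, 3
hidden_states_0 = torch.randn(num_tokens, hidden_size)

# altup_embed-style stacking on the last dim
hidden_states = torch.stack([hidden_states_0] * altup_num_inputs, dim=-1)
assert hidden_states.shape == (num_tokens, hidden_size, altup_num_inputs)

# view used inside the decoder layers, then the round-trip back
layer_view = hidden_states.permute(2, 0, 1)   # [altup, num_tokens, hidden]
restored = layer_view.permute(1, 2, 0)        # [num_tokens, hidden, altup]
assert torch.equal(hidden_states, restored)
```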
+ hidden_states = [hidden_states_0] * self.config.altup_num_inputs + target_magnitude = torch.mean(hidden_states_0**2, dim=-1, + keepdim=True)**0.5 + for i in range(1, self.config.altup_num_inputs): + hidden_states[i] = self.altup_projections[i - 1](hidden_states[i]) + new_magnitude = torch.mean(hidden_states[i]**2, + dim=-1, + keepdim=True)**0.5 + hidden_states[i] *= target_magnitude / torch.maximum( + new_magnitude, self.eps) + hidden_states = torch.stack(hidden_states, dim=-1) + return hidden_states + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + hidden_states_0 = inputs_embeds + else: + hidden_states_0 = self.get_input_embeddings(input_ids) + + adjusted_per_layer_inputs = self.get_per_layer_inputs( + hidden_states_0, per_layer_inputs) + hidden_states = self.altup_embed(hidden_states_0) + + # [altnum_inputs, num_tokens, hidden_size] + hidden_states = hidden_states.permute(2, 0, 1) + + for idx, layer in enumerate(self.decoder_layers): + layer_idx = idx + self.layer_idx_start + # [altup_num_inputs, num_tokens, hidden_size] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :], + **kwargs, + ) + + # [num_tokens, hidden_size, altnum_inputs] + hidden_states = hidden_states.permute(1, 2, 0) + + return hidden_states, adjusted_per_layer_inputs + + +# This enables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. + kv_sharing_fast_prefill) +class Gemma3nCrossDecoder(nn.Module): + """ + Cross-decoder layers + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma3nDecoderLayer], + layer_idx_start: int, + ): + super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_inputs: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # [altnum_inputs, num_tokens, hidden_size] + hidden_states = hidden_states.permute(2, 0, 1) + for idx, layer in enumerate(self.decoder_layers): + layer_idx = idx + self.layer_idx_start + # [altup_num_inputs, num_tokens, hidden_size] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + per_layer_input=per_layer_inputs[:, layer_idx, :], + **kwargs, + ) + # [num_tokens, hidden_size, altnum_inputs] + hidden_states = hidden_states.permute(1, 2, 0) + return hidden_states + + +# This disables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. 
+ cache_config.kv_sharing_fast_prefill) class Gemma3nTextModel(nn.Module, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -543,7 +717,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -613,95 +786,211 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: Gemma3nDecoderLayer( config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") + + self.eps = torch.tensor(torch.finfo().min) + + first_kv_shared_layer_idx = (config.num_hidden_layers - + config.num_kv_shared_layers) + # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO) + with set_model_tag("self_decoder"): + self.self_decoder = Gemma3nSelfDecoder( + vllm_config=vllm_config, + prefix=f"{prefix}.self_decoder", + decoder_layers=self.layers[:first_kv_shared_layer_idx], + layer_idx_start=0, + per_layer_model_projection=self.per_layer_model_projection, + embed_scale_per_layer=self.embed_scale_per_layer, + embed_tokens_per_layer=self.embed_tokens_per_layer, + per_layer_projection_norm=self.per_layer_projection_norm, + per_layer_input_scale=self.per_layer_input_scale, + altup_projections=self.altup_projections, + eps=self.eps, + embed_tokens=self.embed_tokens, + embed_scale=self.embed_scale, + ) + # Layer idx 20-30 are cross-decoder layers in YOCO + with set_model_tag("cross_decoder"): + self.cross_decoder = Gemma3nCrossDecoder( + vllm_config=vllm_config, + prefix=f"{prefix}.cross_decoder", + decoder_layers=self.layers[first_kv_shared_layer_idx:], + layer_idx_start=first_kv_shared_layer_idx, + ) + self.norm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, ) - self.eps = torch.tensor(torch.finfo().min) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) * self.embed_scale + self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill + + if self.fast_prefill_enabled: + # Allocate static buffers for CUDAGraph + # TODO(sarckk): Extract this functionality to interface + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + device = next(self.parameters()).device + self.positions = torch.zeros(max_num_tokens, + dtype=torch.int64, + device=device) + self.hidden_states = torch.zeros( + (max_num_tokens, config.hidden_size, + self.config.altup_num_inputs), + dtype=self.embed_tokens.weight.dtype, + device=device, + ) + self.per_layer_inputs = torch.zeros( + (max_num_tokens, self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input), + dtype=self.embed_tokens.weight.dtype, + device=device, + ) - def get_per_layer_input_embeddings( - self, input_ids: torch.Tensor) -> torch.Tensor: - # Deal with the fact that vocab_size_per_layer_input < vocab_size - # which causes us to have some out of vocab tokens by setting - # those token ids to 0. This matches the HF implementation. 
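Note: when fast prefill is enabled, the text model keeps persistent input buffers sized to `max_num_batched_tokens` and copies the live inputs into slices of them each step, so graph-captured code always reads from stable addresses. A simplified sketch of the pattern, using CPU tensors and made-up sizes (the real buffers also cover `hidden_states` and `per_layer_inputs` on the model's device):

```python
# Simplified sketch of the persistent-buffer pattern used for CUDA graph
# friendliness (CPU tensors, made-up sizes).
import torch

MAX_NUM_TOKENS = 64
positions_buf = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int64)

def step(positions: torch.Tensor) -> torch.Tensor:
    n = positions.size(0)
    positions_buf[:n].copy_(positions)  # in-place copy keeps the buffer address stable
    return positions_buf[:n]            # graph-captured code consumes this view

out = step(torch.arange(5))
assert torch.equal(out, torch.arange(5))
```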
- per_layer_inputs_mask = torch.logical_and( - input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input) - per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, - torch.zeros_like(input_ids)) - return self.embed_tokens_per_layer( - per_layer_inputs_tokens) * self.embed_scale_per_layer + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.self_decoder.get_input_embeddings(input_ids) - def forward( + def fast_prefill_forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor, positions: torch.Tensor, - per_layer_inputs: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - if inputs_embeds is not None: - hidden_states_0 = inputs_embeds - else: - hidden_states_0 = self.get_input_embeddings(input_ids) + ) -> torch.Tensor: + logits_indices_padded, num_logits_indices = None, None + attn_metadata = get_forward_context().attn_metadata + + # attn_metadata is None during dummy runs + if (self.fast_prefill_enabled and attn_metadata is not None): + assert isinstance(attn_metadata, dict) + # Last layer is a KV sharing layer + layer_attn_metadata = attn_metadata[ + self.layers[-1].self_attn.attn.layer_name] + if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)): + logits_indices_padded = ( + layer_attn_metadata.logits_indices_padded) + num_logits_indices = layer_attn_metadata.num_logits_indices + + # Copy inputs for cudagraph + batch_size = positions.size(0) + self.positions[:batch_size].copy_(positions) + self_decoder_hidden_states, per_layer_inputs_adjusted = \ + self.self_decoder( + input_ids=input_ids, + positions=self.positions[:batch_size], + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) - per_layer_projection = self.per_layer_model_projection(hidden_states_0) - per_layer_projection = per_layer_projection.reshape( - *hidden_states_0.shape[:-1], - self.config.num_hidden_layers, - self.config.hidden_size_per_layer_input, + if logits_indices_padded is None: + logits_indices_padded = torch.arange( + positions.size(0), + dtype=positions.dtype, + device=positions.device, + ) + + # NOTE(sarckk): There is currently a bug caused by + # vLLM converting output of last piecewise CUDA graph + # to weakref, causing memory to be prematurely freed + # when there are multiple compilation units + # Keep .clone() until fix in + # https://github.com/vllm-project/vllm/pull/22282 + hidden_states = self_decoder_hidden_states.clone() + + # Copy inputs for cudagraph + num_padded_logits_indices = logits_indices_padded.size(0) + self.positions[:num_padded_logits_indices].copy_( + positions[logits_indices_padded]) + self.hidden_states[:num_padded_logits_indices].copy_( + self_decoder_hidden_states[logits_indices_padded]) + self.per_layer_inputs[:num_padded_logits_indices].copy_( + per_layer_inputs_adjusted[logits_indices_padded]) + cross_decoder_hidden_states = self.cross_decoder( + positions=self.positions[:num_padded_logits_indices], + hidden_states=self.hidden_states[:num_padded_logits_indices], + per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices], + **kwargs, ) - per_layer_projection = self.per_layer_projection_norm( - per_layer_projection) - if per_layer_inputs is not None: - # Profiling run does not compute per_layer_inputs - per_layer_inputs = per_layer_projection + per_layer_inputs - 
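Note: `fast_prefill_forward` detects the extra metadata fields structurally through the `runtime_checkable` `KVSharingFastPrefillMetadata` protocol instead of importing the dynamically created metadata subclass. A minimal illustration of that mechanism, with stand-in class names rather than vLLM's real metadata types:

```python
# Minimal illustration of structural detection via a runtime_checkable
# Protocol (class names are stand-ins, not vLLM's real metadata types).
from typing import Protocol, runtime_checkable
import torch

@runtime_checkable
class FastPrefillFields(Protocol):
    logits_indices_padded: torch.Tensor
    num_logits_indices: int

class PlainMetadata:
    pass

class WrappedMetadata(PlainMetadata):
    def __init__(self) -> None:
        self.logits_indices_padded = torch.tensor([0, 2])
        self.num_logits_indices = 2

assert isinstance(WrappedMetadata(), FastPrefillFields)
assert not isinstance(PlainMetadata(), FastPrefillFields)
```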
per_layer_inputs *= self.per_layer_input_scale + if num_logits_indices is not None: + assert num_logits_indices > 0 + # Merge cross-decoder and self-decoder hidden states + hidden_states[logits_indices_padded[:num_logits_indices]] = ( + cross_decoder_hidden_states[:num_logits_indices]) else: - per_layer_inputs = per_layer_projection + hidden_states = cross_decoder_hidden_states - # Altup embed. - hidden_states = [hidden_states_0] * self.config.altup_num_inputs - target_magnitude = torch.mean(hidden_states_0**2, dim=-1, - keepdim=True)**0.5 - for i in range(1, self.config.altup_num_inputs): - hidden_states[i] = self.altup_projections[i - 1](hidden_states[i]) - new_magnitude = torch.mean(hidden_states[i]**2, - dim=-1, - keepdim=True)**0.5 - hidden_states[i] *= target_magnitude / torch.maximum( - new_magnitude, self.eps) - hidden_states = torch.stack(hidden_states, dim=0) + return hidden_states - # Transformer blocks. - for layer_idx, layer in enumerate(self.layers): - # [altup_num_inputs, num_tokens, hidden_size] - hidden_states = layer( - positions=positions, - hidden_states=hidden_states, - per_layer_input=per_layer_inputs[:, layer_idx, :], - **kwargs, - ) + def normal_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states, per_layer_inputs = self.self_decoder( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + hidden_states = self.cross_decoder( + positions=positions, + hidden_states=hidden_states, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + return hidden_states + def altup_unembed( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: # Altup unembed. 
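Note: only the positions that will actually be sampled run through the cross-decoder; their outputs are scattered back into the self-decoder hidden states, while every other row keeps the self-decoder value (those rows are never read by sampling). An illustrative sketch with toy shapes:

```python
# Illustrative sketch of the prefill-truncation merge: cross-decoder output is
# computed only for rows that will be sampled and scattered back in place.
import torch

num_tokens, hidden_size = 10, 4
self_decoder_hs = torch.randn(num_tokens, hidden_size)
logits_indices = torch.tensor([3, 7, 9])                         # rows used for sampling
cross_decoder_hs = torch.randn(logits_indices.numel(), hidden_size)

merged = self_decoder_hs.clone()
merged[logits_indices] = cross_decoder_hs                         # scatter sampled rows

untouched = torch.ones(num_tokens, dtype=torch.bool)
untouched[logits_indices] = False
assert torch.equal(merged[untouched], self_decoder_hs[untouched])
```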
- target_magnitude = torch.mean(hidden_states[0]**2, + target_magnitude = torch.mean(hidden_states[..., 0]**2, dim=-1, keepdim=True)**0.5 for i in range(1, self.config.altup_num_inputs): - hidden_states[i] = self.altup_unembed_projections[i - 1]( - hidden_states[i]) - new_magnitude = torch.mean(hidden_states[i]**2, + hidden_states[..., i] = self.altup_unembed_projections[i - 1]( + hidden_states[..., i]) + new_magnitude = torch.mean(hidden_states[..., i]**2, dim=-1, keepdim=True)**0.5 - hidden_states[i] *= target_magnitude / torch.maximum( + hidden_states[..., i] *= target_magnitude / torch.maximum( new_magnitude, self.eps) - # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size] - hidden_states = torch.mean(hidden_states, dim=0) + # [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size] + hidden_states = torch.mean(hidden_states, dim=-1) + return hidden_states + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if self.fast_prefill_enabled: + hidden_states = self.fast_prefill_forward( + input_ids, + positions, + inputs_embeds, + per_layer_inputs, + **kwargs, + ) + else: + hidden_states = self.normal_forward( + input_ids, + positions, + inputs_embeds, + per_layer_inputs, + **kwargs, + ) + hidden_states = self.altup_unembed(hidden_states) return self.norm(hidden_states) def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index d59dde1560ae..aba4f98ea5f3 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -620,7 +620,7 @@ def get_input_embeddings( # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache # them here, as the model forward has only access to the input_embeds. 
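Note: `altup_unembed` now indexes the altup streams with `[..., i]` and averages over the last dimension, matching the new `[num_tokens, hidden_size, altup_num_inputs]` layout; it is the same reduction the old `dim=0` code performed on the `[altup, num_tokens, hidden]` layout. A quick equivalence check with toy tensors:

```python
# Quick equivalence check for the axis change in altup_unembed (toy tensors).
import torch

old_layout = torch.randn(3, 5, 7)                 # [altup, num_tokens, hidden]
new_layout = old_layout.permute(1, 2, 0).clone()  # [num_tokens, hidden, altup]

assert torch.equal(old_layout[1], new_layout[..., 1])
assert torch.allclose(old_layout.mean(dim=0), new_layout.mean(dim=-1))
```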
if input_ids is not None: - per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings( + per_layer_inputs = self.language_model.model.self_decoder.get_per_layer_input_embeddings( input_ids) per_layer_inputs = per_layer_inputs.reshape( -1, self.config.text_config.num_hidden_layers, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 39bdbe125635..ad53b2e80bc7 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,11 +4,13 @@ import enum import functools from abc import abstractmethod -from dataclasses import dataclass, make_dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar +from dataclasses import dataclass, fields, make_dataclass +from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol, + TypeVar) import numpy as np import torch +from typing_extensions import runtime_checkable from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils import cdiv @@ -19,7 +21,8 @@ from vllm.v1.worker.gpu_input_batch import InputBatch import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) from vllm.attention.layer import Attention from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) @@ -65,6 +68,10 @@ class CommonAttentionMetadata: causal: bool = True + # Needed by FastPrefillAttentionBuilder + logits_indices_padded: Optional[torch.Tensor] = None + num_logits_indices: Optional[int] = None + @dataclass class UbatchSlice: @@ -542,6 +549,69 @@ def make_local_attention_virtual_batches( ) +def make_kv_sharing_fast_prefill_common_attn_metadata( + common_attn_metadata: CommonAttentionMetadata, +) -> CommonAttentionMetadata: + if common_attn_metadata.max_query_len == 1: + # All requests are decode (assume 1 token for now) + # Skip computing fast prefill path + return common_attn_metadata + + assert common_attn_metadata.logits_indices_padded is not None + assert common_attn_metadata.num_logits_indices is not None + + logits_indices_padded = common_attn_metadata.logits_indices_padded + num_logits_indices = common_attn_metadata.num_logits_indices + # Get rid of CUDAGraph padding, if any + logits_indices = logits_indices_padded[:num_logits_indices] + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + # Example inputs + # num_reqs: 3 + # generation_indices: [14, 18, 19, 27] + # query_start_loc: [0, 15, 20, 28] + # seq_lens: [41, 31, 40] + + # Find how many decode indices belong to each request + # request_ids: [0, 1, 1, 2] + request_ids = torch.bucketize(logits_indices, + query_start_loc[1:], + right=True) + + # Figure out how many tokens are in each request + # num_decode_tokens: [1, 2, 1] + num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + + # Calculate new query_start_loc with tokens in generation_indices + # decode_query_start_loc: [0, 1, 3, 4] + decode_query_start_loc = torch.empty(num_reqs + 1, + device=query_start_loc.device, + dtype=query_start_loc.dtype) + + decode_query_start_loc[0] = 0 + decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) + decode_max_query_len = int(num_decode_tokens.max().item()) + total_num_decode_tokens = int(num_decode_tokens.sum().item()) + + common_attn_metadata = CommonAttentionMetadata( + 
query_start_loc=decode_query_start_loc, + query_start_loc_cpu=decode_query_start_loc.to("cpu", + non_blocking=True), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.to("cpu", non_blocking=True), + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=total_num_decode_tokens, + max_query_len=decode_max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + causal=True, + ) + return common_attn_metadata + + def subclass_attention_backend( name_prefix: str, attention_backend_cls: type[AttentionBackend], builder_cls: type[AttentionMetadataBuilder[M]] @@ -679,13 +749,56 @@ def subclass_attention_metadata( return Wrapped -def make_kv_sharing_fast_prefill_attention_metadata( - metadata_cls: Any, ) -> Any: - """ - Return a new subclass of `metadata_cls` for fast prefill - """ - return subclass_attention_metadata( - name_prefix="KVSharingFastPrefill", - metadata_cls=metadata_cls, - fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS, - ) +@runtime_checkable +class KVSharingFastPrefillMetadata(Protocol): + logits_indices_padded: torch.Tensor + num_logits_indices: int + + +def create_fast_prefill_custom_backend( + prefix: str, + underlying_attn_backend: AttentionBackend, +) -> type[AttentionBackend]: + + underlying_builder = underlying_attn_backend.get_builder_cls() + + class FastPrefillAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + new_common_attn_metadata =\ + make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata) + metadata = super().build(common_prefix_len, + new_common_attn_metadata, fast_build) + + class KVSharingFastPrefillAttentionMetadata( + metadata.__class__, # type: ignore + KVSharingFastPrefillMetadata): + + def __init__(self, metadata, common_attn_metadata): + # Shallow copy all fields in metadata cls + for field in fields(metadata.__class__): + setattr(self, field.name, + getattr(metadata, field.name)) + + # Set additional fields that will be used in model code + assert (common_attn_metadata.logits_indices_padded + is not None + and common_attn_metadata.num_logits_indices + is not None) + self.logits_indices_padded = \ + common_attn_metadata.logits_indices_padded + self.num_logits_indices = \ + common_attn_metadata.num_logits_indices + + return KVSharingFastPrefillAttentionMetadata( + metadata, common_attn_metadata) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=FastPrefillAttentionBuilder) + + return attn_backend diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index dbea0b610b31..7440fe1f07e9 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -335,6 +335,13 @@ async def generate( returning the RequestOutput back to the caller. 
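Note: `make_kv_sharing_fast_prefill_common_attn_metadata` rebuilds `query_start_loc` so that each request keeps only the tokens whose logits will be sampled. The worked example in its comments can be checked standalone:

```python
# Standalone check of the worked example in
# make_kv_sharing_fast_prefill_common_attn_metadata's comments.
import torch

query_start_loc = torch.tensor([0, 15, 20, 28])
logits_indices = torch.tensor([14, 18, 19, 27])
num_reqs = 3

request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
assert request_ids.tolist() == [0, 1, 1, 2]

num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
assert num_decode_tokens.tolist() == [1, 2, 1]

decode_query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.long)
decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
assert decode_query_start_loc.tolist() == [0, 1, 3, 4]
```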
""" + if (self.vllm_config.cache_config.kv_sharing_fast_prefill + and sampling_params.prompt_logprobs): + raise ValueError( + "--kv-sharing-fast-prefill produces incorrect logprobs for " + "prompt tokens, please disable it when the requests need " + "prompt logprobs") + try: # We start the output_handler on the first call to generate() so # we can call __init__ before the event loop, which enables us diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 01c90b2ea38d..07e1b8415aed 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses import gc import itertools import time @@ -58,7 +57,7 @@ supports_dynamo) from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - make_kv_sharing_fast_prefill_attention_metadata, + create_fast_prefill_custom_backend, reorder_batch_to_split_decodes_and_prefills) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, @@ -84,9 +83,10 @@ KVConnectorModelRunnerMixin, KVConnectorOutput) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache, - gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +from .utils import (AttentionGroup, MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, + gather_mm_placeholders, sanity_check_mm_encoder_outputs, + scatter_mm_placeholders) if TYPE_CHECKING: import xgrammar as xgr @@ -860,6 +860,8 @@ def _prepare_inputs( max_seq_len=max_seq_len, block_table_tensor=blk_table_tensor, slot_mapping=slot_mapping, + logits_indices_padded=logits_indices_padded, + num_logits_indices=logits_indices.size(0), causal=True, ) @@ -884,28 +886,7 @@ def _prepare_inputs( common_attn_metadata=common_attn_metadata, )) - fast_prefill_metadata = attn_metadata_i - if (self.cache_config.kv_sharing_fast_prefill - and self.kv_sharing_fast_prefill_eligible_layers): - # Dynamically create a a dataclass type that inherits - # from attention metadata type but includes additional - # fields logits_indices_padded and num_logits_indices - # which are required for prefill truncation - fast_prefill_metadata_type = ( - make_kv_sharing_fast_prefill_attention_metadata( - metadata_cls=type(attn_metadata_i), )) - fast_prefill_metadata = fast_prefill_metadata_type( - **dataclasses.asdict(attn_metadata_i), - logits_indices_padded=logits_indices_padded, - num_logits_indices=logits_indices.size(0), - ) - for layer_name in attn_group.layer_names: - if (self.cache_config.kv_sharing_fast_prefill - and layer_name - in self.kv_sharing_fast_prefill_eligible_layers): - attn_metadata[layer_name] = fast_prefill_metadata - continue attn_metadata[layer_name] = attn_metadata_i # Hot-Swap lora model @@ -1484,6 +1465,12 @@ def execute_model( return self.kv_connector_no_forward(scheduler_output, self.vllm_config) + if self.cache_config.kv_sharing_fast_prefill: + assert not self.input_batch.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect logprobs for " + "prompt tokens, tokens, please disable it when the requests " + "need prompt logprobs") + # Prepare the decoder inputs. 
(attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata, @@ -2741,6 +2728,13 @@ def get_attn_backends_for_layers( # layer. for layer_name in layer_names: attn_backend = layers[layer_name].get_attn_backend() + + if layer_name in self.kv_sharing_fast_prefill_eligible_layers: + attn_backend = create_fast_prefill_custom_backend( + "FastPrefill", + attn_backend, + ) + key = attn_backend.full_cls_name() attn_backends[key] = attn_backend attn_backend_layers[key].append(layer_name) @@ -3073,20 +3067,40 @@ def initialize_kv_cache_tensors( kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors) - # Setup `kv_cache_config` and `kv_caches` for models - # with cross-layer KV sharing - if self.shared_kv_cache_layers: - initialize_kv_cache_for_kv_sharing( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - kv_caches, - self.attn_groups, - self.runner_only_attn_layers, - ) + # Set up cross-layer KV cache sharing + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): + logger.debug("%s reuses KV cache of %s", layer_name, + target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + return kv_caches + + def maybe_add_kv_sharing_layers_to_kv_cache_groups( + self, kv_cache_config: KVCacheConfig) -> None: + """ + Add layers that re-use KV cache to KV cache group of its target layer. + Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + self.runner_only_attn_layers, + ) + + if self.cache_config.kv_sharing_fast_prefill: + # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other + # similar KV sharing setups, only the layers that generate KV caches + # are involved in the prefill phase, enabling prefill to early exit. attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) - # Iterate in reversed order and add layers that re-use KV cache - # e.g. in YOCO-like KV sharing setups (e.g. Gemma3n) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: self.kv_sharing_fast_prefill_eligible_layers.add( @@ -3094,11 +3108,6 @@ def initialize_kv_cache_tensors( else: break - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches) - return kv_caches - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. 
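Note: only the trailing, contiguous run of KV-reusing layers can skip prefill, hence the reversed iteration that stops at the first layer with its own KV cache. A toy illustration with made-up layer names and counts (not Gemma3n's real configuration):

```python
# Toy illustration of kv_sharing_fast_prefill_eligible_layers selection
# (made-up layer names and counts).
layers = [f"model.layers.{i}.self_attn.attn" for i in range(12)]
shared_kv_cache_layers = {name: layers[7] for name in layers[8:]}  # last 4 reuse layer 7's KV

eligible: set[str] = set()
for name in reversed(layers):
    if name in shared_kv_cache_layers:
        eligible.add(name)
    else:
        break

assert eligible == set(layers[8:])
```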
@@ -3110,6 +3119,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.kv_cache_config = kv_cache_config self.may_reinitialize_input_batch(kv_cache_config) self.may_add_encoder_only_layers_to_kv_cache_config() + self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 70ffde39ca33..230700612708 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -55,9 +55,8 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch -from .utils import (MultiModalBudget, bind_kv_cache, - initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs) +from .utils import (MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, sanity_check_mm_encoder_outputs) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -1599,6 +1598,30 @@ def profile_run( self.encoder_cache.clear() gc.collect() + def maybe_setup_cross_layer_kv_sharing( + self, + kv_caches: dict[str, torch.Tensor], + kv_cache_config: KVCacheConfig, + ) -> None: + """ + Add layers that re-use KV cache to KV cache group of its target layer. + Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + ) + + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): + logger.debug("%s reuses KV cache of %s", layer_name, + target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -1664,14 +1687,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: else: raise NotImplementedError - # Setup `kv_cache_config` and `kv_caches` for models - # with cross-layer KV sharing - if self.shared_kv_cache_layers: - initialize_kv_cache_for_kv_sharing( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - kv_caches, - ) + # Set up cross-layer KV cache sharing if needed + self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config) bind_kv_cache( kv_caches, diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index a519336e4161..6767804c71b9 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -203,12 +203,9 @@ def gather_mm_placeholders( return placeholders[is_embed] -def initialize_kv_cache_for_kv_sharing( +def add_kv_sharing_layers_to_kv_cache_groups( shared_kv_cache_layers: dict[str, str], kv_cache_groups: list[KVCacheGroupSpec], - kv_caches: dict[str, torch.Tensor], - # Optional for now to avoid breaking TPU - attn_groups: Optional[list[list[AttentionGroup]]] = None, runner_only_attn_layers: Optional[set[str]] = None, ) -> None: """ @@ -223,38 +220,15 @@ def initialize_kv_cache_for_kv_sharing( means this layer will perform attention using the keys and values from the KV cache of `shared_kv_cache_layers[layer_name]`. kv_cache_groups: The KV cache groups of the model. - kv_caches: The allocated kv_caches with layer names as keys. 
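Note: the new call runs before `initialize_attn_backend`, so the KV-sharing layers are already members of their target layer's KV cache group when attention backends (including the fast-prefill wrapper) are selected. A rough order-of-operations sketch; the method names come from this diff, the `runner` object is a stand-in:

```python
# Illustrative ordering only; method names are taken from this diff and the
# runner object is a stand-in, not vLLM's GPUModelRunner.
def initialize_kv_cache(runner, kv_cache_config):
    runner.may_reinitialize_input_batch(kv_cache_config)
    runner.may_add_encoder_only_layers_to_kv_cache_config()
    # New step: KV-sharing layers join their target layer's KV cache group
    # before the attention backends are instantiated.
    runner.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
    runner.initialize_attn_backend(kv_cache_config)
    return runner.initialize_kv_cache_tensors(kv_cache_config)
```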
- Note that layers in shared_kv_cache_layers.keys() are not - originally included as it only contains layers which have its own - KV cache allocation. - attn_groups: Optional list of attention groups. Layers in the same KV - cache group may be placed in different attention groups if they - have different attention backends. Currently only provided by - GPU model runner. """ - # mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx) - layer_to_attn_group_idx: dict[str, tuple[int, int]] = {} - if attn_groups: - for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups): - for attn_group_idx, attn_group in enumerate(kv_attn_groups): - for layer_name in attn_group.layer_names: - layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, - attn_group_idx) - else: - for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups): - for layer_name in kv_cache_group.layer_names: - # attn group idx default to 0 if not provided - layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0) + layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {} + for kv_cache_group in kv_cache_groups: + for layer_name in kv_cache_group.layer_names: + layer_to_kv_cache_group[layer_name] = kv_cache_group for layer_name, target_layer_name in shared_kv_cache_layers.items(): - kv_caches[layer_name] = kv_caches[target_layer_name] - kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0] - kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name) - - if attn_groups: - attn_group_idx = layer_to_attn_group_idx[target_layer_name][1] - attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append( - layer_name) + tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name] + tgt_kv_cache_group.layer_names.append(layer_name) if runner_only_attn_layers is not None: runner_only_attn_layers.add(layer_name)
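Note: `add_kv_sharing_layers_to_kv_cache_groups` now only appends each sharing layer to the group of its target layer; the tensor aliasing itself happens later in the runner's `initialize_kv_cache_tensors`. A toy version of the grouping step, using a stand-in dataclass rather than the real `KVCacheGroupSpec`:

```python
# Toy version of the grouping step (stand-in dataclass, not vLLM's real
# KVCacheGroupSpec); tensor aliasing is done separately by the runner.
from dataclasses import dataclass, field

@dataclass
class Group:
    layer_names: list[str] = field(default_factory=list)

groups = [Group(["layers.0.attn", "layers.1.attn"])]
shared_kv_cache_layers = {"layers.2.attn": "layers.1.attn"}  # layer 2 reuses layer 1's KV

layer_to_group = {name: g for g in groups for name in g.layer_names}
for layer_name, target_layer_name in shared_kv_cache_layers.items():
    layer_to_group[target_layer_name].layer_names.append(layer_name)

assert groups[0].layer_names == ["layers.0.attn", "layers.1.attn", "layers.2.attn"]
```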