19 changes: 19 additions & 0 deletions vllm/multimodal/registry.py
@@ -372,3 +372,22 @@ def get_encoder_dummy_data(
)

return dummy_data

def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
"""
Get the maximum length of the encoder input for encoder-decoder models.
"""
if not model_config.is_encoder_decoder:
return 0
max_tokens = self.\
get_max_tokens_per_item_by_nonzero_modality(model_config)
if not max_tokens:
# TODO - this function assumes encoder-decoder models are
# multimodal. This will need to change when adding support for more
# than whisper.
return 0
assert len(max_tokens) == 1, "Encoder-decoder models are expected \
to implement the multimodal interface with at most one modality."

first_modality = next(iter(max_tokens))
return max_tokens[first_modality]
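
For context, here is a minimal sketch of the lookup this helper performs, assuming a Whisper-style model whose only non-zero modality is audio. The dict literal is a hypothetical stand-in for the value returned by get_max_tokens_per_item_by_nonzero_modality(); the 1500-token figure matches the Whisper encoder length cited later in kv_cache_interface.py.

    # Hypothetical result of get_max_tokens_per_item_by_nonzero_modality()
    # for a Whisper-style encoder-decoder model.
    max_tokens_by_modality = {"audio": 1500}

    # Mirrors get_encdec_max_encoder_len(): exactly one modality is expected,
    # and its per-item token budget is taken as the encoder length.
    assert len(max_tokens_by_modality) == 1
    first_modality = next(iter(max_tokens_by_modality))
    max_encoder_len = max_tokens_by_modality[first_modality]  # -> 1500
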
34 changes: 25 additions & 9 deletions vllm/v1/core/kv_cache_coordinator.py
@@ -6,7 +6,7 @@
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager, get_manager_for_kv_cache_spec)
CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.request import Request
@@ -42,9 +42,10 @@ def __init__(
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))

def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
new_computed_blocks: tuple[
list[KVCacheBlock], ...],
num_encoder_tokens: int) -> int:
"""
Get the number of blocks needed to be allocated for the request.

@@ -54,14 +55,22 @@ def get_num_blocks_to_allocate(
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.

Returns:
The number of blocks.
"""
num_blocks_to_allocate = 0
for i, manager in enumerate(self.single_type_managers):
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i])
if isinstance(manager, CrossAttentionManager):
# For cross-attention, we issue a single static allocation
# of blocks based on the number of encoder input tokens.
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_encoder_tokens, [])
else:
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i])
return num_blocks_to_allocate

def save_new_computed_blocks(
@@ -79,8 +88,11 @@ def save_new_computed_blocks(
manager.save_new_computed_blocks(request_id,
new_computed_blocks[i])

def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
def allocate_new_blocks(
self,
request_id: str,
num_tokens: int,
num_encoder_tokens: int = 0) -> tuple[list[KVCacheBlock], ...]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
@@ -89,12 +101,16 @@ def allocate_new_blocks(self, request_id: str,
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.

Returns:
The new allocated blocks.
"""
return tuple(
manager.allocate_new_blocks(request_id, num_tokens)
manager.allocate_new_blocks(
request_id, num_encoder_tokens if isinstance(
manager, CrossAttentionManager) else num_tokens)
for manager in self.single_type_managers)

def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
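
To make the allocation split concrete, the sketch below (not vLLM code) shows how the per-manager block counts diverge: decoder-side groups grow with the running token count, while the cross-attention group is sized once from the encoder input and never grows. The block size of 16 is a hypothetical value; 1500 is the Whisper encoder length used elsewhere in this PR.

    from math import ceil

    def blocks_to_allocate(num_tokens: int, num_allocated_blocks: int,
                           block_size: int) -> int:
        # Blocks still needed so that `num_tokens` token slots are covered.
        return max(0, ceil(num_tokens / block_size) - num_allocated_blocks)

    # Decoder (e.g. full-attention) group: grows as the sequence extends.
    blocks_to_allocate(num_tokens=48, num_allocated_blocks=1, block_size=16)    # -> 2

    # Cross-attention group: a single static allocation sized by the encoder
    # input, independent of how many decoder tokens have been generated.
    blocks_to_allocate(num_tokens=1500, num_allocated_blocks=0, block_size=16)  # -> 94
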
6 changes: 4 additions & 2 deletions vllm/v1/core/kv_cache_manager.py
@@ -187,6 +187,7 @@ def allocate_slots(
new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False,
num_encoder_tokens: int = 0,
) -> Optional[KVCacheBlocks]:
"""Add slots for a request with new tokens to append.

@@ -253,6 +254,7 @@
request_id=request.request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
num_encoder_tokens=num_encoder_tokens,
)

if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
@@ -273,7 +275,7 @@
new_computed_block_list)

new_blocks = self.coordinator.allocate_new_blocks(
request.request_id, num_tokens_need_slot)
request.request_id, num_tokens_need_slot, num_encoder_tokens)

# P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks.
@@ -292,7 +294,7 @@

def free(self, request: Request) -> None:
"""Free the blocks allocated for the request.
We free the blocks in reverse order so that he tail blocks are evicted
We free the blocks in reverse order so that the tail blocks are evicted
first when caching is enabled.

Args:
37 changes: 36 additions & 1 deletion vllm/v1/core/sched/scheduler.py
@@ -58,6 +58,7 @@ def __init__(
self.parallel_config = vllm_config.parallel_config
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager
self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder

# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
@@ -83,6 +84,9 @@ def __init__(
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"Multiple KV cache groups are not currently supported "
"with KV connectors")
assert not self.is_encoder_decoder, (
"Encoder-decoder models are not currently supported "
"with KV connectors")
self.connector = KVConnectorFactory.create_connector(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)

@@ -431,13 +435,30 @@ def schedule(self) -> SchedulerOutput:
== 0 else
self.num_lookahead_tokens)

# Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs:
# TODO(russellb): For Whisper, we know that the input is
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
assert ("whisper"
in self.vllm_config.model_config.model.lower()), (
"Whisper is the only supported "
"encoder-decoder model.")
num_encoder_tokens = MULTIMODAL_REGISTRY.\
get_encdec_max_encoder_len(
self.vllm_config.model_config)
else:
num_encoder_tokens = 0

new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_local_computed_tokens,
new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens,
delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens,
)

if new_blocks is None:
@@ -703,7 +724,21 @@ def _try_schedule_encoder_inputs(
# The encoder input is not needed in this step.
break

if start_pos + num_encoder_tokens <= num_computed_tokens:
if self.is_encoder_decoder and num_computed_tokens > 0:
assert start_pos == 0, (
"Encoder input should be processed at the beginning of "
"the sequence when encoder-decoder models are used.")
# Encoder input has already been computed
# The calculation here is a bit different. We don't turn encoder
# output into tokens that get processed by the decoder and
# reflected in num_computed_tokens. Instead, start_pos reflects
# the position where we need to ensure we calculate encoder
# inputs. This should always be 0 to ensure we calculate encoder
# inputs before running the decoder. Once we've calculated some
# decoder tokens (num_computed_tokens > 0), then we know we
# already calculated encoder inputs and can skip here.
continue
elif start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder input is already computed and stored
# in the decoder's KV cache.
continue
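
The skip condition above can be restated as a small predicate. This is an illustrative paraphrase rather than the vLLM API; the encoder-decoder branch relies on the invariant (asserted above) that encoder inputs always start at position 0.

    def encoder_input_already_computed(is_encoder_decoder: bool,
                                       num_computed_tokens: int,
                                       start_pos: int,
                                       num_encoder_tokens: int) -> bool:
        if is_encoder_decoder and num_computed_tokens > 0:
            # The encoder input sits at position 0, so any computed decoder
            # tokens imply the encoder has already been run.
            return True
        # Decoder-only multimodal case: the encoder output is already stored
        # in the first `num_computed_tokens` positions of the decoder KV cache.
        return start_pos + num_encoder_tokens <= num_computed_tokens
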
56 changes: 54 additions & 2 deletions vllm/v1/core/single_type_kv_cache_manager.py
@@ -8,8 +8,9 @@
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
CrossAttentionSpec, FullAttentionSpec,
KVCacheSpec, MambaSpec,
SlidingWindowSpec)
from vllm.v1.request import Request


@@ -552,11 +553,62 @@ def allocate_new_blocks(self, request_id: str,
return new_blocks


class CrossAttentionManager(SingleTypeKVCacheManager):
"""Manager for cross-attention KV cache in encoder-decoder models."""

def save_new_computed_blocks(
self, request_id: str,
new_computed_blocks: list[KVCacheBlock]) -> None:
# We do not cache blocks for cross-attention to be shared between
# requests, so `new_computed_blocks` should always be empty.
assert len(new_computed_blocks) == 0

def cache_blocks(self, request: Request, num_tokens: int) -> None:
# We do not cache blocks for cross-attention to be shared between
# requests, so this method is not relevant.
raise ValueError("Should not be called as prefix caching is disabled.")

def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
# Cross-attention blocks contain request-specific encoder states
# and are not shared between different requests
return 0

@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
"CrossAttentionManager can only be used for cross-attention groups"
)
# Cross-attention does not benefit from prefix caching since:
# 1. Encoder states are unique per request (different audio/image
# inputs)
# 2. Encoder states are computed once per request, not incrementally
# 3. No reusable prefix exists between different multimodal inputs
        # Prefix caching is disabled for cross-attention, so this method
        # should never be called.
raise NotImplementedError(
"CrossAttentionManager does not support caching")

def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# Cross-attention blocks represent encoder states which are needed
# for the entire decoding process, so no blocks should be skipped
pass


spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
CrossAttentionSpec: CrossAttentionManager,
}


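
With the new entry in spec_manager_map, a CrossAttentionSpec resolves to CrossAttentionManager. The sketch below shows the shape of that dispatch; the real get_manager_for_kv_cache_spec() may also forward constructor arguments such as the block pool, so treat this as illustrative only.

    def manager_class_for(kv_cache_spec: KVCacheSpec) -> type[SingleTypeKVCacheManager]:
        # Look up the manager class registered for this spec type.
        manager_cls = spec_manager_map.get(type(kv_cache_spec))
        if manager_cls is None:
            raise ValueError(
                f"No KV cache manager registered for {type(kv_cache_spec).__name__}")
        return manager_cls

    # manager_class_for(cross_attention_spec) -> CrossAttentionManager
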
15 changes: 15 additions & 0 deletions vllm/v1/kv_cache_interface.py
@@ -11,6 +11,7 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv, get_dtype_size

logger = init_logger(__name__)
@@ -211,6 +212,20 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
return 0


@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
"""
KV cache spec for cross-attention layers in encoder-decoder models.
"""

def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
# For cross-attention, we need to cache encoder states
# Get encoder length (e.g., 1500 for Whisper).
max_encoder_len = MULTIMODAL_REGISTRY.\
get_encdec_max_encoder_len(vllm_config.model_config)
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
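
As a rough sanity check of this formula, with the 1500-token Whisper encoder length and otherwise hypothetical numbers (the block size and per-block page size below are illustrative, not taken from a real config):

    from math import ceil

    max_encoder_len = 1500      # Whisper encoder length noted above
    block_size = 16             # hypothetical block size
    page_size_bytes = 81_920    # hypothetical page size per block for this spec

    num_blocks = ceil(max_encoder_len / block_size)   # cdiv(1500, 16) == 94
    max_memory = num_blocks * page_size_bytes         # 7,700,480 bytes (~7.3 MiB)
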


@dataclass
class KVCacheTensor:
"""