
Commit 1e84fbf

russellb and amd-xiaoyu12 authored and committed
[v1] Add cross-attention KV cache support for encoder-decoder models (vllm-project#23664)
Signed-off-by: Russell Bryant <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
1 parent 7dd8a25 commit 1e84fbf

File tree: 6 files changed, +153 -14 lines changed

vllm/multimodal/registry.py

Lines changed: 19 additions & 0 deletions
@@ -372,3 +372,22 @@ def get_encoder_dummy_data(
         )
 
         return dummy_data
+
+    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
+        """
+        Get the maximum length of the encoder input for encoder-decoder models.
+        """
+        if not model_config.is_encoder_decoder:
+            return 0
+        max_tokens = self.\
+            get_max_tokens_per_item_by_nonzero_modality(model_config)
+        if not max_tokens:
+            # TODO - this function assumes encoder-decoder models are
+            # multimodal. This will need to change when adding support for more
+            # than whisper.
+            return 0
+        assert len(max_tokens) == 1, "Encoder-decoder models are expected \
+            to implement the multimodal interface with at most one modality."
+
+        first_modality = next(iter(max_tokens))
+        return max_tokens[first_modality]
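
For illustration (not part of this commit): a standalone sketch of the lookup logic above. It assumes an encoder-decoder model exposes at most one multimodal modality and uses that modality's per-item token maximum as the encoder length; the value 1500 mirrors the Whisper encoder length mentioned in kv_cache_interface.py below.

def encdec_max_encoder_len(is_encoder_decoder: bool,
                           max_tokens_by_modality: dict[str, int]) -> int:
    # Mirrors get_encdec_max_encoder_len: decoder-only models and
    # encoder-decoder models without a multimodal interface report 0.
    if not is_encoder_decoder or not max_tokens_by_modality:
        return 0
    assert len(max_tokens_by_modality) == 1, "at most one modality expected"
    return next(iter(max_tokens_by_modality.values()))

# A Whisper-like model with a single "audio" modality:
print(encdec_max_encoder_len(True, {"audio": 1500}))   # -> 1500
print(encdec_max_encoder_len(False, {}))               # -> 0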

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 25 additions & 9 deletions
@@ -6,7 +6,7 @@
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.core.single_type_kv_cache_manager import (
-    FullAttentionManager, get_manager_for_kv_cache_spec)
+    CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.request import Request
@@ -42,9 +42,10 @@ def __init__(
             ) for i, kv_cache_group in enumerate(
                 self.kv_cache_config.kv_cache_groups))

-    def get_num_blocks_to_allocate(
-            self, request_id: str, num_tokens: int,
-            new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
+    def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
+                                   new_computed_blocks: tuple[
+                                       list[KVCacheBlock], ...],
+                                   num_encoder_tokens: int) -> int:
         """
         Get the number of blocks needed to be allocated for the request.

@@ -54,14 +55,22 @@ def get_num_blocks_to_allocate(
                 tokens that are already allocated).
             new_computed_blocks: The new computed blocks just hitting the
                 prefix caching.
+            num_encoder_tokens: The number of encoder tokens for allocating
+                blocks for cross-attention.

         Returns:
             The number of blocks.
         """
         num_blocks_to_allocate = 0
         for i, manager in enumerate(self.single_type_managers):
-            num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
-                request_id, num_tokens, new_computed_blocks[i])
+            if isinstance(manager, CrossAttentionManager):
+                # For cross-attention, we issue a single static allocation
+                # of blocks based on the number of encoder input tokens.
+                num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
+                    request_id, num_encoder_tokens, [])
+            else:
+                num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
+                    request_id, num_tokens, new_computed_blocks[i])
         return num_blocks_to_allocate

     def save_new_computed_blocks(
@@ -79,8 +88,11 @@ def save_new_computed_blocks(
             manager.save_new_computed_blocks(request_id,
                                              new_computed_blocks[i])

-    def allocate_new_blocks(self, request_id: str,
-                            num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
+    def allocate_new_blocks(
+            self,
+            request_id: str,
+            num_tokens: int,
+            num_encoder_tokens: int = 0) -> tuple[list[KVCacheBlock], ...]:
         """
         Allocate new blocks for the request to give it at least `num_tokens`
         token slots.
@@ -89,12 +101,16 @@ def allocate_new_blocks(self, request_id: str,
             request_id: The request ID.
             num_tokens: The total number of tokens that need a slot (including
                 tokens that are already allocated).
+            num_encoder_tokens: The number of encoder tokens for allocating
+                blocks for cross-attention.

         Returns:
             The new allocated blocks.
         """
         return tuple(
-            manager.allocate_new_blocks(request_id, num_tokens)
+            manager.allocate_new_blocks(
+                request_id, num_encoder_tokens if isinstance(
+                    manager, CrossAttentionManager) else num_tokens)
             for manager in self.single_type_managers)

     def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
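
For illustration (not part of this commit): a toy sketch of the sizing rule the coordinator now applies per manager. Cross-attention groups are sized once from the encoder token count and receive no prefix-cache credit, while every other group is sized from the decoder token count; block_size=16 and the simplified arithmetic are assumptions.

from math import ceil

def blocks_for_manager(manager_kind: str, num_tokens: int,
                       num_encoder_tokens: int, num_cached_blocks: int,
                       block_size: int = 16) -> int:
    if manager_kind == "cross_attention":
        # Single static allocation based on the encoder input length;
        # new_computed_blocks are never credited for cross-attention.
        return ceil(num_encoder_tokens / block_size)
    # Other managers size by decoder tokens, minus blocks already cached.
    return max(0, ceil(num_tokens / block_size) - num_cached_blocks)

# Whisper-like request: 1500 encoder tokens, 48 decoder tokens, 1 cached block.
print(blocks_for_manager("full_attention", 48, 1500, 1))   # -> 2
print(blocks_for_manager("cross_attention", 48, 1500, 1))  # -> 94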

vllm/v1/core/kv_cache_manager.py

Lines changed: 4 additions & 2 deletions
@@ -187,6 +187,7 @@ def allocate_slots(
         new_computed_blocks: Optional[KVCacheBlocks] = None,
         num_lookahead_tokens: int = 0,
         delay_cache_blocks: bool = False,
+        num_encoder_tokens: int = 0,
     ) -> Optional[KVCacheBlocks]:
         """Add slots for a request with new tokens to append.

@@ -253,6 +254,7 @@ def allocate_slots(
             request_id=request.request_id,
             num_tokens=num_tokens_need_slot,
             new_computed_blocks=new_computed_block_list,
+            num_encoder_tokens=num_encoder_tokens,
         )

         if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
@@ -273,7 +275,7 @@ def allocate_slots(
                                                       new_computed_block_list)

         new_blocks = self.coordinator.allocate_new_blocks(
-            request.request_id, num_tokens_need_slot)
+            request.request_id, num_tokens_need_slot, num_encoder_tokens)

         # P/D: delay caching blocks if we have to recv from
         # remote. Update state for locally cached blocks.
@@ -292,7 +294,7 @@ def allocate_slots(

     def free(self, request: Request) -> None:
         """Free the blocks allocated for the request.
-        We free the blocks in reverse order so that he tail blocks are evicted
+        We free the blocks in reverse order so that the tail blocks are evicted
         first when caching is enabled.

         Args:
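
For illustration (not part of this commit): a toy version of the admission check that num_encoder_tokens now feeds into. A request only gets new slots if its decoder blocks plus its one-time cross-attention blocks fit in the free pool; otherwise allocate_slots returns None and the request is not scheduled this step. The numbers and block_size=16 are assumptions.

from math import ceil

def can_allocate_slots(num_decoder_tokens: int, num_encoder_tokens: int,
                       num_free_blocks: int, block_size: int = 16) -> bool:
    needed = ceil(num_decoder_tokens / block_size)   # decoder KV blocks
    needed += ceil(num_encoder_tokens / block_size)  # cross-attention blocks
    return needed <= num_free_blocks

print(can_allocate_slots(32, 1500, num_free_blocks=100))  # True: 2 + 94 = 96
print(can_allocate_slots(32, 1500, num_free_blocks=90))   # False: pool too small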

vllm/v1/core/sched/scheduler.py

Lines changed: 36 additions & 1 deletion
@@ -58,6 +58,7 @@ def __init__(
         self.parallel_config = vllm_config.parallel_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
+        self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder

         # include_finished_set controls whether a separate set of finished
         # request ids should be included in the EngineCoreOutputs returned
@@ -83,6 +84,9 @@ def __init__(
             assert len(self.kv_cache_config.kv_cache_groups) == 1, (
                 "Multiple KV cache groups are not currently supported "
                 "with KV connectors")
+            assert not self.is_encoder_decoder, (
+                "Encoder-decoder models are not currently supported "
+                "with KV connectors")
             self.connector = KVConnectorFactory.create_connector(
                 config=self.vllm_config, role=KVConnectorRole.SCHEDULER)

@@ -431,13 +435,30 @@ def schedule(self) -> SchedulerOutput:
                                          == 0 else
                                          self.num_lookahead_tokens)

+            # Determine if we need to allocate cross-attention blocks.
+            if self.is_encoder_decoder and request.has_encoder_inputs:
+                # TODO(russellb): For Whisper, we know that the input is
+                # always padded to the maximum length. If we support other
+                # encoder-decoder models, this will need to be updated if we
+                # want to only allocate what is needed.
+                assert ("whisper"
+                        in self.vllm_config.model_config.model.lower()), (
+                            "Whisper is the only supported "
+                            "encoder-decoder model.")
+                num_encoder_tokens = MULTIMODAL_REGISTRY.\
+                    get_encdec_max_encoder_len(
+                        self.vllm_config.model_config)
+            else:
+                num_encoder_tokens = 0
+
             new_blocks = self.kv_cache_manager.allocate_slots(
                 request,
                 num_new_tokens + num_external_computed_tokens,
                 num_new_local_computed_tokens,
                 new_computed_blocks,
                 num_lookahead_tokens=effective_lookahead_tokens,
                 delay_cache_blocks=load_kv_async,
+                num_encoder_tokens=num_encoder_tokens,
             )

             if new_blocks is None:
@@ -703,7 +724,21 @@ def _try_schedule_encoder_inputs(
                 # The encoder input is not needed in this step.
                 break

-            if start_pos + num_encoder_tokens <= num_computed_tokens:
+            if self.is_encoder_decoder and num_computed_tokens > 0:
+                assert start_pos == 0, (
+                    "Encoder input should be processed at the beginning of "
+                    "the sequence when encoder-decoder models are used.")
+                # Encoder input has already been computed
+                # The calculation here is a bit different. We don't turn encoder
+                # output into tokens that get processed by the decoder and
+                # reflected in num_computed_tokens. Instead, start_pos reflects
+                # the position where we need to ensure we calculate encoder
+                # inputs. This should always be 0 to ensure we calculate encoder
+                # inputs before running the decoder. Once we've calculated some
+                # decoder tokens (num_computed_tokens > 0), then we know we
+                # already calculated encoder inputs and can skip here.
+                continue
+            elif start_pos + num_encoder_tokens <= num_computed_tokens:
                 # The encoder input is already computed and stored
                 # in the decoder's KV cache.
                 continue
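
For illustration (not part of this commit): a standalone sketch of the two skip conditions in _try_schedule_encoder_inputs after this change. For an encoder-decoder model the encoder runs exactly once, before any decoder token is produced, so any decoder progress means the encoder input can be skipped; for decoder-only multimodal models the original positional check still applies. Inputs here are simplified assumptions.

def encoder_input_already_computed(is_encoder_decoder: bool, start_pos: int,
                                   num_encoder_tokens: int,
                                   num_computed_tokens: int) -> bool:
    if is_encoder_decoder and num_computed_tokens > 0:
        # Encoder-decoder: encoder inputs always sit at position 0 and are
        # computed before the first decoder step, so any decoder progress
        # implies the encoder has already run.
        assert start_pos == 0
        return True
    # Decoder-only multimodal: encoder output occupies placeholder positions
    # that are already reflected in the decoder's computed tokens.
    return start_pos + num_encoder_tokens <= num_computed_tokens

print(encoder_input_already_computed(True, 0, 1500, 1))     # -> True
print(encoder_input_already_computed(True, 0, 1500, 0))     # -> False
print(encoder_input_already_computed(False, 10, 576, 600))  # -> True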

vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 54 additions & 2 deletions
@@ -8,8 +8,9 @@
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
-                                        FullAttentionSpec, KVCacheSpec,
-                                        MambaSpec, SlidingWindowSpec)
+                                        CrossAttentionSpec, FullAttentionSpec,
+                                        KVCacheSpec, MambaSpec,
+                                        SlidingWindowSpec)
 from vllm.v1.request import Request


@@ -552,11 +553,62 @@ def allocate_new_blocks(self, request_id: str,
         return new_blocks


+class CrossAttentionManager(SingleTypeKVCacheManager):
+    """Manager for cross-attention KV cache in encoder-decoder models."""
+
+    def save_new_computed_blocks(
+            self, request_id: str,
+            new_computed_blocks: list[KVCacheBlock]) -> None:
+        # We do not cache blocks for cross-attention to be shared between
+        # requests, so `new_computed_blocks` should always be empty.
+        assert len(new_computed_blocks) == 0
+
+    def cache_blocks(self, request: Request, num_tokens: int) -> None:
+        # We do not cache blocks for cross-attention to be shared between
+        # requests, so this method is not relevant.
+        raise ValueError("Should not be called as prefix caching is disabled.")
+
+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> int:
+        # Cross-attention blocks contain request-specific encoder states
+        # and are not shared between different requests
+        return 0
+
+    @classmethod
+    def find_longest_cache_hit(
+        cls,
+        block_hashes: list[BlockHash],
+        max_length: int,
+        kv_cache_group_ids: list[int],
+        block_pool: BlockPool,
+        kv_cache_spec: KVCacheSpec,
+        use_eagle: bool,
+    ) -> tuple[list[KVCacheBlock], ...]:
+        assert isinstance(kv_cache_spec, CrossAttentionSpec), (
+            "CrossAttentionManager can only be used for cross-attention groups"
+        )
+        # Cross-attention does not benefit from prefix caching since:
+        # 1. Encoder states are unique per request (different audio/image
+        #    inputs)
+        # 2. Encoder states are computed once per request, not incrementally
+        # 3. No reusable prefix exists between different multimodal inputs
+        # Return empty blocks to indicate no cache hits
+        raise NotImplementedError(
+            "CrossAttentionManager does not support caching")
+
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
+        # Cross-attention blocks represent encoder states which are needed
+        # for the entire decoding process, so no blocks should be skipped
+        pass
+
+
 spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,
     SlidingWindowSpec: SlidingWindowManager,
     ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
     MambaSpec: MambaManager,
+    CrossAttentionSpec: CrossAttentionManager,
 }
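
For illustration (not part of this commit): a toy object capturing the block lifecycle that CrossAttentionManager implies. The cross-attention allocation is made once for the full (padded) encoder length, never grows during decoding, is never shared through prefix caching, and is only released when the request finishes. block_size=16 is an assumed value.

from math import ceil

class ToyCrossAttentionAllocation:
    def __init__(self, max_encoder_len: int, block_size: int = 16):
        # One static allocation sized by the encoder input length.
        self.num_blocks = ceil(max_encoder_len / block_size)
        self.freed = False

    def on_decoder_step(self) -> int:
        # Decoding appends nothing to the cross-attention KV cache.
        return self.num_blocks

    def longest_cache_hit(self) -> int:
        # Encoder states are request-specific; nothing is reusable.
        return 0

    def free(self) -> None:
        # Released only when the request completes.
        self.freed = True

alloc = ToyCrossAttentionAllocation(max_encoder_len=1500)
print(alloc.num_blocks)         # -> 94
print(alloc.on_decoder_step())  # -> 94 (unchanged after any number of steps)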

vllm/v1/kv_cache_interface.py

Lines changed: 15 additions & 0 deletions
@@ -11,6 +11,7 @@

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.utils import cdiv, get_dtype_size

 logger = init_logger(__name__)
@@ -211,6 +212,20 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         return 0


+@dataclass(frozen=True)
+class CrossAttentionSpec(AttentionSpec):
+    """
+    KV cache spec for cross-attention layers in encoder-decoder models.
+    """
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        # For cross-attention, we need to cache encoder states
+        # Get encoder length (e.g., 1500 for Whisper).
+        max_encoder_len = MULTIMODAL_REGISTRY.\
+            get_encdec_max_encoder_len(vllm_config.model_config)
+        return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
+
+
 @dataclass
 class KVCacheTensor:
     """
