6
6
from vllm .v1 .core .block_pool import BlockPool
7
7
from vllm .v1 .core .kv_cache_utils import BlockHash , KVCacheBlock
8
8
from vllm .v1 .core .single_type_kv_cache_manager import (
9
- FullAttentionManager , get_manager_for_kv_cache_spec )
9
+ CrossAttentionManager , FullAttentionManager , get_manager_for_kv_cache_spec )
10
10
from vllm .v1 .kv_cache_interface import (FullAttentionSpec , KVCacheConfig ,
11
11
KVCacheSpec )
12
12
from vllm .v1 .request import Request
@@ -42,9 +42,10 @@ def __init__(
42
42
) for i , kv_cache_group in enumerate (
43
43
self .kv_cache_config .kv_cache_groups ))
44
44
45
- def get_num_blocks_to_allocate (
46
- self , request_id : str , num_tokens : int ,
47
- new_computed_blocks : tuple [list [KVCacheBlock ], ...]) -> int :
45
+ def get_num_blocks_to_allocate (self , request_id : str , num_tokens : int ,
46
+ new_computed_blocks : tuple [
47
+ list [KVCacheBlock ], ...],
48
+ num_encoder_tokens : int ) -> int :
48
49
"""
49
50
Get the number of blocks needed to be allocated for the request.
50
51
@@ -54,14 +55,22 @@ def get_num_blocks_to_allocate(
54
55
tokens that are already allocated).
55
56
new_computed_blocks: The new computed blocks just hitting the
56
57
prefix caching.
58
+ num_encoder_tokens: The number of encoder tokens for allocating
59
+ blocks for cross-attention.
57
60
58
61
Returns:
59
62
The number of blocks.
60
63
"""
61
64
num_blocks_to_allocate = 0
62
65
for i , manager in enumerate (self .single_type_managers ):
63
- num_blocks_to_allocate += manager .get_num_blocks_to_allocate (
64
- request_id , num_tokens , new_computed_blocks [i ])
66
+ if isinstance (manager , CrossAttentionManager ):
67
+ # For cross-attention, we issue a single static allocation
68
+ # of blocks based on the number of encoder input tokens.
69
+ num_blocks_to_allocate += manager .get_num_blocks_to_allocate (
70
+ request_id , num_encoder_tokens , [])
71
+ else :
72
+ num_blocks_to_allocate += manager .get_num_blocks_to_allocate (
73
+ request_id , num_tokens , new_computed_blocks [i ])
65
74
return num_blocks_to_allocate
66
75
67
76
def save_new_computed_blocks (
@@ -79,8 +88,11 @@ def save_new_computed_blocks(
79
88
manager .save_new_computed_blocks (request_id ,
80
89
new_computed_blocks [i ])
81
90
82
- def allocate_new_blocks (self , request_id : str ,
83
- num_tokens : int ) -> tuple [list [KVCacheBlock ], ...]:
91
+ def allocate_new_blocks (
92
+ self ,
93
+ request_id : str ,
94
+ num_tokens : int ,
95
+ num_encoder_tokens : int = 0 ) -> tuple [list [KVCacheBlock ], ...]:
84
96
"""
85
97
Allocate new blocks for the request to give it at least `num_tokens`
86
98
token slots.
@@ -89,12 +101,16 @@ def allocate_new_blocks(self, request_id: str,
89
101
request_id: The request ID.
90
102
num_tokens: The total number of tokens that need a slot (including
91
103
tokens that are already allocated).
104
+ num_encoder_tokens: The number of encoder tokens for allocating
105
+ blocks for cross-attention.
92
106
93
107
Returns:
94
108
The new allocated blocks.
95
109
"""
96
110
return tuple (
97
- manager .allocate_new_blocks (request_id , num_tokens )
111
+ manager .allocate_new_blocks (
112
+ request_id , num_encoder_tokens if isinstance (
113
+ manager , CrossAttentionManager ) else num_tokens )
98
114
for manager in self .single_type_managers )
99
115
100
116
def cache_blocks (self , request : Request , num_computed_tokens : int ) -> None :
0 commit comments