@@ -3154,25 +3154,26 @@ def initialize_kv_cache_tensors(
             # TODO: REFACTOR ME to sharing hybrid cache
             for idx in range(len(kv_cache_tensor.shared_by)):
                 layer_name = kv_cache_tensor.shared_by[idx]
-                if "linear_attn" in layer_name:
+                if "linear_attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+                ):
                     # for mamba linear attention
+                    if self.vllm_config.kv_transfer_config is None:
+                        tensor = torch.zeros(kv_cache_tensor.size,
+                                             dtype=torch.int8,
+                                             device=self.device)
+                    else:
+                        cache_size_aligned = kv_cache_tensor.size + alignment
+                        tensor = torch.zeros(cache_size_aligned,
+                                             dtype=torch.int8,
+                                             device=self.device)
+                        tensor = self._align_memory(
+                            tensor, alignment)[:kv_cache_tensor.size]
                     for layer_name_inner in kv_cache_tensor.shared_by:
-                        if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
-                                layer_name_inner in kv_cache_raw_tensors.keys():
-                            continue
-                        if self.vllm_config.kv_transfer_config is None:
-                            tensor = torch.zeros(kv_cache_tensor.size,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                        else:
-                            cache_size_aligned = kv_cache_tensor.size + alignment
-                            tensor = torch.zeros(cache_size_aligned,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                            tensor = self._align_memory(
-                                tensor, alignment)[:kv_cache_tensor.size]
-                        kv_cache_raw_tensors[layer_name_inner] = tensor
-                elif "attn" in layer_name:
+                        # share the kv cache between the linear_attn specs in the same group
+                        if "linear_attn" in layer_name_inner:
+                            kv_cache_raw_tensors[layer_name_inner] = tensor
+                elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+                ):
                     # for other attentions, e.g., self_attn, sliding window attn
                     if self.vllm_config.kv_transfer_config is None:
                         k_tensor = torch.zeros(kv_cache_tensor.size // 2,
@@ -3194,7 +3195,12 @@ def initialize_kv_cache_tensors(
                                                       alignment)[:cache_size]
                         v_tensor = self._align_memory(v_tensor,
                                                       alignment)[:cache_size]
-                    kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
+                    for layer_name_inner in kv_cache_tensor.shared_by:
+                        # share the kv cache between the self_attn specs in the same group
+                        if ("attn" in layer_name_inner
+                                and "linear_attn" not in layer_name_inner):
+                            kv_cache_raw_tensors[layer_name_inner] = (k_tensor,
+                                                                      v_tensor)

         layer_names = set()
         for group in kv_cache_config.kv_cache_groups:
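
Context for the `kv_transfer_config` branches above: when a KV-transfer connector is configured, the buffer is over-allocated by `alignment` bytes and an aligned view is sliced out via `self._align_memory`, whose definition is not part of this diff. A minimal sketch of that pattern, using a hypothetical `align_memory` helper and made-up sizes:

```python
import torch


def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    """Hypothetical stand-in for self._align_memory: return a view of
    `tensor` whose first byte falls on an `alignment`-byte boundary."""
    data_ptr = tensor.data_ptr()
    aligned_ptr = (data_ptr + alignment - 1) // alignment * alignment
    offset = aligned_ptr - data_ptr  # int8 tensor -> one byte per element
    return tensor[offset:]


size = 1024       # stand-in for kv_cache_tensor.size
alignment = 256   # stand-in for the module-level `alignment`

# Over-allocate, then slice an aligned window of exactly `size` bytes,
# mirroring the `cache_size_aligned` branch in the diff.
raw = torch.zeros(size + alignment, dtype=torch.int8)
cache = align_memory(raw, alignment)[:size]

assert cache.data_ptr() % alignment == 0
assert cache.numel() == size
```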
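The other change above is that each raw buffer is now allocated once per `KVCacheTensor` and then handed to every layer listed in its `shared_by` field that matches the branch, so all of those layers alias the same storage. A toy illustration of that aliasing with made-up layer names (not the actual helper structures):

```python
import torch

# Illustrative layer names only; real names come from kv_cache_tensor.shared_by.
shared_by = ["model.layers.0.linear_attn", "model.layers.1.linear_attn"]
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}

# One allocation per group ...
tensor = torch.zeros(1024, dtype=torch.int8)

# ... shared by every matching layer in the group.
for layer_name_inner in shared_by:
    if "linear_attn" in layer_name_inner:
        kv_cache_raw_tensors[layer_name_inner] = tensor

# Both entries point at the same underlying storage.
assert len({t.data_ptr() for t in kv_cache_raw_tensors.values()}) == 1
```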