@@ -3175,25 +3175,26 @@ def initialize_kv_cache_tensors(
             # TODO: REFACTOR ME to sharing hybrid cache
             for idx in range(len(kv_cache_tensor.shared_by)):
                 layer_name = kv_cache_tensor.shared_by[idx]
-                if "linear_attn" in layer_name:
+                if "linear_attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+                ):
                     # for mamba linear attention
+                    if self.vllm_config.kv_transfer_config is None:
+                        tensor = torch.zeros(kv_cache_tensor.size,
+                                             dtype=torch.int8,
+                                             device=self.device)
+                    else:
+                        cache_size_aligned = kv_cache_tensor.size + alignment
+                        tensor = torch.zeros(cache_size_aligned,
+                                             dtype=torch.int8,
+                                             device=self.device)
+                        tensor = self._align_memory(
+                            tensor, alignment)[:kv_cache_tensor.size]
                     for layer_name_inner in kv_cache_tensor.shared_by:
-                        if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
-                            layer_name_inner in kv_cache_raw_tensors.keys():
-                            continue
-                        if self.vllm_config.kv_transfer_config is None:
-                            tensor = torch.zeros(kv_cache_tensor.size,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                        else:
-                            cache_size_aligned = kv_cache_tensor.size + alignment
-                            tensor = torch.zeros(cache_size_aligned,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                            tensor = self._align_memory(
-                                tensor, alignment)[:kv_cache_tensor.size]
-                        kv_cache_raw_tensors[layer_name_inner] = tensor
-                elif "attn" in layer_name:
+                        # share the kv cache between the linear_attn layers in the same group
+                        if "linear_attn" in layer_name_inner:
+                            kv_cache_raw_tensors[layer_name_inner] = tensor
+                elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+                ):
                     # for other attentions, e.g., self_attn, sliding window attn
                     if self.vllm_config.kv_transfer_config is None:
                         k_tensor = torch.zeros(kv_cache_tensor.size // 2,
@@ -3215,7 +3216,12 @@ def initialize_kv_cache_tensors(
                                                       alignment)[:cache_size]
                         v_tensor = self._align_memory(v_tensor,
                                                       alignment)[:cache_size]
-                    kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
+                    for layer_name_inner in kv_cache_tensor.shared_by:
+                        # share the kv cache between the self_attn layers in the same group
+                        if ("attn" in layer_name_inner
+                                and "linear_attn" not in layer_name_inner):
+                            kv_cache_raw_tensors[layer_name_inner] = (k_tensor,
+                                                                      v_tensor)
 
         layer_names = set()
         for group in kv_cache_config.kv_cache_groups:
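
Note on the alignment handling above: when a KV transfer connector is configured, the raw cache buffers are over-allocated by `alignment` bytes and then sliced down so that they start on an aligned address, which is what the `self._align_memory(...)[:kv_cache_tensor.size]` calls accomplish. The helper's implementation is not part of this diff; the sketch below is a minimal, hypothetical version (the standalone `align_memory` name and the 2 MiB boundary are assumptions for illustration) showing why the extra headroom is needed.

import torch


def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    # Hypothetical stand-in for the _align_memory helper called in the diff:
    # return a view of `tensor` that starts at the first address that is a
    # multiple of `alignment`. Assumes a 1-byte dtype such as torch.int8.
    data_ptr = tensor.data_ptr()
    offset = (alignment - data_ptr % alignment) % alignment
    return tensor[offset:]


if __name__ == "__main__":
    alignment = 2 * 1024 * 1024  # assumed 2 MiB boundary, for illustration only
    size = 4096                  # requested raw cache size in bytes
    # Over-allocate by `alignment` so an aligned window of `size` bytes is
    # guaranteed to exist somewhere in the buffer, then slice it back to `size`.
    raw = torch.zeros(size + alignment, dtype=torch.int8)
    cache = align_memory(raw, alignment)[:size]
    assert cache.data_ptr() % alignment == 0
    assert cache.numel() == size

Without the `+ alignment` headroom (the `cache_size_aligned` lines in the diff), a buffer of exactly the requested size could start at an unaligned address with no room left to shift the view forward.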