
Commit 1697ba6

[HybridKV][Bugfix] Fix Hybrid kvcache sharing bug in same attention type
Signed-off-by: MengqingCao <[email protected]>
1 parent bb5f16d commit 1697ba6

File tree

tests/e2e/multicard/test_qwen3_next.py
vllm_ascend/worker/model_runner_v1.py

2 files changed: +26 −20

tests/e2e/multicard/test_qwen3_next.py

Lines changed: 2 additions & 2 deletions

@@ -27,12 +27,12 @@
 def test_models_distributed_Qwen3_NEXT_TP4():
     example_prompts = [
         "Hello, my name is",
-    ]
+    ] * 4
     max_tokens = 5
     with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
                     tensor_parallel_size=4,
                     max_model_len=4096,
-                    gpu_memory_utilization=0.7,
+                    gpu_memory_utilization=0.8,
                     distributed_executor_backend="mp",
                     enforce_eager=True) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)

vllm_ascend/worker/model_runner_v1.py

Lines changed: 24 additions & 18 deletions

@@ -3154,25 +3154,26 @@ def initialize_kv_cache_tensors(
             # TODO: REFACTOR ME to sharing hybrid cache
             for idx in range(len(kv_cache_tensor.shared_by)):
                 layer_name = kv_cache_tensor.shared_by[idx]
-                if "linear_attn" in layer_name:
+                if "linear_attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+                ):
                     # for mamba linear attention
+                    if self.vllm_config.kv_transfer_config is None:
+                        tensor = torch.zeros(kv_cache_tensor.size,
+                                             dtype=torch.int8,
+                                             device=self.device)
+                    else:
+                        cache_size_aligned = kv_cache_tensor.size + alignment
+                        tensor = torch.zeros(cache_size_aligned,
+                                             dtype=torch.int8,
+                                             device=self.device)
+                        tensor = self._align_memory(
+                            tensor, alignment)[:kv_cache_tensor.size]
                     for layer_name_inner in kv_cache_tensor.shared_by:
-                        if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
-                            layer_name_inner in kv_cache_raw_tensors.keys():
-                            continue
-                        if self.vllm_config.kv_transfer_config is None:
-                            tensor = torch.zeros(kv_cache_tensor.size,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                        else:
-                            cache_size_aligned = kv_cache_tensor.size + alignment
-                            tensor = torch.zeros(cache_size_aligned,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                            tensor = self._align_memory(
-                                tensor, alignment)[:kv_cache_tensor.size]
-                        kv_cache_raw_tensors[layer_name_inner] = tensor
-                elif "attn" in layer_name:
+                        # shared the kvcache between the self_attn specs in the same group
+                        if "linear_attn" in layer_name_inner:
+                            kv_cache_raw_tensors[layer_name_inner] = tensor
+                elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+                ):
                     # for other attentions, e.g., self_attn, sliding window attn
                     if self.vllm_config.kv_transfer_config is None:
                         k_tensor = torch.zeros(kv_cache_tensor.size // 2,

@@ -3194,7 +3195,12 @@ def initialize_kv_cache_tensors(
                                                       alignment)[:cache_size]
                         v_tensor = self._align_memory(v_tensor,
                                                       alignment)[:cache_size]
-                    kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
+                    for layer_name_inner in kv_cache_tensor.shared_by:
+                        # shared the kvcache between the self_attn specs in the same group
+                        if ("attn" in layer_name_inner
+                                and "linear_attn" not in layer_name_inner):
+                            kv_cache_raw_tensors[layer_name_inner] = (k_tensor,
+                                                                      v_tensor)

         layer_names = set()
         for group in kv_cache_config.kv_cache_groups:
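As the diff reads, each KVCacheTensor entry backs every layer listed in its shared_by field; after this fix the raw buffer is allocated once per group and then registered in kv_cache_raw_tensors for every layer of the same attention type, instead of being re-allocated inside the inner loop (linear attention) or registered only for the currently iterated layer (self/sliding-window attention). Below is a minimal, self-contained sketch of that sharing rule; KVCacheTensorSpec and share_raw_tensors are illustrative stand-ins rather than the real vllm-ascend classes, and the alignment / kv_transfer_config handling of the actual code is omitted.

# Hypothetical, simplified sketch of the sharing scheme this commit converges on:
# allocate once per KVCacheTensor, then register that buffer for every layer in
# `shared_by` with the same attention type. `KVCacheTensorSpec` and
# `share_raw_tensors` are illustrative names, not the real vllm-ascend classes.
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

import torch


@dataclass
class KVCacheTensorSpec:
    size: int             # buffer size in bytes
    shared_by: List[str]  # names of the layers that share this buffer


def share_raw_tensors(
    specs: List[KVCacheTensorSpec],
    device: str = "cpu",
) -> Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
    raw: Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
    for spec in specs:
        for layer_name in spec.shared_by:
            if layer_name in raw:
                # already covered when another member of the group was processed
                continue
            if "linear_attn" in layer_name:
                # mamba-style linear attention: one flat int8 buffer shared by
                # every linear_attn layer in the group
                tensor = torch.zeros(spec.size, dtype=torch.int8, device=device)
                for inner in spec.shared_by:
                    if "linear_attn" in inner:
                        raw[inner] = tensor
            elif "attn" in layer_name:
                # self / sliding-window attention: split the buffer into K and V
                # halves shared by every non-linear attn layer in the group
                k = torch.zeros(spec.size // 2, dtype=torch.int8, device=device)
                v = torch.zeros(spec.size // 2, dtype=torch.int8, device=device)
                for inner in spec.shared_by:
                    if "attn" in inner and "linear_attn" not in inner:
                        raw[inner] = (k, v)
    return raw


if __name__ == "__main__":
    spec = KVCacheTensorSpec(
        size=1024,
        shared_by=["model.layers.0.linear_attn", "model.layers.1.self_attn"],
    )
    tensors = share_raw_tensors([spec])
    # the linear_attn layer gets a single flat buffer, the self_attn layer a (K, V) pair
    print({name: type(buf).__name__ for name, buf in tensors.items()})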
