From ead4f2fcb1563e016c46f830214b53dce8a6eba2 Mon Sep 17 00:00:00 2001
From: QilaiZhang <245706640@qq.com>
Date: Mon, 13 Oct 2025 11:41:01 +0800
Subject: [PATCH 1/2] [Bugfix] Fix duplicated KV cache allocation in qwen3-Next

Signed-off-by: QilaiZhang <245706640@qq.com>
---
 vllm_ascend/worker/model_runner_v1.py | 58 ++++++++++-----------------
 1 file changed, 21 insertions(+), 37 deletions(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 7bddae09da..1f5b2ba09e 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2850,49 +2850,33 @@ def initialize_kv_cache_tensors(
         # llmdatadist need the addr of cache tensor be aligned with 2M
         alignment = 2 * 1024 * 1024
         for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
-            # TODO: REFACTOR ME to sharing hybrid cache
+            if self.vllm_config.kv_transfer_config is None:
+                tensor = torch.zeros(kv_cache_tensor.size,
+                                     dtype=torch.int8,
+                                     device=self.device)
+                k_tensor, v_tensor = tensor.chunk(2)
+            else:
+                cache_size = kv_cache_tensor.size // 2
+                cache_size_aligned = cache_size + alignment
+                k_tensor_aligned = torch.zeros(cache_size_aligned,
+                                               dtype=torch.int8,
+                                               device=self.device)
+                v_tensor_aligned = torch.zeros(cache_size_aligned,
+                                               dtype=torch.int8,
+                                               device=self.device)
+                k_tensor = self._align_memory(k_tensor_aligned,
+                                              alignment)[:cache_size]
+                v_tensor = self._align_memory(v_tensor_aligned,
+                                              alignment)[:cache_size]
+                tensor = torch.cat([k_tensor, v_tensor])
+
             for idx in range(len(kv_cache_tensor.shared_by)):
                 layer_name = kv_cache_tensor.shared_by[idx]
                 if "linear_attn" in layer_name:
                     # for mamba linear attention
-                    for layer_name_inner in kv_cache_tensor.shared_by:
-                        if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
-                            layer_name_inner in kv_cache_raw_tensors.keys():
-                            continue
-                        if self.vllm_config.kv_transfer_config is None:
-                            tensor = torch.zeros(kv_cache_tensor.size,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                        else:
-                            cache_size_aligned = kv_cache_tensor.size + alignment
-                            tensor = torch.zeros(cache_size_aligned,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                            tensor = self._align_memory(
-                                tensor, alignment)[:kv_cache_tensor.size]
-                        kv_cache_raw_tensors[layer_name_inner] = tensor
+                        kv_cache_raw_tensors[layer_name] = tensor
                 elif "attn" in layer_name:
                     # for other attentions, e.g., self_attn, sliding window attn
-                    if self.vllm_config.kv_transfer_config is None:
-                        k_tensor = torch.zeros(kv_cache_tensor.size // 2,
-                                               dtype=torch.int8,
-                                               device=self.device)
-                        v_tensor = torch.zeros(kv_cache_tensor.size // 2,
-                                               dtype=torch.int8,
-                                               device=self.device)
-                    else:
-                        cache_size = kv_cache_tensor.size // 2
-                        cache_size_aligned = kv_cache_tensor.size // 2 + alignment
-                        k_tensor = torch.zeros(cache_size_aligned,
-                                               dtype=torch.int8,
-                                               device=self.device)
-                        v_tensor = torch.zeros(cache_size_aligned,
-                                               dtype=torch.int8,
-                                               device=self.device)
-                        k_tensor = self._align_memory(k_tensor,
-                                                      alignment)[:cache_size]
-                        v_tensor = self._align_memory(v_tensor,
-                                                      alignment)[:cache_size]
                     kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)

         layer_names = set()

From 1fe26cb57c60a956d422c04ba9f8c8fd91e2abee Mon Sep 17 00:00:00 2001
From: QilaiZhang <245706640@qq.com>
Date: Mon, 13 Oct 2025 12:37:37 +0800
Subject: [PATCH 2/2] [Bugfix] Fix duplicated KV cache allocation in qwen3-Next

Signed-off-by: QilaiZhang <245706640@qq.com>
---
 vllm_ascend/worker/model_runner_v1.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 1f5b2ba09e..6b90025834 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2874,7 +2874,7 @@ def initialize_kv_cache_tensors(
                 layer_name = kv_cache_tensor.shared_by[idx]
                 if "linear_attn" in layer_name:
                     # for mamba linear attention
-                        kv_cache_raw_tensors[layer_name] = tensor
+                    kv_cache_raw_tensors[layer_name] = tensor
                 elif "attn" in layer_name:
                     # for other attentions, e.g., self_attn, sliding window attn
                     kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
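
Note on the resulting allocation scheme: after patch 1, each kv_cache_tensor is
backed by a single buffer that every layer in shared_by reuses. The
mamba/linear-attention layers take the whole buffer, while the other attention
layers take K and V views over its two halves, instead of each branch allocating
its own copy (the duplication this series removes). The sketch below illustrates
that pattern in isolation; it is a minimal approximation, and align_memory is a
hypothetical stand-in for ModelRunner._align_memory, whose body does not appear
in this diff.

    import torch

    # llmdatadist requires cache tensor addresses aligned to 2 MiB
    ALIGNMENT = 2 * 1024 * 1024

    def align_memory(t: torch.Tensor, alignment: int) -> torch.Tensor:
        # Assumed behavior of _align_memory: skip forward to the first
        # address in `t` that is a multiple of `alignment`. Slicing by
        # elements equals slicing by bytes because int8 is one byte wide.
        offset = (-t.data_ptr()) % alignment
        return t[offset:]

    def allocate_shared_kv(size: int, device: str, need_alignment: bool):
        if not need_alignment:
            # One zero-filled buffer per kv_cache_tensor; K and V are views
            # over its halves, so no layer triggers a second allocation.
            tensor = torch.zeros(size, dtype=torch.int8, device=device)
            k_tensor, v_tensor = tensor.chunk(2)
        else:
            # Over-allocate each half by the alignment amount so the aligned
            # view still contains cache_size bytes.
            cache_size = size // 2
            k_raw = torch.zeros(cache_size + ALIGNMENT, dtype=torch.int8,
                                device=device)
            v_raw = torch.zeros(cache_size + ALIGNMENT, dtype=torch.int8,
                                device=device)
            k_tensor = align_memory(k_raw, ALIGNMENT)[:cache_size]
            v_tensor = align_memory(v_raw, ALIGNMENT)[:cache_size]
            # torch.cat copies, so in this branch `tensor` is a fresh buffer
            # holding both halves rather than a view over k_tensor/v_tensor,
            # matching the patched code.
            tensor = torch.cat([k_tensor, v_tensor])
        return tensor, k_tensor, v_tensor

A linear-attention layer would receive `tensor`, and a self-attention or
sliding-window layer would receive the `(k_tensor, v_tensor)` pair, mirroring
the shared_by dispatch in the hunk above.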