
[Bug]: Qwen3-next: Out of Memory due to repeated allocation of shared hybrid cache #3368

@QilaiZhang

Description


Your current environment

The runtime environment matches the setup described in this tutorial: https://vllm-ascend.readthedocs.io/zh-cn/latest/tutorials/multi_npu_qwen3_next.html

🐛 Describe the bug

When launching the vLLM server for the Qwen3-Next-80B model with --gpu-memory-utilization set to 0.7, actual memory usage reaches 0.85, exceeding the configured limit. Raising --gpu-memory-utilization to 0.9 instead results in an "NPU out of memory" error.
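For reference, a launch command along these lines reproduces the setup (the model path and tensor-parallel size are assumptions based on the multi-NPU tutorial, not taken from the report):

vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct \
    --tensor-parallel-size 4 \
    --gpu-memory-utilization 0.7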

The issue has been traced primarily to the initialize_kv_cache_tensors function in vllm_ascend/worker/model_runner_v1.py. Within a single KV cache tensor group, this function allocates a separate buffer for each layer listed in shared_by, instead of allocating one buffer and sharing it across the group as standard vLLM does, so memory consumption grows with the number of layers in the group.

def initialize_kv_cache_tensors(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        """
        Initialize the memory buffer for KV cache.

        Args:
            kv_cache_config: The KV cache config
        Returns:
            Dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache.
        """
        # init kv cache tensors
        kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
                                              Optional[torch.Tensor]]] = {}
        # llmdatadist needs the cache tensor address to be 2 MB aligned
        alignment = 2 * 1024 * 1024
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
            # TODO: REFACTOR ME to sharing hybrid cache
            for idx in range(len(kv_cache_tensor.shared_by)):
                layer_name = kv_cache_tensor.shared_by[idx]
                if "linear_attn" in layer_name:
                    # for mamba linear attention
                    for layer_name_inner in kv_cache_tensor.shared_by:
                        if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
                            layer_name_inner in kv_cache_raw_tensors.keys():
                            continue
                        if self.vllm_config.kv_transfer_config is None:
                            tensor = torch.zeros(kv_cache_tensor.size,
                                                 dtype=torch.int8,
                                                 device=self.device)
                        else:
                            cache_size_aligned = kv_cache_tensor.size + alignment
                            tensor = torch.zeros(cache_size_aligned,
                                                 dtype=torch.int8,
                                                 device=self.device)
                            tensor = self._align_memory(
                                tensor, alignment)[:kv_cache_tensor.size]
                        kv_cache_raw_tensors[layer_name_inner] = tensor
                elif "attn" in layer_name:
                    # for other attentions, e.g., self_attn, sliding window attn
                    if self.vllm_config.kv_transfer_config is None:
                        k_tensor = torch.zeros(kv_cache_tensor.size // 2,
                                               dtype=torch.int8,
                                               device=self.device)
                        v_tensor = torch.zeros(kv_cache_tensor.size // 2,
                                               dtype=torch.int8,
                                               device=self.device)
                    else:
                        cache_size = kv_cache_tensor.size // 2
                        cache_size_aligned = kv_cache_tensor.size // 2 + alignment
                        k_tensor = torch.zeros(cache_size_aligned,
                                               dtype=torch.int8,
                                               device=self.device)
                        v_tensor = torch.zeros(cache_size_aligned,
                                               dtype=torch.int8,
                                               device=self.device)
                        k_tensor = self._align_memory(k_tensor,
                                                      alignment)[:cache_size]
                        v_tensor = self._align_memory(v_tensor,
                                                      alignment)[:cache_size]
                    kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
        # ... (snippet truncated; the rest of the function is omitted here)
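For comparison, upstream vLLM allocates one raw buffer per KVCacheTensor and maps every layer in shared_by to that same buffer. Below is a minimal, self-contained sketch of that sharing scheme; SimpleKVCacheTensor and init_shared_kv_cache_tensors are illustrative stand-ins, not the real vLLM types:

from dataclasses import dataclass

import torch


@dataclass
class SimpleKVCacheTensor:
    # illustrative stand-in for vLLM's KVCacheTensor: one buffer spec
    # shared by several layers
    size: int  # buffer size in bytes
    shared_by: list[str]  # names of the layers that share this buffer


def init_shared_kv_cache_tensors(
        kv_cache_tensors: list[SimpleKVCacheTensor],
        device: str = "cpu") -> dict[str, torch.Tensor]:
    """Allocate ONE buffer per spec and map every layer in shared_by to it,
    so total memory is sum(t.size) rather than
    sum(len(t.shared_by) * t.size)."""
    kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
    for spec in kv_cache_tensors:
        buffer = torch.zeros(spec.size, dtype=torch.int8, device=device)
        for layer_name in spec.shared_by:
            # every layer points at the same storage; no extra allocation
            kv_cache_raw_tensors[layer_name] = buffer
    return kv_cache_raw_tensors


# usage: a linear_attn layer and a self_attn layer sharing one 1 MiB buffer
caches = init_shared_kv_cache_tensors([
    SimpleKVCacheTensor(size=1 << 20,
                        shared_by=["model.layers.0.linear_attn",
                                   "model.layers.1.self_attn"])
])
assert (caches["model.layers.0.linear_attn"].data_ptr() ==
        caches["model.layers.1.self_attn"].data_ptr())

With sharing, each group costs its buffer size once; with the per-layer allocation shown above, the same group costs roughly size times the number of layers in shared_by, which would explain the observed overshoot from 0.7 to 0.85 utilization.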
