58 changes: 21 additions & 37 deletions vllm_ascend/worker/model_runner_v1.py
@@ -2850,49 +2850,33 @@ def initialize_kv_cache_tensors(
# llmdatadist need the addr of cache tensor be aligned with 2M
alignment = 2 * 1024 * 1024
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
    # TODO: REFACTOR ME to sharing hybrid cache
    if self.vllm_config.kv_transfer_config is None:
        tensor = torch.zeros(kv_cache_tensor.size,
                             dtype=torch.int8,
                             device=self.device)
        k_tensor, v_tensor = tensor.chunk(2)
    else:
        cache_size = kv_cache_tensor.size // 2
        cache_size_aligned = cache_size + alignment
        k_tensor_aligned = torch.zeros(cache_size_aligned,
                                       dtype=torch.int8,
                                       device=self.device)
        v_tensor_aligned = torch.zeros(cache_size_aligned,
                                       dtype=torch.int8,
                                       device=self.device)
        k_tensor = self._align_memory(k_tensor_aligned,
                                      alignment)[:cache_size]
        v_tensor = self._align_memory(v_tensor_aligned,
                                      alignment)[:cache_size]
        tensor = torch.cat([k_tensor, v_tensor])
Comment on lines +2858 to +2871
Contributor

critical

This refactoring introduces a memory alignment issue for linear_attn layers when kv_transfer_config is enabled.

The code now allocates k_tensor and v_tensor from separate 2M-aligned memory blocks, which is correct for standard attention layers. However, it then uses tensor = torch.cat([k_tensor, v_tensor]) to create the tensor for linear_attn layers. torch.cat on tensors backed by different storages creates a new tensor by copying the data, and the memory for this new tensor is not guaranteed to be 2M-aligned.

This appears to break the requirement for llmdatadist, which, according to the comment on line 2850, needs the cache tensor to be aligned. This can lead to runtime errors or incorrect behavior. The previous implementation correctly aligned the tensor for linear_attn.

Given the conflicting alignment requirements for a shared hybrid cache (one large aligned buffer for linear_attn vs. two separate aligned buffers for standard attn), please revisit this implementation to ensure all tensors meet their alignment requirements.
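
For illustration, a minimal CPU sketch of this point; align_memory below is a hypothetical stand-in for the module's self._align_memory helper and the sizes are arbitrary. Slicing into an over-allocated buffer keeps the aligned base address, while torch.cat copies into a fresh allocation whose address carries no such guarantee.

import torch

ALIGNMENT = 2 * 1024 * 1024  # 2 MiB, the alignment llmdatadist expects

def align_memory(buffer: torch.Tensor, alignment: int) -> torch.Tensor:
    # Hypothetical stand-in for self._align_memory: skip ahead to the first
    # address inside the over-allocated buffer that is a multiple of `alignment`.
    offset = (-buffer.data_ptr()) % alignment
    return buffer[offset:]

cache_size = 4 * 1024 * 1024  # arbitrary size for illustration
k_buf = torch.zeros(cache_size + ALIGNMENT, dtype=torch.int8)
v_buf = torch.zeros(cache_size + ALIGNMENT, dtype=torch.int8)
k_tensor = align_memory(k_buf, ALIGNMENT)[:cache_size]
v_tensor = align_memory(v_buf, ALIGNMENT)[:cache_size]

print(k_tensor.data_ptr() % ALIGNMENT)  # 0: the slice is a view of the aligned buffer
print(v_tensor.data_ptr() % ALIGNMENT)  # 0
merged = torch.cat([k_tensor, v_tensor])  # copies both inputs into a fresh allocation
print(merged.data_ptr() % ALIGNMENT)  # typically non-zero: the 2M alignment is lost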


    for idx in range(len(kv_cache_tensor.shared_by)):
        layer_name = kv_cache_tensor.shared_by[idx]
        if "linear_attn" in layer_name:
            # for mamba linear attention
            for layer_name_inner in kv_cache_tensor.shared_by:
                if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
                        layer_name_inner in kv_cache_raw_tensors.keys():
                    continue
                if self.vllm_config.kv_transfer_config is None:
                    tensor = torch.zeros(kv_cache_tensor.size,
                                         dtype=torch.int8,
                                         device=self.device)
                else:
                    cache_size_aligned = kv_cache_tensor.size + alignment
                    tensor = torch.zeros(cache_size_aligned,
                                         dtype=torch.int8,
                                         device=self.device)
                    tensor = self._align_memory(
                        tensor, alignment)[:kv_cache_tensor.size]
                kv_cache_raw_tensors[layer_name_inner] = tensor
            kv_cache_raw_tensors[layer_name] = tensor
Collaborator

Sharing kv cache memory between linear_attn and self_attn will cause accuracy issues in Qwen3-Next, so we should allocate memory for linear_attn and self_attn separately.
I also have a PR that fixes this issue and refactors the kv cache initialization logic, but it is not ready for review yet. Feel free to change the bug-fix logic in this PR according to #3106.
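
A hedged sketch of the separation suggested here (not the actual #3106 implementation): give the linear_attn layers and the remaining attention layers distinct backing buffers instead of pointing both at the same allocation. The helper name, layer names, and sizes below are made up for illustration, and the real code keeps (k_tensor, v_tensor) pairs for self_attn.

import torch

def allocate_separate_raw_tensors(shared_by, size, device="cpu"):
    # Hypothetical helper: one int8 buffer backs the linear_attn (mamba) layers,
    # a second, independent buffer backs every other attention layer, so the two
    # kinds of layers never alias the same memory.
    linear_attn_buf = torch.zeros(size, dtype=torch.int8, device=device)
    other_attn_buf = torch.zeros(size, dtype=torch.int8, device=device)
    raw_tensors = {}
    for name in shared_by:
        if "linear_attn" in name:
            raw_tensors[name] = linear_attn_buf
        else:
            raw_tensors[name] = other_attn_buf
    return raw_tensors

# Usage with made-up layer names: the two groups now live in different storage.
raw = allocate_separate_raw_tensors(
    ["model.layers.0.linear_attn", "model.layers.1.self_attn.attn"], size=1024)
assert raw["model.layers.0.linear_attn"].data_ptr() != \
    raw["model.layers.1.self_attn.attn"].data_ptr()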

elif "attn" in layer_name:
# for other attentions, e.g., self_attn, sliding window attn
if self.vllm_config.kv_transfer_config is None:
k_tensor = torch.zeros(kv_cache_tensor.size // 2,
dtype=torch.int8,
device=self.device)
v_tensor = torch.zeros(kv_cache_tensor.size // 2,
dtype=torch.int8,
device=self.device)
else:
cache_size = kv_cache_tensor.size // 2
cache_size_aligned = kv_cache_tensor.size // 2 + alignment
k_tensor = torch.zeros(cache_size_aligned,
dtype=torch.int8,
device=self.device)
v_tensor = torch.zeros(cache_size_aligned,
dtype=torch.int8,
device=self.device)
k_tensor = self._align_memory(k_tensor,
alignment)[:cache_size]
v_tensor = self._align_memory(v_tensor,
alignment)[:cache_size]
kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)

layer_names = set()