[Bugfix] Fix duplicated KV cache allocation in qwen3-Next #3404
Changes from all commits
```diff
@@ -2850,49 +2850,33 @@ def initialize_kv_cache_tensors(
         # llmdatadist need the addr of cache tensor be aligned with 2M
         alignment = 2 * 1024 * 1024
         for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+            # TODO: REFACTOR ME to sharing hybrid cache
+            if self.vllm_config.kv_transfer_config is None:
+                tensor = torch.zeros(kv_cache_tensor.size,
+                                     dtype=torch.int8,
+                                     device=self.device)
+                k_tensor, v_tensor = tensor.chunk(2)
+            else:
+                cache_size = kv_cache_tensor.size // 2
+                cache_size_aligned = cache_size + alignment
+                k_tensor_aligned = torch.zeros(cache_size_aligned,
+                                               dtype=torch.int8,
+                                               device=self.device)
+                v_tensor_aligned = torch.zeros(cache_size_aligned,
+                                               dtype=torch.int8,
+                                               device=self.device)
+                k_tensor = self._align_memory(k_tensor_aligned,
+                                              alignment)[:cache_size]
+                v_tensor = self._align_memory(v_tensor_aligned,
+                                              alignment)[:cache_size]
+                tensor = torch.cat([k_tensor, v_tensor])
+
             for idx in range(len(kv_cache_tensor.shared_by)):
                 layer_name = kv_cache_tensor.shared_by[idx]
                 if "linear_attn" in layer_name:
                     # for mamba linear attention
-                    for layer_name_inner in kv_cache_tensor.shared_by:
-                        if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
-                                layer_name_inner in kv_cache_raw_tensors.keys():
-                            continue
-                        if self.vllm_config.kv_transfer_config is None:
-                            tensor = torch.zeros(kv_cache_tensor.size,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                        else:
-                            cache_size_aligned = kv_cache_tensor.size + alignment
-                            tensor = torch.zeros(cache_size_aligned,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                            tensor = self._align_memory(
-                                tensor, alignment)[:kv_cache_tensor.size]
-                        kv_cache_raw_tensors[layer_name_inner] = tensor
+                    kv_cache_raw_tensors[layer_name] = tensor
                 elif "attn" in layer_name:
                     # for other attentions, e.g., self_attn, sliding window attn
-                    if self.vllm_config.kv_transfer_config is None:
-                        k_tensor = torch.zeros(kv_cache_tensor.size // 2,
-                                               dtype=torch.int8,
-                                               device=self.device)
-                        v_tensor = torch.zeros(kv_cache_tensor.size // 2,
-                                               dtype=torch.int8,
-                                               device=self.device)
-                    else:
-                        cache_size = kv_cache_tensor.size // 2
-                        cache_size_aligned = kv_cache_tensor.size // 2 + alignment
-                        k_tensor = torch.zeros(cache_size_aligned,
-                                               dtype=torch.int8,
-                                               device=self.device)
-                        v_tensor = torch.zeros(cache_size_aligned,
-                                               dtype=torch.int8,
-                                               device=self.device)
-                        k_tensor = self._align_memory(k_tensor,
-                                                      alignment)[:cache_size]
-                        v_tensor = self._align_memory(v_tensor,
-                                                      alignment)[:cache_size]
                     kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)

         layer_names = set()
```

Collaborator (inline comment on `kv_cache_raw_tensors[layer_name] = tensor`):

> Sharing kv cache memory between
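The diff relies on a private `_align_memory` helper that is not shown in the hunk above. As a rough idea of the over-allocate-then-slice pattern it appears to implement, here is a minimal, self-contained sketch; the helper's internals and the CPU tensors are assumptions for illustration only (the real code allocates on the NPU device):

```python
import torch


def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    """Assumed sketch: return a view of `tensor` whose data pointer is a
    multiple of `alignment`. One int8 element corresponds to one byte,
    so slicing by `offset` elements shifts the address by `offset` bytes."""
    offset = (alignment - tensor.data_ptr() % alignment) % alignment
    return tensor[offset:]


if __name__ == "__main__":
    alignment = 2 * 1024 * 1024    # llmdatadist wants 2 MiB-aligned cache addresses
    cache_size = 8 * 1024 * 1024   # illustrative cache size in bytes

    # Over-allocate by `alignment` bytes so an aligned start address is
    # guaranteed to exist inside the buffer, then slice back to `cache_size`.
    raw = torch.zeros(cache_size + alignment, dtype=torch.int8)
    aligned = align_memory(raw, alignment)[:cache_size]

    assert aligned.data_ptr() % alignment == 0
    print(f"aligned view starts at {hex(aligned.data_ptr())}, "
          f"size {aligned.numel()} bytes")
```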
Review comment:

> This refactoring introduces a memory alignment issue for `linear_attn` layers when `kv_transfer_config` is enabled.
>
> The code now allocates `k_tensor` and `v_tensor` from separate 2M-aligned memory blocks, which is correct for standard attention layers. However, it then uses `tensor = torch.cat([k_tensor, v_tensor])` to create the tensor for `linear_attn` layers. `torch.cat` on tensors from different storage will create a new tensor by copying data, and the memory for this new tensor is not guaranteed to be 2M-aligned.
>
> This appears to break the requirement for `llmdatadist`, which, according to the comment on line 2850, needs the cache tensor to be aligned. This can lead to runtime errors or incorrect behavior. The previous implementation correctly aligned the tensor for `linear_attn`.
>
> Given the conflicting alignment requirements for a shared hybrid cache (one large aligned buffer for `linear_attn` vs. two separate aligned buffers for standard `attn`), please revisit this implementation to ensure all tensors meet their alignment requirements.
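To make the reviewer's point concrete, here is a small CPU-only sketch that reuses the assumed `align_memory` helper from the earlier sketch; none of this is the actual vllm-ascend code. The two halves are individually 2 MiB-aligned after slicing, but `torch.cat` must copy them into freshly allocated storage, and nothing guarantees that the new storage starts on a 2 MiB boundary:

```python
import torch


def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    # Same assumed helper as in the earlier sketch: skip the misaligned prefix.
    offset = (alignment - tensor.data_ptr() % alignment) % alignment
    return tensor[offset:]


alignment = 2 * 1024 * 1024
cache_size = 4 * 1024 * 1024

# Allocate k and v separately and align each one, mirroring the new code path.
k_raw = torch.zeros(cache_size + alignment, dtype=torch.int8)
v_raw = torch.zeros(cache_size + alignment, dtype=torch.int8)
k_tensor = align_memory(k_raw, alignment)[:cache_size]
v_tensor = align_memory(v_raw, alignment)[:cache_size]
assert k_tensor.data_ptr() % alignment == 0
assert v_tensor.data_ptr() % alignment == 0

# k_tensor and v_tensor live in two different storages, so torch.cat cannot
# return a view; it allocates a brand-new buffer and copies both inputs into
# it. The base address of that buffer is whatever the allocator returns.
tensor = torch.cat([k_tensor, v_tensor])
print("torch.cat result 2 MiB-aligned:", tensor.data_ptr() % alignment == 0)
```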