
Commit 6666e52

Added support for KV connector v1 (#2039)
### What this PR does / why we need it?

- This PR adds support for the KV connector interface in the V1 architecture, in the same way as vanilla vLLM. vllm-ascend currently lacks this support, which is also required for layerwise management of KV caches.
- The connector interface allows external tools to be used and integrated with vLLM.

### Notes:

We are aware of Issue #684; however, that issue does not modify the attention classes as necessary to perform the layerwise management of KV caches required by connectors such as LMCache. This PR ports the necessary code from vanilla vLLM, and the KV connector API is the same as in vanilla vLLM, supporting the standard KV connector interface.

EDIT: this PR originally re-implemented part of the changes to model_runner_v1.py that were merged one hour before this PR was opened. I resolved the conflicts by removing all modifications to model_runner_v1.py, which are now largely merged in main; this PR is now limited to the modifications to attention_v1.py.

### Does this PR introduce _any_ user-facing change?

The PR does not modify current APIs, but it extends the behavior of the current worker runner and attention classes to save and load KV caches. In the absence of connectors, the behavior should stay untouched.

### How was this patch tested?

- No unit test implemented yet for the worker.
- Tested together with LMCache using https://github.com/LMCache/LMCache/blob/dev/examples/kv_cache_reuse/local_backends/offload.py with the following models, with LMCache used in both layerwise and non-layerwise mode:
  1. Deepseek-R1-Distill-Qwen-1.5B
  2. Qwen3-30B-A3B
  3. Deepseek-v2-lite
  4. Llama-3.1-8B
- Performed LMEval on LMCache integrated with vllm-ascend.

Results without LMCache on Qwen3-8B:

| Tasks | Version | Filter           | n-shot | Metric      |   |  Value |   | Stderr |
|-------|--------:|------------------|-------:|-------------|---|-------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑ | 0.8400 | ± | 0.0101 |
|       |         | strict-match     |      5 | exact_match | ↑ | 0.8355 | ± | 0.0102 |

Results with LMCache layerwise:

| Tasks | Version | Filter           | n-shot | Metric      |   |  Value |   | Stderr |
|-------|--------:|------------------|-------:|-------------|---|-------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑ | 0.8385 | ± | 0.0101 |
|       |         | strict-match     |      5 | exact_match | ↑ | 0.8332 | ± | 0.0103 |

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@50fede6

---------

Signed-off-by: marcobarlo <[email protected]>
1 parent 2967e5e commit 6666e52
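For context, a connector such as LMCache is attached through vLLM's standard `kv_transfer_config`. Below is a minimal sketch, assuming the `KVTransferConfig` API and the `LMCacheConnectorV1` connector name as registered in upstream vLLM around v0.10; field and connector names may differ in other versions.

```python
# Hedged sketch: launch an engine with a v1 KV connector (LMCache) attached.
# Assumes vllm.config.KVTransferConfig and the "LMCacheConnectorV1" name from
# upstream vLLM ~v0.10; verify against your installed version.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="Qwen/Qwen3-8B",
    kv_transfer_config=KVTransferConfig(
        kv_connector="LMCacheConnectorV1",  # connector registered in vLLM
        kv_role="kv_both",                  # this instance saves and loads KV
    ),
)

print(llm.generate(["Hello, world"])[0].outputs[0].text)
```

With no `kv_transfer_config` set, `has_kv_transfer_group()` is false and the new hooks in the diff below return immediately, so the default path is unchanged.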

File tree

1 file changed (+36, −0)


vllm_ascend/attention/attention_v1.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -26,6 +26,9 @@
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import cdiv, direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -37,6 +40,37 @@
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 
+def wait_for_kv_layer_from_connector(layer_name: str):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    # TODO: assert ascendMetadata
+    connector.wait_for_layer_load(layer_name)
+
+
+def maybe_save_kv_layer_to_connector(
+    layer_name: str,
+    kv_cache_layer: List[torch.Tensor],
+):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    # TODO: assert ascendMetadata
+    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
+
+
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
@@ -537,6 +571,7 @@ def unified_ascend_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
 ) -> None:
+    wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.no_compile_layers[layer_name]
@@ -549,6 +584,7 @@ def unified_ascend_attention_with_output(
                       attn_metadata,
                       output,
                       trace_flag=False)
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
     return
```
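The two new hooks bracket each layer's attention kernel: `wait_for_kv_layer_from_connector` blocks until the connector has finished loading that layer's KV cache, and `maybe_save_kv_layer_to_connector` hands the layer's cache back to the connector for saving, which is what enables layerwise offload. To show which callbacks these hooks drive, here is a minimal sketch of a connector, assuming the `KVConnectorBase_V1` base class and method names from upstream vLLM's `vllm.distributed.kv_transfer.kv_connector.v1.base` around v0.10; the signatures may differ across versions, and the `TracingConnector` class itself is hypothetical.

```python
# Hedged sketch of a tracing v1 connector. Base-class location and method
# signatures follow upstream vLLM ~v0.10 and are assumptions; verify against
# your installed version before use.
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata)


class TracingConnector(KVConnectorBase_V1):
    """Hypothetical connector that only logs the layerwise callbacks."""

    # ---- worker side: driven by the hooks added in attention_v1.py ----
    def start_load_kv(self, forward_context, **kwargs) -> None:
        # Begin (possibly async) loads for every layer of the current batch.
        print("start_load_kv")

    def wait_for_layer_load(self, layer_name: str) -> None:
        # Invoked by wait_for_kv_layer_from_connector() just before the
        # attention kernel for `layer_name` runs.
        print(f"wait_for_layer_load: {layer_name}")

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata, **kwargs) -> None:
        # Invoked by maybe_save_kv_layer_to_connector() right after the
        # attention kernel, enabling layerwise offload (e.g. to LMCache).
        print(f"save_kv_layer: {layer_name}")

    def wait_for_save(self) -> None:
        # Block until all pending layerwise saves have completed.
        print("wait_for_save")

    # ---- scheduler side: trivial no-op implementations for the sketch ----
    def get_num_new_matched_tokens(self, request, num_computed_tokens: int):
        return 0, False  # no externally cached tokens available

    def update_state_after_alloc(self, request, blocks,
                                 num_external_tokens: int) -> None:
        pass

    def build_connector_meta(self, scheduler_output) -> KVConnectorMetadata:
        return KVConnectorMetadata()
```

Because both hooks return early when no KV transfer group is configured, a connector like the one sketched above only changes behavior when explicitly enabled, matching the "no user-facing change" claim in the commit message.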