Skip to content

Commit 15a72ac

Browse files
liuyumoyeliuyumoye
andauthored
[V1] Exception Handling when Loading KV Cache from Remote Store (#21534)
Signed-off-by: liuyumoye <[email protected]> Co-authored-by: liuyumoye <[email protected]>
1 parent 04ff4be commit 15a72ac

File tree

10 files changed

+229
-5
lines changed

10 files changed

+229
-5
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import logging
4+
import random
5+
from dataclasses import dataclass
6+
from typing import TYPE_CHECKING
7+
8+
import torch
9+
10+
from vllm.config import VllmConfig
11+
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
12+
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
13+
from vllm.v1.core.sched.output import SchedulerOutput
14+
15+
if TYPE_CHECKING:
16+
from vllm.attention.backends.abstract import AttentionMetadata
17+
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
18+
from vllm.v1.request import Request
19+
20+
logger = logging.getLogger()
21+
logging.basicConfig(level=logging.INFO)
22+
23+
24+
@dataclass
25+
class RandomDropConnectorMetadata(KVConnectorMetadata):
26+
req_meta: dict[str, list[int]]
27+
28+
29+
class RandomDropConnector(KVConnectorBase_V1):
30+
"""
31+
A connector designed for fault tolerance testing by randomly dropping
32+
kv data during the process of loading or receiving KV cache.
33+
34+
This class simulates real-world scenarios where requests or data
35+
might be lost or timeout, allowing developers to test and validate the
36+
system's ability to handle such failures.
37+
38+
Attributes:
39+
finished_recving_kv_req_ids (set[str]): A set of request IDs that
40+
have completed receiving KV cache data.
41+
finished_loading_dict (dict[str, int]): A dictionary that tracks
42+
the actual number of tokens loaded from the remote KV store
43+
for each completed request. The keys are request IDs, and
44+
the values are the corresponding token counts.
45+
"""
46+
47+
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
48+
super().__init__(vllm_config=vllm_config, role=role)
49+
50+
self.failure_request: list[str] = []
51+
self._reqs_need_recv: dict[str, list[int]] = {}
52+
self._finish_load: dict[str, int] = {}
53+
54+
self.chunk_size = 256
55+
56+
############################################################
57+
# Scheduler Side Methods
58+
############################################################
59+
60+
def get_num_new_matched_tokens(
61+
self, request: "Request",
62+
num_computed_tokens: int) -> tuple[int, bool]:
63+
if request.request_id in self.failure_request:
64+
self.failure_request.remove(request.request_id)
65+
return 0, False
66+
num_external_hit_tokens = request.num_prompt_tokens - 1
67+
logger.info(
68+
"request %s num_prompt_tokens %d num_external_hit_tokens %d",
69+
request.request_id, request.num_prompt_tokens,
70+
num_external_hit_tokens)
71+
return num_external_hit_tokens, True
72+
73+
def update_state_after_alloc(self, request: "Request",
74+
blocks: "KVCacheBlocks",
75+
num_external_tokens: int):
76+
if num_external_tokens > 0:
77+
self._reqs_need_recv[
78+
request.
79+
request_id] = request.prompt_token_ids[:num_external_tokens]
80+
81+
def build_connector_meta(
82+
self,
83+
scheduler_output: SchedulerOutput,
84+
) -> KVConnectorMetadata:
85+
req_meta = self._reqs_need_recv.copy()
86+
self._reqs_need_recv.clear()
87+
return RandomDropConnectorMetadata(req_meta)
88+
89+
def add_failure_request(self, request: "Request"):
90+
self.failure_request.append(request.request_id)
91+
92+
def start_load_kv(self, forward_context, **kwargs) -> None:
93+
for request_id, hit_tokens in self._get_connector_metadata(
94+
).req_meta.items():
95+
num_actual_load_tokens = self.load_kv(request_id, hit_tokens)
96+
logger.info("request %s hit_tokens %d num_actual_load_tokens %d",
97+
request_id, len(hit_tokens), num_actual_load_tokens)
98+
self._finish_load[request_id] = num_actual_load_tokens
99+
100+
def wait_for_layer_load(self, layer_name: str) -> None:
101+
pass
102+
103+
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
104+
attn_metadata: "AttentionMetadata", **kwargs) -> None:
105+
pass
106+
107+
def wait_for_save(self):
108+
pass
109+
110+
def load_kv(self, request_id, hit_tokens):
111+
num_actual_load_tokens = random.randint(0, len(hit_tokens))
112+
return num_actual_load_tokens
113+
114+
def get_finished_loading(self) -> dict[str, int]:
115+
if not self._finish_load:
116+
return {}
117+
finished_loading = self._finish_load.copy()
118+
self._finish_load.clear()
119+
120+
return finished_loading
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
2+
3+
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
4+
export PYTHONPATH=$PYTHONPATH:$SCRIPT_DIR
5+
6+
vllm serve DeepSeek-V2-Lite-Chat \
7+
--trust-remote-code \
8+
--served-model-name vllm_cpu_offload \
9+
--max-model-len 32768 \
10+
--no-enable-prefix-caching \
11+
--max-seq-len-to-capture 10000 \
12+
--max-num-seqs 64 \
13+
--gpu-memory-utilization 0.9 \
14+
--host 0.0.0.0 \
15+
-tp 2 \
16+
--kv-transfer-config '{"kv_connector":"RandomDropConnector","kv_role":"kv_both","kv_connector_module_path":"random_drop_connector"}'

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,27 @@ def update_finished_set(req_ids: Optional[set[str]],
139139
finished_set.add(req_id)
140140
del remaining_count_dict[req_id]
141141

142+
def update_finished_load_dict(worker_finished_loading_dict: dict[str,
143+
int],
144+
finished_loading_dict: dict[str, int]):
145+
for req_id, num_actual_load_tokens in (worker_finished_loading_dict
146+
or {}).items():
147+
if req_id in finished_loading_dict:
148+
finished_loading_dict[req_id] = min(
149+
finished_loading_dict[req_id], num_actual_load_tokens)
150+
else:
151+
finished_loading_dict[req_id] = num_actual_load_tokens
152+
142153
finished_sending = set[str]()
143154
finished_recving = set[str]()
155+
finished_loading_dict: dict[str, int] = {}
144156
for output in outputs:
145157
update_finished_set(output.finished_sending,
146158
self._send_remaining_count, finished_sending)
147159
update_finished_set(output.finished_recving,
148160
self._recv_remaining_count, finished_recving)
161+
update_finished_load_dict(output.finished_loading_dict,
162+
finished_loading_dict)
149163

150164
# select output of the worker specified by output_rank
151165
output = outputs[output_rank]
@@ -157,7 +171,7 @@ def update_finished_set(req_ids: Optional[set[str]],
157171
# send/recv
158172
output.finished_sending = finished_sending if finished_sending else None
159173
output.finished_recving = finished_recving if finished_recving else None
160-
174+
output.finished_loading_dict = finished_loading_dict or None
161175
return output
162176

163177
def async_aggregate(self,

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
2929
get_finished() - called with ids of finished requests, returns
3030
ids of requests that have completed async sending/recving.
31+
get_finished_loading() - called with scheduler outputs, returns
32+
a dictionary that the keys are request IDs and the values are
33+
the actual number of tokens loaded from the remote KV cache
3134
"""
3235

3336
import enum
@@ -219,6 +222,23 @@ def get_finished(
219222
"""
220223
return None, None
221224

225+
def get_finished_loading(
226+
self, scheduler_output: "SchedulerOutput") -> dict[str, int]:
227+
"""
228+
Retrieves the actual number of tokens loaded for requests that have
229+
completed the asynchronous loading process from the remote KV cache.
230+
231+
This function is used by the scheduler process (via the Executors)
232+
to track the progress of requests and determine which requests have
233+
successfully finished loading their KV cache data.
234+
235+
Returns:
236+
A dictionary where the keys are request IDs and the values are the
237+
corresponding number of tokens that have been successfully loaded
238+
for each request.
239+
"""
240+
return {}
241+
222242
# ==============================
223243
# Scheduler-side methods
224244
# ==============================

vllm/sequence.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,8 @@ class IntermediateTensors:
11671167
# [req_ids]
11681168
finished_sending: Optional[set[str]] = None
11691169
finished_recving: Optional[set[str]] = None
1170+
#req_id -> num_actual_load_tokens
1171+
finished_loading_dict: Optional[dict[str, int]] = None
11701172

11711173
def __init__(self, tensors):
11721174
# manually define this function, so that

vllm/v1/core/sched/scheduler.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ def __init__(
118118

119119
# KV Connector: requests in process of async KV loading or recving
120120
self.finished_recving_kv_req_ids: set[str] = set()
121+
# The keys are request IDs, and the values are corresponding token
122+
# count that have been successfully loaded from the remote KV store
123+
self.finished_loading_dict: dict[str, int] = {}
121124

122125
# Encoder-related.
123126
# Calculate encoder cache size if applicable
@@ -1094,6 +1097,27 @@ def _connector_finished(
10941097
(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
10951098
return self.connector.request_finished(request, block_ids)
10961099

1100+
def _update_actual_load_token_num_from_remote_kv(self,
1101+
request: Request) -> bool:
1102+
1103+
num_actual_load_tokens = self.finished_loading_dict.pop(
1104+
request.request_id)
1105+
num_computed_tokens = num_actual_load_tokens
1106+
assert self.connector is not None
1107+
if num_actual_load_tokens <= 0 and hasattr(self.connector,
1108+
"add_failure_request"):
1109+
self.connector.add_failure_request(request)
1110+
return True
1111+
1112+
if num_actual_load_tokens == request.num_tokens:
1113+
num_computed_tokens -= 1
1114+
1115+
self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
1116+
1117+
# Update the request state for scheduling.
1118+
request.num_computed_tokens = num_computed_tokens
1119+
return True
1120+
10971121
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
10981122
"""
10991123
KV Connector: check if the request_id is finished_recving.
@@ -1107,6 +1131,9 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
11071131
WAITING_FOR_REMOTE_KV.
11081132
"""
11091133
assert self.connector is not None
1134+
if request.request_id in self.finished_loading_dict:
1135+
return self._update_actual_load_token_num_from_remote_kv(request)
1136+
11101137
if request.request_id not in self.finished_recving_kv_req_ids:
11111138
return False
11121139

@@ -1145,3 +1172,6 @@ def _update_from_kv_xfer_finished(self,
11451172
for req_id in (model_runner_output.finished_sending or ()):
11461173
logger.debug("Finished sending KV transfer for request %s", req_id)
11471174
self._free_blocks(self.requests[req_id])
1175+
if model_runner_output.finished_loading_dict:
1176+
self.finished_loading_dict.update(
1177+
model_runner_output.finished_loading_dict)

vllm/v1/outputs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ class ModelRunnerOutput:
107107
# [req_ids]
108108
finished_sending: Optional[set[str]] = None
109109
finished_recving: Optional[set[str]] = None
110+
# req_id -> actual_load_token from connector
111+
finished_loading_dict: Optional[dict[str, int]] = None
110112

111113
# req_id -> num_nans_in_logits
112114
num_nans_in_logits: Optional[dict[str, int]] = None
@@ -121,4 +123,5 @@ class ModelRunnerOutput:
121123
pooler_output=[],
122124
finished_sending=None,
123125
finished_recving=None,
126+
finished_loading_dict=None,
124127
num_nans_in_logits=None)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,6 +1350,7 @@ def _pool(
13501350
num_scheduled_tokens_np: np.ndarray,
13511351
finished_sending: Optional[set[str]],
13521352
finished_recving: Optional[set[str]],
1353+
finished_loading_dict: Optional[dict[str, int]],
13531354
) -> ModelRunnerOutput:
13541355
assert self.input_batch.num_reqs ==\
13551356
len(self.input_batch.pooling_params), \
@@ -1386,6 +1387,7 @@ def _pool(
13861387
pooler_output=pooler_output,
13871388
finished_sending=finished_sending,
13881389
finished_recving=finished_recving,
1390+
finished_loading_dict=finished_loading_dict,
13891391
)
13901392

13911393
@torch.inference_mode()
@@ -1505,6 +1507,7 @@ def execute_model(
15051507
self.maybe_wait_for_kv_save()
15061508
finished_sending, finished_recving = (
15071509
self.get_finished_kv_transfers(scheduler_output))
1510+
finished_loading_dict = self.get_finished_loading(scheduler_output)
15081511

15091512
if self.use_aux_hidden_state_outputs:
15101513
hidden_states, aux_hidden_states = model_output
@@ -1522,9 +1525,11 @@ def execute_model(
15221525
if not get_pp_group().is_last_rank:
15231526
# For mid-pipeline stages, return the hidden states.
15241527
if not broadcast_pp_output:
1525-
if finished_sending or finished_recving:
1528+
if (finished_sending or finished_recving
1529+
or finished_loading_dict):
15261530
hidden_states.finished_sending = finished_sending
15271531
hidden_states.finished_recving = finished_recving
1532+
hidden_states.finished_loading_dict = finished_loading_dict
15281533
return hidden_states
15291534
assert isinstance(hidden_states, IntermediateTensors)
15301535
get_pp_group().send_tensor_dict(hidden_states.tensors,
@@ -1534,7 +1539,7 @@ def execute_model(
15341539
if self.input_batch.pooling_params:
15351540
return self._pool(hidden_states, num_scheduled_tokens,
15361541
num_scheduled_tokens_np, finished_sending,
1537-
finished_recving)
1542+
finished_recving, finished_loading_dict)
15381543

15391544
sample_hidden_states = hidden_states[logits_indices]
15401545
logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1686,6 +1691,7 @@ def execute_model(
16861691
pooler_output=[],
16871692
finished_sending=finished_sending,
16881693
finished_recving=finished_recving,
1694+
finished_loading_dict=finished_loading_dict,
16891695
num_nans_in_logits=num_nans_in_logits,
16901696
)
16911697

vllm/v1/worker/gpu_worker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,12 @@ def execute_model(
359359
# In case of PP with kv transfer, we need to pass through the
360360
# finished_sending and finished_recving buffers.
361361
new_output = EMPTY_MODEL_RUNNER_OUTPUT
362-
if output.finished_sending or output.finished_recving:
362+
if (output.finished_sending or output.finished_recving
363+
or output.finished_loading_dict):
363364
new_output = copy.copy(new_output)
364365
new_output.finished_sending = output.finished_sending
365366
new_output.finished_recving = output.finished_recving
367+
new_output.finished_loading_dict = output.finished_loading_dict
366368
output = new_output
367369

368370
assert isinstance(output, ModelRunnerOutput)

vllm/v1/worker/kv_connector_model_runner_mixin.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,29 @@ def get_finished_kv_transfers(
5353
scheduler_output.finished_req_ids)
5454
return None, None
5555

56+
@staticmethod
57+
def get_finished_loading(
58+
scheduler_output: "SchedulerOutput", ) -> dict[str, int]:
59+
if has_kv_transfer_group():
60+
return get_kv_transfer_group().get_finished_loading(
61+
scheduler_output)
62+
return {}
63+
5664
def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
5765
vllm_config: VllmConfig) -> ModelRunnerOutput:
5866
# KV send/recv even if no work to do.
5967
with set_forward_context(None, vllm_config):
6068
self.maybe_setup_kv_connector(scheduler_output)
6169
finished_sending, finished_recving = (
6270
self.get_finished_kv_transfers(scheduler_output))
71+
finished_loading_dict = self.get_finished_loading(scheduler_output)
6372

64-
if not finished_sending and not finished_recving:
73+
if (not finished_sending and not finished_recving
74+
and not finished_loading_dict):
6575
return EMPTY_MODEL_RUNNER_OUTPUT
6676

6777
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
6878
output.finished_sending = finished_sending
6979
output.finished_recving = finished_recving
80+
output.finished_loading_dict = finished_loading_dict
7081
return output

0 commit comments

Comments
 (0)