
Commit 2b97c69

[v0.9.1][Bugfix][PD] Auto-clear producer KV cache if no pull notification (#2085)
### What this PR does / why we need it?

This PR addresses a critical issue where Node D (decode) failures cause Node P (prefill) to hang because it can never release its KV cache.

**Trigger Scenarios:**
1. Node D fails mid-inference (e.g., network disconnection)
2. Node D rejects requests at a certain stage (e.g., via the API server)
3. A load-test script terminates and causes Node P or D to abort queued requests

**Root Cause Analysis:**
1. Currently, Node D sends a "KV cache pull complete, release approved" message to Node P
2. This message is transmitted via the worker connector; if the PD connection breaks or requests are rejected upstream, Node D cannot send it
3. Without receiving this message, Node P never releases the KV cache

**Solution:** following the vLLM community's approach (the NIXL connector timeout mechanism), this PR implements:
- a timeout mechanism with comprehensive warnings
- updated README documentation
- reference: vLLM's optimization PR [#20139](vllm-project/vllm#20139)

**Note:** the full disaster-recovery solution is still in design. This PR is merged into the v091-dev branch in a simple form and will evolve on main ([PR #2174](#2174)).

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

---------

Signed-off-by: underfituu <[email protected]>
1 parent 741a8cf commit 2b97c69
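
At a high level, the fix records an expiry deadline for every prefill request whose blocks are kept alive for a remote pull, and force-finishes expired entries so their KV cache can be freed even if the decode node never reports the pull. The sketch below illustrates that idea in isolation; the class and method names (`ReqTracker`, `mark_pending_pull`, etc.) are illustrative and are not the connector's real API.

```python
# Minimal sketch of the timeout-based auto-clear idea (illustrative names,
# not the actual LLMDataDistCMgrConnector API). Each request kept resident
# for a remote KV pull gets an absolute deadline; expired entries are
# force-finished so their blocks can be released.
import time


class ReqTracker:
    def __init__(self, timeout_s: float = 300.0):
        self.timeout_s = timeout_s
        self.reqs_to_send: dict[str, float] = {}  # req_id -> absolute deadline
        self.finished_reqs: set[str] = set()

    def mark_pending_pull(self, req_id: str) -> None:
        # Called when prefill finishes and blocks are kept for the consumer.
        self.reqs_to_send[req_id] = time.perf_counter() + self.timeout_s

    def mark_pulled(self, req_id: str) -> None:
        # Called when the decode node reports that it pulled the KV cache.
        if self.reqs_to_send.pop(req_id, None) is not None:
            self.finished_reqs.add(req_id)

    def sweep_expired(self) -> set[str]:
        # Deadlines are inserted in increasing order, so we can stop at the
        # first entry that has not expired yet.
        now = time.perf_counter()
        for req_id, deadline in list(self.reqs_to_send.items()):
            if now < deadline:
                break
            del self.reqs_to_send[req_id]
            self.finished_reqs.add(req_id)
        done, self.finished_reqs = self.finished_reqs, set()
        return done
```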

File tree

2 files changed: +36 −2 lines

vllm_ascend/distributed/llmdatadist_c_mgr_connector.py

Lines changed: 30 additions & 2 deletions
@@ -1,4 +1,5 @@
 import contextlib
+import copy
 import json
 import math
 import os
@@ -183,6 +184,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
         self.port = dp_rank_local * tp_size + envs.VLLM_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs.VLLM_LLMDD_RPC_PORT

         self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
+        self._reqs_need_send: dict[str, float] = {}

     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -247,7 +249,12 @@ def build_connector_meta(
             meta.add_new_req(request_id=req_id,
                              local_block_ids=block_ids,
                              kv_transfer_params=req.kv_transfer_params)
+
+        meta.reqs_to_send = copy.deepcopy(self._reqs_need_send)
+
+        # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
+        self._reqs_need_send.clear()

         return meta

@@ -271,9 +278,14 @@ def request_finished(
         # note: there might be some issue on this, check it if there is any unexpected result
         computed_block_ids = block_ids
         delay_free_blocks = len(computed_block_ids) > 0
+
         if delay_free_blocks:
             logger.info("Delaying free of %d blocks for request %s",
                         len(computed_block_ids), request.request_id)
+            # Prefill request on remote. It will be read from D upon completion
+            self._reqs_need_send[request.request_id] = time.perf_counter(
+            ) + envs.VLLM_LLMDD_ABORT_REQUEST_TIMEOUT
+
         return delay_free_blocks, dict(
             do_remote_prefill=True,
             do_remote_decode=False,
@@ -340,6 +352,7 @@ def __init__(self, vllm_config: VllmConfig):
         os.environ["HCCL_DETERMINISTIC"] = "true"
         self.done_receiving_counts: defaultdict[str,
                                                 set[int]] = defaultdict(set)
+        self.reqs_to_send: dict[str, float] = {}

     def listen_for_agent_metadata_req(self, event: threading.Event):
         assert self.local_agent_metadata is not None
@@ -383,7 +396,9 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
                     logger.debug(
                         f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
                     )
-                    self.finished_reqs.add(finished_req_id)
+                    if finished_req_id in self.reqs_to_send:
+                        self.finished_reqs.add(finished_req_id)
+                        del self.reqs_to_send[finished_req_id]
                     sock.send_multipart(
                         (identity, b"", b"receiving decode finished"))
                 else:
@@ -606,6 +621,7 @@ def handle_exception(future):

         for future in futures:
             future.add_done_callback(handle_exception)
+        self.reqs_to_send.update(metadata.reqs_to_send)

     def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
         assert self.local_agent_metadata is not None
@@ -860,8 +876,20 @@ def get_finished(
             self, finished_req_ids: set[str]
     ) -> tuple[Optional[set[str]], Optional[set[str]]]:
         """Get the finished recving and sending requuests."""
-        import copy
+        now = time.perf_counter()
+
         with self.thread_lock:
+            while self.reqs_to_send:
+                req_id, expires = next(iter(self.reqs_to_send.items()))
+                if now < expires:
+                    break
+                logger.warning(
+                    "Some requests in prefill node fail to receive KV Cache transfer done signal. "
+                    "If a greater mean TTFT is acceptable, you can 'export VLLM_LLMDD_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
+                )
+                if req_id in self.reqs_to_send:
+                    self.finished_reqs.add(req_id)
+                    del self.reqs_to_send[req_id]
             req_ids_to_ret = copy.deepcopy(self.finished_reqs)
             self.finished_reqs.clear()
             if self.llm_datadist_role == LLMRole.PROMPT:
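
The early `break` in the `get_finished` sweep above relies on `reqs_to_send` preserving deadline order: every entry is stamped with `time.perf_counter()` plus the same constant timeout, so dict insertion order matches expiry order and the first unexpired entry means everything after it is also unexpired. A tiny standalone check of that assumption (request IDs and the timeout value are illustrative):

```python
# Standalone check of the ordering assumption behind the early `break`:
# entries are added as perf_counter() + TIMEOUT, so insertion order == expiry order.
import time

TIMEOUT = 300.0
reqs_to_send: dict[str, float] = {}
for req_id in ("req-a", "req-b", "req-c"):
    reqs_to_send[req_id] = time.perf_counter() + TIMEOUT  # non-decreasing deadlines

deadlines = list(reqs_to_send.values())
assert deadlines == sorted(deadlines)  # safe to stop at the first unexpired entry
```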

vllm_ascend/envs.py

Lines changed: 6 additions & 0 deletions
@@ -133,6 +133,12 @@
     # remote worker.
     "VLLM_LLMDD_RPC_PORT":
     lambda: int(os.getenv("VLLM_LLMDD_RPC_PORT", 5557)),
+    # `LLMDataDistCMgrConnector` required variable. Time (in seconds) after which the KV cache on the producer side is
+    # automatically cleared if no READ notification is received from the consumer.
+    # `VLLM_LLMDD_ABORT_REQUEST_TIMEOUT` is only applicable when using LLMDataDistCMgrConnector in a
+    # disaggregated decode-prefill setup.
+    "VLLM_LLMDD_ABORT_REQUEST_TIMEOUT":
+    lambda: int(os.getenv("VLLM_LLMDD_ABORT_REQUEST_TIMEOUT", 300)),
     # Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
     # and the mla_pa will be the default path of deepseek decode path.
     "VLLM_ASCEND_MLA_PA":
