From 3e6e7111d3b5422dc679a1a82a12912bcbda9dc1 Mon Sep 17 00:00:00 2001
From: underfituu
Date: Tue, 29 Jul 2025 17:40:20 +0800
Subject: [PATCH 1/8] fix_pd_expiry

Signed-off-by: underfituu
---
 .../llmdatadist_c_mgr_connector.py | 25 +++++++++++++++++++
 vllm_ascend/envs.py                |  6 +++++
 2 files changed, 31 insertions(+)

diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
index 45b7dc5e4f..47dfaf9065 100644
--- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
+++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
@@ -183,6 +183,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
         self.port = dp_rank_local * tp_size + envs.VLLM_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs.VLLM_LLMDD_RPC_PORT
 
         self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
+        self._reqs_need_send: dict[str, float] = {}
 
     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -247,7 +248,12 @@ def build_connector_meta(
             meta.add_new_req(request_id=req_id,
                              local_block_ids=block_ids,
                              kv_transfer_params=req.kv_transfer_params)
+
+        meta.reqs_to_send = self._reqs_need_send
+
+        # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
+        self._reqs_need_send.clear()
 
         return meta
 
@@ -271,9 +277,14 @@ def request_finished(
         # note: there might be some issue on this, check it if there is any unexpected result
         computed_block_ids = block_ids
         delay_free_blocks = len(computed_block_ids) > 0
+
        if delay_free_blocks:
             logger.info("Delaying free of %d blocks for request %s",
                         len(computed_block_ids), request.request_id)
+            # Prefill request on remote. It will be read from D upon completion
+            self._reqs_need_send[request.request_id] = time.perf_counter(
+            ) + envs.VLLM_LLMDD_ABORT_REQUEST_TIMEOUT
+
         return delay_free_blocks, dict(
             do_remote_prefill=True,
             do_remote_decode=False,
@@ -340,6 +351,7 @@ def __init__(self, vllm_config: VllmConfig):
         os.environ["HCCL_DETERMINISTIC"] = "true"
         self.done_receiving_counts: defaultdict[str,
                                                 set[int]] = defaultdict(set)
+        self._reqs_to_send: dict[str, float] = {}
 
     def listen_for_agent_metadata_req(self, event: threading.Event):
         assert self.local_agent_metadata is not None
@@ -606,6 +618,7 @@ def handle_exception(future):
 
         for future in futures:
             future.add_done_callback(handle_exception)
+        self._reqs_to_send.update(metadata._reqs_need_send)
 
     def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
         assert self.local_agent_metadata is not None
@@ -861,7 +874,19 @@ def get_finished(
         self, finished_req_ids: set[str]
     ) -> tuple[Optional[set[str]], Optional[set[str]]]:
         """Get the finished recving and sending requuests."""
         import copy
+        now = time.perf_counter()
         with self.thread_lock:
+            while self._reqs_to_send:
+                req_id, expires = next(iter(self._reqs_to_send.items()))
+                if now < expires:
+                    break
+                logger.warning(
+                    "Some requests in prefill node fail to receive KV Cache transfer done signal. "
+                    "If a greater mean TTFT is acceptable, you can 'export VLLM_LLMDD_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
+                    )
+                if req_id not in self.finished_reqs:
+                    self.finished_reqs.add(req_id)
+                del self._reqs_to_send[req_id]
             req_ids_to_ret = copy.deepcopy(self.finished_reqs)
             self.finished_reqs.clear()
             if self.llm_datadist_role == LLMRole.PROMPT:
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 1f0b6ff4b1..caca9904e7 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -133,6 +133,12 @@
     # remote worker.
"VLLM_LLMDD_RPC_PORT": lambda: int(os.getenv("VLLM_LLMDD_RPC_PORT", 5557)), + # `LLMDataDistCMgrConnector` required variable. Time (in seconds) after which the KV cache on the producer side is + # automatically cleared if no READ notification is received from the consumer. + # `VLLM_LLMDD_ABORT_REQUEST_TIMEOUT` is only applicable when using LLMDataDistCMgrConnector in a + # disaggregated decode-prefill setup. + "VLLM_LLMDD_ABORT_REQUEST_TIMEOUT": + lambda: int(os.getenv("VLLM_LLMDD_ABORT_REQUEST_TIMEOUT", 300)), # Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible # and the mla_pa will be the default path of deepseek decode path. "VLLM_ASCEND_MLA_PA": From 4e280ddb9943f28c7c1e745e2b6d6fd62ecd5b34 Mon Sep 17 00:00:00 2001 From: underfituu Date: Fri, 1 Aug 2025 14:02:30 +0800 Subject: [PATCH 2/8] fix_pd_expiry Signed-off-by: underfituu --- .../llmdatadist_c_mgr_connector.py | 57 +++++++++++++++---- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 47dfaf9065..7d647d9c56 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -4,6 +4,7 @@ import os import threading import time +import copy from collections import defaultdict from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor @@ -44,7 +45,8 @@ class LLMDataDistCMgrEvent(Enum): ReqForMetadata = 0 - ReqForFinished = 1 + ReqForChecking = 1 + ReqForFinished = 2 class LLMDataDistCMgrAgentMetadata(msgspec.Struct): @@ -249,7 +251,7 @@ def build_connector_meta( local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params) - meta.reqs_to_send = self._reqs_need_send + meta.reqs_to_send = copy.deepcopy(self._reqs_need_send) # Clear the list once workers start the transfers self._reqs_need_recv.clear() @@ -351,7 +353,7 @@ def __init__(self, vllm_config: VllmConfig): os.environ["HCCL_DETERMINISTIC"] = "true" self.done_receiving_counts: defaultdict[str, set[int]] = defaultdict(set) - self._reqs_to_send: dict[str, float] = {} + self.reqs_to_send: dict[str, float] = {} def listen_for_agent_metadata_req(self, event: threading.Event): assert self.local_agent_metadata is not None @@ -384,6 +386,13 @@ def listen_for_agent_metadata_req(self, event: threading.Event): logger.warning( f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}" ) + elif event_msg == LLMDataDistCMgrEvent.ReqForChecking: + finished_req_id = decode_msg[0] + checking = 1 + with self.thread_lock: + checking = 1 if finished_req_id in self.reqs_to_send else 0 + checking_to_send = msg_encoder.encode(checking) + sock.send_multipart((identity, b"", checking_to_send)) elif event_msg == LLMDataDistCMgrEvent.ReqForFinished: finished_req_id = decode_msg[0] decode_tp_rank = decode_msg[1] @@ -618,7 +627,7 @@ def handle_exception(future): for future in futures: future.add_done_callback(handle_exception) - self._reqs_to_send.update(metadata._reqs_need_send) + self.reqs_to_send.update(metadata._reqs_need_send) def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int: assert self.local_agent_metadata is not None @@ -800,6 +809,28 @@ def send_finish_to_remote(self, host: str, port: int, request_id): logger.error( f"Failed to send reqest_id {request_id} to prefill: {e}") + def send_checking_to_prefill_node(self, host: str, port: int, request_id): + url 
= f"tcp://{host}:{port}" + logger.info(f"Sending checking to remote: {url}") + msg_encoder = msgspec.msgpack.Encoder() + msg_send = msg_encoder.encode([ + LLMDataDistCMgrEvent.ReqForChecking, + [request_id] + ]) + with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] + try: + sock.send(msg_send) + logger.info( + f"Request id {request_id} checking message send to remote {url}" + ) + checking_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder() + checking_flag = decoder.decode(checking_bytes) + return checking_flag + except Exception as e: + logger.error( + f"Failed to send reqest_id {request_id} to prefill: {e}") + def _read_blocks( self, local_block_ids: list[int], @@ -823,6 +854,10 @@ def _read_blocks( remote_block_ids = remote_block_ids[-num_local_blocks:] logger.info(f"remote cluster id is: {remote_cluster_id}") + if not self.send_checking_to_prefill_node(remote_ip, remote_port, request_id): + raise RuntimeError( + "Remote prefill node has already free blocks, skipping pull blocks" + ) if self.use_mla: remote_cache_key_k_normed = BlocksCacheKey( cluster_id=remote_cluster_id, model_id=0) @@ -873,20 +908,22 @@ def get_finished( self, finished_req_ids: set[str] ) -> tuple[Optional[set[str]], Optional[set[str]]]: """Get the finished recving and sending requuests.""" - import copy now = time.perf_counter() + with self.thread_lock: - while self._reqs_to_send: - req_id, expires = next(iter(self._reqs_to_send.items())) + while self.reqs_to_send: + req_id, expires = next(iter(self.reqs_to_send.items())) + if req_id in self.finished_reqs: + del self.reqs_to_send[req_id] + continue if now < expires: break logger.warning( "Some requests in prefill node fail to receive KV Cache transfer done signal. " "If a greater mean TTFT is acceptable, you can 'export VLLM_LLMDD_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. 
" ) - if req_id not in self.finished_reqs: - self.finished_reqs.add(req_id) - del self._reqs_to_send[req_id] + self.finished_reqs.add(req_id) + del self.reqs_to_send[req_id] req_ids_to_ret = copy.deepcopy(self.finished_reqs) self.finished_reqs.clear() if self.llm_datadist_role == LLMRole.PROMPT: From 91a86ff0c8384bdbcc84eca3158ed8b2d1efb10b Mon Sep 17 00:00:00 2001 From: underfituu Date: Fri, 1 Aug 2025 18:09:07 +0800 Subject: [PATCH 3/8] del reqs_to_send in listening thread Signed-off-by: underfituu --- vllm_ascend/distributed/llmdatadist_c_mgr_connector.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 7d647d9c56..5c4cec4c58 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -405,6 +405,7 @@ def listen_for_agent_metadata_req(self, event: threading.Event): f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished" ) self.finished_reqs.add(finished_req_id) + del self.reqs_to_send[finished_req_id] sock.send_multipart( (identity, b"", b"receiving decode finished")) else: @@ -913,9 +914,6 @@ def get_finished( with self.thread_lock: while self.reqs_to_send: req_id, expires = next(iter(self.reqs_to_send.items())) - if req_id in self.finished_reqs: - del self.reqs_to_send[req_id] - continue if now < expires: break logger.warning( From 691c8f6a6a0b1e74b135d361f3aaecf5b39c6d95 Mon Sep 17 00:00:00 2001 From: underfituu Date: Sat, 2 Aug 2025 10:07:05 +0800 Subject: [PATCH 4/8] fix lint Signed-off-by: underfituu --- vllm_ascend/distributed/llmdatadist_c_mgr_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 5c4cec4c58..0e64ee46be 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -1,10 +1,10 @@ +import copy import contextlib import json import math import os import threading import time -import copy from collections import defaultdict from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor From 90464d879d0fc57c349176608a1d81dac347495a Mon Sep 17 00:00:00 2001 From: underfituu Date: Mon, 4 Aug 2025 14:12:13 +0800 Subject: [PATCH 5/8] fix lint Signed-off-by: underfituu --- vllm_ascend/distributed/llmdatadist_c_mgr_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 0e64ee46be..bffdf5d7c7 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -1,5 +1,5 @@ -import copy import contextlib +import copy import json import math import os From 5620f2da358b40f0ba9f6455acb354402fd47c04 Mon Sep 17 00:00:00 2001 From: underfituu Date: Mon, 4 Aug 2025 14:17:16 +0800 Subject: [PATCH 6/8] fix lint Signed-off-by: underfituu --- .../llmdatadist_c_mgr_connector.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index bffdf5d7c7..77c09ebddc 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ 
@@ -814,10 +814,8 @@ def send_checking_to_prefill_node(self, host: str, port: int, request_id):
         url = f"tcp://{host}:{port}"
         logger.info(f"Sending checking to remote: {url}")
         msg_encoder = msgspec.msgpack.Encoder()
-        msg_send = msg_encoder.encode([
-            LLMDataDistCMgrEvent.ReqForChecking,
-            [request_id]
-        ])
+        msg_send = msg_encoder.encode(
+            [LLMDataDistCMgrEvent.ReqForChecking, [request_id]])
         with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
             try:
                 sock.send(msg_send)
@@ -855,10 +853,11 @@ def _read_blocks(
             remote_block_ids = remote_block_ids[-num_local_blocks:]
 
         logger.info(f"remote cluster id is: {remote_cluster_id}")
-        if not self.send_checking_to_prefill_node(remote_ip, remote_port, request_id):
+        if not self.send_checking_to_prefill_node(remote_ip, remote_port,
+                                                  request_id):
             raise RuntimeError(
-                "Remote prefill node has already free blocks, skipping pull blocks"
-                )
+                "Remote prefill node has already free blocks, skipping pull blocks"
+            )
         if self.use_mla:
             remote_cache_key_k_normed = BlocksCacheKey(
                 cluster_id=remote_cluster_id, model_id=0)
@@ -910,16 +909,16 @@ def get_finished(
     ) -> tuple[Optional[set[str]], Optional[set[str]]]:
         """Get the finished recving and sending requuests."""
         now = time.perf_counter()
-        
+
         with self.thread_lock:
             while self.reqs_to_send:
                 req_id, expires = next(iter(self.reqs_to_send.items()))
                 if now < expires:
                     break
                 logger.warning(
-                    "Some requests in prefill node fail to receive KV Cache transfer done signal. "
-                    "If a greater mean TTFT is acceptable, you can 'export VLLM_LLMDD_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
-                    )
+                    "Some requests in prefill node fail to receive KV Cache transfer done signal. "
+                    "If a greater mean TTFT is acceptable, you can 'export VLLM_LLMDD_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
" + ) self.finished_reqs.add(req_id) del self.reqs_to_send[req_id] req_ids_to_ret = copy.deepcopy(self.finished_reqs) From 680bdd1f20a48eca846a6d40bd745ce7f442714c Mon Sep 17 00:00:00 2001 From: underfituu Date: Mon, 4 Aug 2025 14:27:29 +0800 Subject: [PATCH 7/8] only_enable_p_release_kv_cache Signed-off-by: underfituu --- .../llmdatadist_c_mgr_connector.py | 35 +------------------ 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 77c09ebddc..c367d1af54 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -45,8 +45,7 @@ class LLMDataDistCMgrEvent(Enum): ReqForMetadata = 0 - ReqForChecking = 1 - ReqForFinished = 2 + ReqForFinished = 1 class LLMDataDistCMgrAgentMetadata(msgspec.Struct): @@ -386,13 +385,6 @@ def listen_for_agent_metadata_req(self, event: threading.Event): logger.warning( f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}" ) - elif event_msg == LLMDataDistCMgrEvent.ReqForChecking: - finished_req_id = decode_msg[0] - checking = 1 - with self.thread_lock: - checking = 1 if finished_req_id in self.reqs_to_send else 0 - checking_to_send = msg_encoder.encode(checking) - sock.send_multipart((identity, b"", checking_to_send)) elif event_msg == LLMDataDistCMgrEvent.ReqForFinished: finished_req_id = decode_msg[0] decode_tp_rank = decode_msg[1] @@ -810,26 +802,6 @@ def send_finish_to_remote(self, host: str, port: int, request_id): logger.error( f"Failed to send reqest_id {request_id} to prefill: {e}") - def send_checking_to_prefill_node(self, host: str, port: int, request_id): - url = f"tcp://{host}:{port}" - logger.info(f"Sending checking to remote: {url}") - msg_encoder = msgspec.msgpack.Encoder() - msg_send = msg_encoder.encode( - [LLMDataDistCMgrEvent.ReqForChecking, [request_id]]) - with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] - try: - sock.send(msg_send) - logger.info( - f"Request id {request_id} checking message send to remote {url}" - ) - checking_bytes = sock.recv() - decoder = msgspec.msgpack.Decoder() - checking_flag = decoder.decode(checking_bytes) - return checking_flag - except Exception as e: - logger.error( - f"Failed to send reqest_id {request_id} to prefill: {e}") - def _read_blocks( self, local_block_ids: list[int], @@ -853,11 +825,6 @@ def _read_blocks( remote_block_ids = remote_block_ids[-num_local_blocks:] logger.info(f"remote cluster id is: {remote_cluster_id}") - if not self.send_checking_to_prefill_node(remote_ip, remote_port, - request_id): - raise RuntimeError( - "Remote prefill node has already free blocks, skipping pull blocks" - ) if self.use_mla: remote_cache_key_k_normed = BlocksCacheKey( cluster_id=remote_cluster_id, model_id=0) From 5c9ffcafcf82a6e145c24422d0eb5ffa7506b1bd Mon Sep 17 00:00:00 2001 From: underfituu Date: Mon, 4 Aug 2025 15:51:37 +0800 Subject: [PATCH 8/8] check finished_req_id in self.reqs_to_send Signed-off-by: underfituu --- .../distributed/llmdatadist_c_mgr_connector.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index c367d1af54..48dc7698e0 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -396,8 +396,9 @@ def listen_for_agent_metadata_req(self, event: 
                             logger.debug(
                                 f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
                             )
-                            self.finished_reqs.add(finished_req_id)
-                            del self.reqs_to_send[finished_req_id]
+                            if finished_req_id in self.reqs_to_send:
+                                self.finished_reqs.add(finished_req_id)
+                                del self.reqs_to_send[finished_req_id]
                     sock.send_multipart(
                         (identity, b"", b"receiving decode finished"))
                 else:
@@ -620,7 +621,7 @@ def handle_exception(future):
 
         for future in futures:
             future.add_done_callback(handle_exception)
-        self.reqs_to_send.update(metadata._reqs_need_send)
+        self.reqs_to_send.update(metadata.reqs_to_send)
 
     def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
         assert self.local_agent_metadata is not None
@@ -886,8 +887,9 @@ def get_finished(
                     "Some requests in prefill node fail to receive KV Cache transfer done signal. "
                     "If a greater mean TTFT is acceptable, you can 'export VLLM_LLMDD_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
                 )
-                self.finished_reqs.add(req_id)
-                del self.reqs_to_send[req_id]
+                if req_id in self.reqs_to_send:
+                    self.finished_reqs.add(req_id)
+                    del self.reqs_to_send[req_id]
             req_ids_to_ret = copy.deepcopy(self.finished_reqs)
             self.finished_reqs.clear()
             if self.llm_datadist_role == LLMRole.PROMPT:
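Reviewer note, not part of the patch series: the sketch below restates, in isolation, the prefill-side expiry bookkeeping that the series converges on in get_finished(). It is a minimal single-threaded illustration; the helper names (track_request, mark_done, collect_finished) and the module-level variables are hypothetical, the constant stands in for VLLM_LLMDD_ABORT_REQUEST_TIMEOUT, and the real connector guards these structures with thread_lock and is driven by the scheduler/worker plumbing above.

# Illustrative sketch only -- assumptions noted above.
import time

ABORT_REQUEST_TIMEOUT = 300.0  # seconds; stands in for VLLM_LLMDD_ABORT_REQUEST_TIMEOUT

# Insertion-ordered mapping of request_id -> absolute expiry time. Every entry
# gets the same timeout, so the earliest deadline is always at the front and
# only the head of the dict needs to be inspected on each sweep.
reqs_to_send: dict[str, float] = {}
finished_reqs: set[str] = set()


def track_request(request_id: str) -> None:
    """Prefill finished: hold the KV blocks until the decoder reads them."""
    reqs_to_send[request_id] = time.perf_counter() + ABORT_REQUEST_TIMEOUT


def mark_done(request_id: str) -> None:
    """Decoder confirmed the KV read; the blocks can be freed."""
    if request_id in reqs_to_send:
        finished_reqs.add(request_id)
        del reqs_to_send[request_id]


def collect_finished() -> set[str]:
    """Expire stale entries, then return (and reset) the finished set."""
    now = time.perf_counter()
    while reqs_to_send:
        req_id, expires = next(iter(reqs_to_send.items()))
        if now < expires:
            break
        # No READ notification arrived in time: release the blocks anyway.
        finished_reqs.add(req_id)
        del reqs_to_send[req_id]
    done = set(finished_reqs)
    finished_reqs.clear()
    return done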