Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import copy
import json
import math
import os
Expand Down Expand Up @@ -183,6 +184,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
self.port = dp_rank_local * tp_size + envs.VLLM_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs.VLLM_LLMDD_RPC_PORT

self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[str, float] = {}

def get_num_new_matched_tokens(
self, request: "Request",
Expand Down Expand Up @@ -247,7 +249,12 @@ def build_connector_meta(
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params)

meta.reqs_to_send = copy.deepcopy(self._reqs_need_send)

# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_need_send.clear()

return meta

Expand All @@ -271,9 +278,14 @@ def request_finished(
# note: there might be some issue on this, check it if there is any unexpected result
computed_block_ids = block_ids
delay_free_blocks = len(computed_block_ids) > 0

if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(computed_block_ids), request.request_id)
# Prefill request on remote. It will be read from D upon completion
self._reqs_need_send[request.request_id] = time.perf_counter(
) + envs.VLLM_LLMDD_ABORT_REQUEST_TIMEOUT

return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
Expand Down Expand Up @@ -340,6 +352,7 @@ def __init__(self, vllm_config: VllmConfig):
os.environ["HCCL_DETERMINISTIC"] = "true"
self.done_receiving_counts: defaultdict[str,
set[int]] = defaultdict(set)
self.reqs_to_send: dict[str, float] = {}

def listen_for_agent_metadata_req(self, event: threading.Event):
assert self.local_agent_metadata is not None
Expand Down Expand Up @@ -383,7 +396,9 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
logger.debug(
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
)
self.finished_reqs.add(finished_req_id)
if finished_req_id in self.reqs_to_send:
self.finished_reqs.add(finished_req_id)
del self.reqs_to_send[finished_req_id]
sock.send_multipart(
(identity, b"", b"receiving decode finished"))
else:
Expand Down Expand Up @@ -606,6 +621,7 @@ def handle_exception(future):

for future in futures:
future.add_done_callback(handle_exception)
self.reqs_to_send.update(metadata.reqs_to_send)

def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
assert self.local_agent_metadata is not None
Expand Down Expand Up @@ -860,8 +876,20 @@ def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""Get the finished recving and sending requuests."""
import copy
now = time.perf_counter()

with self.thread_lock:
while self.reqs_to_send:
req_id, expires = next(iter(self.reqs_to_send.items()))
if now < expires:
break
logger.warning(
"Some requests in prefill node fail to receive KV Cache transfer done signal. "
"If a greater mean TTFT is acceptable, you can 'export VLLM_LLMDD_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
)
if req_id in self.reqs_to_send:
self.finished_reqs.add(req_id)
del self.reqs_to_send[req_id]
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
self.finished_reqs.clear()
if self.llm_datadist_role == LLMRole.PROMPT:
Expand Down
6 changes: 6 additions & 0 deletions vllm_ascend/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@
# remote worker.
"VLLM_LLMDD_RPC_PORT":
lambda: int(os.getenv("VLLM_LLMDD_RPC_PORT", 5557)),
# `LLMDataDistCMgrConnector` required variable. Time (in seconds) after which the KV cache on the producer side is
# automatically cleared if no READ notification is received from the consumer.
# `VLLM_LLMDD_ABORT_REQUEST_TIMEOUT` is only applicable when using LLMDataDistCMgrConnector in a
# disaggregated decode-prefill setup.
"VLLM_LLMDD_ABORT_REQUEST_TIMEOUT":
lambda: int(os.getenv("VLLM_LLMDD_ABORT_REQUEST_TIMEOUT", 300)),
# Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible
# and the mla_pa will be the default path of deepseek decode path.
"VLLM_ASCEND_MLA_PA":
Expand Down