Commit d6ef3df

[Bugfix]fix_mulit_connector_bug (#3332)
### What this PR does / why we need it?

When the multi connector is used, get_finished_count is not defined on it, which causes the KV cache to be released prematurely.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@83f478b

---------

Signed-off-by: baxingpiaochong <[email protected]>
1 parent 07873d9 commit d6ef3df
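For context on how the fix works: instead of relying on a connector-specific get_finished_count, the prefill worker now marks any request whose KV blocks it will never transfer as finished immediately (the new add_not_transfer_request path), so those requests flow through the normal finished-request reporting. The snippet below is a minimal, self-contained sketch of that tracking pattern; the class name FinishedRequestTracker and the usage block are illustrative stand-ins, not the actual vllm_ascend implementation (see the diff of mooncake_connector.py below for the real methods).

```python
# Minimal sketch of the thread-safe "finished requests" tracking pattern the
# fix relies on. Hypothetical stand-in class; the real logic lives in
# vllm_ascend/distributed/mooncake_connector.py.
import threading


class FinishedRequestTracker:
    """Illustrative stand-in for the connector's task tracker."""

    def __init__(self) -> None:
        self.done_task_lock = threading.Lock()
        self.finished_requests: set[str] = set()

    def add_not_transfer_request(self, request_id: str) -> None:
        # Mirrors the method added in this commit: no KV transfer is needed
        # for this request, so mark it finished right away.
        with self.done_task_lock:
            self.finished_requests.add(request_id)

    def get_and_clear_finished_requests(self) -> set[str]:
        # Polled by the worker; hands back each finished request id exactly
        # once so its KV cache blocks can be released.
        with self.done_task_lock:
            finished = self.finished_requests
            self.finished_requests = set()
            return finished


if __name__ == "__main__":
    tracker = FinishedRequestTracker()
    tracker.add_not_transfer_request("req-1")
    print(tracker.get_and_clear_finished_requests())  # {'req-1'}
    print(tracker.get_and_clear_finished_requests())  # set()
```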

2 files changed: 9 additions & 41 deletions

tests/ut/kv_connector/test_mooncake_connector.py

Lines changed: 0 additions & 4 deletions
@@ -673,10 +673,6 @@ def test_build_connector_meta(self):
         self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
         self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
 
-    def test_get_finished_count(self):
-        count = self.scheduler.get_finished_count()
-        self.assertEqual(count, 2)
-
 
 class TestHelperFunctions(unittest.TestCase):

vllm_ascend/distributed/mooncake_connector.py

Lines changed: 9 additions & 37 deletions
@@ -80,6 +80,10 @@ def __init__(self):
         self.record_finished_requests: set[str] = set()
         self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
 
+    def add_not_transfer_request(self, request_id: str):
+        with self.done_task_lock:
+            self.finished_requests.add(request_id)
+
     def update_done_task_count(self, request_id: str):
         with self.done_task_lock:
             self.finished_requests.add(request_id)
@@ -157,6 +161,9 @@ def get_and_clear_finished_requests(self) -> set[str]:
         """
         return self.task_tracker.get_and_clear_finished_requests()
 
+    def add_not_transfer_request(self, request_id: str):
+        self.task_tracker.add_not_transfer_request(request_id)
+
     def add_delayed_request(self, request_id: str, delay_start_time: float):
         return self.task_tracker.add_delayed_request(request_id,
                                                      delay_start_time)
@@ -658,10 +665,6 @@ def request_finished(
         assert self.connector_scheduler is not None
         return self.connector_scheduler.request_finished(request, block_ids)
 
-    def get_finished_count(self) -> Optional[int]:
-        assert self.connector_scheduler is not None
-        return self.connector_scheduler.get_finished_count()
-
     ############################################################
     # Worker Side Methods
     ############################################################
@@ -846,39 +849,6 @@ def request_finished(
             last_token_id=request.output_token_ids[-1],
         )
 
-    def get_finished_count(self) -> Optional[int]:
-        prefill_parallel_config: dict[
-            str,
-            Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
-                "prefill", {})
-
-        assert "tp_size" in prefill_parallel_config.keys()
-        self._prefill_tp_size = prefill_parallel_config["tp_size"]
-        decode_parallel_config: dict[
-            str,
-            Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
-                "decode", {})
-        assert "tp_size" in decode_parallel_config.keys()
-        self._decode_tp_size = decode_parallel_config["tp_size"]
-        num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
-        if self.vllm_config.model_config.use_mla or hasattr(
-                self.vllm_config.model_config.hf_config, "index_topk"):
-            num_need_pulls = 1
-        else:
-            num_p_block_heads = max(
-                1, num_key_value_heads // self._prefill_tp_size)
-            num_d_block_heads = max(
-                1, num_key_value_heads // self._decode_tp_size)
-            num_need_pulls = num_d_block_heads // num_p_block_heads
-        kv_role = self.vllm_config.kv_transfer_config.kv_role
-        logger.debug(
-            "get_finished_count, kv_role=%s, num_need_pulls=%d, decode_tp_size=%d",
-            kv_role, num_need_pulls, self._decode_tp_size)
-        if kv_role == 'kv_producer':
-            return num_need_pulls * self._decode_tp_size
-        else:
-            return self._decode_tp_size
-
 
 class MooncakeConnectorWorker:
     """Implementation of Worker side methods"""
@@ -1150,6 +1120,8 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
             if self.tp_rank in self._prefill_get_remote_tp_rank(req_id):
                 self.kv_send_thread.add_delayed_request(
                     req_id, delay_start_time)
+            else:
+                self.kv_send_thread.add_not_transfer_request(req_id)
 
     def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]:
         return sum(self._get_remote_tp_ranks_for_req(req_id), [])
