@@ -80,6 +80,10 @@ def __init__(self):
8080 self .record_finished_requests : set [str ] = set ()
8181 self .delayed_free_requests : OrderedDict [str , float ] = OrderedDict ()
8282
83+ def add_not_transfer_request (self , request_id : str ):
84+ with self .done_task_lock :
85+ self .finished_requests .add (request_id )
86+
8387 def update_done_task_count (self , request_id : str ):
8488 with self .done_task_lock :
8589 self .finished_requests .add (request_id )
@@ -157,6 +161,9 @@ def get_and_clear_finished_requests(self) -> set[str]:
157161 """
158162 return self .task_tracker .get_and_clear_finished_requests ()
159163
164+ def add_not_transfer_request (self , request_id : str ):
165+ self .task_tracker .add_not_transfer_request (request_id )
166+
160167 def add_delayed_request (self , request_id : str , delay_start_time : float ):
161168 return self .task_tracker .add_delayed_request (request_id ,
162169 delay_start_time )
@@ -658,10 +665,6 @@ def request_finished(
658665 assert self .connector_scheduler is not None
659666 return self .connector_scheduler .request_finished (request , block_ids )
660667
661- def get_finished_count (self ) -> Optional [int ]:
662- assert self .connector_scheduler is not None
663- return self .connector_scheduler .get_finished_count ()
664-
665668 ############################################################
666669 # Worker Side Methods
667670 ############################################################
@@ -846,39 +849,6 @@ def request_finished(
846849 last_token_id = request .output_token_ids [- 1 ],
847850 )
848851
849- def get_finished_count (self ) -> Optional [int ]:
850- prefill_parallel_config : dict [
851- str ,
852- Any ] = self .vllm_config .kv_transfer_config .get_from_extra_config (
853- "prefill" , {})
854-
855- assert "tp_size" in prefill_parallel_config .keys ()
856- self ._prefill_tp_size = prefill_parallel_config ["tp_size" ]
857- decode_parallel_config : dict [
858- str ,
859- Any ] = self .vllm_config .kv_transfer_config .get_from_extra_config (
860- "decode" , {})
861- assert "tp_size" in decode_parallel_config .keys ()
862- self ._decode_tp_size = decode_parallel_config ["tp_size" ]
863- num_key_value_heads = self .vllm_config .model_config .hf_config .num_key_value_heads
864- if self .vllm_config .model_config .use_mla or hasattr (
865- self .vllm_config .model_config .hf_config , "index_topk" ):
866- num_need_pulls = 1
867- else :
868- num_p_block_heads = max (
869- 1 , num_key_value_heads // self ._prefill_tp_size )
870- num_d_block_heads = max (
871- 1 , num_key_value_heads // self ._decode_tp_size )
872- num_need_pulls = num_d_block_heads // num_p_block_heads
873- kv_role = self .vllm_config .kv_transfer_config .kv_role
874- logger .debug (
875- "get_finished_count, kv_role=%s, num_need_pulls=%d, decode_tp_size=%d" ,
876- kv_role , num_need_pulls , self ._decode_tp_size )
877- if kv_role == 'kv_producer' :
878- return num_need_pulls * self ._decode_tp_size
879- else :
880- return self ._decode_tp_size
881-
882852
883853class MooncakeConnectorWorker :
884854 """Implementation of Worker side methods"""
@@ -1150,6 +1120,8 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
11501120 if self .tp_rank in self ._prefill_get_remote_tp_rank (req_id ):
11511121 self .kv_send_thread .add_delayed_request (
11521122 req_id , delay_start_time )
1123+ else :
1124+ self .kv_send_thread .add_not_transfer_request (req_id )
11531125
11541126 def _prefill_get_remote_tp_rank (self , req_id : str ) -> List [int ]:
11551127 return sum (self ._get_remote_tp_ranks_for_req (req_id ), [])
0 commit comments