@@ -205,8 +205,12 @@ def _initialize_kv_caches(
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.model_executor.supported_tasks
 
-    def add_request(self, request: EngineCoreRequest):
-        """Add request to the scheduler."""
+    def add_request(self, request: Request, request_wave: int = 0):
+        """Add request to the scheduler.
+
+        `request_wave`: indicates which wave of requests this is expected
+        to belong to in the DP case.
+        """
         # Validate the request_id type.
         if not isinstance(request.request_id, str):
             raise TypeError(
@@ -222,27 +226,12 @@ def add_request(self, request: EngineCoreRequest):
                 raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                  f"Supported tasks: {supported_pooling_tasks}")
 
-        if request.mm_hashes is not None:
-            # Here, if hash exists for a multimodal input, then it will be
-            # fetched from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client cache, so
-            # anything that has a hash must have a HIT cache entry here
-            # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
-                request.mm_inputs, request.mm_hashes)
-
-        req = Request.from_engine_core_request(request)
-        if req.use_structured_output:
-            # Start grammar compilation asynchronously
-            self.structured_output_manager.grammar_init(req)
-
-        if req.kv_transfer_params is not None and (
+        if request.kv_transfer_params is not None and (
                 not self.scheduler.get_kv_connector()):
             logger.warning("Got kv_transfer_params, but no KVConnector found. "
                            "Disabling KVTransfer for this request.")
 
-        self.scheduler.add_request(req)
+        self.scheduler.add_request(request)
 
     def abort_requests(self, request_ids: list[str]):
         """Abort requests from the scheduler."""
@@ -414,6 +403,31 @@ def save_tensorized_model(
         self.model_executor.save_tensorized_model(
             tensorizer_config=tensorizer_config, )
 
+    def preprocess_add_request(
+            self, request: EngineCoreRequest) -> tuple[Request, int]:
+        """Preprocess the request.
+
+        This function can be called from the input processing thread so
+        that request initialization runs in parallel with the model forward.
+        """
+        if request.mm_hashes is not None:
+            assert request.mm_inputs is not None
+            # Note on thread safety: no race condition.
+            # `mm_input_cache_server` is reset at the end of LLMEngine init,
+            # and is only accessed in the input processing thread afterwards.
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
+                request.mm_inputs, request.mm_hashes)
+
+        req = Request.from_engine_core_request(request)
+        if req.use_structured_output:
+            # Note on thread safety: no race condition.
+            # `grammar_init` is only invoked in the input processing thread.
+            # For `structured_output_manager`, each request is independent
+            # and grammar compilation is async. The scheduler always checks
+            # grammar compilation status before scheduling a request.
+            self.structured_output_manager.grammar_init(req)
+        return req, request.current_wave
+
 
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -707,7 +721,8 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
         """Dispatch request from client."""
 
         if request_type == EngineCoreRequestType.ADD:
-            self.add_request(request)
+            req, request_wave = request
+            self.add_request(req, request_wave)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
         elif request_type == EngineCoreRequestType.UTILITY:
@@ -806,10 +821,11 @@ def process_input_sockets(self, input_addresses: list[str],
                         bytes(type_frame.buffer))
 
                     # Deserialize the request data.
-                    decoder = add_request_decoder if (
-                        request_type
-                        == EngineCoreRequestType.ADD) else generic_decoder
-                    request = decoder.decode(data_frames)
+                    if request_type == EngineCoreRequestType.ADD:
+                        request = add_request_decoder.decode(data_frames)
+                        request = self.preprocess_add_request(request)
+                    else:
+                        request = generic_decoder.decode(data_frames)
 
                     # Push to input queue for core busy loop.
                     self.input_queue.put_nowait((request_type, request))
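One consequence of decoding and preprocessing in the socket thread is that the
ADD payload pushed onto `input_queue` is now a `(Request, current_wave)` tuple
rather than a raw `EngineCoreRequest`, which is exactly what the earlier
`_handle_client_request` hunk unpacks. A toy dispatcher showing the shape of
that contract (the string request types and payloads are placeholders, not
vLLM types):

    from typing import Any

    def handle(request_type: str, request: Any) -> None:
        if request_type == "ADD":
            # ADD payloads arrive preprocessed as a (request, wave) pair.
            req, request_wave = request
            print(f"add {req} (wave {request_wave})")
        elif request_type == "ABORT":
            # Other request types keep their original payload shape.
            print(f"abort {request}")

    handle("ADD", ("req-1", 0))
    handle("ABORT", ["req-1"])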
@@ -939,17 +955,17 @@ def shutdown(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)
 
-    def add_request(self, request: EngineCoreRequest):
-        if self.has_coordinator and request.current_wave != self.current_wave:
-            if request.current_wave > self.current_wave:
-                self.current_wave = request.current_wave
+    def add_request(self, request: Request, request_wave: int = 0):
+        if self.has_coordinator and request_wave != self.current_wave:
+            if request_wave > self.current_wave:
+                self.current_wave = request_wave
             elif not self.engines_running:
                 # Request received for an already-completed wave, notify
                 # front-end that we need to start the next one.
                 self.output_queue.put_nowait(
                     (-1, EngineCoreOutputs(start_wave=self.current_wave)))
 
-        super().add_request(request)
+        super().add_request(request, request_wave)
 
     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
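The DP override's wave bookkeeping is unchanged in substance; the wave number
now simply arrives as an explicit `request_wave` argument (returned by
`preprocess_add_request`) instead of being read off the raw request. A
self-contained sketch of the same logic, where `WaveTracker` is an
illustrative stand-in for the DP engine subclass, not a vLLM class:

    class WaveTracker:
        def __init__(self) -> None:
            self.current_wave = 0
            self.engines_running = False

        def add_request(self, request_wave: int = 0) -> None:
            if request_wave != self.current_wave:
                if request_wave > self.current_wave:
                    # A newer wave has started elsewhere: catch up.
                    self.current_wave = request_wave
                elif not self.engines_running:
                    # Stale wave while engines are idle: tell the
                    # front-end to start the next wave.
                    print(f"start_wave={self.current_wave}")

    tracker = WaveTracker()
    tracker.add_request(request_wave=1)  # advances current_wave to 1
    tracker.add_request(request_wave=0)  # stale: prints start_wave=1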