Commit ca9e2be
[Core] Move EngineCoreRequest to Request conversion out of EngineCore (#21627)
Signed-off-by: linzebing <[email protected]>
1 parent 601f856 commit ca9e2be

File tree

3 files changed: +73 -48 lines
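At a glance, the commit splits the old single-step add_request(EngineCoreRequest) into a two-step flow: preprocess_add_request converts the wire-level EngineCoreRequest into a scheduler-level Request (returning it together with the DP wave number), and add_request now only enqueues that Request. The following is a minimal, self-contained sketch of just this control flow; the class bodies are toy stand-ins, not vLLM's real implementations.

from dataclasses import dataclass, field


@dataclass
class EngineCoreRequest:  # wire-level request (toy stand-in)
    request_id: str
    current_wave: int = 0


@dataclass
class Request:  # scheduler-level request (toy stand-in)
    request_id: str

    @classmethod
    def from_engine_core_request(cls, r: EngineCoreRequest) -> "Request":
        return cls(request_id=r.request_id)


@dataclass
class Scheduler:  # toy scheduler that only tracks the waiting queue
    waiting: list = field(default_factory=list)

    def add_request(self, req: Request) -> None:
        self.waiting.append(req)


class EngineCore:
    def __init__(self) -> None:
        self.scheduler = Scheduler()

    def preprocess_add_request(
            self, request: EngineCoreRequest) -> tuple[Request, int]:
        # Conversion (and, in vLLM, mm-cache lookups and grammar init)
        # happens here, so it can run on the input processing thread.
        return Request.from_engine_core_request(request), request.current_wave

    def add_request(self, req: Request, request_wave: int = 0) -> None:
        # The busy loop only enqueues the already-converted Request.
        self.scheduler.add_request(req)


core = EngineCore()
core.add_request(*core.preprocess_add_request(EngineCoreRequest("r0")))
assert len(core.scheduler.waiting) == 1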

tests/v1/engine/test_engine_core.py

Lines changed: 26 additions & 18 deletions
@@ -65,7 +65,8 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     """Test basic request lifecycle."""

     # First request.
-    engine_core.add_request(make_request())
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0

@@ -74,7 +75,8 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     assert len(engine_core.scheduler.running) == 1

     # Second request.
-    engine_core.add_request(make_request())
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 1

@@ -83,8 +85,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     assert len(engine_core.scheduler.running) == 2

     # Add two requests in a row.
-    engine_core.add_request(make_request())
-    engine_core.add_request(make_request())
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(make_request()))
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 2

@@ -104,7 +108,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     req = make_request()
     request_id = req.request_id

-    engine_core.add_request(req)
+    engine_core.add_request(*engine_core.preprocess_add_request(req))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0
     assert engine_core.scheduler.has_unfinished_requests()

@@ -131,16 +135,16 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     req1 = make_request()
     req2 = make_request()

-    engine_core.add_request(req0)
-    engine_core.add_request(req1)
+    engine_core.add_request(*engine_core.preprocess_add_request(req0))
+    engine_core.add_request(*engine_core.preprocess_add_request(req1))
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 0

     _ = engine_core.step()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2

-    engine_core.add_request(req2)
+    engine_core.add_request(*engine_core.preprocess_add_request(req2))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 2

@@ -166,12 +170,12 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
     req0 = make_request()
     req1 = make_request()
     req0.request_id = req1.request_id = "test"
-    engine_core.add_request(req0)
+    engine_core.add_request(*engine_core.preprocess_add_request(req0))

     while (outs := engine_core.step()[0].get(0)) and outs.outputs:
         pass

-    engine_core.add_request(req1)
+    engine_core.add_request(*engine_core.preprocess_add_request(req1))
     while (outs := engine_core.step()[0].get(0)) and outs.outputs:
         pass

@@ -207,7 +211,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
         repetition_penalty=0.1,
         stop_token_ids=[1001, 1002],
     )
-    engine_core.add_request(request)
+    engine_core.add_request(*engine_core.preprocess_add_request(request))

     def _check_engine_state():
         assert len(engine_core.scheduler.waiting) == 1

@@ -226,7 +230,7 @@ def _check_engine_state():
         top_p=0.99,
         top_k=50,
     )
-    engine_core.add_request(request2)
+    engine_core.add_request(*engine_core.preprocess_add_request(request2))
     _check_engine_state()

@@ -298,9 +302,9 @@ def shutdown(self):

     # Add two requests in a row. Each request have 12 prompt tokens.
     req0 = make_request_with_max_tokens("0", 5)
-    engine_core.add_request(req0)
+    engine_core.add_request(*engine_core.preprocess_add_request(req0))
     req1 = make_request_with_max_tokens("1", 5)
-    engine_core.add_request(req1)
+    engine_core.add_request(*engine_core.preprocess_add_request(req1))

     # Schedule Batch 1: (10, req0)
     assert engine_core.step_with_batch_queue()[0] is None

@@ -436,26 +440,30 @@ def test_engine_core_invalid_request_id_type(monkeypatch: pytest.MonkeyPatch):

     with pytest.raises(TypeError,
                        match="request_id must be a string, got.*UUID"):
-        engine_core.add_request(uuid_request)
+        engine_core.add_request(
+            *engine_core.preprocess_add_request(uuid_request))

     # Test with integer
     int_request = make_request()
     int_request.request_id = 12345

     with pytest.raises(TypeError,
                        match="request_id must be a string, got.*int"):
-        engine_core.add_request(int_request)
+        engine_core.add_request(
+            *engine_core.preprocess_add_request(int_request))

     # Test with None
     none_request = make_request()
     none_request.request_id = None

     with pytest.raises(TypeError,
                        match="request_id must be a string, got.*NoneType"):
-        engine_core.add_request(none_request)
+        engine_core.add_request(
+            *engine_core.preprocess_add_request(none_request))

     # Verify engine is still functional after errors
     valid_request = make_request()
-    engine_core.add_request(valid_request)
+    engine_core.add_request(
+        *engine_core.preprocess_add_request(valid_request))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0

vllm/v1/engine/core.py

Lines changed: 45 additions & 29 deletions
@@ -205,8 +205,12 @@ def _initialize_kv_caches(
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.model_executor.supported_tasks

-    def add_request(self, request: EngineCoreRequest):
-        """Add request to the scheduler."""
+    def add_request(self, request: Request, request_wave: int = 0):
+        """Add request to the scheduler.
+
+        `request_wave`: indicates which wave of requests this is expected
+        to belong to in the DP case.
+        """
         # Validate the request_id type.
         if not isinstance(request.request_id, str):
             raise TypeError(

@@ -222,27 +226,12 @@ def add_request(self, request: EngineCoreRequest):
             raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                              f"Supported tasks: {supported_pooling_tasks}")

-        if request.mm_hashes is not None:
-            # Here, if hash exists for a multimodal input, then it will be
-            # fetched from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client cache, so
-            # anything that has a hash must have a HIT cache entry here
-            # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
-                request.mm_inputs, request.mm_hashes)
-
-        req = Request.from_engine_core_request(request)
-        if req.use_structured_output:
-            # Start grammar compilation asynchronously
-            self.structured_output_manager.grammar_init(req)
-
-        if req.kv_transfer_params is not None and (
+        if request.kv_transfer_params is not None and (
                 not self.scheduler.get_kv_connector()):
             logger.warning("Got kv_transfer_params, but no KVConnector found. "
                            "Disabling KVTransfer for this request.")

-        self.scheduler.add_request(req)
+        self.scheduler.add_request(request)

     def abort_requests(self, request_ids: list[str]):
         """Abort requests from the scheduler."""

@@ -414,6 +403,31 @@ def save_tensorized_model(
         self.model_executor.save_tensorized_model(
             tensorizer_config=tensorizer_config, )

+    def preprocess_add_request(
+            self, request: EngineCoreRequest) -> tuple[Request, int]:
+        """Preprocess the request.
+
+        This function can be called directly from the input processing
+        thread so that request initialization runs in parallel with the
+        model forward pass.
+        """
+        if request.mm_hashes is not None:
+            assert request.mm_inputs is not None
+            # Note on thread safety: no race condition.
+            # `mm_input_cache_server` is reset at the end of LLMEngine init,
+            # and is only accessed from the input processing thread
+            # afterwards.
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
+                request.mm_inputs, request.mm_hashes)
+
+        req = Request.from_engine_core_request(request)
+        if req.use_structured_output:
+            # Note on thread safety: no race condition.
+            # `grammar_init` is only invoked from the input processing
+            # thread. For `structured_output_manager`, each request is
+            # independent and grammar compilation is async. The scheduler
+            # always checks grammar compilation status before scheduling
+            # a request.
+            self.structured_output_manager.grammar_init(req)
+        return req, request.current_wave
+

 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""

@@ -707,7 +721,8 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
         """Dispatch request from client."""

         if request_type == EngineCoreRequestType.ADD:
-            self.add_request(request)
+            req, request_wave = request
+            self.add_request(req, request_wave)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
         elif request_type == EngineCoreRequestType.UTILITY:

@@ -806,10 +821,11 @@ def process_input_sockets(self, input_addresses: list[str],
                     bytes(type_frame.buffer))

                 # Deserialize the request data.
-                decoder = add_request_decoder if (
-                    request_type
-                    == EngineCoreRequestType.ADD) else generic_decoder
-                request = decoder.decode(data_frames)
+                if request_type == EngineCoreRequestType.ADD:
+                    request = add_request_decoder.decode(data_frames)
+                    request = self.preprocess_add_request(request)
+                else:
+                    request = generic_decoder.decode(data_frames)

                 # Push to input queue for core busy loop.
                 self.input_queue.put_nowait((request_type, request))

@@ -939,17 +955,17 @@ def shutdown(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)

-    def add_request(self, request: EngineCoreRequest):
-        if self.has_coordinator and request.current_wave != self.current_wave:
-            if request.current_wave > self.current_wave:
-                self.current_wave = request.current_wave
+    def add_request(self, request: Request, request_wave: int = 0):
+        if self.has_coordinator and request_wave != self.current_wave:
+            if request_wave > self.current_wave:
+                self.current_wave = request_wave
             elif not self.engines_running:
                 # Request received for an already-completed wave, notify
                 # front-end that we need to start the next one.
                 self.output_queue.put_nowait(
                     (-1, EngineCoreOutputs(start_wave=self.current_wave)))

-        super().add_request(request)
+        super().add_request(request, request_wave)

     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
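Why route ADD requests through preprocess_add_request inside process_input_sockets? It moves the conversion work off the engine busy loop so it overlaps with model execution. Below is a minimal sketch of that producer/consumer shape, using toy names rather than vLLM's actual classes.

import queue
import threading

input_queue: queue.Queue = queue.Queue()


def input_thread(raw_requests):
    # Stand-in for process_input_sockets: decode + preprocess each ADD
    # request here, concurrently with the busy loop below.
    for raw in raw_requests:
        converted = (f"Request({raw})", 0)  # stand-in for (Request, wave)
        input_queue.put(("ADD", converted))


def busy_loop(n_requests):
    # Stand-in for the engine busy loop: it only dequeues and schedules.
    for _ in range(n_requests):
        request_type, (req, wave) = input_queue.get()
        print(f"{request_type}: scheduling {req} (wave={wave})")


producer = threading.Thread(target=input_thread, args=(["r0", "r1"],))
producer.start()
busy_loop(2)
producer.join()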

vllm/v1/engine/core_client.py

Lines changed: 2 additions & 1 deletion
@@ -250,7 +250,8 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.engine_core.get_supported_tasks()

     def add_request(self, request: EngineCoreRequest) -> None:
-        self.engine_core.add_request(request)
+        req, request_wave = self.engine_core.preprocess_add_request(request)
+        self.engine_core.add_request(req, request_wave)

     def abort_requests(self, request_ids: list[str]) -> None:
         if len(request_ids) > 0:
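Both entry points now funnel through the same two-step contract: the in-process client above calls preprocess_add_request itself, while in the multiprocess path EngineCoreProc.process_input_sockets performs it on the socket thread and forwards the resulting (Request, request_wave) pair through the input queue to _handle_client_request. Either way, EngineCore.add_request only ever receives an already-converted Request.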
