Skip to content

Commit e0fd906

Browse files
authored
Merge branch 'vllm-project:main' into Eagle-mulitmodal-support-Qwen2.5vl
2 parents 0853e02 + 540d54c commit e0fd906

File tree

9 files changed

+84
-50
lines changed

9 files changed

+84
-50
lines changed

docs/getting_started/installation/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ vLLM supports the following hardware platforms:
1818
## Hardware Plugins
1919

2020
The backends below live **outside** the main `vllm` repository and follow the
21-
[Hardware-Pluggable RFC](../design/plugin_system.md).
21+
[Hardware-Pluggable RFC](../../design/plugin_system.md).
2222

2323
| Accelerator | PyPI / package | Repository |
2424
|-------------|----------------|------------|

tests/entrypoints/openai/test_transcription_validation.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,6 @@ async def test_bad_requests(mary_had_lamb):
8080
async def test_long_audio_request(mary_had_lamb, model_name):
8181
server_args = ["--enforce-eager"]
8282

83-
if model_name.startswith("openai"):
84-
return
85-
8683
mary_had_lamb.seek(0)
8784
audio, sr = librosa.load(mary_had_lamb)
8885
# Add small silence after each audio for repeatability in the split process

tests/models/multimodal/test_tensor_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,4 @@ def validate_model_input(model):
153153
if hasattr(model, method_name):
154154
getattr(model, method_name)(**mm_kwargs)
155155

156-
vllm_model.apply_model(validate_model_input)
156+
vllm_model.apply_model(validate_model_input)

vllm/engine/async_llm_engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,7 @@ async def reset_prefix_cache(self,
10921092
self.engine.reset_prefix_cache(device)
10931093

10941094
async def sleep(self, level: int = 1) -> None:
1095+
await self.reset_prefix_cache()
10951096
self.engine.sleep(level)
10961097

10971098
async def wake_up(self, tags: Optional[list[str]] = None) -> None:

vllm/model_executor/models/llava.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,9 @@ class PixtralHFImagePixelInputs(TensorSchema):
7272
in which case the data is passed as a list instead of a batched tensor.
7373
"""
7474
type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
75-
pixel_values: Annotated[Union[torch.Tensor, list[torch.Tensor]],
76-
TensorShape("bn", "c", "h", "w")]
75+
pixel_values: Annotated[
76+
Union[torch.Tensor, list[torch.Tensor]],
77+
TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"})]
7778

7879

7980
class LlavaImageEmbeddingInputs(TensorSchema):

vllm/transformers_utils/config.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -449,23 +449,6 @@ def get_config(
449449
raise e
450450
config = _maybe_remap_hf_config_attrs(config)
451451

452-
# Phi4Flash misuses this config as list[int]. Convert it to int and add
453-
# the layer_types list[str] to make it HF compatible
454-
if (config.model_type == "phi4flash"):
455-
# TODO: Remove after the following PR is merged:
456-
# https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/6
457-
if not hasattr(config, "layer_types"):
458-
config.layer_types = [
459-
"sliding_attention" if i < config.num_hidden_layers // 2
460-
and i % 2 == 1 else "full_attention"
461-
for i in range(config.num_hidden_layers)
462-
]
463-
# TODO: Remove after the following PR is merged:
464-
# https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/7
465-
if isinstance(config.sliding_window, list):
466-
config.sliding_window = next(
467-
filter(None, config.sliding_window), None)
468-
469452
elif config_format == ConfigFormat.MISTRAL:
470453
# This function loads a params.json config which
471454
# should be used when loading models in mistral format

vllm/utils/__init__.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -709,8 +709,28 @@ def cancel_tasks():
709709

710710

711711
def cancel_task_threadsafe(task: Task):
712-
if task and not task.done() and not (loop := task.get_loop()).is_closed():
713-
loop.call_soon_threadsafe(task.cancel)
712+
if task and not task.done():
713+
run_in_loop(task.get_loop(), task.cancel)
714+
715+
716+
def close_sockets(sockets: Sequence[Union[zmq.Socket, zmq.asyncio.Socket]]):
717+
for sock in sockets:
718+
if sock is not None:
719+
sock.close(linger=0)
720+
721+
722+
def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
723+
if in_loop(loop):
724+
function(*args)
725+
elif not loop.is_closed():
726+
loop.call_soon_threadsafe(function, *args)
727+
728+
729+
def in_loop(event_loop: AbstractEventLoop) -> bool:
730+
try:
731+
return asyncio.get_running_loop() == event_loop
732+
except RuntimeError:
733+
return False
714734

715735

716736
def make_async(

vllm/v1/engine/async_llm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ async def reset_prefix_cache(self,
576576
await self.engine_core.reset_prefix_cache_async()
577577

578578
async def sleep(self, level: int = 1) -> None:
579+
await self.reset_prefix_cache()
579580
await self.engine_core.sleep_async(level)
580581

581582
async def wake_up(self, tags: Optional[list[str]] = None) -> None:

vllm/v1/engine/core_client.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from vllm.logger import init_logger
2424
from vllm.lora.request import LoRARequest
2525
from vllm.tasks import SupportedTask
26-
from vllm.utils import (cancel_task_threadsafe, get_open_port,
27-
get_open_zmq_inproc_path, make_zmq_socket)
26+
from vllm.utils import (close_sockets, get_open_port, get_open_zmq_inproc_path,
27+
in_loop, make_zmq_socket)
2828
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
2929
EngineCoreRequestType,
3030
ReconfigureDistributedRequest, ReconfigureRankType,
@@ -317,7 +317,7 @@ class BackgroundResources:
317317
"""Used as a finalizer for clean shutdown, avoiding
318318
circular reference back to the client object."""
319319

320-
ctx: Union[zmq.Context]
320+
ctx: zmq.Context
321321
# If CoreEngineProcManager, it manages local engines;
322322
# if CoreEngineActorManager, it manages all engines.
323323
engine_manager: Optional[Union[CoreEngineProcManager,
@@ -326,6 +326,8 @@ class BackgroundResources:
326326
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
327327
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
328328
first_req_send_socket: Optional[zmq.asyncio.Socket] = None
329+
first_req_rcv_socket: Optional[zmq.asyncio.Socket] = None
330+
stats_update_socket: Optional[zmq.asyncio.Socket] = None
329331
output_queue_task: Optional[asyncio.Task] = None
330332
stats_update_task: Optional[asyncio.Task] = None
331333
shutdown_path: Optional[str] = None
@@ -343,23 +345,47 @@ def __call__(self):
343345
if self.coordinator is not None:
344346
self.coordinator.close()
345347

346-
cancel_task_threadsafe(self.output_queue_task)
347-
cancel_task_threadsafe(self.stats_update_task)
348+
if isinstance(self.output_socket, zmq.asyncio.Socket):
349+
# Async case.
350+
loop = self.output_socket._get_loop()
351+
asyncio.get_running_loop()
352+
sockets = (self.output_socket, self.input_socket,
353+
self.first_req_send_socket, self.first_req_rcv_socket,
354+
self.stats_update_socket)
355+
356+
tasks = (self.output_queue_task, self.stats_update_task)
357+
358+
def close_sockets_and_tasks():
359+
close_sockets(sockets)
360+
for task in tasks:
361+
if task is not None and not task.done():
362+
task.cancel()
363+
364+
if in_loop(loop):
365+
close_sockets_and_tasks()
366+
elif not loop.is_closed():
367+
loop.call_soon_threadsafe(close_sockets_and_tasks)
368+
else:
369+
# Loop has been closed, try to clean up directly.
370+
del tasks
371+
del close_sockets_and_tasks
372+
close_sockets(sockets)
373+
del self.output_queue_task
374+
del self.stats_update_task
375+
else:
376+
# Sync case.
348377

349-
# ZMQ context termination can hang if the sockets
350-
# aren't explicitly closed first.
351-
for socket in (self.output_socket, self.input_socket,
352-
self.first_req_send_socket):
353-
if socket is not None:
354-
socket.close(linger=0)
378+
# ZMQ context termination can hang if the sockets
379+
# aren't explicitly closed first.
380+
close_sockets((self.output_socket, self.input_socket))
355381

356-
if self.shutdown_path is not None:
357-
# We must ensure that the sync output socket is
358-
# closed cleanly in its own thread.
359-
with self.ctx.socket(zmq.PAIR) as shutdown_sender:
360-
shutdown_sender.connect(self.shutdown_path)
361-
# Send shutdown signal.
362-
shutdown_sender.send(b'')
382+
if self.shutdown_path is not None:
383+
# We must ensure that the sync output socket is
384+
# closed cleanly in its own thread.
385+
with self.ctx.socket(zmq.PAIR) as shutdown_sender:
386+
shutdown_sender.connect(self.shutdown_path)
387+
# Send shutdown signal.
388+
shutdown_sender.send(b'')
363389

364390
def validate_alive(self, frames: Sequence[zmq.Frame]):
365391
if len(frames) == 1 and (frames[0].buffer
@@ -969,14 +995,19 @@ def _ensure_stats_update_task(self):
969995
self.engine_ranks_managed[-1] + 1)
970996

971997
async def run_engine_stats_update_task():
972-
with make_zmq_socket(self.ctx, self.stats_update_address,
973-
zmq.XSUB) as socket, make_zmq_socket(
974-
self.ctx,
975-
self.first_req_sock_addr,
976-
zmq.PAIR,
977-
bind=False) as first_req_rcv_socket:
998+
with (make_zmq_socket(self.ctx,
999+
self.stats_update_address,
1000+
zmq.XSUB,
1001+
linger=0) as socket,
1002+
make_zmq_socket(self.ctx,
1003+
self.first_req_sock_addr,
1004+
zmq.PAIR,
1005+
bind=False,
1006+
linger=0) as first_req_rcv_socket):
9781007
assert isinstance(socket, zmq.asyncio.Socket)
9791008
assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket)
1009+
self.resources.stats_update_socket = socket
1010+
self.resources.first_req_rcv_socket = first_req_rcv_socket
9801011
# Send subscription message.
9811012
await socket.send(b'\x01')
9821013

0 commit comments

Comments
 (0)