diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index 6b8b1519d021..caf458c24497 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -359,13 +359,19 @@ Full example: [examples/offline_inference/audio_language.py](../../examples/offl To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model, pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary. +You must enable this feature via `enable_mm_embeds=True`. + +!!! warning + The vLLM engine may crash if embeddings of the wrong shape are passed. + Only enable this flag for trusted users! + ??? code ```python from vllm import LLM # Inference with image embeddings as input - llm = LLM(model="llava-hf/llava-1.5-7b-hf") + llm = LLM(model="llava-hf/llava-1.5-7b-hf", enable_mm_embeds=True) # Refer to the HuggingFace repo for the correct format to use prompt = "USER: \nWhat is the content of this image?\nASSISTANT:" @@ -397,7 +403,11 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd image_embeds = torch.load(...) # Qwen2-VL - llm = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": 4}) + llm = LLM( + "Qwen/Qwen2-VL-2B-Instruct", + limit_mm_per_prompt={"image": 4}, + enable_mm_embeds=True, + ) mm_data = { "image": { "image_embeds": image_embeds, @@ -407,7 +417,12 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd } # MiniCPM-V - llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={"image": 4}) + llm = LLM( + "openbmb/MiniCPM-V-2_6", + trust_remote_code=True, + limit_mm_per_prompt={"image": 4}, + enable_mm_embeds=True, + ) mm_data = { "image": { "image_embeds": image_embeds, @@ -732,7 +747,13 @@ Full example: [examples/online_serving/openai_chat_completion_client_for_multimo ### Embedding Inputs To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model, -pass a tensor of shape to the corresponding field of the multi-modal dictionary. +pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary. + +You must enable this feature via the `--enable-mm-embeds` flag in `vllm serve`. + +!!! warning + The vLLM engine may crash if embeddings of the wrong shape are passed. + Only enable this flag for trusted users! #### Image Embedding Inputs diff --git a/docs/features/prompt_embeds.md b/docs/features/prompt_embeds.md index 041025887612..b81d2f28e3b9 100644 --- a/docs/features/prompt_embeds.md +++ b/docs/features/prompt_embeds.md @@ -20,12 +20,16 @@ You can pass prompt embeddings from Hugging Face Transformers models to the `'p ## Online Serving -Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package. +Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package and are enabled by the `--enable-prompt-embeds` flag in `vllm serve`. When a mixture of `'prompt_embeds'` and `'prompt'` inputs are provided in a single request, the prompt embeds are always returned first.
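For reference, the client-side flow that this documentation change describes can be sketched as follows. This is an illustration, not part of the patch: the model name, the use of Hugging Face Transformers to produce the embeddings, and the local server URL are assumptions; any `(seq_len, hidden_size)` tensor serialized with `torch.save` and base64-encoded is sent the same way through `extra_body`.

```python
# Minimal sketch (not part of this patch): send prompt embeddings to a server
# assumed to be started with
#   vllm serve meta-llama/Llama-3.2-1B-Instruct --enable-prompt-embeds
# The model name and the Transformers-based embedding source are illustrative.
import base64
import io

import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Build a (seq_len, hidden_size) tensor from the model's input embedding layer.
token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
prompt_embeds = model.get_input_embeddings()(token_ids).squeeze(0).detach()

# Serialize with torch.save and base64-encode the raw bytes.
buffer = io.BytesIO()
torch.save(prompt_embeds, buffer)
encoded_embeds = base64.b64encode(buffer.getvalue()).decode("utf-8")

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model=model_name,
    prompt="",  # the text prompt may be empty when embeddings are supplied
    max_tokens=5,
    # Non-standard fields go through `extra_body` of the OpenAI client.
    extra_body={"prompt_embeds": encoded_embeds},
)
print(completion.choices[0].text)
```

If the server was started without `--enable-prompt-embeds`, the renderer change later in this patch rejects such a request with a `ValueError` that names the flag.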
Prompt embeddings are passed in as base64 encoded torch tensors. +!!! warning + The vLLM engine may crash if embeddings of the wrong shape are passed. + Only enable this flag for trusted users! + ### Transformers Inputs via OpenAI Client First, launch the OpenAI-compatible server: diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 2c73ed6aa608..8aa3fc1d3c85 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -49,6 +49,7 @@ def __init__(self, model): dtype="float16", enforce_eager=True, model_impl="terratorch", + enable_mm_embeds=True, ) def run(self, input_data, location_coords): diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py index 6c47b5715438..afe8f056cc5f 100644 --- a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -38,6 +38,7 @@ def main(): max_num_seqs=32, io_processor_plugin="prithvi_to_tiff", model_impl="terratorch", + enable_mm_embeds=True, ) pooling_params = PoolingParams(task="token_classify", activation=False) diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py index 611a7cbc89fa..fba52fe77139 100644 --- a/examples/online_serving/prithvi_geospatial_mae.py +++ b/examples/online_serving/prithvi_geospatial_mae.py @@ -19,6 +19,7 @@ # --task embed --trust-remote-code # --skip-tokenizer-init --enforce-eager # --io-processor-plugin prithvi_to_tiff +# --enable-mm-embeds def main(): diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py index 81126a4f16f9..c17486d962f3 100644 --- a/tests/entrypoints/llm/test_prompt_validation.py +++ b/tests/entrypoints/llm/test_prompt_validation.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +import torch from vllm import LLM @@ -12,8 +13,22 @@ def test_empty_prompt(): llm.generate([""]) -@pytest.mark.skip_v1 def test_out_of_vocab_token(): llm = LLM(model="openai-community/gpt2", enforce_eager=True) with pytest.raises(ValueError, match="out of vocabulary"): llm.generate({"prompt_token_ids": [999999]}) + + +def test_require_mm_embeds(): + llm = LLM( + model="llava-hf/llava-1.5-7b-hf", + enforce_eager=True, + enable_mm_embeds=False, + ) + with pytest.raises(ValueError, match="--enable-mm-embeds"): + llm.generate( + { + "prompt": "", + "multi_modal_data": {"image": torch.empty(1, 1, 1)}, + } + ) diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index 3ed98ffe0e39..0a057b1848ad 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -292,3 +292,16 @@ async def test_prompt_logprobs_raises_error( temperature=0.0, extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True}, ) + + +@pytest.mark.asyncio +async def test_empty_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, +) -> None: + await client_with_prompt_embeds.completions.create( + model=MODEL_NAME, + prompt="Hello", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": []}, + ) diff --git a/tests/entrypoints/openai/test_prompt_validation.py
b/tests/entrypoints/openai/test_prompt_validation.py index 3d0885414b24..cd5661e5739f 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import io +from unittest.mock import Mock # imports for structured outputs tests import openai @@ -10,7 +11,8 @@ import regex as re import torch -from vllm.entrypoints.renderer import BaseRenderer +from vllm.config import ModelConfig +from vllm.entrypoints.renderer import CompletionRenderer from ...utils import RemoteOpenAIServer @@ -59,6 +61,10 @@ async def test_out_of_vocab_token_ids(): def test_load_prompt_embeds( dtype: torch.dtype, layout: torch.layout, seq_len: int, hidden_size: int ): + model_config = Mock(spec=ModelConfig) + model_config.enable_prompt_embeds = True + renderer = CompletionRenderer(model_config, tokenizer=None) + # construct arbitrary tensors of various dtypes, layouts, and sizes. # We need to check against different layouts to make sure that if a user # uses sparse tensors to reduce the transmission size of prompt embeddings, @@ -83,7 +89,7 @@ def test_load_prompt_embeds( buffer.seek(0) encoded_tensor = pybase64.b64encode(buffer.getvalue()) - loaded_prompt_embeds = BaseRenderer.load_prompt_embeds(encoded_tensor) + loaded_prompt_embeds = renderer.load_prompt_embeds(encoded_tensor) assert len(loaded_prompt_embeds) == 1 loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"] assert loaded_tensor.device.type == "cpu" @@ -91,3 +97,22 @@ def test_load_prompt_embeds( torch.testing.assert_close( loaded_tensor, tensor.to("cpu").to_dense(), equal_nan=True ) + + +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("seq_len", [2]) +@pytest.mark.parametrize("hidden_size", [2]) +def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int): + model_config = Mock(spec=ModelConfig) + model_config.enable_prompt_embeds = False + renderer = CompletionRenderer(model_config, tokenizer=None) + + tensor = torch.randn((seq_len, hidden_size), dtype=dtype) + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + encoded_tensor = pybase64.b64encode(buffer.getvalue()) + + with pytest.raises(ValueError, match="--enable-prompt-embeds"): + renderer.load_prompt_embeds(encoded_tensor) diff --git a/tests/entrypoints/openai/test_skip_tokenizer.py b/tests/entrypoints/openai/test_vision_embeds.py similarity index 76% rename from tests/entrypoints/openai/test_skip_tokenizer.py rename to tests/entrypoints/openai/test_vision_embeds.py index 6998566c03d0..a6593c5b05e2 100644 --- a/tests/entrypoints/openai/test_skip_tokenizer.py +++ b/tests/entrypoints/openai/test_vision_embeds.py @@ -15,30 +15,7 @@ DTYPE = "float16" -@pytest.fixture(scope="module") -def server(): - args = [ - "--runner", - "pooling", - # use half precision for speed and memory savings in CI environment - "--dtype", - DTYPE, - "--enforce-eager", - "--trust-remote-code", - "--skip-tokenizer-init", - "--max-num-seqs", - "32", - "--model-impl", - "terratorch", - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_request(server: RemoteOpenAIServer, model_name: str): +def _terratorch_dummy_inputs(model_name: str): pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) @@ -54,7 +31,7 @@ 
async def test_single_request(server: RemoteOpenAIServer, model_name: str): binary_data = buffer_coord.read() base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8") - prompt = { + return { "model": model_name, "additional_data": {"prompt_token_ids": [1]}, "encoding_format": "base64", @@ -74,12 +51,33 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str): ], } - # test single pooling - response = requests.post(server.url_for("pooling"), json=prompt) - response.raise_for_status() - output = response.json()["data"][0]["data"] +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_single_request(model_name: str): + args = [ + "--runner", + "pooling", + # use half precision for speed and memory savings in CI environment + "--dtype", + DTYPE, + "--enforce-eager", + "--trust-remote-code", + "--max-num-seqs", + "32", + "--model-impl", + "terratorch", + "--skip-tokenizer-init", + "--enable-mm-embeds", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as server: + prompt = _terratorch_dummy_inputs(model_name) + + # test single pooling + response = requests.post(server.url_for("pooling"), json=prompt) + response.raise_for_status() - np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32) + output = response.json()["data"][0]["data"] - assert len(np_response) == 524288 + np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32) + assert len(np_response) == 524288 diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 8c7e626830ce..f5faa7a3507c 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -73,6 +73,19 @@ def phi3v_model_config_mm_interleaved(): ) +@pytest.fixture(scope="function") +def phi3v_model_config_image_embeds(): + return ModelConfig( + PHI3V_MODEL_ID, + runner="generate", + trust_remote_code=True, + limit_mm_per_prompt={ + "image": 2, + }, + enable_mm_embeds=True, + ) + + @pytest.fixture(scope="module") def phi3v_tokenizer(): return get_tokenizer(PHI3V_MODEL_ID) @@ -799,7 +812,7 @@ def test_parse_chat_messages_empty_pil_image_with_uuid( def test_parse_chat_messages_empty_image_embeds_with_uuid( - phi3v_model_config, + phi3v_model_config_image_embeds, phi3v_tokenizer, ): uuid = "abcd" @@ -813,7 +826,7 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( ], } ], - phi3v_model_config, + phi3v_model_config_image_embeds, phi3v_tokenizer, content_format="string", ) @@ -832,7 +845,7 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( @pytest.mark.asyncio async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( - phi3v_model_config, + phi3v_model_config_image_embeds, phi3v_tokenizer, ): uuid = "abcd" @@ -846,7 +859,7 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( ], } ], - phi3v_model_config, + phi3v_model_config_image_embeds, phi3v_tokenizer, content_format="string", ) diff --git a/tests/entrypoints/test_renderer.py b/tests/entrypoints/test_renderer.py index c811a6ba63cb..b0ef3dd045bd 100644 --- a/tests/entrypoints/test_renderer.py +++ b/tests/entrypoints/test_renderer.py @@ -17,6 +17,7 @@ class MockModelConfig: max_model_len: int = 100 encoder_config: dict | None = None + enable_prompt_embeds: bool = True class MockTokenizerResult: diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 44bbc4479ca4..1aace66480f8 100644 --- a/tests/models/multimodal/generation/test_common.py +++
b/tests/models/multimodal/generation/test_common.py @@ -109,8 +109,7 @@ limit_mm_per_prompt={"image": 4}, ) ], - # TODO: Revert to "auto" when CPU backend can use torch > 2.6 - dtype="bfloat16" if current_platform.is_cpu() else "auto", + vllm_runner_kwargs={"enable_mm_embeds": True}, marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), "qwen2_5_vl": VLMTestInfo( diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index a4abf6e405f7..e10b8e1e77af 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -292,6 +292,7 @@ def run_embedding_input_test( tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, default_torch_num_threads=1, + enable_mm_embeds=True, ) as vllm_model: outputs_per_case_for_original_input = [ vllm_model.generate_greedy_logprobs( diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index 62154b083487..676076c45847 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -34,6 +34,7 @@ def _run_test( dtype="half", enforce_eager=True, skip_tokenizer_init=True, + enable_mm_embeds=True, # Limit the maximum number of sequences to avoid the # test going OOM during the warmup run max_num_seqs=32, diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 6074cdef1bd1..9edab1ad67fc 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -104,6 +104,11 @@ def _initialize_kv_caches_v1(self, vllm_config): m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") if model_arch == "WhisperForConditionalGeneration": m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + + extra_args = {} + if model_arch in ("PrithviGeoSpatialMAE", "Terratorch"): + extra_args["enable_mm_embeds"] = True + LLM( model_info.default, tokenizer=model_info.tokenizer, @@ -128,6 +133,7 @@ def _initialize_kv_caches_v1(self, vllm_config): else "vllm", hf_overrides=hf_overrides_fn, max_num_seqs=model_info.max_num_seqs, + **extra_args, ) diff --git a/tests/models/test_terratorch.py b/tests/models/test_terratorch.py index cadce5d2b2bb..15764145bc1a 100644 --- a/tests/models/test_terratorch.py +++ b/tests/models/test_terratorch.py @@ -32,6 +32,7 @@ def test_inference( dtype="half", enforce_eager=True, skip_tokenizer_init=True, + enable_mm_embeds=True, # Limit the maximum number of sequences to avoid the # test going OOM during the warmup run max_num_seqs=32, diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 936f27fb69bc..e4a60f95eb87 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -38,6 +38,7 @@ def server(): "prithvi_to_tiff", "--model-impl", "terratorch", + "--enable-mm-embeds", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index c66a66b84b62..736ccbefbc4d 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -6,7 +6,6 @@ import pytest import pytest_asyncio import regex as re -import requests from openai import BadRequestError from tests.utils import RemoteOpenAIServer @@ -686,17 +685,3 @@ async def 
test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): "structured_outputs": {"grammar": invalid_simplified_sql_grammar} }, ) - - -@pytest.mark.asyncio -async def test_completion_with_empty_prompt_embeds(client: openai.AsyncOpenAI) -> None: - """Test completion with empty prompt embeds.""" - payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []} - headers: dict[str, str] = {"Content-Type": "application/json"} - # base_url = http://localhost:8000/v1/completions - response = requests.post( - f"{client.base_url}completions", headers=headers, json=payload - ) - assert response.status_code == 200, ( - f"Expected status code 200, got {response.status_code}. " - ) diff --git a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py index 3c2b3de33958..276de2ff8e2c 100644 --- a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py +++ b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py @@ -32,6 +32,7 @@ def default_image_embeds_server_args() -> list[str]: "--enforce-eager", "--limit-mm-per-prompt", json.dumps({"image": MAXIMUM_IMAGES}), + "--enable-mm-embeds", ] diff --git a/vllm/config/model.py b/vllm/config/model.py index 7bf8b4bfc15a..27bcbf90c2bc 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -232,8 +232,10 @@ class ModelConfig: output will contain token ids.""" enable_prompt_embeds: bool = False """If `True`, enables passing text embeddings as inputs via the - `prompt_embeds` key. Note that enabling this will double the time required - for graph compilation.""" + `prompt_embeds` key. + + WARNING: The vLLM engine may crash if embeddings of the wrong shape are passed. + Only enable this flag for trusted users!""" served_model_name: str | list[str] | None = None """The model name(s) used in the API. If multiple names are provided, the server will respond to any of the provided names. The model name in the @@ -303,6 +305,7 @@ class ModelConfig: """Configuration for multimodal model.
If `None`, this will be inferred from the architecture of `self.model`.""" limit_mm_per_prompt: InitVar[dict[str, int | dict[str, int]] | None] = None + enable_mm_embeds: InitVar[bool | None] = None media_io_kwargs: InitVar[dict[str, dict[str, Any]] | None] = None mm_processor_kwargs: InitVar[dict[str, Any] | None] = None mm_processor_cache_gb: InitVar[float | None] = None @@ -421,6 +424,7 @@ def __post_init__( self, # Multimodal config init vars limit_mm_per_prompt: dict[str, int] | None, + enable_mm_embeds: bool | None, media_io_kwargs: dict[str, dict[str, Any]] | None, mm_processor_kwargs: dict[str, Any] | None, mm_processor_cache_gb: float | None, @@ -731,6 +735,7 @@ def _task_to_convert(task: TaskOption) -> ConvertType: mm_config_kwargs = dict( limit_per_prompt=limit_mm_per_prompt, + enable_mm_embeds=enable_mm_embeds, media_io_kwargs=media_io_kwargs, mm_processor_kwargs=mm_processor_kwargs, mm_processor_cache_gb=mm_processor_cache_gb, diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index e80d072dab45..ef73720efe09 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -75,6 +75,14 @@ class MultiModalConfig: {"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512, "height": 512}} """ + enable_mm_embeds: bool = False + """If `True`, enables passing multimodal embeddings: + for the `LLM` class, this refers to tensor inputs under `multi_modal_data`; + for the OpenAI-compatible server, this refers to chat messages with content + `"type": "*_embeds"`. + + WARNING: The vLLM engine may crash if embeddings of the wrong shape are passed. + Only enable this flag for trusted users!""" media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict) """Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a06ec92b51c8..ce41c377e457 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -438,6 +438,7 @@ class EngineArgs: limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field( MultiModalConfig, "limit_per_prompt" ) + enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings media_io_kwargs: dict[str, dict[str, Any]] = get_field( MultiModalConfig, "media_io_kwargs" @@ -896,6 +897,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument( "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"] ) + multimodal_group.add_argument( + "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"] + ) multimodal_group.add_argument( "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"] ) @@ -1159,6 +1163,7 @@ def create_model_config(self) -> ModelConfig: enable_prompt_embeds=self.enable_prompt_embeds, served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, + enable_mm_embeds=self.enable_mm_embeds, interleave_mm_strings=self.interleave_mm_strings, media_io_kwargs=self.media_io_kwargs, skip_mm_profiling=self.skip_mm_profiling, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 9bed327723b9..4c73e94fb72b 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -811,6 +811,10 @@ def __init__(self, tracker: MultiModalItemTracker) -> None: allowed_media_domains=tracker.allowed_media_domains, ) + @property + def model_config(self) -> ModelConfig: + return self._tracker.model_config + def parse_image(self, image_url: str | None, uuid: str | None = None) -> None: image = self._connector.fetch_image(image_url) if image_url else None @@ -822,6 +826,12 @@ def parse_image_embeds( image_embeds: str | dict[str, str] | None, uuid: str | None = None, ) -> None: + mm_config = self.model_config.get_multimodal_config() + if not mm_config.enable_mm_embeds: + raise ValueError( + "You must set `--enable-mm-embeds` to input `image_embeds`" + ) + if isinstance(image_embeds, dict): embeds = { k: self._connector.fetch_image_embedding(v) @@ -886,6 +896,10 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: allowed_media_domains=tracker.allowed_media_domains, ) + @property + def model_config(self) -> ModelConfig: + return self._tracker.model_config + def parse_image(self, image_url: str | None, uuid: str | None = None) -> None: image_coro = self._connector.fetch_image_async(image_url) if image_url else None @@ -897,6 +911,12 @@ def parse_image_embeds( image_embeds: str | dict[str, str] | None, uuid: str | None = None, ) -> None: + mm_config = self.model_config.get_multimodal_config() + if not mm_config.enable_mm_embeds: + raise ValueError( + "You must set `--enable-mm-embeds` to input `image_embeds`" + ) + future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future() if isinstance(image_embeds, dict): diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py index a845528200d5..3c5a396a99f9 100644 --- a/vllm/entrypoints/renderer.py +++ b/vllm/entrypoints/renderer.py @@ -156,14 +156,17 @@ async def render_prompt_and_embeds( """ raise NotImplementedError - @classmethod def load_prompt_embeds( - cls, + self, prompt_embeds: bytes | list[bytes], truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None, cache_salt: str | 
None = None, ) -> list[EngineEmbedsPrompt]: """Load and validate base64-encoded embeddings into prompt objects.""" + if not self.model_config.enable_prompt_embeds: + raise ValueError( + "You must set `--enable-prompt-embeds` to input `prompt_embeds`." + ) def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt: tensor = torch.load( diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 94122c1d4cc9..55132a6036ef 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1308,6 +1308,16 @@ def _to_mm_items( [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. """ mm_items = self.data_parser.parse_mm_data(mm_data) + + mm_config = self.info.ctx.model_config.get_multimodal_config() + if not mm_config.enable_mm_embeds: + for modality, items in mm_items.items(): + if isinstance(items, (EmbeddingItems, DictEmbeddingItems)): + raise ValueError( + f"You must set `--enable-mm-embeds` to input " + f"`{modality}_embeds`" + ) + for modality, items in mm_items.items(): self.validate_num_items(modality, len(items))
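Taken together, these changes gate both prompt embeddings and multimodal embeddings behind explicit opt-in flags, with validation in the chat parser, the renderer, and the multimodal processor. As a rough illustration of the online multimodal path (not part of the patch; the model name, the placeholder tensor, and the request schema below follow the existing "Image Embedding Inputs" documentation and are assumptions rather than additions of this change):

```python
# Illustrative only: a chat request carrying pre-computed image embeddings to a
# server assumed to be started with
#   vllm serve llava-hf/llava-1.5-7b-hf --enable-mm-embeds
# The tensor is a placeholder standing in for real
# (num_items, feature_size, hidden_size) image features.
import base64
import io

import torch
from openai import OpenAI

image_embeds = torch.load("image_embeds.pt")  # placeholder path

buffer = io.BytesIO()
torch.save(image_embeds, buffer)
base64_image_embedding = base64.b64encode(buffer.getvalue()).decode("utf-8")

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat_completion = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # Without --enable-mm-embeds, parse_image_embeds() above rejects
                # this content type with a ValueError naming the flag.
                {"type": "image_embeds", "image_embeds": base64_image_embedding},
            ],
        }
    ],
    max_tokens=32,
)
print(chat_completion.choices[0].message.content)
```

Offline, the equivalent guard lives in `_to_mm_items`, which raises before embedding tensors passed under `multi_modal_data` can reach the engine.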