diff --git a/tests/models/multimodal/processing/test_transformers.py b/tests/models/multimodal/processing/test_transformers.py
new file mode 100644
index 000000000000..c7d1b5271ff7
--- /dev/null
+++ b/tests/models/multimodal/processing/test_transformers.py
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+from vllm.assets.image import ImageAsset
+from vllm.config import ModelConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+
+# yapf: disable
+@pytest.mark.parametrize("model_id",
+                         ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
+def test_multimodal_processor(model_id):
+    model_config = ModelConfig(
+        model=model_id,
+        model_impl="transformers",
+    )
+
+    mm_processor = MULTIMODAL_REGISTRY.create_processor(model_config, )
+
+    image_pil = ImageAsset('cherry_blossom').pil_image
+    mm_data = {"image": image_pil}
+    str_prompt = "<|im_start|>user \nWhat is the content of this image?<|im_end|><|im_start|>assistant\n"  # noqa: E501
+    str_processed_inputs = mm_processor.apply(
+        prompt=str_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    ids_prompt = [
+        151644, 872, 220, 151646, 198, 3838, 374, 279, 2213, 315, 419, 2168,
+        30, 151645, 151644, 77091, 198
+    ]
+    ids_processed_inputs = mm_processor.apply(
+        prompt=ids_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert str_processed_inputs["prompt"] == ids_processed_inputs["prompt"]
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 47cff29caab0..8086f1bebac3 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -325,6 +325,11 @@ def apply(
         mm_items = self._to_mm_items(mm_data)
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+        if not isinstance(prompt, str):
+            # the prompt is the tokenized ids which is not supported
+            # by the hf_processor, which is why we would need to decode the ids
+            # into string
+            prompt = hf_processor.decode(prompt)
         (prompt_ids, processed_data,
          mm_token_type_ids) = self._apply_hf_processor_text_mm(