diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index f1200103171e..1fbbba7ace5e 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -45,14 +45,14 @@ Each pooling model in vLLM supports one or more of these tasks according to [Pooler.get_supported_tasks][vllm.model_executor.layers.pooler.Pooler.get_supported_tasks], enabling the corresponding APIs: -| Task | APIs | -|------------|--------------------| -| `encode` | `encode` | -| `embed` | `embed`, `score`\* | -| `classify` | `classify` | -| `score` | `score` | +| Task | APIs | +|------------|--------------------------------------| +| `encode` | `LLM.reward(...)` | +| `embed` | `LLM.embed(...)`, `LLM.score(...)`\* | +| `classify` | `LLM.classify(...)` | +| `score` | `LLM.score(...)` | -\* The `score` API falls back to `embed` task if the model does not support `score` task. +\* The `LLM.score(...)` API falls back to `embed` task if the model does not support `score` task. ### Pooler Configuration @@ -66,11 +66,11 @@ you can override some of its attributes via the `--override-pooler-config` optio If the model has been converted via `--convert` (see above), the pooler assigned to each task has the following attributes by default: -| Task | Pooling Type | Normalization | Softmax | -|------------|----------------|---------------|---------| -| `encode` | `ALL` | ❌ | ❌ | -| `embed` | `LAST` | ✅︎ | ❌ | -| `classify` | `LAST` | ❌ | ✅︎ | +| Task | Pooling Type | Normalization | Softmax | +|------------|--------------|---------------|---------| +| `reward` | `ALL` | ❌ | ❌ | +| `embed` | `LAST` | ✅︎ | ❌ | +| `classify` | `LAST` | ❌ | ✅︎ | When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models, its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults. @@ -83,21 +83,6 @@ which takes priority over both the model's and Sentence Transformers's defaults. The [LLM][vllm.LLM] class provides various methods for offline inference. See [configuration][configuration] for a list of options when initializing the model. -### `LLM.encode` - -The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM. -It returns the extracted hidden states directly, which is useful for reward models. - -```python -from vllm import LLM - -llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", runner="pooling") -(output,) = llm.encode("Hello, my name is") - -data = output.outputs.data -print(f"Data: {data!r}") -``` - ### `LLM.embed` The [embed][vllm.LLM.embed] method outputs an embedding vector for each prompt. @@ -106,7 +91,7 @@ It is primarily designed for embedding models. ```python from vllm import LLM -llm = LLM(model="intfloat/e5-mistral-7b-instruct", runner="pooling") +llm = LLM(model="intfloat/e5-small", runner="pooling") (output,) = llm.embed("Hello, my name is") embeds = output.outputs.embedding @@ -154,6 +139,46 @@ print(f"Score: {score}") A code example can be found here: +### `LLM.reward` + +The [reward][vllm.LLM.reward] method is available to all reward models in vLLM. +It returns the extracted hidden states directly. + +```python +from vllm import LLM + +llm = LLM(model="internlm/internlm2-1_8b-reward", runner="pooling", trust_remote_code=True) +(output,) = llm.reward("Hello, my name is") + +data = output.outputs.data +print(f"Data: {data!r}") +``` + +A code example can be found here: + +### `LLM.encode` + +The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM. 
+It returns the extracted hidden states directly. + +!!! note + Please use one of the more specific methods or set the task directly when using `LLM.encode`: + + - For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`. + - For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`. + - For rewards, use `LLM.reward(...)` or `pooling_task="reward"`. + - For similarity scores, use `LLM.score(...)`. + +```python +from vllm import LLM + +llm = LLM(model="intfloat/e5-small", runner="pooling") +(output,) = llm.encode("Hello, my name is", pooling_task="embed") + +data = output.outputs.data +print(f"Data: {data!r}") +``` + ## Online Serving Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs: diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py index 526753bcef22..158836728bee 100644 --- a/examples/offline_inference/basic/embed.py +++ b/examples/offline_inference/basic/embed.py @@ -12,10 +12,9 @@ def parse_args(): parser = EngineArgs.add_cli_args(parser) # Set example specific arguments parser.set_defaults( - model="intfloat/e5-mistral-7b-instruct", + model="intfloat/e5-small", runner="pooling", enforce_eager=True, - max_model_len=1024, ) return parser.parse_args() diff --git a/examples/offline_inference/basic/reward.py b/examples/offline_inference/basic/reward.py new file mode 100644 index 000000000000..aa173cf96f5b --- /dev/null +++ b/examples/offline_inference/basic/reward.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults( + model="internlm/internlm2-1_8b-reward", + runner="pooling", + enforce_eager=True, + max_model_len=1024, + trust_remote_code=True, + ) + return parser.parse_args() + + +def main(args: Namespace): + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Create an LLM. + # You should pass runner="pooling" for reward models + llm = LLM(**vars(args)) + + # Generate rewards. The output is a list of PoolingRequestOutput. + outputs = llm.reward(prompts) + + # Print the outputs. 
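+    # output.outputs.data holds the values produced by the reward model's
+    # pooler; long outputs are trimmed to their first 16 values below so the
+    # printout stays readable.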
+ print("\nGenerated Outputs:\n" + "-" * 60) + for prompt, output in zip(prompts, outputs): + rewards = output.outputs.data + rewards_trimmed = ( + (str(rewards[:16])[:-1] + ", ...]") if len(rewards) > 16 else rewards + ) + print(f"Prompt: {prompt!r} \nReward: {rewards_trimmed} (size={len(rewards)})") + print("-" * 60) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/tests/conftest.py b/tests/conftest.py index e4df6ebf2c26..67f0e7424038 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1053,6 +1053,10 @@ def encode(self, prompts: list[str]) -> list[list[float]]: req_outputs = self.llm.encode(prompts) return [req_output.outputs.data for req_output in req_outputs] + def reward(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.llm.reward(prompts) + return [req_output.outputs.data for req_output in req_outputs] + def score( self, text_1: Union[str, list[str]], diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index 3b7fab3ba5c9..a5f7dca76d82 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -95,7 +95,7 @@ def test_prm_models( monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.encode(math_step_prompts) + vllm_outputs = vllm_model.reward(math_step_prompts) with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model: hf_model = step_reward_patch_hf_model(hf_model) diff --git a/tests/models/language/pooling/test_truncation_control.py b/tests/models/language/pooling/test_truncation_control.py index dc2bf21ef63b..c6ef899958a0 100644 --- a/tests/models/language/pooling/test_truncation_control.py +++ b/tests/models/language/pooling/test_truncation_control.py @@ -28,7 +28,7 @@ def test_smaller_truncation_size(vllm_runner, with vllm_runner(model_name, runner="pooling", max_model_len=max_model_len) as vllm_model: - vllm_output = vllm_model.llm.encode( + vllm_output = vllm_model.llm.embed( input_str, truncate_prompt_tokens=truncate_prompt_tokens) prompt_tokens = vllm_output[0].prompt_token_ids @@ -43,7 +43,7 @@ def test_max_truncation_size(vllm_runner, with vllm_runner(model_name, runner="pooling", max_model_len=max_model_len) as vllm_model: - vllm_output = vllm_model.llm.encode( + vllm_output = vllm_model.llm.embed( input_str, truncate_prompt_tokens=truncate_prompt_tokens) prompt_tokens = vllm_output[0].prompt_token_ids @@ -61,7 +61,7 @@ def test_bigger_truncation_size(vllm_runner, model_name, runner="pooling", max_model_len=max_model_len) as vllm_model: - llm_output = vllm_model.llm.encode( + llm_output = vllm_model.llm.embed( input_str, truncate_prompt_tokens=truncate_prompt_tokens) assert llm_output == f"""truncate_prompt_tokens value diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index adef350931f3..842a22ccebaa 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1037,7 +1037,7 @@ def encode( truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", + pooling_task: Optional[PoolingTask] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: """Apply pooling to the hidden states corresponding to the input @@ -1069,6 +1069,25 @@ def encode( considered legacy and may be deprecated in the future. 
             You should instead pass them via the `inputs` parameter.
         """
+        if pooling_task is None:
+            if "embed" in self.supported_tasks:
+                pooling_task = "embed"
+            else:
+                pooling_task = "encode"
+
+            logger.warning_once(
+                "`LLM.encode` is currently using `pooling_task = %s`.\n"
+                "Please use one of the more specific methods or set the "
+                "task directly when using `LLM.encode`:\n"
+                " - For embeddings, use `LLM.embed(...)` "
+                "or `pooling_task=\"embed\"`.\n"
+                " - For classification logits, use `LLM.classify(...)` "
+                "or `pooling_task=\"classify\"`.\n"
+                " - For rewards, use `LLM.reward(...)` "
+                "or `pooling_task=\"reward\"`.\n"
+                " - For similarity scores, use `LLM.score(...)`.",
+                pooling_task)
+
         model_config = self.llm_engine.model_config
         runner_type = model_config.runner_type
         if runner_type != "pooling":
@@ -1207,6 +1226,45 @@ def classify(
 
         return [ClassificationRequestOutput.from_base(item) for item in items]
 
+    def reward(
+        self,
+        prompts: Union[PromptType, Sequence[PromptType]],
+        /,
+        *,
+        truncate_prompt_tokens: Optional[int] = None,
+        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
+    ) -> list[PoolingRequestOutput]:
+        """
+        Generate rewards for each prompt.
+
+        Args:
+            prompts: The prompts to the LLM. You may pass a sequence of prompts
+                for batch inference. See [PromptType][vllm.inputs.PromptType]
+                for more details about the format of each prompt.
+            use_tqdm: If `True`, shows a tqdm progress bar.
+                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
+                it is used to create the progress bar.
+                If `False`, no progress bar is created.
+            lora_request: LoRA request to use for generation, if any.
+            pooling_params: The pooling parameters for pooling. If None, we
+                use the default pooling parameters.
+        Returns:
+            A list of `PoolingRequestOutput` objects containing the
+            pooled hidden states in the same order as the input prompts.
+        """
+
+        return self.encode(
+            prompts,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+            pooling_params=pooling_params,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            pooling_task="encode",
+        )
+
     def _embedding_score(
         self,
         tokenizer: AnyTokenizer,