From be054e27597afd6603dc5319678ccae4abc0d631 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 28 Jul 2025 13:51:53 +0800
Subject: [PATCH 01/23] + reward

Signed-off-by: wang.yuqi
---
 vllm/entrypoints/llm.py | 49 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 48 insertions(+), 1 deletion(-)

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 98921a49fad6..44fcaa9f37e0 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1060,7 +1060,7 @@ def encode(
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
-        pooling_task: PoolingTask = "encode",
+        pooling_task: Optional[PoolingTask] = None,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         """Apply pooling to the hidden states corresponding to the input
@@ -1092,6 +1092,14 @@ def encode(
             considered legacy and may be deprecated in the future. You should
             instead pass them via the `inputs` parameter.
         """
+        if pooling_task is None:
+            raise ValueError(
+                "`pooling_task` must be specified. "
+                "Get embedding prefer `LLm.embed`. "
+                "Get classification logits prefer to `LLm.classify`. "
+                "Get reward scores prefer `LLm.reward`. "
+                "Get pairs similarity scores prefer `LLm.score`.")
+
         model_config = self.llm_engine.model_config
         runner_type = model_config.runner_type
         if runner_type != "pooling":
@@ -1230,6 +1238,45 @@ def classify(
 
         return [ClassificationRequestOutput.from_base(item) for item in items]
 
+    def reward(
+        self,
+        prompts: Union[PromptType, Sequence[PromptType]],
+        /,
+        *,
+        truncate_prompt_tokens: Optional[int] = None,
+        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
+    ) -> list[PoolingRequestOutput]:
+        """
+        Generate reward scores for each prompt.
+
+        Args:
+            prompts: The prompts to the LLM. You may pass a sequence of prompts
+                for batch inference. See [PromptType][vllm.inputs.PromptType]
+                for more details about the format of each prompts.
+            use_tqdm: If `True`, shows a tqdm progress bar.
+                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
+                it is used to create the progress bar.
+                If `False`, no progress bar is created.
+            lora_request: LoRA request to use for generation, if any.
+            pooling_params: The pooling parameters for pooling. If None, we
+                use the default pooling parameters.
+        Returns:
+            A list of `PoolingRequestOutput` objects containing the
+            pooled hidden states in the same order as the input prompts.
+        """
+
+        return self.encode(
+            prompts,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+            pooling_params=pooling_params,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            pooling_task="encode",
+        )
+
     def _embedding_score(
         self,
         tokenizer: AnyTokenizer,

From f33d0847b06867bdb1c8e153ab87312f5c1033b4 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 28 Jul 2025 14:02:24 +0800
Subject: [PATCH 02/23] + fix

Signed-off-by: wang.yuqi
---
 vllm/entrypoints/llm.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 44fcaa9f37e0..805316b7714b 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1094,11 +1094,12 @@ def encode(
         """
         if pooling_task is None:
             raise ValueError(
-                "`pooling_task` must be specified. "
-                "Get embedding prefer `LLm.embed`. "
-                "Get classification logits prefer to `LLm.classify`. "
-                "Get reward scores prefer `LLm.reward`. "
-                "Get pairs similarity scores prefer `LLm.score`.")
+                "`pooling_task` must be specified. Please use one of the more "
+                "specific methods instead of `encode`:\n"
+                " - For embeddings, use `LLM.embed(...)`.\n"
+                " - For classification logits, use `LLM.classify(...)`.\n"
+                " - For reward scores, use `LLM.reward(...)`.\n"
+                " - For similarity scores, use `LLM.score(...)`.")
 
         model_config = self.llm_engine.model_config
         runner_type = model_config.runner_type

From 5025a2f7d51449def27da9f6b7eddc3dbdd11fbb Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 28 Jul 2025 14:13:19 +0800
Subject: [PATCH 03/23] + warning

Signed-off-by: wang.yuqi
---
 vllm/entrypoints/llm.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 805316b7714b..ec04002337d9 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1093,13 +1093,14 @@ def encode(
             instead pass them via the `inputs` parameter.
         """
         if pooling_task is None:
-            raise ValueError(
+            logger.warning(
                 "`pooling_task` must be specified. Please use one of the more "
                 "specific methods instead of `encode`:\n"
                 " - For embeddings, use `LLM.embed(...)`.\n"
                 " - For classification logits, use `LLM.classify(...)`.\n"
                 " - For reward scores, use `LLM.reward(...)`.\n"
                 " - For similarity scores, use `LLM.score(...)`.")
+            pooling_task = "embed"
 
         model_config = self.llm_engine.model_config
         runner_type = model_config.runner_type

From 3bf071eb26de2222cc1f4eb64a6c72a693e45b04 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 28 Jul 2025 15:26:08 +0800
Subject: [PATCH 04/23] + reward

Signed-off-by: wang.yuqi
---
 docs/models/pooling_models.md                | 54 +++++++++++++------
 tests/conftest.py                            |  4 +-
 tests/models/language/pooling/test_reward.py |  2 +-
 .../pooling/test_truncation_control.py       |  6 +--
 vllm/entrypoints/llm.py                      |  4 +-
 5 files changed, 46 insertions(+), 24 deletions(-)

diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md
index a06d86523af1..556f849be5fc 100644
--- a/docs/models/pooling_models.md
+++ b/docs/models/pooling_models.md
@@ -83,21 +83,6 @@ which takes priority over both the model's and Sentence Transformers's defaults
 The [LLM][vllm.LLM] class provides various methods for offline inference.
 See [configuration][configuration] for a list of options when initializing the model.
 
-### `LLM.encode`
-
-The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
-It returns the extracted hidden states directly, which is useful for reward models.
-
-```python
-from vllm import LLM
-
-llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", runner="pooling")
-(output,) = llm.encode("Hello, my name is")
-
-data = output.outputs.data
-print(f"Data: {data!r}")
-```
-
 ### `LLM.embed`
 
 The [embed][vllm.LLM.embed] method outputs an embedding vector for each prompt.
@@ -106,7 +91,7 @@ It is primarily designed for embedding models.
 ```python
 from vllm import LLM
 
-llm = LLM(model="intfloat/e5-mistral-7b-instruct", runner="pooling")
+llm = LLM(model="intfloat/e5-small", runner="pooling")
 (output,) = llm.embed("Hello, my name is")
 
 embeds = output.outputs.embedding
@@ -154,6 +139,43 @@ print(f"Score: {score}")
 
 A code example can be found here: <gh-file:examples/offline_inference/basic/score.py>
 
+### `LLM.reward`
+
+The [reward][vllm.LLM.reward] method is available to all pooling models in vLLM.
+It returns the extracted hidden states directly, which is useful for reward models.
+
+```python
+from vllm import LLM
+
+llm = LLM(model="internlm/internlm2-1_8b-reward", runner="pooling", trust_remote_code=True)
+(output,) = llm.reward("Hello, my name is")
+
+data = output.outputs.data
+print(f"Data: {data!r}")
+```
+
+### `LLM.encode`
+
+The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
+It returns the extracted hidden states directly.
+
+!!! note
+    `LLM.encode` defaults to using `pooling_task = embed`.
+    - For embeddings, use `LLM.embed(...)`.
+    - For classification logits, use `LLM.classify(...)`.
+    - For reward scores, use `LLM.reward(...)`.
+    - For similarity scores, use `LLM.score(...)`.
+
+```python
+from vllm import LLM
+
+llm = LLM(model="intfloat/e5-small", runner="pooling")
+(output,) = llm.encode("Hello, my name is")
+
+data = output.outputs.data
+print(f"Data: {data!r}")
+```
+
 ## Online Serving
 
 Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs:
diff --git a/tests/conftest.py b/tests/conftest.py
index e4df6ebf2c26..62c30d76c29f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1049,8 +1049,8 @@ def embed(self,
         req_outputs = self.llm.embed(inputs, *args, **kwargs)
         return [req_output.outputs.embedding for req_output in req_outputs]
 
-    def encode(self, prompts: list[str]) -> list[list[float]]:
-        req_outputs = self.llm.encode(prompts)
+    def reward(self, prompts: list[str]) -> list[list[float]]:
+        req_outputs = self.llm.reward(prompts)
         return [req_output.outputs.data for req_output in req_outputs]
 
     def score(
diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py
index 3b7fab3ba5c9..a5f7dca76d82 100644
--- a/tests/models/language/pooling/test_reward.py
+++ b/tests/models/language/pooling/test_reward.py
@@ -95,7 +95,7 @@ def test_prm_models(
         monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
 
     with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.encode(math_step_prompts)
+        vllm_outputs = vllm_model.reward(math_step_prompts)
 
     with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
         hf_model = step_reward_patch_hf_model(hf_model)
diff --git a/tests/models/language/pooling/test_truncation_control.py b/tests/models/language/pooling/test_truncation_control.py
index dc2bf21ef63b..c6ef899958a0 100644
--- a/tests/models/language/pooling/test_truncation_control.py
+++ b/tests/models/language/pooling/test_truncation_control.py
@@ -28,7 +28,7 @@ def test_smaller_truncation_size(vllm_runner,
     with vllm_runner(model_name, runner="pooling",
                      max_model_len=max_model_len) as vllm_model:
-        vllm_output = vllm_model.llm.encode(
+        vllm_output = vllm_model.llm.embed(
             input_str, truncate_prompt_tokens=truncate_prompt_tokens)
 
         prompt_tokens = vllm_output[0].prompt_token_ids
@@ -43,7 +43,7 @@ def test_max_truncation_size(vllm_runner,
     with vllm_runner(model_name, runner="pooling",
                      max_model_len=max_model_len) as vllm_model:
-        vllm_output = vllm_model.llm.encode(
+        vllm_output = vllm_model.llm.embed(
             input_str, truncate_prompt_tokens=truncate_prompt_tokens)
 
         prompt_tokens = vllm_output[0].prompt_token_ids
@@ -61,7 +61,7 @@ def test_bigger_truncation_size(vllm_runner,
                      model_name, runner="pooling",
                      max_model_len=max_model_len) as vllm_model:
-        llm_output = vllm_model.llm.encode(
+        llm_output = vllm_model.llm.embed(
             input_str, truncate_prompt_tokens=truncate_prompt_tokens)
 
         assert llm_output == f"""truncate_prompt_tokens value
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index ec04002337d9..79f985578379 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1094,8 +1094,8 @@ def encode(
         """
         if pooling_task is None:
             logger.warning(
-                "`pooling_task` must be specified. Please use one of the more "
-                "specific methods instead of `encode`:\n"
+                "`LLM.encode` defaults to using `pooling_task = embed`.\n"
+                "Please use one of the more specific methods instead of `encode`:\n"
                 " - For embeddings, use `LLM.embed(...)`.\n"
                 " - For classification logits, use `LLM.classify(...)`.\n"
                 " - For reward scores, use `LLM.reward(...)`.\n"

From 332ffea95bbacfd88a84e95b35e7a0785c9754a7 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 28 Jul 2025 16:07:41 +0800
Subject: [PATCH 05/23] fix

Signed-off-by: wang.yuqi
---
 vllm/entrypoints/llm.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 79f985578379..254601117ced 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1095,7 +1095,8 @@ def encode(
         if pooling_task is None:
             logger.warning(
                 "`LLM.encode` defaults to using `pooling_task = embed`.\n"
-                "Please use one of the more specific methods instead of `encode`:\n"
+                "Please use one of the more specific methods instead "
+                "of `encode`:\n"
                 " - For embeddings, use `LLM.embed(...)`.\n"
                 " - For classification logits, use `LLM.classify(...)`.\n"
                 " - For reward scores, use `LLM.reward(...)`.\n"

From 6dc06a082a894bd38c95061c0456e908a83292dc Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 28 Jul 2025 16:13:28 +0800
Subject: [PATCH 06/23] fix

Signed-off-by: wang.yuqi
---
 docs/models/pooling_models.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md
index 556f849be5fc..4ff1810d7612 100644
--- a/docs/models/pooling_models.md
+++ b/docs/models/pooling_models.md
@@ -161,6 +161,7 @@ It returns the extracted hidden states directly.
 
 !!! note
     `LLM.encode` defaults to using `pooling_task = embed`.
+    Please use one of the more specific methods instead of `encode`:
     - For embeddings, use `LLM.embed(...)`.
     - For classification logits, use `LLM.classify(...)`.
     - For reward scores, use `LLM.reward(...)`.
From 8ed7dc91c49548cdb918f598d180aa75d5046198 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 28 Jul 2025 16:40:57 +0800
Subject: [PATCH 07/23] fix

Signed-off-by: wang.yuqi
---
 docs/models/pooling_models.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md
index 4ff1810d7612..88e165665e50 100644
--- a/docs/models/pooling_models.md
+++ b/docs/models/pooling_models.md
@@ -68,7 +68,7 @@ the pooler assigned to each task has the following attributes by default:
 
 | Task | Pooling Type | Normalization | Softmax |
 |------------|----------------|---------------|---------|
-| `encode` | `ALL` | ❌ | ❌ |
+| `reward` | `ALL` | ❌ | ❌ |
 | `embed` | `LAST` | ✅︎ | ❌ |
 | `classify` | `LAST` | ❌ | ✅︎ |

From 26d6f8c16949a5f8f5d134508b22e110dda22218 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 28 Jul 2025 17:09:54 +0800
Subject: [PATCH 08/23] + encoder_config check

Signed-off-by: wang.yuqi
---
 docs/models/pooling_models.md |  1 -
 vllm/entrypoints/llm.py       | 12 ++++++++----
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md
index 88e165665e50..e10b09a8b98d 100644
--- a/docs/models/pooling_models.md
+++ b/docs/models/pooling_models.md
@@ -160,7 +160,6 @@ The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM
 It returns the extracted hidden states directly.
 
 !!! note
-    `LLM.encode` defaults to using `pooling_task = embed`.
     Please use one of the more specific methods instead of `encode`:
     - For embeddings, use `LLM.embed(...)`.
    - For classification logits, use `LLM.classify(...)`.
    - For reward scores, use `LLM.reward(...)`.
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 254601117ced..8d33be12ecef 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1093,15 +1093,19 @@ def encode(
             instead pass them via the `inputs` parameter.
         """
         if pooling_task is None:
-            logger.warning(
-                "`LLM.encode` defaults to using `pooling_task = embed`.\n"
+            if self.llm_engine.model_config.encoder_config is not None:
+                pooling_task = "embed"
+            else:
+                pooling_task = "encode"
+
+            logger.warning_once(
+                "`LLM.encode` is currently using `pooling_task = %s`.\n"
                 "Please use one of the more specific methods instead "
                 "of `encode`:\n"
                 " - For embeddings, use `LLM.embed(...)`.\n"
                 " - For classification logits, use `LLM.classify(...)`.\n"
                 " - For reward scores, use `LLM.reward(...)`.\n"
-                " - For similarity scores, use `LLM.score(...)`.")
-            pooling_task = "embed"
+                " - For similarity scores, use `LLM.score(...)`.", pooling_task)

From 6e69ae827396f40d3324cb710b44af3c76d059a0 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 28 Jul 2025 18:34:14 +0800
Subject: [PATCH 09/23] fix

Signed-off-by: wang.yuqi
---
 vllm/entrypoints/llm.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 8d33be12ecef..dbdd6f11b860 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1105,7 +1105,8 @@ def encode(
                 " - For embeddings, use `LLM.embed(...)`.\n"
                 " - For classification logits, use `LLM.classify(...)`.\n"
                 " - For reward scores, use `LLM.reward(...)`.\n"
-                " - For similarity scores, use `LLM.score(...)`.", pooling_task)
+                " - For similarity scores, use `LLM.score(...)`.",
+                pooling_task)

From 900ffcca5876bc5458bffaf7a9249193a0217458 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 28 Jul 2025 20:11:05 +0800
Subject: [PATCH 10/23] + check if "embed" in self.supported_tasks

Signed-off-by: wang.yuqi
---
 vllm/entrypoints/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index dbdd6f11b860..3beae3b87edc 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1093,7 +1093,7 @@ def encode(
             instead pass them via the `inputs` parameter.
         """
         if pooling_task is None:
-            if self.llm_engine.model_config.encoder_config is not None:
+            if "embed" in self.supported_tasks:
                 pooling_task = "embed"
             else:
                 pooling_task = "encode"

From 5bbaa22110bbf02cf3c3a487569b876202d0abbe Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Tue, 29 Jul 2025 08:57:32 +0800
Subject: [PATCH 11/23] fix

Signed-off-by: wang.yuqi
---
 docs/models/pooling_models.md              |  4 +-
 examples/offline_inference/basic/embed.py  |  3 +-
 examples/offline_inference/basic/reward.py | 53 ++++++++++++++++++++++
 vllm/entrypoints/llm.py                    |  2 +-
 4 files changed, 58 insertions(+), 4 deletions(-)
 create mode 100644 examples/offline_inference/basic/reward.py

diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md
index e10b09a8b98d..47f0bd3543aa 100644
--- a/docs/models/pooling_models.md
+++ b/docs/models/pooling_models.md
@@ -154,6 +154,8 @@ data = output.outputs.data
 print(f"Data: {data!r}")
 ```
 
+A code example can be found here: <gh-file:examples/offline_inference/basic/reward.py>
+
 ### `LLM.encode`
 
 The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
@@ -163,7 +165,7 @@ It returns the extracted hidden states directly.
     Please use one of the more specific methods instead of `encode`:
     - For embeddings, use `LLM.embed(...)`.
     - For classification logits, use `LLM.classify(...)`.
-    - For reward scores, use `LLM.reward(...)`.
+    - For rewards, use `LLM.reward(...)`.
     - For similarity scores, use `LLM.score(...)`.
 
 ```python
 from vllm import LLM
diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py
index 526753bcef22..158836728bee 100644
--- a/examples/offline_inference/basic/embed.py
+++ b/examples/offline_inference/basic/embed.py
@@ -12,10 +12,9 @@ def parse_args():
     parser = EngineArgs.add_cli_args(parser)
     # Set example specific arguments
     parser.set_defaults(
-        model="intfloat/e5-mistral-7b-instruct",
+        model="intfloat/e5-small",
         runner="pooling",
         enforce_eager=True,
-        max_model_len=1024,
     )
     return parser.parse_args()
diff --git a/examples/offline_inference/basic/reward.py b/examples/offline_inference/basic/reward.py
new file mode 100644
index 000000000000..55756777d0ca
--- /dev/null
+++ b/examples/offline_inference/basic/reward.py
@@ -0,0 +1,53 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from argparse import Namespace
+
+from vllm import LLM, EngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+
+def parse_args():
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(
+        model="internlm/internlm2-1_8b-reward",
+        runner="pooling",
+        enforce_eager=True,
+        max_model_len=1024,
+        trust_remote_code=True
+    )
+    return parser.parse_args()
+
+
+def main(args: Namespace):
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # Create an LLM.
+    # You should pass runner="pooling" for reward models
+    llm = LLM(**vars(args))
+
+    # Generate rewards. The output is a list of PoolingRequestOutput.
+    outputs = llm.reward(prompts)
+
+    # Print the outputs.
+    print("\nGenerated Outputs:\n" + "-" * 60)
+    for prompt, output in zip(prompts, outputs):
+        rewards = output.outputs.data
+        rewards_trimmed = (
+            (str(rewards[:16])[:-1] + ", ...]") if len(rewards) > 16 else rewards
+        )
+        print(f"Prompt: {prompt!r} \nReward: {rewards_trimmed} (size={len(rewards)})")
+        print("-" * 60)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 3beae3b87edc..61bf05fcdf36 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1104,7 +1104,7 @@ def encode(
                 "of `encode`:\n"
                 " - For embeddings, use `LLM.embed(...)`.\n"
                 " - For classification logits, use `LLM.classify(...)`.\n"
-                " - For reward scores, use `LLM.reward(...)`.\n"
+                " - For rewards, use `LLM.reward(...)`.\n"
                 " - For similarity scores, use `LLM.score(...)`.",
                 pooling_task)

From 4faeee574e6fadb326aeda444424695a3b0e2918 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Tue, 29 Jul 2025 09:05:56 +0800
Subject: [PATCH 12/23] fix

Signed-off-by: wang.yuqi
---
 examples/offline_inference/basic/reward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/offline_inference/basic/reward.py b/examples/offline_inference/basic/reward.py
index 55756777d0ca..aa173cf96f5b 100644
--- a/examples/offline_inference/basic/reward.py
+++ b/examples/offline_inference/basic/reward.py
@@ -16,7 +16,7 @@ def parse_args():
         runner="pooling",
         enforce_eager=True,
         max_model_len=1024,
-        trust_remote_code=True
+        trust_remote_code=True,
     )
     return parser.parse_args()

From 731cb1be267dc55b4bd9915c73ae160237c2ed3e Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Tue, 29 Jul 2025 09:16:08 +0800
Subject: [PATCH 13/23] fix

Signed-off-by: wang.yuqi
---
 vllm/entrypoints/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 61bf05fcdf36..0d66cb02122e 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1258,7 +1258,7 @@ def reward(
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
     ) -> list[PoolingRequestOutput]:
         """
-        Generate reward scores for each prompt.
+        Generate rewards for each prompt.
 
         Args:
             prompts: The prompts to the LLM. You may pass a sequence of prompts
You may pass a sequence of prompts From 8d10bd1d9bcd0c24e484b3643d0bca448fe59bd4 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 29 Jul 2025 09:51:35 +0800 Subject: [PATCH 14/23] fix Signed-off-by: wang.yuqi --- docs/models/pooling_models.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index 47f0bd3543aa..975314fe7ced 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -66,11 +66,11 @@ you can override some of its attributes via the `--override-pooler-config` optio If the model has been converted via `--convert` (see above), the pooler assigned to each task has the following attributes by default: -| Task | Pooling Type | Normalization | Softmax | -|------------|----------------|---------------|---------| -| `reward` | `ALL` | ❌ | ❌ | -| `embed` | `LAST` | ✅︎ | ❌ | -| `classify` | `LAST` | ❌ | ✅︎ | +| Task | Pooling Type | Normalization | Softmax | +|------------|--------------|---------------|---------| +| `reward` | `ALL` | ❌ | ❌ | +| `embed` | `LAST` | ✅︎ | ❌ | +| `classify` | `LAST` | ❌ | ✅︎ | When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models, its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults. @@ -141,8 +141,8 @@ A code example can be found here: