
Commit ffe61a0

noooop authored and jinzhen-lin committed
[Frontend] Add LLM.reward specific to reward models (vllm-project#21720)
Signed-off-by: wang.yuqi <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
1 parent c5c747c commit ffe61a0

File tree

7 files changed (+174, -35 lines)

docs/models/pooling_models.md

Lines changed: 53 additions & 28 deletions
@@ -45,14 +45,14 @@ Each pooling model in vLLM supports one or more of these tasks according to
 [Pooler.get_supported_tasks][vllm.model_executor.layers.pooler.Pooler.get_supported_tasks],
 enabling the corresponding APIs:
 
-| Task       | APIs               |
-|------------|--------------------|
-| `encode`   | `encode`           |
-| `embed`    | `embed`, `score`\* |
-| `classify` | `classify`         |
-| `score`    | `score`            |
+| Task       | APIs                                 |
+|------------|--------------------------------------|
+| `encode`   | `LLM.reward(...)`                    |
+| `embed`    | `LLM.embed(...)`, `LLM.score(...)`\* |
+| `classify` | `LLM.classify(...)`                  |
+| `score`    | `LLM.score(...)`                     |
 
-\* The `score` API falls back to `embed` task if the model does not support `score` task.
+\* The `LLM.score(...)` API falls back to `embed` task if the model does not support `score` task.
 
 ### Pooler Configuration
 
@@ -66,11 +66,11 @@ you can override some of its attributes via the `--override-pooler-config` option
 If the model has been converted via `--convert` (see above),
 the pooler assigned to each task has the following attributes by default:
 
-| Task       | Pooling Type   | Normalization | Softmax |
-|------------|----------------|---------------|---------|
-| `encode`   | `ALL`          | ❌            | ❌      |
-| `embed`    | `LAST`         | ✅︎            | ❌      |
-| `classify` | `LAST`         | ❌            | ✅︎      |
+| Task       | Pooling Type | Normalization | Softmax |
+|------------|--------------|---------------|---------|
+| `reward`   | `ALL`        | ❌            | ❌      |
+| `embed`    | `LAST`       | ✅︎            | ❌      |
+| `classify` | `LAST`       | ❌            | ✅︎      |
 
 When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
 its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults.
@@ -83,21 +83,6 @@ which takes priority over both the model's and Sentence Transformers's defaults.
 The [LLM][vllm.LLM] class provides various methods for offline inference.
 See [configuration][configuration] for a list of options when initializing the model.
 
-### `LLM.encode`
-
-The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
-It returns the extracted hidden states directly, which is useful for reward models.
-
-```python
-from vllm import LLM
-
-llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", runner="pooling")
-(output,) = llm.encode("Hello, my name is")
-
-data = output.outputs.data
-print(f"Data: {data!r}")
-```
-
 ### `LLM.embed`
 
 The [embed][vllm.LLM.embed] method outputs an embedding vector for each prompt.
@@ -106,7 +91,7 @@ It is primarily designed for embedding models.
 ```python
 from vllm import LLM
 
-llm = LLM(model="intfloat/e5-mistral-7b-instruct", runner="pooling")
+llm = LLM(model="intfloat/e5-small", runner="pooling")
 (output,) = llm.embed("Hello, my name is")
 
 embeds = output.outputs.embedding
@@ -154,6 +139,46 @@ print(f"Score: {score}")
 
 A code example can be found here: <gh-file:examples/offline_inference/basic/score.py>
 
+### `LLM.reward`
+
+The [reward][vllm.LLM.reward] method is available to all reward models in vLLM.
+It returns the extracted hidden states directly.
+
+```python
+from vllm import LLM
+
+llm = LLM(model="internlm/internlm2-1_8b-reward", runner="pooling", trust_remote_code=True)
+(output,) = llm.reward("Hello, my name is")
+
+data = output.outputs.data
+print(f"Data: {data!r}")
+```
+
+A code example can be found here: <gh-file:examples/offline_inference/basic/reward.py>
+
+### `LLM.encode`
+
+The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
+It returns the extracted hidden states directly.
+
+!!! note
+    Please use one of the more specific methods or set the task directly when using `LLM.encode`:
+
+    - For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`.
+    - For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`.
+    - For rewards, use `LLM.reward(...)` or `pooling_task="reward"`.
+    - For similarity scores, use `LLM.score(...)`.
+
+```python
+from vllm import LLM
+
+llm = LLM(model="intfloat/e5-small", runner="pooling")
+(output,) = llm.encode("Hello, my name is", pooling_task="embed")
+
+data = output.outputs.data
+print(f"Data: {data!r}")
+```
+
 ## Online Serving
 
 Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs:
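Editor's note (not part of the commit): the documentation diff above replaces the old catch-all `LLM.encode` reward example with a dedicated `LLM.reward` section. A minimal before/after sketch of that migration, reusing the `internlm/internlm2-1_8b-reward` model from the new docs example; the explicit `pooling_task="encode"` call stands in for the pre-commit default behavior of `LLM.encode`:

```python
from vllm import LLM

llm = LLM(model="internlm/internlm2-1_8b-reward", runner="pooling", trust_remote_code=True)

# Old style: the generic encode API, with the reward-style pooling task spelled out.
(old_output,) = llm.encode("Hello, my name is", pooling_task="encode")

# New style: the dedicated reward API added by this commit; it delegates to
# encode(pooling_task="encode") internally, so the pooled data is the same.
(new_output,) = llm.reward("Hello, my name is")

print(old_output.outputs.data)
print(new_output.outputs.data)
```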

examples/offline_inference/basic/embed.py

Lines changed: 1 addition & 2 deletions
@@ -12,10 +12,9 @@ def parse_args():
     parser = EngineArgs.add_cli_args(parser)
     # Set example specific arguments
     parser.set_defaults(
-        model="intfloat/e5-mistral-7b-instruct",
+        model="intfloat/e5-small",
         runner="pooling",
         enforce_eager=True,
-        max_model_len=1024,
     )
     return parser.parse_args()
 
examples/offline_inference/basic/reward.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from argparse import Namespace
+
+from vllm import LLM, EngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+
+def parse_args():
+    parser = FlexibleArgumentParser()
+    parser = EngineArgs.add_cli_args(parser)
+    # Set example specific arguments
+    parser.set_defaults(
+        model="internlm/internlm2-1_8b-reward",
+        runner="pooling",
+        enforce_eager=True,
+        max_model_len=1024,
+        trust_remote_code=True,
+    )
+    return parser.parse_args()
+
+
+def main(args: Namespace):
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # Create an LLM.
+    # You should pass runner="pooling" for reward models
+    llm = LLM(**vars(args))
+
+    # Generate rewards. The output is a list of PoolingRequestOutput.
+    outputs = llm.reward(prompts)
+
+    # Print the outputs.
+    print("\nGenerated Outputs:\n" + "-" * 60)
+    for prompt, output in zip(prompts, outputs):
+        rewards = output.outputs.data
+        rewards_trimmed = (
+            (str(rewards[:16])[:-1] + ", ...]") if len(rewards) > 16 else rewards
+        )
+        print(f"Prompt: {prompt!r} \nReward: {rewards_trimmed} (size={len(rewards)})")
+        print("-" * 60)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)

tests/conftest.py

Lines changed: 4 additions & 0 deletions
@@ -1053,6 +1053,10 @@ def encode(self, prompts: list[str]) -> list[list[float]]:
         req_outputs = self.llm.encode(prompts)
         return [req_output.outputs.data for req_output in req_outputs]
 
+    def reward(self, prompts: list[str]) -> list[list[float]]:
+        req_outputs = self.llm.reward(prompts)
+        return [req_output.outputs.data for req_output in req_outputs]
+
     def score(
         self,
         text_1: Union[str, list[str]],

tests/models/language/pooling/test_reward.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def test_prm_models(
     monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
 
     with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.encode(math_step_prompts)
+        vllm_outputs = vllm_model.reward(math_step_prompts)
 
     with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
         hf_model = step_reward_patch_hf_model(hf_model)

tests/models/language/pooling/test_truncation_control.py

Lines changed: 3 additions & 3 deletions
@@ -28,7 +28,7 @@ def test_smaller_truncation_size(vllm_runner,
 
     with vllm_runner(model_name, runner="pooling",
                      max_model_len=max_model_len) as vllm_model:
-        vllm_output = vllm_model.llm.encode(
+        vllm_output = vllm_model.llm.embed(
             input_str, truncate_prompt_tokens=truncate_prompt_tokens)
 
     prompt_tokens = vllm_output[0].prompt_token_ids
@@ -43,7 +43,7 @@ def test_max_truncation_size(vllm_runner,
 
     with vllm_runner(model_name, runner="pooling",
                      max_model_len=max_model_len) as vllm_model:
-        vllm_output = vllm_model.llm.encode(
+        vllm_output = vllm_model.llm.embed(
            input_str, truncate_prompt_tokens=truncate_prompt_tokens)
 
     prompt_tokens = vllm_output[0].prompt_token_ids
@@ -61,7 +61,7 @@ def test_bigger_truncation_size(vllm_runner,
             model_name, runner="pooling",
             max_model_len=max_model_len) as vllm_model:
 
-        llm_output = vllm_model.llm.encode(
+        llm_output = vllm_model.llm.embed(
             input_str, truncate_prompt_tokens=truncate_prompt_tokens)
 
         assert llm_output == f"""truncate_prompt_tokens value
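Editor's note (not part of the commit): the truncation tests above now go through `LLM.embed` instead of the generic `LLM.encode`. A minimal sketch of the behavior they exercise, assuming the `intfloat/e5-small` embedding model used elsewhere in this commit; `max_model_len=512` and `truncate_prompt_tokens=10` are illustrative choices:

```python
from vllm import LLM

llm = LLM(model="intfloat/e5-small", runner="pooling", max_model_len=512)

long_text = "Hello, my name is " * 50  # deliberately longer than the truncation limit

# truncate_prompt_tokens limits how many prompt tokens are kept before pooling.
(output,) = llm.embed(long_text, truncate_prompt_tokens=10)

# The returned request output records the tokens that were actually used.
print(len(output.prompt_token_ids))  # at most 10
```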

vllm/entrypoints/llm.py

Lines changed: 59 additions & 1 deletion
@@ -1037,7 +1037,7 @@ def encode(
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
-        pooling_task: PoolingTask = "encode",
+        pooling_task: Optional[PoolingTask] = None,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         """Apply pooling to the hidden states corresponding to the input
@@ -1069,6 +1069,25 @@ def encode(
             considered legacy and may be deprecated in the future. You should
             instead pass them via the `inputs` parameter.
         """
+        if pooling_task is None:
+            if "embed" in self.supported_tasks:
+                pooling_task = "embed"
+            else:
+                pooling_task = "encode"
+
+            logger.warning_once(
+                "`LLM.encode` is currently using `pooling_task = %s`.\n"
+                "Please use one of the more specific methods or set the "
+                "task directly when using `LLM.encode`:\n"
+                " - For embeddings, use `LLM.embed(...)` "
+                "or `pooling_task=\"embed\"`.\n"
+                " - For classification logits, use `LLM.classify(...)` "
+                "or `pooling_task=\"classify\"`.\n"
+                " - For rewards, use `LLM.reward(...)` "
+                "or `pooling_task=\"reward\"`\n"
+                " - For similarity scores, use `LLM.score(...)`.",
+                pooling_task)
+
         model_config = self.llm_engine.model_config
         runner_type = model_config.runner_type
         if runner_type != "pooling":
@@ -1207,6 +1226,45 @@ def classify(
 
         return [ClassificationRequestOutput.from_base(item) for item in items]
 
+    def reward(
+        self,
+        prompts: Union[PromptType, Sequence[PromptType]],
+        /,
+        *,
+        truncate_prompt_tokens: Optional[int] = None,
+        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
+    ) -> list[PoolingRequestOutput]:
+        """
+        Generate rewards for each prompt.
+
+        Args:
+            prompts: The prompts to the LLM. You may pass a sequence of prompts
+                for batch inference. See [PromptType][vllm.inputs.PromptType]
+                for more details about the format of each prompts.
+            use_tqdm: If `True`, shows a tqdm progress bar.
+                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
+                it is used to create the progress bar.
+                If `False`, no progress bar is created.
+            lora_request: LoRA request to use for generation, if any.
+            pooling_params: The pooling parameters for pooling. If None, we
+                use the default pooling parameters.
+        Returns:
+            A list of `PoolingRequestOutput` objects containing the
+            pooled hidden states in the same order as the input prompts.
+        """
+
+        return self.encode(
+            prompts,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+            pooling_params=pooling_params,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            pooling_task="encode",
+        )
+
     def _embedding_score(
         self,
         tokenizer: AnyTokenizer,
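Editor's note (not part of the commit): the `vllm/entrypoints/llm.py` diff above makes `pooling_task` optional on `LLM.encode`. When omitted, it now resolves to `"embed"` if the model supports it and otherwise falls back to `"encode"`, emitting a one-time warning that points callers at the task-specific methods; the new `LLM.reward` is a thin wrapper that calls `encode` with `pooling_task="encode"`. A small sketch of what a caller sees after this change, assuming the `intfloat/e5-small` embedding model:

```python
from vllm import LLM

llm = LLM(model="intfloat/e5-small", runner="pooling")

# No pooling_task given: this model supports "embed", so encode() resolves to the
# embed task and logs a one-time warning recommending the specific APIs instead.
(implicit,) = llm.encode("Hello, my name is")

# Equivalent explicit calls, which do not trigger the warning:
(explicit,) = llm.encode("Hello, my name is", pooling_task="embed")
(embedding,) = llm.embed("Hello, my name is")
```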
