Commit d8f7dbb

+ reward
1 parent e626d28

File tree

1 file changed: +48 -1


vllm/entrypoints/llm.py

Lines changed: 48 additions & 1 deletion
@@ -1060,7 +1060,7 @@ def encode(
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
-        pooling_task: PoolingTask = "encode",
+        pooling_task: Optional[PoolingTask] = None,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         """Apply pooling to the hidden states corresponding to the input
@@ -1092,6 +1092,14 @@ def encode(
             considered legacy and may be deprecated in the future. You should
             instead pass them via the `inputs` parameter.
         """
+        if pooling_task is None:
+            raise ValueError(
+                "`pooling_task` must be specified. "
+                "To get embeddings, prefer `LLM.embed`. "
+                "To get classification logits, prefer `LLM.classify`. "
+                "To get reward scores, prefer `LLM.reward`. "
+                "To get pairwise similarity scores, prefer `LLM.score`.")
+
         model_config = self.llm_engine.model_config
         runner_type = model_config.runner_type
         if runner_type != "pooling":
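
To illustrate the behavior change, here is a minimal sketch of what callers see after this commit. The model name is illustrative, not part of this commit, and depending on your vLLM version you may need to select the pooling runner explicitly when constructing `LLM`:

from vllm import LLM

# Illustrative embedding checkpoint; any pooling-capable model works here.
llm = LLM(model="intfloat/e5-mistral-7b-instruct")

# Omitting `pooling_task` now raises a ValueError pointing callers at the
# task-specific helpers instead of silently defaulting to "encode".
try:
    llm.encode("What is the capital of France?")
except ValueError as err:
    print(err)

# Passing the task explicitly still works; for embeddings, `LLM.embed`
# remains the preferred entry point.
outputs = llm.encode("What is the capital of France?", pooling_task="embed")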
@@ -1230,6 +1238,45 @@ def classify(
 
         return [ClassificationRequestOutput.from_base(item) for item in items]
 
+    def reward(
+        self,
+        prompts: Union[PromptType, Sequence[PromptType]],
+        /,
+        *,
+        truncate_prompt_tokens: Optional[int] = None,
+        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
+    ) -> list[PoolingRequestOutput]:
+        """
+        Generate reward scores for each prompt.
+
+        Args:
+            prompts: The prompts to the LLM. You may pass a sequence of prompts
+                for batch inference. See [PromptType][vllm.inputs.PromptType]
+                for more details about the format of each prompt.
+            use_tqdm: If `True`, shows a tqdm progress bar.
+                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
+                it is used to create the progress bar.
+                If `False`, no progress bar is created.
+            lora_request: LoRA request to use for generation, if any.
+            pooling_params: The pooling parameters for pooling. If None, we
+                use the default pooling parameters.
+        Returns:
+            A list of `PoolingRequestOutput` objects containing the
+            pooled hidden states in the same order as the input prompts.
+        """
+
+        return self.encode(
+            prompts,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+            pooling_params=pooling_params,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            pooling_task="encode",
+        )
+
     def _embedding_score(
         self,
         tokenizer: AnyTokenizer,
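
And a usage sketch for the new `reward` helper, assuming a checkpoint whose pooler emits per-prompt reward scores. The model name is illustrative, not part of this commit:

from vllm import LLM

# Illustrative reward-model checkpoint; substitute any model whose
# pooler produces reward scores.
llm = LLM(model="internlm/internlm2-1_8b-reward", trust_remote_code=True)

# `reward` is a thin wrapper over `encode(..., pooling_task="encode")`,
# so each output carries the raw pooled hidden states as the score.
outputs = llm.reward([
    "The capital of France is Paris.",
    "The capital of France is Berlin.",
])
for out in outputs:
    print(out.outputs.data)  # pooled tensor for this prompt

Note that `prompts` is positional-only (`/`) and the remaining arguments are keyword-only (`*`), matching the other task-specific helpers such as `LLM.embed` and `LLM.classify`.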
