@@ -1060,7 +1060,7 @@ def encode(
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
-        pooling_task: PoolingTask = "encode",
+        pooling_task: Optional[PoolingTask] = None,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         """Apply pooling to the hidden states corresponding to the input
@@ -1092,6 +1092,14 @@ def encode(
             considered legacy and may be deprecated in the future. You should
             instead pass them via the `inputs` parameter.
         """
+        if pooling_task is None:
+            raise ValueError(
+                "`pooling_task` must be specified. "
+                "To get embeddings, prefer `LLM.embed`. "
+                "To get classification logits, prefer `LLM.classify`. "
+                "To get reward scores, prefer `LLM.reward`. "
+                "To get pairwise similarity scores, prefer `LLM.score`.")
+
         model_config = self.llm_engine.model_config
         runner_type = model_config.runner_type
         if runner_type != "pooling":
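With this change, `pooling_task` no longer defaults to `"encode"`: callers of `LLM.encode` must pass a task explicitly, or use the task-specific convenience methods named in the error message. A minimal sketch of the new calling convention (the model name is illustrative; any checkpoint that runs with the pooling runner works):

from vllm import LLM

# Illustrative checkpoint; substitute any pooling-capable model.
llm = LLM(model="intfloat/e5-mistral-7b-instruct")

# Omitting `pooling_task` now raises a ValueError pointing to
# `LLM.embed` / `LLM.classify` / `LLM.reward` / `LLM.score`.
outputs = llm.encode(["Hello, world!"], pooling_task="encode")
print(outputs[0].outputs.data)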
@@ -1230,6 +1238,45 @@ def classify(

         return [ClassificationRequestOutput.from_base(item) for item in items]

+    def reward(
+        self,
+        prompts: Union[PromptType, Sequence[PromptType]],
+        /,
+        *,
+        truncate_prompt_tokens: Optional[int] = None,
+        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
+    ) -> list[PoolingRequestOutput]:
+        """
+        Generate reward scores for each prompt.
+
+        Args:
+            prompts: The prompts to the LLM. You may pass a sequence of prompts
+                for batch inference. See [PromptType][vllm.inputs.PromptType]
+                for more details about the format of each prompt.
+            use_tqdm: If `True`, shows a tqdm progress bar.
+                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
+                it is used to create the progress bar.
+                If `False`, no progress bar is created.
+            lora_request: LoRA request to use for generation, if any.
+            pooling_params: The pooling parameters for pooling. If None, we
+                use the default pooling parameters.
+        Returns:
+            A list of `PoolingRequestOutput` objects containing the
+            pooled hidden states in the same order as the input prompts.
+        """
+
+        return self.encode(
+            prompts,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+            pooling_params=pooling_params,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            pooling_task="encode",
+        )
+
     def _embedding_score(
         self,
         tokenizer: AnyTokenizer,
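The new `LLM.reward` method is a thin wrapper that forwards to `encode` with `pooling_task="encode"` pinned. A usage sketch, assuming a reward-model checkpoint (the model name below is illustrative, not mandated by this diff):

from vllm import LLM

# Illustrative reward-model checkpoint; substitute your own.
llm = LLM(model="internlm/internlm2-1_8b-reward", trust_remote_code=True)

outputs = llm.reward(["The quick brown fox jumps over the lazy dog."])
for out in outputs:
    # Each PoolingRequestOutput carries the pooled hidden states
    # (the reward scores) for the corresponding prompt.
    print(out.outputs.data)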