@@ -1037,7 +1037,7 @@ def encode(
1037
1037
truncate_prompt_tokens : Optional [int ] = None ,
1038
1038
use_tqdm : Union [bool , Callable [..., tqdm ]] = True ,
1039
1039
lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
1040
- pooling_task : PoolingTask = "encode" ,
1040
+ pooling_task : Optional [ PoolingTask ] = None ,
1041
1041
tokenization_kwargs : Optional [dict [str , Any ]] = None ,
1042
1042
) -> list [PoolingRequestOutput ]:
1043
1043
"""Apply pooling to the hidden states corresponding to the input
@@ -1069,6 +1069,25 @@ def encode(
1069
1069
considered legacy and may be deprecated in the future. You should
1070
1070
instead pass them via the `inputs` parameter.
1071
1071
"""
1072
+ if pooling_task is None :
1073
+ if "embed" in self .supported_tasks :
1074
+ pooling_task = "embed"
1075
+ else :
1076
+ pooling_task = "encode"
1077
+
1078
+ logger .warning_once (
1079
+ "`LLM.encode` is currently using `pooling_task = %s`.\n "
1080
+ "Please use one of the more specific methods or set the "
1081
+ "task directly when using `LLM.encode`:\n "
1082
+ " - For embeddings, use `LLM.embed(...)` "
1083
+ "or `pooling_task=\" embed\" `.\n "
1084
+ " - For classification logits, use `LLM.classify(...)` "
1085
+ "or `pooling_task=\" classify\" `.\n "
1086
+ " - For rewards, use `LLM.reward(...)` "
1087
+ "or `pooling_task=\" reward\" `\n "
1088
+ " - For similarity scores, use `LLM.score(...)`." ,
1089
+ pooling_task )
1090
+
1072
1091
model_config = self .llm_engine .model_config
1073
1092
runner_type = model_config .runner_type
1074
1093
if runner_type != "pooling" :
@@ -1207,6 +1226,45 @@ def classify(
1207
1226
1208
1227
return [ClassificationRequestOutput .from_base (item ) for item in items ]
1209
1228
1229
+ def reward (
1230
+ self ,
1231
+ prompts : Union [PromptType , Sequence [PromptType ]],
1232
+ / ,
1233
+ * ,
1234
+ truncate_prompt_tokens : Optional [int ] = None ,
1235
+ use_tqdm : Union [bool , Callable [..., tqdm ]] = True ,
1236
+ pooling_params : Optional [Union [PoolingParams ,
1237
+ Sequence [PoolingParams ]]] = None ,
1238
+ lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
1239
+ ) -> list [PoolingRequestOutput ]:
1240
+ """
1241
+ Generate rewards for each prompt.
1242
+
1243
+ Args:
1244
+ prompts: The prompts to the LLM. You may pass a sequence of prompts
1245
+ for batch inference. See [PromptType][vllm.inputs.PromptType]
1246
+ for more details about the format of each prompts.
1247
+ use_tqdm: If `True`, shows a tqdm progress bar.
1248
+ If a callable (e.g., `functools.partial(tqdm, leave=False)`),
1249
+ it is used to create the progress bar.
1250
+ If `False`, no progress bar is created.
1251
+ lora_request: LoRA request to use for generation, if any.
1252
+ pooling_params: The pooling parameters for pooling. If None, we
1253
+ use the default pooling parameters.
1254
+ Returns:
1255
+ A list of `PoolingRequestOutput` objects containing the
1256
+ pooled hidden states in the same order as the input prompts.
1257
+ """
1258
+
1259
+ return self .encode (
1260
+ prompts ,
1261
+ use_tqdm = use_tqdm ,
1262
+ lora_request = lora_request ,
1263
+ pooling_params = pooling_params ,
1264
+ truncate_prompt_tokens = truncate_prompt_tokens ,
1265
+ pooling_task = "encode" ,
1266
+ )
1267
+
1210
1268
def _embedding_score (
1211
1269
self ,
1212
1270
tokenizer : AnyTokenizer ,
0 commit comments