@@ -965,6 +965,7 @@ def encode(
965965 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
966966 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
967967 pooling_task : PoolingTask = "encode" ,
968+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
968969 ) -> list [PoolingRequestOutput ]:
969970 ...
970971
@@ -981,6 +982,7 @@ def encode(
981982 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
982983 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
983984 pooling_task : PoolingTask = "encode" ,
985+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
984986 ) -> list [PoolingRequestOutput ]:
985987 ...
986988
@@ -997,6 +999,7 @@ def encode(
997999 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
9981000 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
9991001 pooling_task : PoolingTask = "encode" ,
1002+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
10001003 ) -> list [PoolingRequestOutput ]:
10011004 ...
10021005
@@ -1014,6 +1017,7 @@ def encode(
10141017 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
10151018 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
10161019 pooling_task : PoolingTask = "encode" ,
1020+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
10171021 ) -> list [PoolingRequestOutput ]:
10181022 ...
10191023
@@ -1031,6 +1035,7 @@ def encode(
10311035 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
10321036 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
10331037 pooling_task : PoolingTask = "encode" ,
1038+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
10341039 ) -> list [PoolingRequestOutput ]:
10351040 ...
10361041
@@ -1046,6 +1051,7 @@ def encode(
10461051 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
10471052 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
10481053 pooling_task : PoolingTask = "encode" ,
1054+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
10491055 ) -> list [PoolingRequestOutput ]:
10501056 ...
10511057
@@ -1066,6 +1072,7 @@ def encode(
10661072 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
10671073 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
10681074 pooling_task : PoolingTask = "encode" ,
1075+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
10691076 ) -> list [PoolingRequestOutput ]:
10701077 """Apply pooling to the hidden states corresponding to the input
10711078 prompts.
@@ -1131,9 +1138,11 @@ def encode(
11311138 for pooling_param in pooling_params :
11321139 pooling_param .verify (pooling_task , model_config )
11331140
1134- tokenization_kwargs = dict [str , Any ]()
1135- _validate_truncation_size (model_config .max_model_len ,
1136- truncate_prompt_tokens , tokenization_kwargs )
1141+ if tokenization_kwargs is None :
1142+ tokenization_kwargs = dict [str , Any ]()
1143+ _validate_truncation_size (model_config .max_model_len ,
1144+ truncate_prompt_tokens ,
1145+ tokenization_kwargs )
11371146
11381147 self ._validate_and_add_requests (
11391148 prompts = parsed_prompts ,
0 commit comments