@@ -56,10 +56,11 @@ class EngineClient:
5656 EngineClient is a class that handles the communication between the client and the server.
5757 """
5858
59- def __init__ (self , pid : int | str , port : int | str , fd_config : FDConfig , workers : int = 1 ):
59+ def __init__ (self , pid : int | str , port : int | str , fd_config : FDConfig , workers : int = 1 , max_logprobs : int = 20 ):
6060 self .fd_config = fd_config
6161 self .tensor_parallel_size = self .fd_config .parallel_config .tensor_parallel_size
6262 self .enable_mm = self .fd_config .model_config .enable_mm
63+ self .max_logprobs = max_logprobs
6364 input_processor = InputPreprocessor (
6465 self .fd_config .model_config ,
6566 self .fd_config .structured_outputs_config .reasoning_parser ,
@@ -70,6 +71,11 @@ def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers
7071 )
7172 self .enable_logprob = self .fd_config .model_config .enable_logprob
7273 self .data_processor = input_processor .create_processor ()
74+ self .ori_vocab_size = (
75+ len (self .data_processor .tokenizer .sp_model )
76+ if hasattr (self .data_processor .tokenizer , "sp_model" )
77+ else len (self .data_processor .tokenizer .vocab )
78+ )
7379 self .max_model_len = self .fd_config .model_config .max_model_len
7480 self .enable_prefix_caching = self .fd_config .cache_config .enable_prefix_caching
7581 self .enable_splitwise = self .fd_config .scheduler_config .splitwise_role != "mixed"
@@ -424,6 +430,53 @@ def valid_parameters(self, data):
424430 elif logprobs :
425431 raise ParameterError ("logprobs" , "Invalid type for 'logprobs'" )
426432
433+ max_logprobs = self .max_logprobs
434+ if max_logprobs == - 1 :
435+ max_logprobs = self .ori_vocab_size
436+ if max_logprobs < - 1 :
437+ err_msg = f"Invalid 'max_logprobs': must be >= -1, got { max_logprobs } ."
438+ api_server_logger .error (err_msg )
439+ raise ValueError ("max_logprobs" , err_msg )
440+ if max_logprobs > self .ori_vocab_size :
441+ err_msg = f"Invalid 'max_logprobs': must be <= vocab_size { self .ori_vocab_size } , got { max_logprobs } ."
442+ api_server_logger .error (err_msg )
443+ raise ValueError ("max_logprobs" , err_msg )
444+
445+ prompt_logprobs = data .get ("prompt_logprobs" , None )
446+
447+ if prompt_logprobs is not None :
448+ if not self .enable_logprob :
449+ err_msg = "`enable_logprob` is disabled, please enable it in startup config."
450+ api_server_logger .error (err_msg )
451+ raise ParameterError ("prompt_logprobs" , err_msg )
452+
453+ if not envs .FD_USE_GET_SAVE_OUTPUT_V1 :
454+ err_msg = "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled."
455+ api_server_logger .error (err_msg )
456+ raise ParameterError ("prompt_logprobs" , err_msg )
457+
458+ if self .enable_prefix_caching :
459+ err_msg = "prompt_logprobs is not support when prefix caching is enabled."
460+ api_server_logger .error (err_msg )
461+ raise ParameterError ("prompt_logprobs" , err_msg )
462+
463+ if prompt_logprobs == - 1 and self .ori_vocab_size > max_logprobs :
464+ err_msg = f"The requested value of ({ self .ori_vocab_size } ) for prompt_logprobs (-1) exceeds the maximum allowed value of ({ max_logprobs } )"
465+ api_server_logger .error (err_msg )
466+ raise ValueError ("prompt_logprobs" , err_msg )
467+
468+ if prompt_logprobs < - 1 :
469+ err_msg = (
470+ f"prompt_logprobs must be a non-negative value or -1; the current value is { prompt_logprobs } ."
471+ )
472+ api_server_logger .error (err_msg )
473+ raise ValueError ("prompt_logprobs" , err_msg )
474+
475+ if prompt_logprobs > max_logprobs :
476+ err_msg = f"Number of prompt_logprobs requested ({ prompt_logprobs } ) exceeds maximum allowed value ({ max_logprobs } )."
477+ api_server_logger .error (err_msg )
478+ raise ValueError ("prompt_logprobs" , err_msg )
479+
427480 # enable_logprob
428481 if top_logprobs :
429482 if not self .enable_logprob :
@@ -437,15 +490,26 @@ def valid_parameters(self, data):
437490 api_server_logger .error (err_msg )
438491 raise ParameterError ("top_logprobs" , err_msg )
439492
440- if top_logprobs < 0 :
441- err_msg = f"Invalid 'top_logprobs': must be >= 0, got { top_logprobs } ."
442- api_server_logger .error (err_msg )
443- raise ParameterError ("top_logprobs" , err_msg )
444-
445- if top_logprobs > 20 :
446- err_msg = "Invalid value for 'top_logprobs': must be <= 20."
447- api_server_logger .error (err_msg )
448- raise ParameterError ("top_logprobs" , err_msg )
493+ if not envs .FD_USE_GET_SAVE_OUTPUT_V1 :
494+ if top_logprobs < 0 or top_logprobs > 20 :
495+ err_msg = f"top_logprobs must be between 0 and 20; the current value is { top_logprobs } ."
496+ api_server_logger .error (err_msg )
497+ raise ValueError ("top_logprobs" , err_msg )
498+ else :
499+ if top_logprobs == - 1 and self .ori_vocab_size > max_logprobs :
500+ err_msg = f"The requested value of ({ self .ori_vocab_size } ) for top_logprobs (-1) exceeds the maximum allowed value of ({ max_logprobs } )"
501+ api_server_logger .error (err_msg )
502+ raise ValueError ("top_logprobs" , err_msg )
503+
504+ if top_logprobs < - 1 :
505+ err_msg = f"top_logprobs must be a non-negative value or -1; the current value is { top_logprobs } ."
506+ api_server_logger .error (err_msg )
507+ raise ValueError ("top_logprobs" , err_msg )
508+
509+ if top_logprobs > max_logprobs :
510+ err_msg = f"Number of logprobs requested ({ top_logprobs } ) exceeds maximum allowed value ({ max_logprobs } )."
511+ api_server_logger .error (err_msg )
512+ raise ValueError ("top_logprobs" , err_msg )
449513
450514 def check_health (self , time_interval_threashold = 30 ):
451515 """
0 commit comments