
Commit 117980d

[LogProbs] Enable prompt logprobs output and modify the data transmission method for the online interface. (#5089)
* add prompt logprobs
* Merge prompt_logprobs_tensors and prompt_logprobs
* fix param check
* trigger ci
* fix unit test
* fix logprobs bug
1 parent af39819 commit 117980d

27 files changed: +4950 −236 lines
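
For context, the new parameter is consumed by the OpenAI-compatible endpoint roughly as sketched below. The host, port, model name, and prompt are placeholders rather than values taken from this commit, and the checks added here require the server to run with logprobs enabled and FD_USE_GET_SAVE_OUTPUT_V1=1.

    # Hypothetical client call against a locally running FastDeploy OpenAI-compatible server.
    # Endpoint URL, port, and model name are assumptions for illustration only.
    import requests

    payload = {
        "model": "default",
        "messages": [{"role": "user", "content": "Explain log probabilities in one sentence."}],
        "logprobs": True,
        "top_logprobs": 5,      # per generated token, checked against max_logprobs
        "prompt_logprobs": 5,   # new in this commit: top-5 logprobs for every prompt token
        "stream": False,        # prompt_logprobs is rejected for streaming requests
    }
    response = requests.post("http://localhost:8188/v1/chat/completions", json=payload, timeout=60)
    print(response.json())
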

fastdeploy/config.py

Lines changed: 8 additions & 1 deletion
@@ -229,8 +229,15 @@ def __init__(
         self.think_end_id = args.get("think_end_id", -1)
         self.im_patch_id = args.get("image_patch_id", -1)
         self.line_break_id = args.get("line_break_id", -1)
-        if self.max_logprobs < -1:
+
+        num_max_logprobs = args.get("max_logprobs", None)
+        if num_max_logprobs is not None and num_max_logprobs < -1:
             raise ValueError(" The possible values for max_logprobs can't be less than -1 ")
+        if self.ori_vocab_size is not None and num_max_logprobs is not None:
+            if num_max_logprobs > self.ori_vocab_size:
+                raise ValueError(
+                    f" The possible values for max_logprobs can't be greater than the vocabulary size {self.ori_vocab_size}"
+                )
 
         self._post_init()
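
A standalone sketch of the bound this adds: -1 keeps its meaning of "up to the full vocabulary", while anything above the vocabulary size is rejected at config time. The vocab_size value and helper name are illustrative only.

    # Minimal re-statement of the check above; not the real config class.
    def validate_max_logprobs(max_logprobs, vocab_size=100352):
        if max_logprobs is None:
            return
        if max_logprobs < -1:
            raise ValueError("max_logprobs can't be less than -1")
        if max_logprobs > vocab_size:
            raise ValueError(f"max_logprobs can't be greater than the vocabulary size {vocab_size}")

    validate_max_logprobs(-1)    # ok: interpreted later as "full vocabulary"
    validate_max_logprobs(20)    # ok
    validate_max_logprobs(None)  # ok: no explicit value configured
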

fastdeploy/engine/request.py

Lines changed: 1 addition & 8 deletions
@@ -31,12 +31,7 @@
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.entrypoints.openai.protocol import ToolCall
 from fastdeploy.utils import data_processor_logger
-from fastdeploy.worker.output import (
-    LogprobsLists,
-    LogprobsTensors,
-    PromptLogprobs,
-    SampleLogprobs,
-)
+from fastdeploy.worker.output import LogprobsLists, PromptLogprobs, SampleLogprobs
 
 
 class RequestStatus(Enum):
@@ -519,7 +514,6 @@ def __init__(
         prompt: Optional[str] = None,
         prompt_token_ids: Optional[list[int]] = None,
         prompt_logprobs: Optional[PromptLogprobs] = None,
-        prompt_logprobs_tensors: Optional[LogprobsTensors] = None,
         output_type: Optional[int] = 3,
         outputs: CompletionOutput = None,
         finished: bool = False,
@@ -537,7 +531,6 @@ def __init__(
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
         self.prompt_logprobs = prompt_logprobs
-        self.prompt_logprobs_tensors = prompt_logprobs_tensors
         self.output_type = output_type
         self.outputs = outputs
         self.finished = finished
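
With the tensors field removed, downstream code reads prompt logprobs from the single prompt_logprobs attribute. A small sketch, assuming the constructor shown above belongs to RequestOutput (the helper name is hypothetical):

    from typing import Optional

    from fastdeploy.engine.request import RequestOutput
    from fastdeploy.worker.output import PromptLogprobs

    def extract_prompt_logprobs(result: RequestOutput) -> Optional[PromptLogprobs]:
        # After this commit there is no separate prompt_logprobs_tensors attribute;
        # the per-token prompt logprobs travel on prompt_logprobs itself.
        return result.prompt_logprobs
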

fastdeploy/engine/sampling_params.py

Lines changed: 13 additions & 7 deletions
@@ -16,12 +16,13 @@
 
 from __future__ import annotations
 
-import os
 import random
 from dataclasses import dataclass, fields
 from enum import Enum
 from typing import Any, List, Optional, Union
 
+from fastdeploy import envs
+
 
 @dataclass
 class SamplingParams:
@@ -207,12 +208,17 @@ def _verify_args(self) -> None:
             raise ValueError(
                 f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
             )
-        if self.logprobs is not None and self.logprobs < -1:
-            raise ValueError(f"logprobs must be greater than -1, got {self.logprobs}.")
-        if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
-            raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
-        if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
-            raise ValueError(f"prompt_logprobs must be greater than or equal to -1, got {self.prompt_logprobs}.")
+
+        if not envs.FD_USE_GET_SAVE_OUTPUT_V1:  # False (0)
+            if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 20):
+                raise ValueError("Invalid value for 'top_logprobs': must be between 0 and 20.")
+            if self.prompt_logprobs is not None:
+                raise ValueError("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled.")
+        else:  # True (1)
+            if self.logprobs is not None and self.logprobs < -1:
+                raise ValueError(f"logprobs must be a non-negative value or -1, got {self.logprobs}.")
+            if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
+                raise ValueError(f"prompt_logprobs a must be non-negative value or -1, got {self.prompt_logprobs}.")
 
         if not 0 <= self.seed <= 922337203685477580:
             raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")

fastdeploy/entrypoints/engine_client.py

Lines changed: 74 additions & 10 deletions
@@ -56,10 +56,11 @@ class EngineClient:
     EngineClient is a class that handles the communication between the client and the server.
     """
 
-    def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1):
+    def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1, max_logprobs: int = 20):
         self.fd_config = fd_config
         self.tensor_parallel_size = self.fd_config.parallel_config.tensor_parallel_size
         self.enable_mm = self.fd_config.model_config.enable_mm
+        self.max_logprobs = max_logprobs
         input_processor = InputPreprocessor(
             self.fd_config.model_config,
             self.fd_config.structured_outputs_config.reasoning_parser,
@@ -70,6 +71,11 @@ def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers
         )
         self.enable_logprob = self.fd_config.model_config.enable_logprob
         self.data_processor = input_processor.create_processor()
+        self.ori_vocab_size = (
+            len(self.data_processor.tokenizer.sp_model)
+            if hasattr(self.data_processor.tokenizer, "sp_model")
+            else len(self.data_processor.tokenizer.vocab)
+        )
         self.max_model_len = self.fd_config.model_config.max_model_len
         self.enable_prefix_caching = self.fd_config.cache_config.enable_prefix_caching
         self.enable_splitwise = self.fd_config.scheduler_config.splitwise_role != "mixed"
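
The same vocabulary-size probe, written as a standalone helper: sentencepiece-backed tokenizers expose sp_model, while dict-style tokenizers expose a vocab mapping. The helper name is hypothetical.

    def tokenizer_vocab_size(tokenizer) -> int:
        # Mirror of the attribute probe above: prefer the sentencepiece model size,
        # fall back to the plain vocab mapping.
        if hasattr(tokenizer, "sp_model"):
            return len(tokenizer.sp_model)
        return len(tokenizer.vocab)
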
@@ -424,6 +430,53 @@ def valid_parameters(self, data):
         elif logprobs:
             raise ParameterError("logprobs", "Invalid type for 'logprobs'")
 
+        max_logprobs = self.max_logprobs
+        if max_logprobs == -1:
+            max_logprobs = self.ori_vocab_size
+        if max_logprobs < -1:
+            err_msg = f"Invalid 'max_logprobs': must be >= -1, got {max_logprobs}."
+            api_server_logger.error(err_msg)
+            raise ValueError("max_logprobs", err_msg)
+        if max_logprobs > self.ori_vocab_size:
+            err_msg = f"Invalid 'max_logprobs': must be <= vocab_size {self.ori_vocab_size}, got {max_logprobs}."
+            api_server_logger.error(err_msg)
+            raise ValueError("max_logprobs", err_msg)
+
+        prompt_logprobs = data.get("prompt_logprobs", None)
+
+        if prompt_logprobs is not None:
+            if not self.enable_logprob:
+                err_msg = "`enable_logprob` is disabled, please enable it in startup config."
+                api_server_logger.error(err_msg)
+                raise ParameterError("prompt_logprobs", err_msg)
+
+            if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
+                err_msg = "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled."
+                api_server_logger.error(err_msg)
+                raise ParameterError("prompt_logprobs", err_msg)
+
+            if self.enable_prefix_caching:
+                err_msg = "prompt_logprobs is not support when prefix caching is enabled."
+                api_server_logger.error(err_msg)
+                raise ParameterError("prompt_logprobs", err_msg)
+
+            if prompt_logprobs == -1 and self.ori_vocab_size > max_logprobs:
+                err_msg = f"The requested value of ({self.ori_vocab_size}) for prompt_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
+                api_server_logger.error(err_msg)
+                raise ValueError("prompt_logprobs", err_msg)
+
+            if prompt_logprobs < -1:
+                err_msg = (
+                    f"prompt_logprobs must be a non-negative value or -1; the current value is {prompt_logprobs}."
+                )
+                api_server_logger.error(err_msg)
+                raise ValueError("prompt_logprobs", err_msg)
+
+            if prompt_logprobs > max_logprobs:
+                err_msg = f"Number of prompt_logprobs requested ({prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                api_server_logger.error(err_msg)
+                raise ValueError("prompt_logprobs", err_msg)
+
         # enable_logprob
         if top_logprobs:
             if not self.enable_logprob:
@@ -437,15 +490,26 @@ def valid_parameters(self, data):
                 api_server_logger.error(err_msg)
                 raise ParameterError("top_logprobs", err_msg)
 
-            if top_logprobs < 0:
-                err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}."
-                api_server_logger.error(err_msg)
-                raise ParameterError("top_logprobs", err_msg)
-
-            if top_logprobs > 20:
-                err_msg = "Invalid value for 'top_logprobs': must be <= 20."
-                api_server_logger.error(err_msg)
-                raise ParameterError("top_logprobs", err_msg)
+            if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
+                if top_logprobs < 0 or top_logprobs > 20:
+                    err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}."
+                    api_server_logger.error(err_msg)
+                    raise ValueError("top_logprobs", err_msg)
+            else:
+                if top_logprobs == -1 and self.ori_vocab_size > max_logprobs:
+                    err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
+                    api_server_logger.error(err_msg)
+                    raise ValueError("top_logprobs", err_msg)
+
+                if top_logprobs < -1:
+                    err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}."
+                    api_server_logger.error(err_msg)
+                    raise ValueError("top_logprobs", err_msg)
+
+                if top_logprobs > max_logprobs:
+                    err_msg = f"Number of logprobs requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                    api_server_logger.error(err_msg)
+                    raise ValueError("top_logprobs", err_msg)
 
     def check_health(self, time_interval_threashold=30):
         """

fastdeploy/entrypoints/llm.py

Lines changed: 27 additions & 10 deletions
@@ -335,23 +335,40 @@ def _add_request(
                 current_sampling_params = sampling_params[i]
             else:
                 current_sampling_params = sampling_params
-            if kwargs.get("stream") and current_sampling_params.prompt_logprobs is not None:
-                raise ValueError("prompt_logprobs is not supported with streaming.")
+
+            ori_vocab_size = (
+                len(self.llm_engine.data_processor.tokenizer.sp_model)
+                if hasattr(self.llm_engine.data_processor.tokenizer, "sp_model")
+                else len(self.llm_engine.data_processor.tokenizer.vocab)
+            )
             max_logprobs = self.llm_engine.cfg.model_config.max_logprobs
             if max_logprobs == -1:
-                max_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
+                max_logprobs = ori_vocab_size
+            if max_logprobs < -1:
+                raise ValueError(f"max_logprobs ({max_logprobs}) can't be less than -1.")
+            if max_logprobs > ori_vocab_size:
+                raise ValueError(f"max_logprobs ({max_logprobs}) exceeds vocabulary size ({ori_vocab_size}).")
+
             if current_sampling_params.logprobs is not None:
                 num_logprobs = current_sampling_params.logprobs
-                if num_logprobs == -1:
-                    num_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
+                if num_logprobs == -1 and ori_vocab_size > max_logprobs:
+                    raise ValueError(
+                        f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
+                    )
                 if num_logprobs > max_logprobs:
                     raise ValueError(
                         f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
                     )
             if current_sampling_params.prompt_logprobs is not None:
+                if self.llm_engine.cfg.cache_config.enable_prefix_caching:
+                    raise ValueError("prompt_logprobs is not supported with prefix caching enabled.")
+                if kwargs.get("stream"):
+                    raise ValueError("prompt_logprobs is not supported with streaming.")
                 num_prompt_logprobs = current_sampling_params.prompt_logprobs
-                if num_prompt_logprobs == -1:
-                    num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
+                if num_prompt_logprobs == -1 and ori_vocab_size > max_logprobs:
+                    raise ValueError(
+                        f"Number of prompt_logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
+                    )
                 if num_prompt_logprobs > max_logprobs:
                     raise ValueError(
                         f"Number of logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
@@ -436,7 +453,7 @@ def _build_prompt_logprobs(
         prompt_token_ranks = ranks.tolist()
         prompt_logprobs = logprobs.tolist()
         token_ids = token_ids.tolist()
-        result: Optional[PromptLogprobs] = []
+        result: Optional[PromptLogprobs] = [None]
         # Make Logprob for each position.
         for pos in range(num_prompt_tokens):
             # Handle flattening.
@@ -548,11 +565,11 @@ def _run_engine(
                         result.outputs.logprobs = self._build_sample_logprobs(
                             result.outputs.top_logprobs, topk_logprobs
                         )
-                    if result.prompt_logprobs_tensors and num_prompt_logprobs:
+                    if result.prompt_logprobs is not None and num_prompt_logprobs is not None:
                         if num_prompt_logprobs == -1:
                             num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
                         result.prompt_logprobs = self._build_prompt_logprobs(
-                            result.prompt_logprobs_tensors, num_prompt_logprobs
+                            result.prompt_logprobs, num_prompt_logprobs
                         )
 
                     output[pos] = result
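
On the offline path the merged field is consumed roughly as below; the import paths, model path, and values are placeholders, and prompt_logprobs is rejected when streaming or prefix caching is enabled, as the checks above show.

    # Hypothetical offline usage; model location and exact API surface are assumptions.
    from fastdeploy.engine.sampling_params import SamplingParams
    from fastdeploy.entrypoints.llm import LLM

    llm = LLM(model="./models/ERNIE-4.5-0.3B")  # placeholder model path
    params = SamplingParams(max_tokens=16, logprobs=5, prompt_logprobs=5)

    outputs = llm.generate(["Hello, my name is"], sampling_params=params)
    for out in outputs:
        print(out.prompt_logprobs)   # built from result.prompt_logprobs (no *_tensors field)
        print(out.outputs.logprobs)  # per-generated-token sample logprobs
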

fastdeploy/entrypoints/openai/api_server.py

Lines changed: 1 addition & 0 deletions
@@ -181,6 +181,7 @@ async def lifespan(app: FastAPI):
             port=int(os.environ.get("INFERENCE_MSG_QUEUE_ID", "0")),
             fd_config=fd_config,
             workers=args.workers,
+            max_logprobs=args.max_logprobs,
         )
         await engine_client.connection_manager.initialize()
         app.state.dynamic_load_weight = args.dynamic_load_weight
