Commit 11a4c9d

[Misc] Simplify get_max_tokens (vllm-project#34036)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 15a0b9e · commit 11a4c9d

File tree: 5 files changed, +8 −30 lines

vllm/entrypoints/openai/chat_completion/serving.py

Lines changed: 3 additions & 1 deletion

@@ -388,7 +388,9 @@ async def create_chat_completion(

         max_tokens = get_max_tokens(
             self.max_model_len,
-            request,
+            request.max_completion_tokens
+            if request.max_completion_tokens is not None
+            else request.max_tokens,
             self._extract_prompt_len(engine_prompt),
             self.default_sampling_params,
         )
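
The chat endpoint is the only call site that needs a fallback: the Chat Completions API accepts the newer max_completion_tokens field alongside the older max_tokens, and this resolution now happens at the call site rather than inside get_max_tokens. A minimal runnable sketch of that resolution, using a hypothetical stand-in for ChatCompletionRequest:

from dataclasses import dataclass


@dataclass
class FakeChatRequest:
    # Hypothetical stand-in for vLLM's ChatCompletionRequest, reduced to
    # the two fields the diff above touches.
    max_completion_tokens: int | None = None  # newer OpenAI field
    max_tokens: int | None = None             # older field, still accepted


def resolve_chat_max_tokens(request: FakeChatRequest) -> int | None:
    # Mirrors the replacement expression in the diff: prefer the newer
    # field, fall back to the older one.
    return (
        request.max_completion_tokens
        if request.max_completion_tokens is not None
        else request.max_tokens
    )


assert resolve_chat_max_tokens(FakeChatRequest(max_tokens=256)) == 256
assert resolve_chat_max_tokens(
    FakeChatRequest(max_completion_tokens=128, max_tokens=256)
) == 128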

vllm/entrypoints/openai/completion/serving.py

Lines changed: 1 addition & 1 deletion

@@ -164,7 +164,7 @@ async def create_completion(

         max_tokens = get_max_tokens(
             self.max_model_len,
-            request,
+            request.max_tokens,
             self._extract_prompt_len(engine_prompt),
             self.default_sampling_params,
         )

vllm/entrypoints/openai/engine/serving.py

Lines changed: 1 addition & 1 deletion

@@ -1176,7 +1176,7 @@ async def _generate_with_builtin_tools(

         sampling_params.max_tokens = get_max_tokens(
             self.max_model_len,
-            context.request,
+            context.request.max_output_tokens,
             self._extract_prompt_len(engine_prompt),
             self.default_sampling_params,  # type: ignore
         )

vllm/entrypoints/openai/responses/serving.py

Lines changed: 1 addition & 1 deletion

@@ -441,7 +441,7 @@ async def create_responses(

         default_max_tokens = get_max_tokens(
             self.max_model_len,
-            request,
+            request.max_output_tokens,
             self._extract_prompt_len(engine_prompt),
             self.default_sampling_params,
         )
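
Taken together, each endpoint now resolves its own protocol field before calling the helper. The mapping below is an illustrative summary assembled from the four diffs above, not a structure that exists in vLLM:

# Request field each endpoint passes as the helper's max_tokens argument.
ENDPOINT_TOKEN_FIELD = {
    "chat completions": "max_completion_tokens, falling back to max_tokens",
    "completions": "max_tokens",
    "builtin-tool generation": "max_output_tokens",
    "responses": "max_output_tokens",
}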

vllm/entrypoints/utils.py

Lines changed: 2 additions & 26 deletions

@@ -22,23 +22,11 @@
 from vllm.utils.argparse_utils import FlexibleArgumentParser

 if TYPE_CHECKING:
-    from vllm.entrypoints.openai.chat_completion.protocol import (
-        ChatCompletionRequest,
-    )
-    from vllm.entrypoints.openai.completion.protocol import (
-        CompletionRequest,
-    )
-    from vllm.entrypoints.openai.engine.protocol import (
-        StreamOptions,
-    )
+    from vllm.entrypoints.openai.engine.protocol import StreamOptions
     from vllm.entrypoints.openai.models.protocol import LoRAModulePath
-    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
 else:
-    ChatCompletionRequest = object
-    CompletionRequest = object
     StreamOptions = object
     LoRAModulePath = object
-    ResponsesRequest = object


 logger = init_logger(__name__)

@@ -186,22 +174,10 @@ def cli_env_setup():

 def get_max_tokens(
     max_model_len: int,
-    request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
+    max_tokens: int | None,
     input_length: int,
     default_sampling_params: dict,
 ) -> int:
-    # NOTE: Avoid isinstance() for better efficiency
-    max_tokens: int | None = None
-    if max_tokens is None:
-        # ChatCompletionRequest
-        max_tokens = getattr(request, "max_completion_tokens", None)
-    if max_tokens is None:
-        # ResponsesRequest
-        max_tokens = getattr(request, "max_output_tokens", None)
-    if max_tokens is None:
-        # CompletionRequest (also a fallback for ChatCompletionRequest)
-        max_tokens = getattr(request, "max_tokens", None)
-
     default_max_tokens = max_model_len - input_length
     max_output_tokens = current_platform.get_max_output_tokens(input_length)
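
The hunk above shows only the first two statements of the simplified helper; the rest of the function is outside the diff context. Below is a hedged sketch of its new shape: the signature and the two assignments are taken from the diff, while the final clamping step and the platform stub are assumptions for illustration.

class _FakePlatform:
    # Stand-in for vLLM's current_platform object (assumption; the real
    # one is imported in utils.py).
    def get_max_output_tokens(self, input_length: int) -> int:
        return 2**31 - 1  # "no platform cap" for this sketch


current_platform = _FakePlatform()


def get_max_tokens(
    max_model_len: int,
    max_tokens: int | None,  # callers now resolve the request field themselves
    input_length: int,
    default_sampling_params: dict,
) -> int:
    default_max_tokens = max_model_len - input_length
    max_output_tokens = current_platform.get_max_output_tokens(input_length)

    # Assumed combination step: honor the request's budget when given,
    # otherwise the server default, never exceeding the remaining context
    # or the platform cap.
    requested = (
        max_tokens
        if max_tokens is not None
        else default_sampling_params.get("max_tokens", default_max_tokens)
    )
    return min(requested, default_max_tokens, max_output_tokens)


# 4096-token context with a 1000-token prompt: an unset budget yields the
# remaining 3096 tokens; an explicit 256 wins when it is smaller.
print(get_max_tokens(4096, None, 1000, {}))  # -> 3096
print(get_max_tokens(4096, 256, 1000, {}))   # -> 256

Passing a plain int | None also lets utils.py drop its TYPE_CHECKING imports of the three request classes, which is what the first hunk in this file removes.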
