Commit 625b876

Author: xusenlin (committed)
Fix model max length

1 parent d8313af · commit 625b876

File tree

2 files changed: +14 −7 lines changed


api/generation/core.py

Lines changed: 5 additions & 6 deletions
@@ -13,10 +13,10 @@
 )

 from api.apapter import get_prompt_adapter
-from api.utils.constants import ErrorCode
 from api.generation.baichuan import build_baichuan_chat_input, check_is_baichuan
 from api.generation.chatglm import generate_stream_chatglm, check_is_chatglm
 from api.generation.qwen import build_qwen_chat_input, check_is_qwen
+from api.utils.constants import ErrorCode
 from api.utils.protocol import ChatMessage

 server_error_msg = (
@@ -298,11 +298,7 @@ def __init__(
         self.model_name = model_name.lower()
         self.prompt_name = prompt_name.lower() if prompt_name is not None else None
         self.stream_interval = stream_interval
-
-        if context_len is None:
-            self.context_len = get_context_length(self.model.config)
-        else:
-            self.context_len = context_len
+        self.context_len = context_len

         self.construct_prompt = True
         if check_is_chatglm(self.model):
@@ -316,10 +312,13 @@ def __init__(
             logger.info("Using Qwen Model for Chat!")
             self.construct_prompt = False
             self.generate_stream_func = generate_stream
+            self.context_len = 8192
         else:
             self.generate_stream_func = generate_stream

         self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)
+        if self.context_len is None:
+            self.context_len = get_context_length(self.model.config)

     def count_token(self, params):
         prompt = params["prompt"]
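
Taken together, the three hunks defer context-length resolution until after the model type is detected: Qwen models are pinned to 8192 (overriding any value passed in), an explicit `context_len` argument survives for other models, and the model config is the final fallback. A minimal sketch of the resulting precedence — the function name `resolve_context_len` and the `is_qwen` flag are illustrative, not part of the commit; `get_context_length` stands in for the helper the file already uses:

# Sketch only: mirrors the precedence that __init__ now implements.
def resolve_context_len(context_len, is_qwen, model_config):
    if is_qwen:
        # The Qwen branch assigns 8192 unconditionally, so it wins
        # even over an explicit context_len passed by the caller.
        return 8192
    if context_len is not None:
        # An explicit argument survives for non-Qwen models.
        return context_len
    # Final fallback: read the length from the model config, as the
    # existing get_context_length helper does.
    return get_context_length(model_config)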

api/models.py

Lines changed: 9 additions & 1 deletion
@@ -47,6 +47,14 @@ def get_generate_model():
     )


+def get_context_len(model_config):
+    if "qwen" in config.MODEL_NAME.lower():
+        max_model_len = config.CONTEXT_LEN or 8192
+    else:
+        max_model_len = config.CONTEXT_LEN or model_config.get_max_model_len()
+    return max_model_len
+
+
 def get_vllm_engine():
     try:
         from vllm.engine.arg_utils import AsyncEngineArgs
@@ -76,7 +84,7 @@ def get_vllm_engine():
     )

     engine_model_config = asyncio.run(engine.get_model_config())
-    engine.max_model_len = engine_model_config.get_max_model_len()
+    engine.max_model_len = get_context_len(engine_model_config)

     return engine
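
Note that the vLLM path resolves the length in a different order from core.py: here an explicit `config.CONTEXT_LEN` wins even for Qwen, and 8192 is only the default when nothing is configured. A usage sketch under that assumption — the `StubModelConfig` class is illustrative; the real argument is vLLM's `ModelConfig`, whose `get_max_model_len()` the diff already calls:

# Sketch only: how get_context_len behaves for a few configurations.
class StubModelConfig:
    """Illustrative stand-in for vLLM's ModelConfig."""
    def get_max_model_len(self):
        return 4096  # pretend the model config reports a 4k window

# config.MODEL_NAME = "qwen-7b-chat", config.CONTEXT_LEN = None
#   -> get_context_len(StubModelConfig()) == 8192  (Qwen default)
# config.MODEL_NAME = "baichuan-13b", config.CONTEXT_LEN = None
#   -> get_context_len(StubModelConfig()) == 4096  (from the model config)
# config.MODEL_NAME = anything,       config.CONTEXT_LEN = 2048
#   -> get_context_len(StubModelConfig()) == 2048  (explicit setting wins)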
