@@ -13,10 +13,10 @@
 )
 
 from api.apapter import get_prompt_adapter
-from api.utils.constants import ErrorCode
 from api.generation.baichuan import build_baichuan_chat_input, check_is_baichuan
 from api.generation.chatglm import generate_stream_chatglm, check_is_chatglm
 from api.generation.qwen import build_qwen_chat_input, check_is_qwen
+from api.utils.constants import ErrorCode
 from api.utils.protocol import ChatMessage
 
 server_error_msg = (
@@ -298,11 +298,7 @@ def __init__(
         self.model_name = model_name.lower()
         self.prompt_name = prompt_name.lower() if prompt_name is not None else None
         self.stream_interval = stream_interval
-
-        if context_len is None:
-            self.context_len = get_context_length(self.model.config)
-        else:
-            self.context_len = context_len
+        self.context_len = context_len
 
         self.construct_prompt = True
         if check_is_chatglm(self.model):
@@ -316,10 +312,13 @@ def __init__(
             logger.info("Using Qwen Model for Chat!")
             self.construct_prompt = False
             self.generate_stream_func = generate_stream
+            self.context_len = 8192
         else:
             self.generate_stream_func = generate_stream
 
         self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)
+        if self.context_len is None:
+            self.context_len = get_context_length(self.model.config)
 
     def count_token(self, params):
         prompt = params["prompt"]
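The net effect of this diff is a new resolution order for `context_len`: the constructor stores the argument as passed, the Qwen branch then overwrites it with a hard-coded 8192, and anything still `None` after the prompt adapter is set up falls back to `get_context_length(self.model.config)`. A minimal sketch of that order, assuming the standalone `resolve_context_len` helper and its parameters are hypothetical (for illustration only, not part of this PR):

```python
from typing import Optional


def resolve_context_len(
    context_len: Optional[int],
    is_qwen: bool,
    config_context_len: int,
) -> int:
    """Hypothetical helper mirroring the order of assignments in __init__."""
    if is_qwen:
        # The Qwen branch assigns 8192 after the constructor argument is
        # stored, so it wins even over an explicitly passed context_len.
        return 8192
    if context_len is not None:
        # An explicit argument is kept for all other model types.
        return context_len
    # Final fallback, applied after the prompt adapter is set up;
    # stands in for get_context_length(self.model.config).
    return config_context_len


# No explicit value, non-Qwen model: fall back to the config's length.
assert resolve_context_len(None, is_qwen=False, config_context_len=4096) == 4096
# Qwen model: hard-coded 8192, regardless of the other inputs.
assert resolve_context_len(2048, is_qwen=True, config_context_len=2048) == 8192
```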