 import torch
 import torch.nn.functional as F
 from loguru import logger
-from transformers.generation.logits_process import (
-    LogitsProcessorList,
-    RepetitionPenaltyLogitsProcessor,
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
-)

 from api.apapter import get_prompt_adapter
 from api.generation.baichuan import build_baichuan_chat_input, check_is_baichuan
 from api.generation.chatglm import generate_stream_chatglm, check_is_chatglm
 from api.generation.qwen import build_qwen_chat_input, check_is_qwen
+from api.generation.utils import prepare_logits_processor, is_partial_stop, get_context_length
 from api.utils.constants import ErrorCode
 from api.utils.protocol import ChatMessage

...
 )


-def prepare_logits_processor(
-    temperature: float, repetition_penalty: float, top_p: float, top_k: int
-) -> LogitsProcessorList:
-    processor_list = LogitsProcessorList()
-    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases.
-    if temperature >= 1e-5 and temperature != 1.0:
-        processor_list.append(TemperatureLogitsWarper(temperature))
-    if repetition_penalty > 1.0:
-        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
-    if 1e-8 <= top_p < 1.0:
-        processor_list.append(TopPLogitsWarper(top_p))
-    if top_k > 0:
-        processor_list.append(TopKLogitsWarper(top_k))
-    return processor_list
-
-
-def is_partial_stop(output: str, stop_str: str):
-    """Check whether the output contains a partial stop str."""
-    for i in range(0, min(len(output), len(stop_str))):
-        if stop_str.startswith(output[-i:]):
-            return True
-    return False
-
-
 @torch.inference_mode()
 def generate_stream(
     model,
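The helpers deleted above (prepare_logits_processor and is_partial_stop) are now imported from api.generation.utils instead. A minimal usage sketch, assuming their signatures are unchanged after the move; the tensors and sampling values below are illustrative, not output from this repository:

# Sketch only: signatures taken from the code removed in this hunk,
# assumed to be identical in their new home api.generation.utils.
import torch

from api.generation.utils import is_partial_stop, prepare_logits_processor

logits_processor = prepare_logits_processor(
    temperature=0.7, repetition_penalty=1.1, top_p=0.9, top_k=40
)

# Inside the decoding loop the processors rewrite the last-token logits
# before sampling (illustrative tensors, not a real model forward pass).
input_ids = [[1, 2, 3]]
last_token_logits = torch.randn(1, 32000)
scores = logits_processor(torch.as_tensor(input_ids), last_token_logits)
probs = torch.softmax(scores[-1], dim=-1)
next_token = int(torch.multinomial(probs, num_samples=1))

# While streaming, a chunk that ends with a prefix of a stop string is
# held back until more tokens arrive instead of being sent to the client.
if is_partial_stop(output="Observation", stop_str="Observation:"):
    pass  # wait for the next token before emitting this chunk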
@@ -76,9 +46,9 @@ def generate_stream(
     )

     if isinstance(prompt, list) and check_is_baichuan(model):
-        input_ids = build_baichuan_chat_input(tokenizer, prompt, context_len)
+        input_ids = build_baichuan_chat_input(tokenizer, prompt, context_len, max_new_tokens)
     elif isinstance(prompt, list) and check_is_qwen(model):
-        input_ids = build_qwen_chat_input(tokenizer, prompt)
+        input_ids = build_qwen_chat_input(tokenizer, prompt, context_len, max_new_tokens)
         stop_token_ids.extend([tokenizer.im_end_id, tokenizer.im_start_id])
     else:
         input_ids = tokenizer(prompt).input_ids
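Both chat-input builders now receive context_len and max_new_tokens, presumably so the conversation history can be trimmed to leave room for generation. The diff only shows the new call sites; the sketch below illustrates that budgeting idea (build_chat_input_sketch is a hypothetical name, the internals are an assumption, and ChatMessage is assumed to expose a content field):

from typing import List

from api.utils.protocol import ChatMessage  # imported at the top of this file


def build_chat_input_sketch(
    tokenizer, messages: List[ChatMessage], context_len: int = 8192, max_new_tokens: int = 256
) -> List[int]:
    """Illustrative only: keep the newest messages whose tokens fit within
    context_len - max_new_tokens, so generation has room left in the window."""
    max_input_tokens = context_len - max_new_tokens
    input_ids: List[int] = []
    # Walk the history from newest to oldest and stop once the budget is spent.
    for message in reversed(messages):
        ids = tokenizer.encode(message.content)
        if len(input_ids) + len(ids) > max_input_tokens:
            break
        input_ids = ids + input_ids
    return input_ids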
@@ -262,25 +232,6 @@ def generate_stream(
     torch.cuda.empty_cache()


-SEQUENCE_LENGTH_KEYS = [
-    "max_sequence_length",
-    "seq_length",
-    "max_position_embeddings",
-    "max_seq_len",
-    "model_max_length",
-]
-
-
-def get_context_length(config):
-    """Get the context length of a model from a huggingface model config."""
-    for key in SEQUENCE_LENGTH_KEYS:
-        if hasattr(config, key):
-            val = getattr(config, key)
-            if val is not None:
-                return val
-    return 2048
-
-
 class ModelServer:
     def __init__(
         self,
@@ -312,7 +263,7 @@ def __init__(
             logger.info("Using Qwen Model for Chat!")
             self.construct_prompt = False
             self.generate_stream_func = generate_stream
-            self.context_len = 8192
+            self.context_len = 8192 if self.context_len is None else self.context_len
         else:
             self.generate_stream_func = generate_stream
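The Qwen branch no longer hard-codes 8192: a context length that is already set is kept, and 8192 is only the fallback. Combined with get_context_length (removed above and now imported from api.generation.utils), the resolution might look like the sketch below; ModelServerSketch and its arguments are hypothetical stand-ins, since the full __init__ is not shown in the diff:

# Illustrative wiring only, not the repository's exact ModelServer code.
from api.generation.utils import get_context_length  # helper removed above, assumed moved here


class ModelServerSketch:
    def __init__(self, model, tokenizer, context_len=None, is_qwen=False):
        self.model = model
        self.tokenizer = tokenizer
        self.context_len = context_len  # optional caller override
        if is_qwen:
            # Previously overwritten with 8192 unconditionally; now a
            # caller-supplied value wins and 8192 is only the default.
            self.context_len = 8192 if self.context_len is None else self.context_len
        if self.context_len is None:
            # Fall back to the HuggingFace config (max_sequence_length,
            # seq_length, max_position_embeddings, ..., default 2048).
            self.context_len = get_context_length(self.model.config)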
|
|