
Commit 141a5bd

Author: xusenlin
Commit message: Fix model input for chat
Parent: 625b876

File tree: 5 files changed, +168, -112 lines

api/generation/baichuan.py (35 additions, 18 deletions)

@@ -1,27 +1,44 @@
 from typing import List
 
+from transformers import PreTrainedTokenizer
+
+from api.generation.utils import parse_messages
 from api.utils.protocol import Role, ChatMessage
 
 
-def build_baichuan_chat_input(tokenizer, messages: List[ChatMessage], context_len: int = 4096):
-    """ https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/modeling_baichuan.py """
-    total_input, round_input = [], []
-    for message in messages[::-1]:
-        role, content_tokens = message.role, tokenizer.encode(message.content)
-        if role in [Role.USER, Role.SYSTEM]:
-            round_input = [195] + content_tokens + round_input
-            if total_input and len(total_input) + len(round_input) > context_len:
-                break
+def build_baichuan_chat_input(
+    tokenizer: PreTrainedTokenizer,
+    messages: List[ChatMessage],
+    context_len: int = 4096,
+    max_new_tokens: int = 256
+) -> List[int]:
+    """ https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_utils.py """
+    max_input_tokens = context_len - max_new_tokens
+    system, rounds = parse_messages(messages)
+    system_tokens = tokenizer.encode(system)
+    max_history_tokens = max_input_tokens - len(system_tokens)
+
+    history_tokens = []
+    for round in rounds[::-1]:
+        round_tokens = []
+        for message in round:
+            if message.role == Role.USER:
+                round_tokens.append(195)
             else:
-                total_input = round_input + total_input
-                round_input = []
-        elif role == Role.ASSISTANT:
-            round_input = [196] + content_tokens + round_input
-        else:
-            raise ValueError(f"message role not supported yet: {role}")
-    total_input = total_input[-context_len:]  # truncate left
-    total_input.append(196)
-    return total_input
+                round_tokens.append(196)
+            round_tokens.extend(tokenizer.encode(message.content))
+
+        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
+            history_tokens = round_tokens + history_tokens  # concat left
+            if len(history_tokens) < max_history_tokens:
+                continue
+        break
+
+    input_tokens = system_tokens + history_tokens
+    if messages[-1].role != Role.ASSISTANT:
+        input_tokens.append(196)
+
+    return input_tokens[-max_input_tokens:]  # truncate left
 
 
 def check_is_baichuan(model):
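
Note on the new budgeting (not part of the commit): whole rounds are dropped oldest-first until the history fits in context_len - max_new_tokens, and the assistant id 196 is only appended when the conversation does not already end on an assistant turn. A minimal sketch of that round dropping, using a made-up one-token-per-character tokenizer and a stand-in message class rather than the project's ChatMessage:

# Sketch only (not project code): a toy tokenizer and message type to show how
# the new round-based budgeting behaves.  Baichuan's real special ids 195 (user)
# and 196 (assistant) are reused; everything else here is made up.
from dataclasses import dataclass
from typing import List


@dataclass
class Msg:                           # stand-in for api.utils.protocol.ChatMessage
    role: str                        # "user" / "assistant"
    content: str


class ToyTokenizer:
    def encode(self, text: str) -> List[int]:
        return [ord(c) for c in text]            # one token per character


def budget_history(rounds: List[List[Msg]], tokenizer, max_history_tokens: int) -> List[int]:
    """Walk rounds newest-first and keep a whole round only if it still fits."""
    history_tokens: List[int] = []
    for round in rounds[::-1]:
        round_tokens: List[int] = []
        for message in round:
            round_tokens.append(195 if message.role == "user" else 196)
            round_tokens.extend(tokenizer.encode(message.content))
        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
            history_tokens = round_tokens + history_tokens   # newer rounds stay on the right
            if len(history_tokens) < max_history_tokens:
                continue
        break
    return history_tokens


rounds = [[Msg("user", "aaaa"), Msg("assistant", "bbbb")], [Msg("user", "cccc")]]
print(len(budget_history(rounds, ToyTokenizer(), max_history_tokens=20)))   # 15: both rounds fit
print(len(budget_history(rounds, ToyTokenizer(), max_history_tokens=8)))    # 5: the older round is dropped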

api/generation/core.py (4 additions, 53 deletions)

@@ -4,18 +4,12 @@
 import torch
 import torch.nn.functional as F
 from loguru import logger
-from transformers.generation.logits_process import (
-    LogitsProcessorList,
-    RepetitionPenaltyLogitsProcessor,
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
-)
 
 from api.apapter import get_prompt_adapter
 from api.generation.baichuan import build_baichuan_chat_input, check_is_baichuan
 from api.generation.chatglm import generate_stream_chatglm, check_is_chatglm
 from api.generation.qwen import build_qwen_chat_input, check_is_qwen
+from api.generation.utils import prepare_logits_processor, is_partial_stop, get_context_length
 from api.utils.constants import ErrorCode
 from api.utils.protocol import ChatMessage
 
@@ -24,30 +18,6 @@
 )
 
 
-def prepare_logits_processor(
-    temperature: float, repetition_penalty: float, top_p: float, top_k: int
-) -> LogitsProcessorList:
-    processor_list = LogitsProcessorList()
-    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases.
-    if temperature >= 1e-5 and temperature != 1.0:
-        processor_list.append(TemperatureLogitsWarper(temperature))
-    if repetition_penalty > 1.0:
-        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
-    if 1e-8 <= top_p < 1.0:
-        processor_list.append(TopPLogitsWarper(top_p))
-    if top_k > 0:
-        processor_list.append(TopKLogitsWarper(top_k))
-    return processor_list
-
-
-def is_partial_stop(output: str, stop_str: str):
-    """Check whether the output contains a partial stop str."""
-    for i in range(0, min(len(output), len(stop_str))):
-        if stop_str.startswith(output[-i:]):
-            return True
-    return False
-
-
 @torch.inference_mode()
 def generate_stream(
     model,
@@ -76,9 +46,9 @@ def generate_stream(
     )
 
     if isinstance(prompt, list) and check_is_baichuan(model):
-        input_ids = build_baichuan_chat_input(tokenizer, prompt, context_len)
+        input_ids = build_baichuan_chat_input(tokenizer, prompt, context_len, max_new_tokens)
     elif isinstance(prompt, list) and check_is_qwen(model):
-        input_ids = build_qwen_chat_input(tokenizer, prompt)
+        input_ids = build_qwen_chat_input(tokenizer, prompt, context_len, max_new_tokens)
         stop_token_ids.extend([tokenizer.im_end_id, tokenizer.im_start_id])
     else:
         input_ids = tokenizer(prompt).input_ids
@@ -262,25 +232,6 @@ def generate_stream(
         torch.cuda.empty_cache()
 
 
-SEQUENCE_LENGTH_KEYS = [
-    "max_sequence_length",
-    "seq_length",
-    "max_position_embeddings",
-    "max_seq_len",
-    "model_max_length",
-]
-
-
-def get_context_length(config):
-    """Get the context length of a model from a huggingface model config."""
-    for key in SEQUENCE_LENGTH_KEYS:
-        if hasattr(config, key):
-            val = getattr(config, key)
-            if val is not None:
-                return val
-    return 2048
-
-
 class ModelServer:
     def __init__(
         self,
@@ -312,7 +263,7 @@ def __init__(
            logger.info("Using Qwen Model for Chat!")
            self.construct_prompt = False
            self.generate_stream_func = generate_stream
-            self.context_len = 8192
+            self.context_len = 8192 if self.context_len is None else self.context_len
        else:
            self.generate_stream_func = generate_stream

api/generation/qwen.py (38 additions, 26 deletions)

@@ -2,15 +2,21 @@
 
 from transformers import PreTrainedTokenizer
 
+from api.generation.baichuan import parse_messages
 from api.utils.protocol import Role, ChatMessage
 
 
 def build_qwen_chat_input(
     tokenizer: PreTrainedTokenizer,
     messages: List[ChatMessage],
-    max_window_size: int = 6144,
-):
+    context_len: int = 8192,
+    max_new_tokens: int = 256
+) -> List[int]:
     """ https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py """
+    max_input_tokens = context_len - max_new_tokens
+    system, rounds = parse_messages(messages)
+    system = "You are a helpful assistant." + system  # fix system prompt
+
     im_start_tokens, im_end_tokens = [tokenizer.im_start_id], [tokenizer.im_end_id]
     nl_tokens = tokenizer.encode("\n")
 
@@ -19,31 +25,37 @@ def _tokenize_str(role, content):
             role, allowed_special=set()
         ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
 
-    system_tokens_part = _tokenize_str("system", "You are a helpful assistant.")
+    system_tokens_part = _tokenize_str("system", system)
     system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
-
-    context_tokens = []
-    for i, message in enumerate(messages[::-1]):
-        role, content = message.role, message.content
-        if context_tokens:
-            context_tokens = nl_tokens + context_tokens
-
-        if role == Role.USER:
-            content_tokens = _tokenize_str("user", content)
-        elif role == Role.SYSTEM:
-            content_tokens = _tokenize_str("system", content)
-        elif role == Role.ASSISTANT:
-            content_tokens = _tokenize_str("assistant", content)
-        else:
-            raise ValueError(f"message role not supported yet: {role}")
-
-        if len(im_start_tokens + content_tokens + im_end_tokens + context_tokens) > max_window_size:
-            break
-        else:
-            context_tokens = im_start_tokens + content_tokens + im_end_tokens + context_tokens
-
-    context_tokens = system_tokens + nl_tokens + context_tokens
-    return context_tokens + nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens
+    max_history_tokens = max_input_tokens - len(system_tokens)
+
+    history_tokens = []
+    for round in rounds[::-1]:
+        round_tokens = []
+        for message in round:
+            if round_tokens:
+                round_tokens += nl_tokens
+
+            if message.role == Role.USER:
+                content_tokens = im_start_tokens + _tokenize_str("user", message.content) + im_end_tokens
+            else:
+                content_tokens = im_start_tokens + _tokenize_str("assistant", message.content) + im_end_tokens
+
+            round_tokens.extend(content_tokens)
+
+        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
+            if history_tokens:
+                history_tokens = nl_tokens + history_tokens
+
+            history_tokens = round_tokens + history_tokens  # concat left
+            if len(history_tokens) < max_history_tokens:
+                continue
+        break
+
+    input_tokens = system_tokens + nl_tokens + history_tokens
+    if messages[-1].role != Role.ASSISTANT:
+        input_tokens += nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens
+    return input_tokens[-max_input_tokens:]  # truncate left
 
 
 def check_is_qwen(model):
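
For reference (not part of the commit), the rewritten Qwen builder assembles a standard ChatML prompt as token ids. A rough sketch that renders the same layout as text, with literal <|im_start|> / <|im_end|> markers standing in for tokenizer.im_start_id / tokenizer.im_end_id:

# Sketch only (not project code): renders the ChatML layout that
# build_qwen_chat_input assembles as token ids.
from typing import List, Tuple


def render_chatml(system: str, rounds: List[List[Tuple[str, str]]], last_role: str) -> str:
    parts = [f"<|im_start|>system\n{system}<|im_end|>"]
    for round in rounds:
        for role, content in round:
            parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")
    prompt = "\n".join(parts)
    if last_role != "assistant":
        prompt += "\n<|im_start|>assistant\n"     # open the assistant turn for generation
    return prompt


print(render_chatml(
    "You are a helpful assistant.",
    [[("user", "Hi"), ("assistant", "Hello!")], [("user", "What is 2 + 2?")]],
    last_role="user",
))

The trailing assistant header mirrors the guard in the diff (messages[-1].role != Role.ASSISTANT), so continuing a partial assistant message does not open a second assistant turn.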

api/generation/utils.py (80 additions, 0 deletions)

@@ -0,0 +1,80 @@
+from typing import List
+from typing import Tuple
+
+from transformers.generation.logits_process import (
+    LogitsProcessorList,
+    RepetitionPenaltyLogitsProcessor,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+)
+
+from api.utils.protocol import ChatMessage, Role
+
+
+def parse_messages(messages: List[ChatMessage], split_role=Role.USER) -> Tuple[str, List[List[ChatMessage]]]:
+    system, rounds = "", []
+    round = []
+    for i, message in enumerate(messages):
+        if message.role == Role.SYSTEM:
+            system = message.content
+            continue
+        if message.role == split_role and round:
+            rounds.append(round)
+            round = []
+        round.append(message)
+    if round:
+        rounds.append(round)
+    return system, rounds
+
+
+def prepare_logits_processor(
+    temperature: float, repetition_penalty: float, top_p: float, top_k: int
+) -> LogitsProcessorList:
+    processor_list = LogitsProcessorList()
+    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases.
+    if temperature >= 1e-5 and temperature != 1.0:
+        processor_list.append(TemperatureLogitsWarper(temperature))
+    if repetition_penalty > 1.0:
+        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
+    if 1e-8 <= top_p < 1.0:
+        processor_list.append(TopPLogitsWarper(top_p))
+    if top_k > 0:
+        processor_list.append(TopKLogitsWarper(top_k))
+    return processor_list
+
+
+def is_partial_stop(output: str, stop_str: str):
+    """Check whether the output contains a partial stop str."""
+    for i in range(0, min(len(output), len(stop_str))):
+        if stop_str.startswith(output[-i:]):
+            return True
+    return False
+
+
+# Models don't use the same configuration key for determining the maximum
+# sequence length.  Store them here so we can sanely check them.
+# NOTE: The ordering here is important.  Some models have two of these and we
+# have a preference for which value gets used.
+SEQUENCE_LENGTH_KEYS = [
+    "max_sequence_length",
+    "seq_length",
+    "max_position_embeddings",
+    "max_seq_len",
+    "model_max_length",
+]
+
+
+def get_context_length(config):
+    """Get the context length of a model from a huggingface model config."""
+    rope_scaling = getattr(config, "rope_scaling", None)
+    if rope_scaling:
+        rope_scaling_factor = config.rope_scaling["factor"]
+    else:
+        rope_scaling_factor = 1
+
+    for key in SEQUENCE_LENGTH_KEYS:
+        val = getattr(config, key, None)
+        if val is not None:
+            return int(rope_scaling_factor * val)
+    return 2048
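
A quick sanity check of the new get_context_length helper against stub configs (made-up attribute values; assumes the repository's api package is on the Python path):

# Sketch only: exercise get_context_length with bare stub config objects.
from types import SimpleNamespace

from api.generation.utils import get_context_length

cfg_rope = SimpleNamespace(
    max_position_embeddings=4096,
    rope_scaling={"type": "linear", "factor": 2.0},
)
print(get_context_length(cfg_rope))    # 8192: factor 2.0 * max_position_embeddings

cfg_plain = SimpleNamespace(seq_length=8192, rope_scaling=None)
print(get_context_length(cfg_plain))   # 8192: no rope scaling, seq_length wins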

api/vllm_routes/utils.py (11 additions, 15 deletions)

@@ -27,21 +27,17 @@ async def get_model_inputs(request, prompt, model_name):
        input_ids = prompt
    else:
        if "baichuan-13b" in model_name:
-            input_ids = build_baichuan_chat_input(VLLM_ENGINE.encode_tokenizer, prompt)
+            input_ids = build_baichuan_chat_input(
+                VLLM_ENGINE.encode_tokenizer,
+                prompt,
+                max_new_tokens=request.max_tokens,
+            )
        elif "qwen" in model_name:
-            input_ids = build_qwen_chat_input(VLLM_ENGINE.encode_tokenizer, prompt)
+            input_ids = build_qwen_chat_input(
+                VLLM_ENGINE.encode_tokenizer,
+                prompt,
+                max_new_tokens=request.max_tokens,
+            )
        else:
            raise ValueError(f"Model not supported yet: {model_name}")
-
-    token_num = len(input_ids)
-    if token_num + request.max_tokens > VLLM_ENGINE.max_model_len:
-        return input_ids, create_error_response(
-            HTTPStatus.BAD_REQUEST,
-            f"This model's maximum context length is {VLLM_ENGINE.max_model_len} tokens. "
-            f"However, you requested {request.max_tokens + token_num} tokens "
-            f"({token_num} in the messages, "
-            f"{request.max_tokens} in the completion). "
-            f"Please reduce the length of the messages or completion.",
-        )
-    else:
-        return input_ids, None
+    return input_ids, None
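
Why the explicit max_model_len check could go away (a reading, not part of the commit): the builders now clamp the prompt to context_len - max_new_tokens up front, so prompt plus completion should stay within the window by construction. The arithmetic, with made-up numbers:

# Sketch only, with made-up numbers: the builders reserve room for the
# completion up front instead of rejecting oversized prompts afterwards.
context_len = 4096                  # builder default for the Baichuan path
request_max_tokens = 512            # request.max_tokens from the client
max_input_tokens = context_len - request_max_tokens
print(max_input_tokens)             # 3584 tokens left for system prompt + history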
