
Commit cb289ad

pco111 and ArthurZucker authored
feat(tokenization): add encode_message to tokenize messages one by one (#39507)
* feat(tokenization): add encode_message to tokenize messages one by one
* Fix the `encode_message` method: remove the `add_generation_prompt` parameter and add the corresponding error handling. Update the documentation to reflect this change and verify the error handling in the tests.
* Optimize the `encode_message` method: improve the handling of an empty conversation history and ensure that the chat template is applied correctly when the history is empty. Update the documentation to reflect these changes.
* Delete the `_encode_message` method, simplify the message-encoding logic, and keep the `encode_message` method fully functional. Update the documentation to reflect these changes.
* Docs fix
* Revert changes in docstring of pad()
* Revert changes in docstring
* Update src/transformers/tokenization_utils_base.py (Co-authored-by: Arthur <[email protected]>)
* Fix the calls to the `encode_message` method, updating them to `encode_message_with_chat_template` so they support the chat template, and adjust the relevant test cases to reflect this change.
* Optimize the call format of `apply_chat_template`, merging multi-line calls into a single line to improve code readability.

---------

Co-authored-by: pco111 <[email protected]>
Co-authored-by: Arthur <[email protected]>
1 parent 4f93cc9 commit cb289ad

File tree

2 files changed: +86 -0 lines changed


src/transformers/tokenization_utils_base.py

Lines changed: 56 additions & 0 deletions
@@ -1694,6 +1694,62 @@ def apply_chat_template(
         else:
             return rendered_chat
 
+    def encode_message_with_chat_template(
+        self,
+        message: dict[str, str],
+        conversation_history: Optional[list[dict[str, str]]] = None,
+        **kwargs,
+    ) -> list[int]:
+        """
+        Tokenize a single message. This method is a convenience wrapper around `apply_chat_template` that allows you
+        to tokenize messages one by one. This is useful for things like token-by-token streaming.
+        This method is not guaranteed to be perfect. For some models, it may be impossible to robustly tokenize
+        single messages. For example, if the chat template adds tokens after each message, but also has a prefix that
+        is added to the entire chat, it will be impossible to distinguish a chat-start-token from a message-start-token.
+        In these cases, this method will do its best to find the correct tokenization, but it may not be perfect.
+        **Note:** This method does not support `add_generation_prompt`. If you want to add a generation prompt,
+        you should do it separately after tokenizing the conversation.
+        Args:
+            message (`dict`):
+                A dictionary with "role" and "content" keys, representing the message to tokenize.
+            conversation_history (`list[dict]`, *optional*):
+                A list of dicts with "role" and "content" keys, representing the chat history so far. If you are
+                tokenizing messages one by one, you should pass the previous messages in the conversation here.
+            **kwargs:
+                Additional kwargs to pass to the `apply_chat_template` method.
+        Returns:
+            `list[int]`: A list of token ids representing the tokenized message.
+        """
+        if "add_generation_prompt" in kwargs:
+            raise ValueError(
+                "`encode_message_with_chat_template` does not support `add_generation_prompt`. Please add the generation prompt "
+                "separately."
+            )
+
+        if conversation_history is None or len(conversation_history) == 0:
+            return self.apply_chat_template([message], add_generation_prompt=False, tokenize=True, **kwargs)
+
+        conversation = conversation_history + [message]
+        tokens = self.apply_chat_template(conversation, add_generation_prompt=False, tokenize=True, **kwargs)
+
+        prefix_tokens = self.apply_chat_template(
+            conversation_history, add_generation_prompt=False, tokenize=True, **kwargs
+        )
+        # It's possible that the prefix tokens are not a prefix of the full list of tokens.
+        # For example, if the prefix is `<s>User: Hi` and the full conversation is `<s>User: Hi</s><s>Assistant: Hello`.
+        # In this case, we can't simply find the prefix, so we have to do something a bit more subtle.
+        # We look for the first place where the tokens differ, and that's our split point.
+        # This is not perfect, but it's the best we can do without a token-level API.
+        # To make this more robust, we could do a diff and find the longest common subsequence, but this is
+        # a good first approximation.
+        # This is particularly important for models like Llama3 that have changed their chat template to include
+        # EOS tokens after user messages.
+        min_len = min(len(prefix_tokens), len(tokens))
+        for i in range(min_len):
+            if prefix_tokens[i] != tokens[i]:
+                return tokens[i:]
+        return tokens[min_len:]
+
     def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[list[dict]] = None) -> str:
         """
         Retrieve the chat template string used for tokenizing chat messages. This template is used
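
For illustration only (not part of this diff), here is a minimal sketch of how the new method can be used to build up a conversation's token ids incrementally, e.g. for token-by-token streaming. It assumes a tokenizer that defines a chat template; the `HuggingFaceH4/zephyr-7b-beta` checkpoint used in the tests below is one example.

from transformers import AutoTokenizer

# Any tokenizer with a chat template works; this checkpoint is the one used in the tests.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hey there, how are you?"},
    {"role": "assistant", "content": "Thank you for asking, I am doing well"},
]

# Encode the conversation one message at a time, passing the earlier messages
# as `conversation_history` so the chat template is applied consistently.
token_ids = []
for i, message in enumerate(conversation):
    token_ids += tokenizer.encode_message_with_chat_template(
        message, conversation_history=conversation[:i]
    )

# The concatenated ids should match encoding the whole conversation at once,
# as verified by the new test below.
assert token_ids == tokenizer.apply_chat_template(conversation, tokenize=True)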

tests/tokenization/test_tokenization_utils.py

Lines changed: 30 additions & 0 deletions
@@ -24,6 +24,7 @@
 import numpy as np
 
 from transformers import (
+    AutoTokenizer,
     BatchEncoding,
     BertTokenizer,
     BertTokenizerFast,
@@ -375,3 +376,32 @@ def test_training_new_tokenizer_edge_cases(self):
         tokenizer = PreTrainedTokenizerFast(tokenizer_object=_tokenizer)
         toy_text_iterator = ("a" for _ in range(1000))
         tokenizer.train_new_from_iterator(text_iterator=toy_text_iterator, length=1000, vocab_size=50)
+
+    def test_encode_message(self):
+        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+        conversation = [
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": "Hey there, how are you?"},
+            {"role": "assistant", "content": "Thank you for asking, I am doing well"},
+            {"role": "user", "content": "What's the weather like today?"},
+            {"role": "assistant", "content": "Today the weather is nice"},
+        ]
+
+        # First, test the default case, where we encode the whole conversation at once
+        whole_conversation_tokens = tokenizer.apply_chat_template(conversation, tokenize=True)
+
+        # Now, test the message-by-message encoding
+        tokens = []
+        for i, message in enumerate(conversation):
+            tokens += tokenizer.encode_message_with_chat_template(message, conversation_history=conversation[:i])
+
+        self.assertEqual(whole_conversation_tokens, tokens)
+
+    def test_encode_message_raises_on_add_generation_prompt(self):
+        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+        conversation = [
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": "Hey there, how are you?"},
+        ]
+        with self.assertRaises(ValueError):
+            tokenizer.encode_message_with_chat_template(conversation[0], add_generation_prompt=True)
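
Since `encode_message_with_chat_template` deliberately rejects `add_generation_prompt` (see the second test above), the generation prompt has to be appended separately. One possible approach, sketched below and not part of this commit, is to render the finished conversation with and without the generation prompt and keep the extra suffix; this assumes the template only appends the generation prompt at the end, which holds for most templates but is not guaranteed.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hey there, how are you?"},
]

# Token ids built message by message (no generation prompt, per the new API).
token_ids = []
for i, message in enumerate(conversation):
    token_ids += tokenizer.encode_message_with_chat_template(
        message, conversation_history=conversation[:i]
    )

# Recover the generation-prompt tokens by diffing the rendered conversation
# with and without `add_generation_prompt`, then append the extra suffix.
with_prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=True)
without_prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt=False, tokenize=True)
token_ids += with_prompt[len(without_prompt):]

# `token_ids` is now ready to be fed to a model for generation.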
