56 changes: 56 additions & 0 deletions src/transformers/tokenization_utils_base.py
@@ -1694,6 +1694,62 @@ def apply_chat_template(
        else:
            return rendered_chat

    def encode_message_with_chat_template(
        self,
        message: dict[str, str],
        conversation_history: Optional[list[dict[str, str]]] = None,
        **kwargs,
    ) -> list[int]:
        """
        Tokenize a single message. This method is a convenience wrapper around `apply_chat_template` that allows you
        to tokenize messages one by one. This is useful for things like token-by-token streaming.

        This method is not guaranteed to be perfect. For some models, it may be impossible to robustly tokenize
        single messages. For example, if the chat template adds tokens after each message, but also has a prefix that
        is added to the entire chat, it will be impossible to distinguish a chat-start token from a message-start
        token. In these cases, this method will do its best to find the correct tokenization, but it may not be
        perfect.

        **Note:** This method does not support `add_generation_prompt`. If you want to add a generation prompt,
        you should do it separately after tokenizing the conversation.

        Args:
            message (`dict`):
                A dictionary with "role" and "content" keys, representing the message to tokenize.
            conversation_history (`list[dict]`, *optional*):
                A list of dicts with "role" and "content" keys, representing the chat history so far. If you are
                tokenizing messages one by one, you should pass the previous messages in the conversation here.
            **kwargs:
                Additional kwargs to pass to the `apply_chat_template` method.

        Returns:
            `list[int]`: A list of token ids representing the tokenized message.
        """
        if "add_generation_prompt" in kwargs:
            raise ValueError(
                "`encode_message_with_chat_template` does not support `add_generation_prompt`. Please add the generation prompt "
                "separately."
            )

        if conversation_history is None or len(conversation_history) == 0:
            return self.apply_chat_template([message], add_generation_prompt=False, tokenize=True, **kwargs)

        conversation = conversation_history + [message]
        tokens = self.apply_chat_template(conversation, add_generation_prompt=False, tokenize=True, **kwargs)

        prefix_tokens = self.apply_chat_template(
            conversation_history, add_generation_prompt=False, tokenize=True, **kwargs
        )
        # It's possible that the prefix tokens are not a prefix of the full list of tokens.
        # For example, the prefix may be `<s>User: Hi` while the full conversation is
        # `<s>User: Hi</s><s>Assistant: Hello`. In that case we can't simply strip the prefix,
        # so we look for the first place where the tokens differ and treat that as our split point.
        # This is not perfect, but it's the best we can do without a token-level API.
        # To make this more robust, we could do a diff and find the longest common subsequence, but this is
        # a good first approximation.
        # This is particularly important for models like Llama 3 that have changed their chat template to include
        # EOS tokens after user messages.
        min_len = min(len(prefix_tokens), len(tokens))
        for i in range(min_len):
            if prefix_tokens[i] != tokens[i]:
                return tokens[i:]
        return tokens[min_len:]

    def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[list[dict]] = None) -> str:
        """
        Retrieve the chat template string used for tokenizing chat messages. This template is used
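For orientation, here is a minimal usage sketch of the new method (not part of this diff). It uses the Zephyr checkpoint that the tests below rely on, and the suffix-diff trick at the end for appending a generation prompt is an illustrative assumption rather than an API guarantee:

```python
from transformers import AutoTokenizer

# Any chat model with a template works; this is the checkpoint used by the tests in this PR.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

conversation = []
tokens = []
for message in [
    {"role": "user", "content": "Hey there, how are you?"},
    {"role": "assistant", "content": "Thank you for asking, I am doing well"},
]:
    # Encode each message incrementally, passing the messages seen so far as history.
    tokens += tokenizer.encode_message_with_chat_template(message, conversation_history=conversation)
    conversation.append(message)

# The new method rejects `add_generation_prompt`, so one possible way to append the generation
# prompt afterwards is to diff the templated conversation with and without it. This assumes the
# generation prompt is a pure suffix, which holds for most chat templates.
with_prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=True)
without_prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt=False, tokenize=True)
tokens += with_prompt[len(without_prompt):]
```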
30 changes: 30 additions & 0 deletions tests/tokenization/test_tokenization_utils.py
@@ -24,6 +24,7 @@
import numpy as np

from transformers import (
    AutoTokenizer,
    BatchEncoding,
    BertTokenizer,
    BertTokenizerFast,
@@ -375,3 +376,32 @@ def test_training_new_tokenizer_edge_cases(self):
        tokenizer = PreTrainedTokenizerFast(tokenizer_object=_tokenizer)
        toy_text_iterator = ("a" for _ in range(1000))
        tokenizer.train_new_from_iterator(text_iterator=toy_text_iterator, length=1000, vocab_size=50)

    def test_encode_message(self):
        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
        conversation = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hey there, how are you?"},
            {"role": "assistant", "content": "Thank you for asking, I am doing well"},
            {"role": "user", "content": "What's the weather like today?"},
            {"role": "assistant", "content": "Today the weather is nice"},
        ]

        # First, test the default case, where we encode the whole conversation at once
        whole_conversation_tokens = tokenizer.apply_chat_template(conversation, tokenize=True)

        # Now, test the message-by-message encoding
        tokens = []
        for i, message in enumerate(conversation):
            tokens += tokenizer.encode_message_with_chat_template(message, conversation_history=conversation[:i])

        self.assertEqual(whole_conversation_tokens, tokens)

    def test_encode_message_raises_on_add_generation_prompt(self):
        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
        conversation = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hey there, how are you?"},
        ]
        with self.assertRaises(ValueError):
            tokenizer.encode_message_with_chat_template(conversation[0], add_generation_prompt=True)
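The inline comment in `encode_message_with_chat_template` notes that a token-level diff (longest common subsequence) would be more robust than splitting at the first point of divergence. A minimal sketch of that idea, using Python's standard `difflib` as a stand-in for a true LCS diff, is shown below; it is an assumption about how such a variant could look, not code from this PR:

```python
from difflib import SequenceMatcher


def new_tokens_via_diff(prefix_tokens: list[int], tokens: list[int]) -> list[int]:
    """Return the tokens in `tokens` that are not matched against `prefix_tokens`."""
    matcher = SequenceMatcher(None, prefix_tokens, tokens, autojunk=False)
    new_tokens: list[int] = []
    for tag, _i1, _i2, j1, j2 in matcher.get_opcodes():
        # "insert" and "replace" regions exist only in the full conversation's token list,
        # so they are attributed to the newly added message rather than dropped.
        if tag in ("insert", "replace"):
            new_tokens.extend(tokens[j1:j2])
    return new_tokens
```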