diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 36e57c0713e3..8f91ca6fcddf 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1694,6 +1694,62 @@ def apply_chat_template(
         else:
             return rendered_chat
 
+    def encode_message_with_chat_template(
+        self,
+        message: dict[str, str],
+        conversation_history: Optional[list[dict[str, str]]] = None,
+        **kwargs,
+    ) -> list[int]:
+        """
+        Tokenize a single message. This method is a convenience wrapper around `apply_chat_template` that allows you
+        to tokenize messages one by one. This is useful for things like token-by-token streaming.
+        This method is not guaranteed to be perfect. For some models, it may be impossible to robustly tokenize
+        single messages. For example, if the chat template adds tokens after each message, but also has a prefix that
+        is added to the entire chat, it will be impossible to distinguish a chat-start-token from a message-start-token.
+        In these cases, this method will do its best to find the correct tokenization, but it may not be perfect.
+        **Note:** This method does not support `add_generation_prompt`. If you want to add a generation prompt,
+        you should do it separately after tokenizing the conversation.
+        Args:
+            message (`dict`):
+                A dictionary with "role" and "content" keys, representing the message to tokenize.
+            conversation_history (`list[dict]`, *optional*):
+                A list of dicts with "role" and "content" keys, representing the chat history so far. If you are
+                tokenizing messages one by one, you should pass the previous messages in the conversation here.
+            **kwargs:
+                Additional kwargs to pass to the `apply_chat_template` method.
+        Returns:
+            `list[int]`: A list of token ids representing the tokenized message.
+        """
+        if "add_generation_prompt" in kwargs:
+            raise ValueError(
+                "`encode_message_with_chat_template` does not support `add_generation_prompt`. Please add the generation prompt "
+                "separately."
+            )
+
+        if conversation_history is None or len(conversation_history) == 0:
+            return self.apply_chat_template([message], add_generation_prompt=False, tokenize=True, **kwargs)
+
+        conversation = conversation_history + [message]
+        tokens = self.apply_chat_template(conversation, add_generation_prompt=False, tokenize=True, **kwargs)
+
+        prefix_tokens = self.apply_chat_template(
+            conversation_history, add_generation_prompt=False, tokenize=True, **kwargs
+        )
+        # It's possible that the prefix tokens are not a prefix of the full list of tokens.
+        # For example, if the prefix is `User: Hi` and the full conversation is `User: HiAssistant: Hello`.
+        # In this case, we can't simply find the prefix, so we have to do something a bit more subtle.
+        # We look for the first place where the tokens differ, and that's our split point.
+        # This is not perfect, but it's the best we can do without a token-level API.
+        # To make this more robust, we could do a diff and find the longest common subsequence, but this is
+        # a good first approximation.
+        # This is particularly important for models like Llama3 that have changed their chat template to include
+        # EOS tokens after user messages.
+        min_len = min(len(prefix_tokens), len(tokens))
+        for i in range(min_len):
+            if prefix_tokens[i] != tokens[i]:
+                return tokens[i:]
+        return tokens[min_len:]
+
     def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[list[dict]] = None) -> str:
         """
         Retrieve the chat template string used for tokenizing chat messages. This template is used
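For context, here is a minimal usage sketch of the new method (not part of the diff). It assumes a tokenizer with a chat template, using the `HuggingFaceH4/zephyr-7b-beta` checkpoint exercised by the tests below, and shows the incremental encoding loop the docstring describes.

```python
from transformers import AutoTokenizer

# Any tokenizer with a chat template should work; this checkpoint is the one used in the tests below.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

history = []
token_ids = []
for message in [
    {"role": "user", "content": "Hey there, how are you?"},
    {"role": "assistant", "content": "Thank you for asking, I am doing well"},
]:
    # Encode each new message against the conversation so far, then add it to the history.
    token_ids += tokenizer.encode_message_with_chat_template(message, conversation_history=history)
    history.append(message)

# Concatenating the per-message ids should reproduce the whole-conversation encoding.
assert token_ids == tokenizer.apply_chat_template(history, tokenize=True)
```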
diff --git a/tests/tokenization/test_tokenization_utils.py b/tests/tokenization/test_tokenization_utils.py
index dd1aae486d13..fc74223110f8 100644
--- a/tests/tokenization/test_tokenization_utils.py
+++ b/tests/tokenization/test_tokenization_utils.py
@@ -24,6 +24,7 @@
 import numpy as np
 
 from transformers import (
+    AutoTokenizer,
     BatchEncoding,
     BertTokenizer,
     BertTokenizerFast,
@@ -375,3 +376,32 @@ def test_training_new_tokenizer_edge_cases(self):
         tokenizer = PreTrainedTokenizerFast(tokenizer_object=_tokenizer)
         toy_text_iterator = ("a" for _ in range(1000))
         tokenizer.train_new_from_iterator(text_iterator=toy_text_iterator, length=1000, vocab_size=50)
+
+    def test_encode_message(self):
+        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+        conversation = [
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": "Hey there, how are you?"},
+            {"role": "assistant", "content": "Thank you for asking, I am doing well"},
+            {"role": "user", "content": "What's the weather like today?"},
+            {"role": "assistant", "content": "Today the weather is nice"},
+        ]
+
+        # First, test the default case, where we encode the whole conversation at once
+        whole_conversation_tokens = tokenizer.apply_chat_template(conversation, tokenize=True)
+
+        # Now, test the message-by-message encoding
+        tokens = []
+        for i, message in enumerate(conversation):
+            tokens += tokenizer.encode_message_with_chat_template(message, conversation_history=conversation[:i])
+
+        self.assertEqual(whole_conversation_tokens, tokens)
+
+    def test_encode_message_raises_on_add_generation_prompt(self):
+        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+        conversation = [
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": "Hey there, how are you?"},
+        ]
+        with self.assertRaises(ValueError):
+            tokenizer.encode_message_with_chat_template(conversation[0], add_generation_prompt=True)
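Since the new method rejects `add_generation_prompt`, one possible way to append a generation prompt after incremental encoding is sketched below. This is not part of the PR; it assumes the prompted rendering only appends tokens at the end of the unprompted one, which holds for typical chat templates.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
conversation = [{"role": "user", "content": "Hey there, how are you?"}]

# Encode message by message; add_generation_prompt is intentionally not passed here.
token_ids = []
for i, message in enumerate(conversation):
    token_ids += tokenizer.encode_message_with_chat_template(message, conversation_history=conversation[:i])

# Recover the generation-prompt suffix by rendering the conversation with and without the flag
# and keeping only the extra trailing tokens. Assumes the prompt is purely appended.
without_prompt = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=False)
with_prompt = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True)
token_ids += with_prompt[len(without_prompt):]
```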