Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import re
import types
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
from inspect import isfunction
Expand Down Expand Up @@ -503,10 +504,27 @@ def render_jinja_template(

rendered = []
all_generation_indices = []
continue_final_message_tag = "CONTINUE_FINAL_MESSAGE_TAG "
for chat in conversations:
if hasattr(chat, "messages"):
# Indicates it's a Conversation object
chat = chat.messages
if continue_final_message:
chat = deepcopy(chat)
final_message = chat[-1]["content"]
if isinstance(final_message, (list, tuple)):
for content_block in reversed(final_message):
if "text" in content_block:
# Pick the last text block in the message (the first one we hit while iterating in reverse)
final_message = content_block["text"]
content_block["text"] = content_block["text"] + continue_final_message_tag
break
else:
raise ValueError(
"continue_final_message is set but we could not find any text to continue in the final message!"
)
else:
chat[-1]["content"] = chat[-1]["content"] + continue_final_message_tag
if return_assistant_tokens_mask:
rendered_chat, generation_indices = _render_with_assistant_indices(
compiled_template=compiled_template,
Expand All @@ -526,31 +544,22 @@ def render_jinja_template(
**kwargs,
)
if continue_final_message:
final_message = chat[-1]["content"]
if isinstance(final_message, (list, tuple)):
for content_block in reversed(final_message):
if "text" in content_block:
# Pick the last text block in the message (the first one we hit while iterating in reverse)
final_message = content_block["text"]
break
else:
raise ValueError(
"continue_final_message is set but we could not find any text to continuein the final message!"
)
if final_message.strip() not in rendered_chat:
if (final_message.strip() not in rendered_chat) or (
continue_final_message_tag.strip() not in rendered_chat
):
raise ValueError(
"continue_final_message is set but the final message does not appear in the chat after "
"applying the chat template! This can happen if the chat template deletes portions of "
"the final message. Please verify the chat template and final message in your chat to "
"ensure they are compatible."
)
final_msg_loc = rendered_chat.rindex(final_message.strip())
if rendered_chat[final_msg_loc : final_msg_loc + len(final_message.lstrip())] == final_message:
# The template preserves spacing or the message doesn't have trailing spacing, so things are simple
rendered_chat = rendered_chat[: final_msg_loc + len(final_message.lstrip())]
tag_loc = rendered_chat.rindex(continue_final_message_tag.strip())
if rendered_chat[tag_loc : tag_loc + len(continue_final_message_tag)] == continue_final_message_tag:
# The template preserves spacing, so things are simple
rendered_chat = rendered_chat[:tag_loc]
else:
# The message has trailing spacing that was trimmed, so we must be more cautious
rendered_chat = rendered_chat[: final_msg_loc + len(final_message.strip())]
rendered_chat = rendered_chat[:tag_loc].rstrip()
rendered.append(rendered_chat)

return rendered, all_generation_indices