Skip to content

Commit 5a468e5

Browse files
Fix continue_final_message in apply_chat_template to prevent substring matching issues (#40732)
* Fix continue_final_message parameter in apply_chat_template * after run fixup * Handle trim in the template * after fixup * Update src/transformers/utils/chat_template_utils.py --------- Co-authored-by: Matt <[email protected]>
1 parent e8db153 commit 5a468e5

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

src/transformers/utils/chat_template_utils.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import re
1818
import types
1919
from contextlib import contextmanager
20+
from copy import deepcopy
2021
from datetime import datetime
2122
from functools import lru_cache
2223
from inspect import isfunction
@@ -503,10 +504,27 @@ def render_jinja_template(
503504

504505
rendered = []
505506
all_generation_indices = []
507+
continue_final_message_tag = "CONTINUE_FINAL_MESSAGE_TAG "
506508
for chat in conversations:
507509
if hasattr(chat, "messages"):
508510
# Indicates it's a Conversation object
509511
chat = chat.messages
512+
if continue_final_message:
513+
chat = deepcopy(chat)
514+
final_message = chat[-1]["content"]
515+
if isinstance(final_message, (list, tuple)):
516+
for content_block in reversed(final_message):
517+
if "text" in content_block:
518+
# Pick the last text block in the message (the first one we hit while iterating in reverse)
519+
final_message = content_block["text"]
520+
content_block["text"] = content_block["text"] + continue_final_message_tag
521+
break
522+
else:
523+
raise ValueError(
524+
"continue_final_message is set but we could not find any text to continue in the final message!"
525+
)
526+
else:
527+
chat[-1]["content"] = chat[-1]["content"] + continue_final_message_tag
510528
if return_assistant_tokens_mask:
511529
rendered_chat, generation_indices = _render_with_assistant_indices(
512530
compiled_template=compiled_template,
@@ -526,31 +544,22 @@ def render_jinja_template(
526544
**kwargs,
527545
)
528546
if continue_final_message:
529-
final_message = chat[-1]["content"]
530-
if isinstance(final_message, (list, tuple)):
531-
for content_block in reversed(final_message):
532-
if "text" in content_block:
533-
# Pick the last text block in the message (the first one we hit while iterating in reverse)
534-
final_message = content_block["text"]
535-
break
536-
else:
537-
raise ValueError(
538-
"continue_final_message is set but we could not find any text to continuein the final message!"
539-
)
540-
if final_message.strip() not in rendered_chat:
547+
if (final_message.strip() not in rendered_chat) or (
548+
continue_final_message_tag.strip() not in rendered_chat
549+
):
541550
raise ValueError(
542551
"continue_final_message is set but the final message does not appear in the chat after "
543552
"applying the chat template! This can happen if the chat template deletes portions of "
544553
"the final message. Please verify the chat template and final message in your chat to "
545554
"ensure they are compatible."
546555
)
547-
final_msg_loc = rendered_chat.rindex(final_message.strip())
548-
if rendered_chat[final_msg_loc : final_msg_loc + len(final_message.lstrip())] == final_message:
549-
# The template preserves spacing or the message doesn't have trailing spacing, so things are simple
550-
rendered_chat = rendered_chat[: final_msg_loc + len(final_message.lstrip())]
556+
tag_loc = rendered_chat.rindex(continue_final_message_tag.strip())
557+
if rendered_chat[tag_loc : tag_loc + len(continue_final_message_tag)] == continue_final_message_tag:
558+
# The template preserves spacing, so things are simple
559+
rendered_chat = rendered_chat[:tag_loc]
551560
else:
552561
# The message has trailing spacing that was trimmed, so we must be more cautious
553-
rendered_chat = rendered_chat[: final_msg_loc + len(final_message.strip())]
562+
rendered_chat = rendered_chat[:tag_loc].rstrip()
554563
rendered.append(rendered_chat)
555564

556565
return rendered, all_generation_indices

0 commit comments

Comments
 (0)