1717import re
1818import types
1919from contextlib import contextmanager
20+ from copy import deepcopy
2021from datetime import datetime
2122from functools import lru_cache
2223from inspect import isfunction
@@ -503,10 +504,27 @@ def render_jinja_template(
503504
504505 rendered = []
505506 all_generation_indices = []
507+ continue_final_message_tag = "CONTINUE_FINAL_MESSAGE_TAG "
506508 for chat in conversations :
507509 if hasattr (chat , "messages" ):
508510 # Indicates it's a Conversation object
509511 chat = chat .messages
512+ if continue_final_message :
513+ chat = deepcopy (chat )
514+ final_message = chat [- 1 ]["content" ]
515+ if isinstance (final_message , (list , tuple )):
516+ for content_block in reversed (final_message ):
517+ if "text" in content_block :
518+ # Pick the last text block in the message (the first one we hit while iterating in reverse)
519+ final_message = content_block ["text" ]
520+ content_block ["text" ] = content_block ["text" ] + continue_final_message_tag
521+ break
522+ else :
523+ raise ValueError (
524+ "continue_final_message is set but we could not find any text to continue in the final message!"
525+ )
526+ else :
527+ chat [- 1 ]["content" ] = chat [- 1 ]["content" ] + continue_final_message_tag
510528 if return_assistant_tokens_mask :
511529 rendered_chat , generation_indices = _render_with_assistant_indices (
512530 compiled_template = compiled_template ,
@@ -526,31 +544,22 @@ def render_jinja_template(
526544 ** kwargs ,
527545 )
528546 if continue_final_message :
529- final_message = chat [- 1 ]["content" ]
530- if isinstance (final_message , (list , tuple )):
531- for content_block in reversed (final_message ):
532- if "text" in content_block :
533- # Pick the last text block in the message (the first one we hit while iterating in reverse)
534- final_message = content_block ["text" ]
535- break
536- else :
537- raise ValueError (
538- "continue_final_message is set but we could not find any text to continuein the final message!"
539- )
540- if final_message .strip () not in rendered_chat :
547+ if (final_message .strip () not in rendered_chat ) or (
548+ continue_final_message_tag .strip () not in rendered_chat
549+ ):
541550 raise ValueError (
542551 "continue_final_message is set but the final message does not appear in the chat after "
543552 "applying the chat template! This can happen if the chat template deletes portions of "
544553 "the final message. Please verify the chat template and final message in your chat to "
545554 "ensure they are compatible."
546555 )
547- final_msg_loc = rendered_chat .rindex (final_message .strip ())
548- if rendered_chat [final_msg_loc : final_msg_loc + len (final_message . lstrip ()) ] == final_message :
549- # The template preserves spacing or the message doesn't have trailing spacing , so things are simple
550- rendered_chat = rendered_chat [: final_msg_loc + len ( final_message . lstrip ()) ]
556+ tag_loc = rendered_chat .rindex (continue_final_message_tag .strip ())
557+ if rendered_chat [tag_loc : tag_loc + len (continue_final_message_tag ) ] == continue_final_message_tag :
558+ # The template preserves spacing, so things are simple
559+ rendered_chat = rendered_chat [:tag_loc ]
551560 else :
552561 # The message has trailing spacing that was trimmed, so we must be more cautious
553- rendered_chat = rendered_chat [: final_msg_loc + len ( final_message . strip ())]
562+ rendered_chat = rendered_chat [:tag_loc ]. rstrip ()
554563 rendered .append (rendered_chat )
555564
556565 return rendered , all_generation_indices
0 commit comments