Skip to content

Commit af30443

Browse files
matthewchan-g authored and copybara-github committed
Fix Conversation::GetSingleTurnText when message is an array and history is non-empty.
LiteRT-LM — PiperOrigin-RevId: 821983876
1 parent 24974da commit af30443

File tree

2 files changed

+251
-17
lines changed

2 files changed

+251
-17
lines changed

runtime/conversation/conversation.cc

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -156,26 +156,23 @@ absl::StatusOr<std::string> Conversation::GetSingleTurnText(
156156
ASSIGN_OR_RETURN(const std::string old_string,
157157
prompt_template_.Apply(old_tmpl_input));
158158

159-
if (std::holds_alternative<nlohmann::ordered_json>(message)) {
160-
PromptTemplateInput new_tmpl_input = std::move(old_tmpl_input);
159+
PromptTemplateInput new_tmpl_input = std::move(old_tmpl_input);
160+
for (const auto& message : messages) {
161161
ASSIGN_OR_RETURN(nlohmann::ordered_json message_tmpl_input,
162-
model_data_processor_->MessageToTemplateInput(
163-
std::get<nlohmann::ordered_json>(message)));
162+
model_data_processor_->MessageToTemplateInput(message));
164163
new_tmpl_input.messages.push_back(message_tmpl_input);
165-
new_tmpl_input.add_generation_prompt = true;
166-
ASSIGN_OR_RETURN(const std::string& new_string,
167-
prompt_template_.Apply(new_tmpl_input));
168-
if (new_string.substr(0, old_string.size()) != old_string) {
169-
return absl::InternalError(absl::StrCat(
170-
"The new rendered template string does not start with the previous "
171-
"rendered template string. \nold_string: ",
172-
old_string, "\nnew_string: ", new_string));
173-
}
174-
return {new_string.substr(old_string.size(),
175-
new_string.size() - old_string.size())};
176-
} else {
177-
return absl::InvalidArgumentError("Json message is required for now.");
178164
}
165+
new_tmpl_input.add_generation_prompt = true;
166+
ASSIGN_OR_RETURN(const std::string& new_string,
167+
prompt_template_.Apply(new_tmpl_input));
168+
if (new_string.substr(0, old_string.size()) != old_string) {
169+
return absl::InternalError(absl::StrCat(
170+
"The new rendered template string does not start with the previous "
171+
"rendered template string. \nold_string: ",
172+
old_string, "\nnew_string: ", new_string));
173+
}
174+
return {new_string.substr(old_string.size(),
175+
new_string.size() - old_string.size())};
179176
}
180177

181178
absl::StatusOr<DecodeConfig> Conversation::CreateDecodeConfig() {

runtime/conversation/conversation_test.cc

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,117 @@ TEST(ConversationTest, SendMultipleMessages) {
391391
assistant_message));
392392
}
393393

394+
TEST(ConversationTest, SendMultipleMessagesWithHistory) {
395+
// Set up mock Session.
396+
auto mock_session = std::make_unique<MockSession>();
397+
MockSession* mock_session_ptr = mock_session.get();
398+
SessionConfig session_config = SessionConfig::CreateDefault();
399+
session_config.SetStartTokenId(0);
400+
session_config.GetMutableStopTokenIds().push_back({1});
401+
*session_config.GetMutableLlmModelType().mutable_gemma3() = {};
402+
session_config.GetMutableJinjaPromptTemplate() = kTestJinjaPromptTemplate;
403+
EXPECT_CALL(*mock_session_ptr, GetSessionConfig())
404+
.WillRepeatedly(testing::ReturnRef(session_config));
405+
auto mock_tokenizer = std::make_unique<MockTokenizer>();
406+
EXPECT_CALL(*mock_session_ptr, GetTokenizer())
407+
.WillRepeatedly(testing::ReturnRef(*mock_tokenizer));
408+
409+
// Set up mock Engine.
410+
auto mock_engine = std::make_unique<MockEngine>();
411+
EXPECT_CALL(*mock_engine, CreateSession(testing::_))
412+
.WillOnce(testing::Return(std::move(mock_session)));
413+
ASSERT_OK_AND_ASSIGN(auto model_assets,
414+
ModelAssets::Create(GetTestdataPath(kTestLlmPath)));
415+
ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault(
416+
model_assets, Backend::CPU));
417+
EXPECT_CALL(*mock_engine, GetEngineSettings())
418+
.WillRepeatedly(testing::ReturnRef(engine_settings));
419+
420+
// Create Conversation.
421+
ASSERT_OK_AND_ASSIGN(auto conversation_config,
422+
ConversationConfig::CreateFromSessionConfig(
423+
*mock_engine, session_config));
424+
ASSERT_OK_AND_ASSIGN(auto conversation,
425+
Conversation::Create(*mock_engine, conversation_config));
426+
427+
// The first user message.
428+
JsonMessage user_message_1 = nlohmann::ordered_json::parse(R"json(
429+
{
430+
"role": "user",
431+
"content": "How are you?"
432+
}
433+
)json");
434+
EXPECT_CALL(*mock_session_ptr, RunPrefill(testing::_))
435+
.WillOnce(testing::Return(absl::OkStatus()));
436+
437+
// The first assistant response.
438+
Responses responses_1(1);
439+
responses_1.GetMutableResponseTexts()[0] = "I am good.";
440+
EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_))
441+
.WillOnce(testing::Return(responses_1));
442+
443+
// Send the first user message to fill the history.
444+
ASSERT_OK(conversation->SendMessage(user_message_1));
445+
ASSERT_THAT(conversation->GetHistory().size(), testing::Eq(2));
446+
447+
// We will send two consecutive messages when the history is not empty.
448+
JsonMessage user_messages = nlohmann::ordered_json::parse(R"json(
449+
[
450+
{
451+
"role": "user",
452+
"content": "foo"
453+
},
454+
{
455+
"role": "user",
456+
"content": "bar"
457+
}
458+
]
459+
)json");
460+
absl::string_view expected_input_text =
461+
"<start_of_turn>user\n"
462+
"foo<end_of_turn>\n"
463+
"<start_of_turn>user\n"
464+
"bar<end_of_turn>\n";
465+
EXPECT_CALL(*mock_session_ptr,
466+
RunPrefill(testing::ElementsAre(
467+
testing::VariantWith<InputText>(testing::Property(
468+
&InputText::GetRawTextString, expected_input_text)))))
469+
.WillOnce(testing::Return(absl::OkStatus()));
470+
471+
// The second assistant response.
472+
Responses responses_2(1);
473+
responses_2.GetMutableResponseTexts()[0] = "baz";
474+
EXPECT_CALL(*mock_session_ptr, RunDecode(testing::_))
475+
.WillOnce(testing::Return(responses_2));
476+
477+
// Send the user messages.
478+
ASSERT_OK(conversation->SendMessage(user_messages));
479+
480+
// Check the history.
481+
JsonMessage assistant_message_1 = nlohmann::ordered_json::parse(R"({
482+
"role": "assistant",
483+
"content": [
484+
{
485+
"type": "text",
486+
"text": "I am good."
487+
}
488+
]
489+
})");
490+
JsonMessage assistant_message_2 = nlohmann::ordered_json::parse(R"({
491+
"role": "assistant",
492+
"content": [
493+
{
494+
"type": "text",
495+
"text": "baz"
496+
}
497+
]
498+
})");
499+
EXPECT_THAT(conversation->GetHistory(),
500+
testing::ElementsAre(user_message_1, assistant_message_1,
501+
user_messages[0], user_messages[1],
502+
assistant_message_2));
503+
}
504+
394505
TEST(ConversationTest, SendMessageAsync) {
395506
ASSERT_OK_AND_ASSIGN(auto model_assets,
396507
ModelAssets::Create(GetTestdataPath(kTestLlmPath)));
@@ -587,6 +698,132 @@ TEST(ConversationTest, SendMultipleMessagesAsync) {
587698
assistant_message));
588699
}
589700

701+
TEST(ConversationTest, SendMultipleMessagesAsyncWithHistory) {
702+
// Set up mock Session.
703+
auto mock_session = std::make_unique<MockSession>();
704+
MockSession* mock_session_ptr = mock_session.get();
705+
SessionConfig session_config = SessionConfig::CreateDefault();
706+
session_config.SetStartTokenId(0);
707+
session_config.GetMutableStopTokenIds().push_back({1});
708+
*session_config.GetMutableLlmModelType().mutable_gemma3() = {};
709+
session_config.GetMutableJinjaPromptTemplate() = kTestJinjaPromptTemplate;
710+
EXPECT_CALL(*mock_session_ptr, GetSessionConfig())
711+
.WillRepeatedly(testing::ReturnRef(session_config));
712+
auto mock_tokenizer = std::make_unique<MockTokenizer>();
713+
EXPECT_CALL(*mock_session_ptr, GetTokenizer())
714+
.WillRepeatedly(testing::ReturnRef(*mock_tokenizer));
715+
716+
// Set up mock Engine.
717+
auto mock_engine = std::make_unique<MockEngine>();
718+
EXPECT_CALL(*mock_engine, CreateSession(testing::_))
719+
.WillOnce(testing::Return(std::move(mock_session)));
720+
ASSERT_OK_AND_ASSIGN(auto model_assets,
721+
ModelAssets::Create(GetTestdataPath(kTestLlmPath)));
722+
ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault(
723+
model_assets, Backend::CPU));
724+
EXPECT_CALL(*mock_engine, GetEngineSettings())
725+
.WillRepeatedly(testing::ReturnRef(engine_settings));
726+
727+
// Create Conversation.
728+
ASSERT_OK_AND_ASSIGN(auto conversation_config,
729+
ConversationConfig::CreateFromSessionConfig(
730+
*mock_engine, session_config));
731+
ASSERT_OK_AND_ASSIGN(auto conversation,
732+
Conversation::Create(*mock_engine, conversation_config));
733+
734+
// The first user message.
735+
JsonMessage user_message_1 = nlohmann::ordered_json::parse(R"json(
736+
{
737+
"role": "user",
738+
"content": "How are you?"
739+
}
740+
)json");
741+
EXPECT_CALL(*mock_session_ptr,
742+
GenerateContentStream(testing::_, testing::_, testing::_))
743+
.WillOnce([](const std::vector<InputData>& contents,
744+
std::unique_ptr<InferenceCallbacks> callbacks,
745+
const DecodeConfig& decode_config) {
746+
Responses responses(1);
747+
responses.GetMutableResponseTexts()[0] = "I am good.";
748+
callbacks->OnNext(responses);
749+
callbacks->OnDone();
750+
return absl::OkStatus();
751+
});
752+
753+
JsonMessage assistant_message_1 = nlohmann::ordered_json::parse(R"json({
754+
"role": "assistant",
755+
"content": [
756+
{
757+
"type": "text",
758+
"text": "I am good."
759+
}
760+
]
761+
})json");
762+
absl::Notification done_1;
763+
auto message_callbacks_1 =
764+
std::make_unique<TestMessageCallbacks>(assistant_message_1, done_1);
765+
EXPECT_OK(conversation->SendMessageAsync(user_message_1,
766+
std::move(message_callbacks_1)));
767+
done_1.WaitForNotification();
768+
ASSERT_THAT(conversation->GetHistory().size(), testing::Eq(2));
769+
770+
// We will send two consecutive messages when the history is not empty.
771+
JsonMessage user_messages = nlohmann::ordered_json::parse(R"json(
772+
[
773+
{
774+
"role": "user",
775+
"content": "foo"
776+
},
777+
{
778+
"role": "user",
779+
"content": "bar"
780+
}
781+
]
782+
)json");
783+
784+
absl::string_view expected_input_text =
785+
"<start_of_turn>user\n"
786+
"foo<end_of_turn>\n"
787+
"<start_of_turn>user\n"
788+
"bar<end_of_turn>\n";
789+
EXPECT_CALL(*mock_session_ptr,
790+
GenerateContentStream(
791+
testing::ElementsAre(
792+
testing::VariantWith<InputText>(testing::Property(
793+
&InputText::GetRawTextString, expected_input_text))),
794+
testing::_, testing::_))
795+
.WillOnce([](const std::vector<InputData>& contents,
796+
std::unique_ptr<InferenceCallbacks> callbacks,
797+
const DecodeConfig& decode_config) {
798+
Responses responses(1);
799+
responses.GetMutableResponseTexts()[0] = "baz";
800+
callbacks->OnNext(responses);
801+
callbacks->OnDone();
802+
return absl::OkStatus();
803+
});
804+
805+
JsonMessage assistant_message_2 = nlohmann::ordered_json::parse(R"json({
806+
"role": "assistant",
807+
"content": [
808+
{
809+
"type": "text",
810+
"text": "baz"
811+
}
812+
]
813+
})json");
814+
absl::Notification done_2;
815+
auto message_callbacks_2 =
816+
std::make_unique<TestMessageCallbacks>(assistant_message_2, done_2);
817+
EXPECT_OK(conversation->SendMessageAsync(user_messages,
818+
std::move(message_callbacks_2)));
819+
done_2.WaitForNotification();
820+
821+
EXPECT_THAT(conversation->GetHistory(),
822+
testing::ElementsAre(user_message_1, assistant_message_1,
823+
user_messages[0], user_messages[1],
824+
assistant_message_2));
825+
}
826+
590827
TEST(ConversationTest, SendMessageWithPreface) {
591828
ASSERT_OK_AND_ASSIGN(auto model_assets,
592829
ModelAssets::Create(GetTestdataPath(kTestLlmPath)));

0 commit comments

Comments (0)