@@ -391,6 +391,117 @@ TEST(ConversationTest, SendMultipleMessages) {
391391 assistant_message));
392392}
393393
394+ TEST (ConversationTest, SendMultipleMessagesWithHistory) {
395+ // Set up mock Session.
396+ auto mock_session = std::make_unique<MockSession>();
397+ MockSession* mock_session_ptr = mock_session.get ();
398+ SessionConfig session_config = SessionConfig::CreateDefault ();
399+ session_config.SetStartTokenId (0 );
400+ session_config.GetMutableStopTokenIds ().push_back ({1 });
401+ *session_config.GetMutableLlmModelType ().mutable_gemma3 () = {};
402+ session_config.GetMutableJinjaPromptTemplate () = kTestJinjaPromptTemplate ;
403+ EXPECT_CALL (*mock_session_ptr, GetSessionConfig ())
404+ .WillRepeatedly (testing::ReturnRef (session_config));
405+ auto mock_tokenizer = std::make_unique<MockTokenizer>();
406+ EXPECT_CALL (*mock_session_ptr, GetTokenizer ())
407+ .WillRepeatedly (testing::ReturnRef (*mock_tokenizer));
408+
409+ // Set up mock Engine.
410+ auto mock_engine = std::make_unique<MockEngine>();
411+ EXPECT_CALL (*mock_engine, CreateSession (testing::_))
412+ .WillOnce (testing::Return (std::move (mock_session)));
413+ ASSERT_OK_AND_ASSIGN (auto model_assets,
414+ ModelAssets::Create (GetTestdataPath (kTestLlmPath )));
415+ ASSERT_OK_AND_ASSIGN (auto engine_settings, EngineSettings::CreateDefault (
416+ model_assets, Backend::CPU));
417+ EXPECT_CALL (*mock_engine, GetEngineSettings ())
418+ .WillRepeatedly (testing::ReturnRef (engine_settings));
419+
420+ // Create Conversation.
421+ ASSERT_OK_AND_ASSIGN (auto conversation_config,
422+ ConversationConfig::CreateFromSessionConfig (
423+ *mock_engine, session_config));
424+ ASSERT_OK_AND_ASSIGN (auto conversation,
425+ Conversation::Create (*mock_engine, conversation_config));
426+
427+ // The first user message.
428+ JsonMessage user_message_1 = nlohmann::ordered_json::parse (R"json(
429+ {
430+ "role": "user",
431+ "content": "How are you?"
432+ }
433+ )json" );
434+ EXPECT_CALL (*mock_session_ptr, RunPrefill (testing::_))
435+ .WillOnce (testing::Return (absl::OkStatus ()));
436+
437+ // The first assistant response.
438+ Responses responses_1 (1 );
439+ responses_1.GetMutableResponseTexts ()[0 ] = " I am good." ;
440+ EXPECT_CALL (*mock_session_ptr, RunDecode (testing::_))
441+ .WillOnce (testing::Return (responses_1));
442+
443+ // Send the first user message to fill the history.
444+ ASSERT_OK (conversation->SendMessage (user_message_1));
445+ ASSERT_THAT (conversation->GetHistory ().size (), testing::Eq (2 ));
446+
447+ // We will send two consecutive messages when the history is not empty.
448+ JsonMessage user_messages = nlohmann::ordered_json::parse (R"json(
449+ [
450+ {
451+ "role": "user",
452+ "content": "foo"
453+ },
454+ {
455+ "role": "user",
456+ "content": "bar"
457+ }
458+ ]
459+ )json" );
460+ absl::string_view expected_input_text =
461+ " <start_of_turn>user\n "
462+ " foo<end_of_turn>\n "
463+ " <start_of_turn>user\n "
464+ " bar<end_of_turn>\n " ;
465+ EXPECT_CALL (*mock_session_ptr,
466+ RunPrefill (testing::ElementsAre (
467+ testing::VariantWith<InputText>(testing::Property (
468+ &InputText::GetRawTextString, expected_input_text)))))
469+ .WillOnce (testing::Return (absl::OkStatus ()));
470+
471+ // The second assistant response.
472+ Responses responses_2 (1 );
473+ responses_2.GetMutableResponseTexts ()[0 ] = " baz" ;
474+ EXPECT_CALL (*mock_session_ptr, RunDecode (testing::_))
475+ .WillOnce (testing::Return (responses_2));
476+
477+ // Send the user messages.
478+ ASSERT_OK (conversation->SendMessage (user_messages));
479+
480+ // Check the history.
481+ JsonMessage assistant_message_1 = nlohmann::ordered_json::parse (R"( {
482+ "role": "assistant",
483+ "content": [
484+ {
485+ "type": "text",
486+ "text": "I am good."
487+ }
488+ ]
489+ })" );
490+ JsonMessage assistant_message_2 = nlohmann::ordered_json::parse (R"( {
491+ "role": "assistant",
492+ "content": [
493+ {
494+ "type": "text",
495+ "text": "baz"
496+ }
497+ ]
498+ })" );
499+ EXPECT_THAT (conversation->GetHistory (),
500+ testing::ElementsAre (user_message_1, assistant_message_1,
501+ user_messages[0 ], user_messages[1 ],
502+ assistant_message_2));
503+ }
504+
394505TEST (ConversationTest, SendMessageAsync) {
395506 ASSERT_OK_AND_ASSIGN (auto model_assets,
396507 ModelAssets::Create (GetTestdataPath (kTestLlmPath )));
@@ -587,6 +698,132 @@ TEST(ConversationTest, SendMultipleMessagesAsync) {
587698 assistant_message));
588699}
589700
701+ TEST (ConversationTest, SendMultipleMessagesAsyncWithHistory) {
702+ // Set up mock Session.
703+ auto mock_session = std::make_unique<MockSession>();
704+ MockSession* mock_session_ptr = mock_session.get ();
705+ SessionConfig session_config = SessionConfig::CreateDefault ();
706+ session_config.SetStartTokenId (0 );
707+ session_config.GetMutableStopTokenIds ().push_back ({1 });
708+ *session_config.GetMutableLlmModelType ().mutable_gemma3 () = {};
709+ session_config.GetMutableJinjaPromptTemplate () = kTestJinjaPromptTemplate ;
710+ EXPECT_CALL (*mock_session_ptr, GetSessionConfig ())
711+ .WillRepeatedly (testing::ReturnRef (session_config));
712+ auto mock_tokenizer = std::make_unique<MockTokenizer>();
713+ EXPECT_CALL (*mock_session_ptr, GetTokenizer ())
714+ .WillRepeatedly (testing::ReturnRef (*mock_tokenizer));
715+
716+ // Set up mock Engine.
717+ auto mock_engine = std::make_unique<MockEngine>();
718+ EXPECT_CALL (*mock_engine, CreateSession (testing::_))
719+ .WillOnce (testing::Return (std::move (mock_session)));
720+ ASSERT_OK_AND_ASSIGN (auto model_assets,
721+ ModelAssets::Create (GetTestdataPath (kTestLlmPath )));
722+ ASSERT_OK_AND_ASSIGN (auto engine_settings, EngineSettings::CreateDefault (
723+ model_assets, Backend::CPU));
724+ EXPECT_CALL (*mock_engine, GetEngineSettings ())
725+ .WillRepeatedly (testing::ReturnRef (engine_settings));
726+
727+ // Create Conversation.
728+ ASSERT_OK_AND_ASSIGN (auto conversation_config,
729+ ConversationConfig::CreateFromSessionConfig (
730+ *mock_engine, session_config));
731+ ASSERT_OK_AND_ASSIGN (auto conversation,
732+ Conversation::Create (*mock_engine, conversation_config));
733+
734+ // The first user message.
735+ JsonMessage user_message_1 = nlohmann::ordered_json::parse (R"json(
736+ {
737+ "role": "user",
738+ "content": "How are you?"
739+ }
740+ )json" );
741+ EXPECT_CALL (*mock_session_ptr,
742+ GenerateContentStream (testing::_, testing::_, testing::_))
743+ .WillOnce ([](const std::vector<InputData>& contents,
744+ std::unique_ptr<InferenceCallbacks> callbacks,
745+ const DecodeConfig& decode_config) {
746+ Responses responses (1 );
747+ responses.GetMutableResponseTexts ()[0 ] = " I am good." ;
748+ callbacks->OnNext (responses);
749+ callbacks->OnDone ();
750+ return absl::OkStatus ();
751+ });
752+
753+ JsonMessage assistant_message_1 = nlohmann::ordered_json::parse (R"json( {
754+ "role": "assistant",
755+ "content": [
756+ {
757+ "type": "text",
758+ "text": "I am good."
759+ }
760+ ]
761+ })json" );
762+ absl::Notification done_1;
763+ auto message_callbacks_1 =
764+ std::make_unique<TestMessageCallbacks>(assistant_message_1, done_1);
765+ EXPECT_OK (conversation->SendMessageAsync (user_message_1,
766+ std::move (message_callbacks_1)));
767+ done_1.WaitForNotification ();
768+ ASSERT_THAT (conversation->GetHistory ().size (), testing::Eq (2 ));
769+
770+ // We will send two consecutive messages when the history is not empty.
771+ JsonMessage user_messages = nlohmann::ordered_json::parse (R"json(
772+ [
773+ {
774+ "role": "user",
775+ "content": "foo"
776+ },
777+ {
778+ "role": "user",
779+ "content": "bar"
780+ }
781+ ]
782+ )json" );
783+
784+ absl::string_view expected_input_text =
785+ " <start_of_turn>user\n "
786+ " foo<end_of_turn>\n "
787+ " <start_of_turn>user\n "
788+ " bar<end_of_turn>\n " ;
789+ EXPECT_CALL (*mock_session_ptr,
790+ GenerateContentStream (
791+ testing::ElementsAre (
792+ testing::VariantWith<InputText>(testing::Property (
793+ &InputText::GetRawTextString, expected_input_text))),
794+ testing::_, testing::_))
795+ .WillOnce ([](const std::vector<InputData>& contents,
796+ std::unique_ptr<InferenceCallbacks> callbacks,
797+ const DecodeConfig& decode_config) {
798+ Responses responses (1 );
799+ responses.GetMutableResponseTexts ()[0 ] = " baz" ;
800+ callbacks->OnNext (responses);
801+ callbacks->OnDone ();
802+ return absl::OkStatus ();
803+ });
804+
805+ JsonMessage assistant_message_2 = nlohmann::ordered_json::parse (R"json( {
806+ "role": "assistant",
807+ "content": [
808+ {
809+ "type": "text",
810+ "text": "baz"
811+ }
812+ ]
813+ })json" );
814+ absl::Notification done_2;
815+ auto message_callbacks_2 =
816+ std::make_unique<TestMessageCallbacks>(assistant_message_2, done_2);
817+ EXPECT_OK (conversation->SendMessageAsync (user_messages,
818+ std::move (message_callbacks_2)));
819+ done_2.WaitForNotification ();
820+
821+ EXPECT_THAT (conversation->GetHistory (),
822+ testing::ElementsAre (user_message_1, assistant_message_1,
823+ user_messages[0 ], user_messages[1 ],
824+ assistant_message_2));
825+ }
826+
590827TEST (ConversationTest, SendMessageWithPreface) {
591828 ASSERT_OK_AND_ASSIGN (auto model_assets,
592829 ModelAssets::Create (GetTestdataPath (kTestLlmPath )));
0 commit comments