Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions runtime/conversation/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,46 +31,11 @@ cc_library(
],
)

# Library target built from internal_callbacks_adapter.cc / .h.
# The ":conversation" target in this package lists it as a dependency.
cc_library(
name = "internal_callbacks_adapter",
srcs = ["internal_callbacks_adapter.cc"],
hdrs = ["internal_callbacks_adapter.h"],
# Dependency order used here: package-local targets, then external (@...)
# repositories, then other //runtime packages.
deps = [
":io_types",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"//runtime/conversation/model_data_processor",
"//runtime/conversation/model_data_processor:config_registry",
"//runtime/engine:io_types",
],
)

# Test target for ":internal_callbacks_adapter", built from
# internal_callbacks_adapter_test.cc and linked against gtest_main.
cc_test(
name = "internal_callbacks_adapter_test",
srcs = ["internal_callbacks_adapter_test.cc"],
# Besides the library under test, the test depends on the gemma3 data
# processor and its config, nlohmann JSON, and //runtime/util:test_utils.
deps = [
":internal_callbacks_adapter",
":io_types",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@nlohmann_json//:json",
"//runtime/conversation/model_data_processor:config_registry",
"//runtime/conversation/model_data_processor:gemma3_data_processor",
"//runtime/conversation/model_data_processor:gemma3_data_processor_config",
"//runtime/engine:io_types",
"//runtime/util:test_utils",
],
)

cc_library(
name = "conversation",
srcs = ["conversation.cc"],
hdrs = ["conversation.h"],
deps = [
":internal_callbacks_adapter",
":io_types",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:absl_log",
Expand Down
151 changes: 14 additions & 137 deletions runtime/conversation/conversation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include "absl/synchronization/mutex.h" // from @com_google_absl
#include "nlohmann/json.hpp" // from @nlohmann_json
#include "runtime/components/prompt_template.h"
#include "runtime/conversation/internal_callbacks_adapter.h"
#include "runtime/conversation/io_types.h"
#include "runtime/conversation/model_data_processor/config_registry.h"
#include "runtime/conversation/model_data_processor/model_data_processor.h"
Expand All @@ -43,93 +42,14 @@

namespace litert::lm {

// Renders the prompt text contributed by a single new turn.
//
// The template input is assembled from the JSON preface (messages, tools,
// extra context) followed by every message already stored in `history_`.
//   * Empty history: `message` is appended and the whole prompt is rendered
//     with the generation prompt enabled; the full rendering is returned.
//   * Non-empty history: the prior state is rendered without a generation
//     prompt, then rendered again with `message` appended and the generation
//     prompt enabled; only the text following the prior rendering is returned.
//
// Returns:
//   * UnimplementedError if the preface is not a JsonPreface, if a history
//     entry is not JSON, or (empty history only) if `message` is not JSON.
//   * InvalidArgumentError if the history is non-empty and `message` is not
//     JSON.
//   * InternalError if the new rendering does not start with the old one.
//   * Whatever error MessageToTemplateInput, FormatTools or
//     PromptTemplate::Apply report (forwarded via ASSIGN_OR_RETURN).
//
// `history_mutex_` is acquired after the preface has been processed and is
// held until the function returns.
absl::StatusOr<std::string> Conversation::GetSingleTurnText(
    const Message& message) const {
  PromptTemplateInput old_tmpl_input;
  if (!std::holds_alternative<JsonPreface>(preface_)) {
    return absl::UnimplementedError("Preface type is not supported yet");
  }
  // The preface is only read here, so bind a reference instead of deep-copying
  // its JSON payloads on every call.
  const JsonPreface& json_preface = std::get<JsonPreface>(preface_);

  // Named `preface_message` so it does not shadow the `message` parameter.
  for (const auto& preface_message : json_preface.messages) {
    ASSIGN_OR_RETURN(
        nlohmann::ordered_json message_tmpl_input,
        model_data_processor_->MessageToTemplateInput(preface_message));
    old_tmpl_input.messages.push_back(std::move(message_tmpl_input));
  }

  if (json_preface.tools.is_null()) {
    old_tmpl_input.tools = nullptr;
  } else {
    ASSIGN_OR_RETURN(old_tmpl_input.tools,
                     model_data_processor_->FormatTools(json_preface.tools));
  }
  old_tmpl_input.extra_context = json_preface.extra_context;

  absl::MutexLock lock(&history_mutex_);  // NOLINT
  for (const auto& history_msg : history_) {
    if (!std::holds_alternative<nlohmann::ordered_json>(history_msg)) {
      return absl::UnimplementedError("Message type is not supported yet");
    }
    ASSIGN_OR_RETURN(nlohmann::ordered_json message_tmpl_input,
                     model_data_processor_->MessageToTemplateInput(
                         std::get<nlohmann::ordered_json>(history_msg)));
    old_tmpl_input.messages.push_back(std::move(message_tmpl_input));
  }

  if (history_.empty()) {
    // First turn: there is no earlier rendering to subtract, so the complete
    // prompt (preface + message + generation prompt) is the result.
    if (!std::holds_alternative<nlohmann::ordered_json>(message)) {
      return absl::UnimplementedError("Message type is not supported yet");
    }
    PromptTemplateInput new_tmpl_input = std::move(old_tmpl_input);
    ASSIGN_OR_RETURN(nlohmann::ordered_json message_tmpl_input,
                     model_data_processor_->MessageToTemplateInput(
                         std::get<nlohmann::ordered_json>(message)));
    new_tmpl_input.messages.push_back(std::move(message_tmpl_input));
    new_tmpl_input.add_generation_prompt = true;
    return prompt_template_.Apply(new_tmpl_input);
  }

  // Later turns: render the conversation as it stood before this message.
  old_tmpl_input.add_generation_prompt = false;
  ASSIGN_OR_RETURN(const std::string old_string,
                   prompt_template_.Apply(old_tmpl_input));

  if (!std::holds_alternative<nlohmann::ordered_json>(message)) {
    return absl::InvalidArgumentError("Json message is required for now.");
  }

  PromptTemplateInput new_tmpl_input = std::move(old_tmpl_input);
  ASSIGN_OR_RETURN(nlohmann::ordered_json message_tmpl_input,
                   model_data_processor_->MessageToTemplateInput(
                       std::get<nlohmann::ordered_json>(message)));
  new_tmpl_input.messages.push_back(std::move(message_tmpl_input));
  new_tmpl_input.add_generation_prompt = true;
  // Held by value (like `old_string`) rather than as a reference bound to a
  // temporary managed by the macro.
  ASSIGN_OR_RETURN(const std::string new_string,
                   prompt_template_.Apply(new_tmpl_input));

  // The new rendering must extend the old one; compare in place instead of
  // materialising a prefix copy with substr().
  if (new_string.compare(0, old_string.size(), old_string) != 0) {
    return absl::InternalError(absl::StrCat(
        "The new rendered template string does not start with the previous "
        "rendered template string. \nold_string: ",
        old_string, "\nnew_string: ", new_string));
  }
  // Everything after the shared prefix is what this turn added.
  return new_string.substr(old_string.size());
}

absl::StatusOr<std::unique_ptr<Conversation>> Conversation::Create(
std::unique_ptr<Engine::Session> session, std::optional<Preface> preface,
std::optional<PromptTemplate> prompt_template,
std::optional<DataProcessorConfig> processor_config) {
if (!preface.has_value()) {
preface = JsonPreface();
}
const proto::LlmModelType& llm_model_type =
session->GetSessionConfig().GetLlmModelType();
processor_config = processor_config.value_or(std::monostate());
ASSIGN_OR_RETURN(
std::unique_ptr<ModelDataProcessor> model_data_processor,
CreateModelDataProcessor(llm_model_type, *processor_config, *preface));

if (!prompt_template.has_value()) {
// Get template from the session or model file when the template is not
// provided by the user.
Expand All @@ -138,72 +58,29 @@ absl::StatusOr<std::unique_ptr<Conversation>> Conversation::Create(
prompt_template =
PromptTemplate(session->GetSessionConfig().GetJinjaPromptTemplate());
}
auto conversation = absl::WrapUnique(
new Conversation(std::move(session), std::move(model_data_processor),
*preface, *prompt_template));

const proto::LlmModelType& llm_model_type =
session->GetSessionConfig().GetLlmModelType();
processor_config = processor_config.value_or(std::monostate());
ASSIGN_OR_RETURN(
std::unique_ptr<ModelDataProcessor> model_data_processor,
CreateModelDataProcessor(llm_model_type, std::move(session),
*prompt_template, *preface, *processor_config));
auto conversation =
absl::WrapUnique(new Conversation(std::move(model_data_processor)));
return conversation;
}

// Sends `message` to the model and returns the assistant's reply.
//
// As written, the body renders the text for this turn, records the user
// message in `history_`, converts it to session inputs, runs
// `session_->GenerateContent`, converts the responses back into a Message,
// records that message in `history_` and returns it.
//
// Returns InvalidArgumentError when `message` does not hold the JSON
// alternative. The ASSIGN_OR_RETURN statements presumably forward errors from
// the calls they wrap (macro defined outside this view).
absl::StatusOr<Message> Conversation::SendMessage(
const Message& message, std::optional<DataProcessorArguments> args) {
// Only the nlohmann::ordered_json alternative of Message is accepted.
if (!std::holds_alternative<nlohmann::ordered_json>(message)) {
return absl::InvalidArgumentError("Json message is required for now.");
}
auto json_message = std::get<nlohmann::ordered_json>(message);
// GetSingleTurnText takes `history_mutex_` itself, so it is called before
// this function acquires the lock below.
ASSIGN_OR_RETURN(const std::string& single_turn_text,
GetSingleTurnText(message));
// Function-scope lock: it stays held through GenerateContent and ToMessage
// until the function returns.
absl::MutexLock lock(&history_mutex_); // NOLINT
// The user message is appended before generation runs.
// NOTE(review): if a later step returns early, the message appears to stay in
// history_ without a matching assistant reply -- confirm this is intended.
history_.push_back(json_message);
// The data processor receives the rendered text plus a one-element JSON array
// holding the message; `args` falls back to std::monostate when absent.
ASSIGN_OR_RETURN(
const auto session_inputs,
model_data_processor_->ToInputDataVector(
single_turn_text, nlohmann::ordered_json::array({json_message}),
args.value_or(std::monostate())));
ASSIGN_OR_RETURN(const Responses& responses,
session_->GenerateContent(session_inputs));
ASSIGN_OR_RETURN(const Message assistant_message,
model_data_processor_->ToMessage(
responses, args.value_or(std::monostate())));
history_.push_back(assistant_message);
return assistant_message;
// NOTE(review): this statement follows an unconditional return and can never
// execute. It looks like a replacement body that delegates the whole call to
// the data processor -- confirm which implementation is intended and remove
// the other.
return model_data_processor_->SendMessage(message, args);
}

// Streaming variant of SendMessage: starts an asynchronous decode for
// `message` and reports results through `callbacks`.
//
// As written, the body renders the text for this turn, records the user
// message in `history_`, converts it to session inputs, wraps `callbacks` in
// an InternalCallbacksAdapter, runs prefill, and hands the adapter to
// `session_->RunDecodeAsync`.
//
// Returns InvalidArgumentError when `message` does not hold the JSON
// alternative, and OkStatus once RunDecodeAsync has returned without error.
// The ASSIGN_OR_RETURN / RETURN_IF_ERROR statements presumably forward errors
// from the calls they wrap (macros defined outside this view).
absl::Status Conversation::SendMessageStream(
const Message& message, std::unique_ptr<MessageCallbacks> callbacks,
std::optional<DataProcessorArguments> args) {
// Only the nlohmann::ordered_json alternative of Message is accepted.
if (!std::holds_alternative<nlohmann::ordered_json>(message)) {
return absl::InvalidArgumentError("Json message is required for now.");
}
auto json_message = std::get<nlohmann::ordered_json>(message);
ASSIGN_OR_RETURN(const std::string& single_turn_text,
GetSingleTurnText(message));
// The lock is confined to this inner scope, so it is released before prefill
// and decode are started.
{
absl::MutexLock lock(&history_mutex_); // NOLINT
history_.push_back(message);
}

// The data processor receives the rendered text plus a one-element JSON array
// holding the message; `args` falls back to std::monostate when absent.
ASSIGN_OR_RETURN(
const auto session_inputs,
model_data_processor_->ToInputDataVector(
single_turn_text, nlohmann::ordered_json::array({json_message}),
args.value_or(std::monostate())));

// The adapter takes ownership of the caller's callbacks and is given a
// non-owning pointer to the data processor.
auto internal_callbacks_adapter = InternalCallbacksAdapter::Create(
model_data_processor_.get(), std::move(callbacks),
args.value_or(std::monostate()));

// Appends the message it receives to history_ under the history lock.
// Presumably the adapter invokes it once the full reply is assembled --
// confirm in internal_callbacks_adapter.
// NOTE(review): the lambda captures `this` and is moved into the adapter that
// is passed to RunDecodeAsync, so this Conversation presumably has to outlive
// the asynchronous decode -- confirm the lifetime contract.
InternalCallbacksAdapter::CompleteMessageCallback complete_message_callback =
[this](const Message& complete_message) {
absl::MutexLock lock(&this->history_mutex_); // NOLINT
this->history_.push_back(complete_message);
};
internal_callbacks_adapter->SetCompleteMessageCallback(
std::move(complete_message_callback));

RETURN_IF_ERROR(session_->RunPrefill(session_inputs));
RETURN_IF_ERROR(
session_->RunDecodeAsync(std::move(internal_callbacks_adapter)));
return absl::OkStatus();
// NOTE(review): this statement follows an unconditional return and can never
// execute; `callbacks` has also already been moved from by this point. It
// looks like a replacement body that delegates to the data processor --
// confirm which implementation is intended and remove the other.
return model_data_processor_->SendMessageStream(message, std::move(callbacks),
args);
// NOTE(review): the ';' after the closing brace is an empty declaration and
// can be dropped.
};

} // namespace litert::lm
16 changes: 2 additions & 14 deletions runtime/conversation/conversation.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,22 +119,10 @@ class Conversation {

private:
// Builds a Conversation from its already-created parts.
//
// Takes ownership of `session` and `model_data_processor`, and stores
// `preface` and `prompt_template`. Every parameter is a by-value sink and is
// moved into its member, so no argument is copied a second time.
explicit Conversation(
    std::unique_ptr<Engine::Session> session,
    std::unique_ptr<ModelDataProcessor> model_data_processor, Preface preface,
    PromptTemplate prompt_template)
    : session_(std::move(session)),
      model_data_processor_(std::move(model_data_processor)),
      // Moved like the other parameters; it was previously copied.
      preface_(std::move(preface)),
      prompt_template_(std::move(prompt_template)) {}

absl::StatusOr<std::string> GetSingleTurnText(const Message& message) const;
std::unique_ptr<ModelDataProcessor> model_data_processor)
: model_data_processor_(std::move(model_data_processor)) {}

std::unique_ptr<Engine::Session> session_;
std::unique_ptr<ModelDataProcessor> model_data_processor_;
Preface preface_;
PromptTemplate prompt_template_;
mutable absl::Mutex history_mutex_;
std::vector<Message> history_ ABSL_GUARDED_BY(history_mutex_);
};
} // namespace litert::lm

Expand Down
94 changes: 0 additions & 94 deletions runtime/conversation/internal_callbacks_adapter.h

This file was deleted.

Loading