Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions runtime/conversation/model_data_processor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ cc_library(
srcs = ["data_utils.cc"],
hdrs = ["data_utils.h"],
deps = [
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@nlohmann_json//:json",
"//runtime/util:memory_mapped_file",
],
Expand Down Expand Up @@ -99,14 +99,11 @@ cc_library(
":generic_data_processor",
":generic_data_processor_config",
":model_data_processor",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@nlohmann_json//:json",
"//runtime/conversation:io_types",
"//runtime/engine:io_types",
"//runtime/proto:llm_model_type_cc_proto",
"//runtime/util:litert_status_util",
],
)

Expand Down
8 changes: 7 additions & 1 deletion runtime/conversation/model_data_processor/data_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/escaping.h" // from @com_google_absl
#include "nlohmann/json.hpp" // from @nlohmann_json
#include "runtime/util/memory_mapped_file.h"

Expand All @@ -38,7 +39,12 @@ absl::StatusOr<std::unique_ptr<MemoryMappedFile>> LoadItemData(
return MemoryMappedFile::Create(item["path"].get<std::string>());
}
if (item.contains("blob")) {
return InMemoryFile::Create(item["blob"]);
std::string blob_b64 = item["blob"].get<std::string>();
std::string blob;
if (!absl::Base64Unescape(blob_b64, &blob)) {
return absl::InvalidArgumentError("Failed to decode base64 blob.");
}
return InMemoryFile::Create(blob);
}
return absl::InvalidArgumentError(
"Audio or image item must contain a path or blob.");
Expand Down
4 changes: 2 additions & 2 deletions runtime/conversation/model_data_processor/data_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace litert::lm {
// }
// {
// "type": "image",
// "blob": "raw image bytes as string",
// "blob": "base64 encoded image bytes as string",
// }
//
// 3. Audio item
Expand All @@ -48,7 +48,7 @@ namespace litert::lm {
// }
// {
// "type": "audio",
// "blob": "raw audio bytes as string",
// "blob": "base64 encoded audio bytes as string",
// }
//
// Note: though we support loading image and audio data from blob, this format
Expand Down
15 changes: 13 additions & 2 deletions runtime/conversation/model_data_processor/data_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ TEST(DataUtilsTest, LoadItemData_ImageItemWithBlob) {
ASSERT_OK_AND_ASSIGN(std::unique_ptr<MemoryMappedFile> memory_mapped_file,
LoadItemData({
{"type", "image"},
{"blob", "image_contents"},
{"blob", "aW1hZ2VfY29udGVudHM="},
}));
EXPECT_EQ(std::string(static_cast<const char*>(memory_mapped_file->data()),
memory_mapped_file->length()),
Expand All @@ -90,7 +90,7 @@ TEST(DataUtilsTest, LoadItemData_AudioItemWithBlob) {
ASSERT_OK_AND_ASSIGN(std::unique_ptr<MemoryMappedFile> memory_mapped_file,
LoadItemData({
{"type", "audio"},
{"blob", "audio_contents"},
{"blob", "YXVkaW9fY29udGVudHM="},
}));
EXPECT_EQ(std::string(static_cast<const char*>(memory_mapped_file->data()),
memory_mapped_file->length()),
Expand Down Expand Up @@ -120,5 +120,16 @@ TEST(DataUtilsTest, LoadItemData_InvalidItem) {
testing::HasSubstr("Audio or image item must contain a path or blob."));
}

// An image item whose "blob" field is not decodable as base64 must be rejected:
// the returned status is expected to be kInvalidArgument and to carry the
// "Failed to decode base64 blob." message.
TEST(DataUtilsTest, LoadItemData_ImageItemWithInvalidBase64Blob) {
  const auto load_result = LoadItemData({
      {"type", "image"},
      {"blob", "invalid base64"},
  });

  // Check the status code and the message separately so that a failure
  // pinpoints which of the two expectations was violated.
  const auto& load_status = load_result.status();
  EXPECT_THAT(load_status,
              testing::status::StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(load_status.message(),
              testing::HasSubstr("Failed to decode base64 blob."));
}

} // namespace
} // namespace litert::lm
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <optional>
#include <variant>

#include "absl/log/absl_log.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "runtime/conversation/io_types.h"
Expand All @@ -37,12 +38,16 @@ absl::StatusOr<std::unique_ptr<ModelDataProcessor>> CreateModelDataProcessor(
switch (model_type.model_type_case()) {
case proto::LlmModelType::kGemma3N:
case proto::LlmModelType::kGemma3:
ABSL_LOG(INFO) << "Creating Gemma3DataProcessor for model type: "
<< model_type.model_type_case();
return Gemma3DataProcessor::Create(
std::holds_alternative<Gemma3DataProcessorConfig>(config)
? std::get<Gemma3DataProcessorConfig>(config)
: Gemma3DataProcessorConfig(),
preface);
case proto::LlmModelType::kGenericModel: {
ABSL_LOG(INFO) << "Creating GenericDataProcessor for model type: "
<< model_type.model_type_case();
if (std::holds_alternative<GenericDataProcessorConfig>(config)) {
return GenericDataProcessor::Create(
std::holds_alternative<GenericDataProcessorConfig>(config)
Expand Down