diff --git a/runtime/conversation/model_data_processor/BUILD b/runtime/conversation/model_data_processor/BUILD index ebf93ab5..1290f6d8 100644 --- a/runtime/conversation/model_data_processor/BUILD +++ b/runtime/conversation/model_data_processor/BUILD @@ -38,9 +38,9 @@ cc_library( srcs = ["data_utils.cc"], hdrs = ["data_utils.h"], deps = [ - "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@nlohmann_json//:json", "//runtime/util:memory_mapped_file", ], @@ -99,14 +99,11 @@ cc_library( ":generic_data_processor", ":generic_data_processor_config", ":model_data_processor", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@nlohmann_json//:json", "//runtime/conversation:io_types", - "//runtime/engine:io_types", "//runtime/proto:llm_model_type_cc_proto", - "//runtime/util:litert_status_util", ], ) diff --git a/runtime/conversation/model_data_processor/data_utils.cc b/runtime/conversation/model_data_processor/data_utils.cc index 6d046332..acd91183 100644 --- a/runtime/conversation/model_data_processor/data_utils.cc +++ b/runtime/conversation/model_data_processor/data_utils.cc @@ -19,6 +19,7 @@ #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl +#include "absl/strings/escaping.h" // from @com_google_absl #include "nlohmann/json.hpp" // from @nlohmann_json #include "runtime/util/memory_mapped_file.h" @@ -38,7 +39,12 @@ absl::StatusOr> LoadItemData( return MemoryMappedFile::Create(item["path"].get()); } if (item.contains("blob")) { - return InMemoryFile::Create(item["blob"]); + std::string blob_b64 = item["blob"].get(); + std::string blob; + if (!absl::Base64Unescape(blob_b64, &blob)) { + return absl::InvalidArgumentError("Failed to decode base64 blob."); + } + return InMemoryFile::Create(blob); } return absl::InvalidArgumentError( "Audio or image item must contain a path or blob."); diff --git a/runtime/conversation/model_data_processor/data_utils.h b/runtime/conversation/model_data_processor/data_utils.h index 202f286b..b460c445 100644 --- a/runtime/conversation/model_data_processor/data_utils.h +++ b/runtime/conversation/model_data_processor/data_utils.h @@ -38,7 +38,7 @@ namespace litert::lm { // } // { // "type": "image", -// "blob": "raw image bytes as string", +// "blob": "base64 encoded image bytes as string", // } // // 3. Audio item @@ -48,7 +48,7 @@ namespace litert::lm { // } // { // "type": "audio", -// "blob": "raw audio bytes as string", +// "blob": "base64 encoded audio bytes as string", // } // // Note: though we support loading image and audio data from blob, this format diff --git a/runtime/conversation/model_data_processor/data_utils_test.cc b/runtime/conversation/model_data_processor/data_utils_test.cc index 4e049b54..6c5ed1c2 100644 --- a/runtime/conversation/model_data_processor/data_utils_test.cc +++ b/runtime/conversation/model_data_processor/data_utils_test.cc @@ -66,7 +66,7 @@ TEST(DataUtilsTest, LoadItemData_ImageItemWithBlob) { ASSERT_OK_AND_ASSIGN(std::unique_ptr memory_mapped_file, LoadItemData({ {"type", "image"}, - {"blob", "image_contents"}, + {"blob", "aW1hZ2VfY29udGVudHM="}, })); EXPECT_EQ(std::string(static_cast(memory_mapped_file->data()), memory_mapped_file->length()), @@ -90,7 +90,7 @@ TEST(DataUtilsTest, LoadItemData_AudioItemWithBlob) { ASSERT_OK_AND_ASSIGN(std::unique_ptr memory_mapped_file, LoadItemData({ {"type", "audio"}, - {"blob", "audio_contents"}, + {"blob", "YXVkaW9fY29udGVudHM="}, })); EXPECT_EQ(std::string(static_cast(memory_mapped_file->data()), memory_mapped_file->length()), @@ -120,5 +120,16 @@ TEST(DataUtilsTest, LoadItemData_InvalidItem) { testing::HasSubstr("Audio or image item must contain a path or blob.")); } +TEST(DataUtilsTest, LoadItemData_ImageItemWithInvalidBase64Blob) { + auto result = LoadItemData({ + {"type", "image"}, + {"blob", "invalid base64"}, + }); + EXPECT_THAT(result.status(), + testing::status::StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(result.status().message(), + testing::HasSubstr("Failed to decode base64 blob.")); +} + } // namespace } // namespace litert::lm diff --git a/runtime/conversation/model_data_processor/model_data_processor_factory.cc b/runtime/conversation/model_data_processor/model_data_processor_factory.cc index 1590f4a2..e32c9b26 100644 --- a/runtime/conversation/model_data_processor/model_data_processor_factory.cc +++ b/runtime/conversation/model_data_processor/model_data_processor_factory.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/log/absl_log.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "runtime/conversation/io_types.h" @@ -37,12 +38,16 @@ absl::StatusOr> CreateModelDataProcessor( switch (model_type.model_type_case()) { case proto::LlmModelType::kGemma3N: case proto::LlmModelType::kGemma3: + ABSL_LOG(INFO) << "Creating Gemma3DataProcessor for model type: " + << model_type.model_type_case(); return Gemma3DataProcessor::Create( std::holds_alternative(config) ? std::get(config) : Gemma3DataProcessorConfig(), preface); case proto::LlmModelType::kGenericModel: { + ABSL_LOG(INFO) << "Creating GenericDataProcessor for model type: " + << model_type.model_type_case(); if (std::holds_alternative(config)) { return GenericDataProcessor::Create( std::holds_alternative(config)