diff --git a/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.cc b/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.cc
index 9e301aa3..fafda396 100644
--- a/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.cc
+++ b/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.cc
@@ -39,6 +39,16 @@
 
 namespace litert::lm {
 
+absl::Status EndOfMultiModalEmbedding::LookupDecode(
+    int token, std::vector<float>& output_vector) {
+  return LookupPrefill(token, output_vector);
+}
+
+absl::Status EndOfMultiModalEmbedding::LookupDecode(
+    int token, litert::TensorBuffer* output_tensor) {
+  return LookupPrefill({token}, output_tensor, 0);
+}
+
 absl::Status EndOfMultiModalEmbedding::LookupPrefill(
     int token, std::vector<float>& output_vector) {
   if (token != special_token_) {
diff --git a/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.h b/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.h
index 1f078126..f40de339 100644
--- a/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.h
+++ b/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.h
@@ -47,19 +47,15 @@ class EndOfMultiModalEmbedding : public EmbeddingLookup {
   static absl::StatusOr<std::unique_ptr<EndOfMultiModalEmbedding>> Create(
       const litert::Model* absl_nonnull model, int special_token);
 
-  // Multimodal embeddings are not supported during decode.
+  // For a given token, looks up the end of multi-modal embedding and stores it
+  // in the provided vector.
   absl::Status LookupDecode(int token,
-                            std::vector<float>& output_vector) override {
-    return absl::UnimplementedError(
-        "LookupDecode is not implemented for EndOfMultiModalEmbedding.");
-  }
+                            std::vector<float>& output_vector) override;
 
-  // Multimodal embeddings are not supported during decode.
+  // For a given token, looks up the end of multi-modal embedding and stores it
+  // in the output tensor.
   absl::Status LookupDecode(int token,
-                            litert::TensorBuffer* output_tensor) override {
-    return absl::UnimplementedError(
-        "LookupDecode is not implemented for EndOfMultiModalEmbedding.");
-  }
+                            litert::TensorBuffer* output_tensor) override;
 
   // If the token is the special token, looks up the end of multimodal
   // embedding and stores it in the provided vector.
diff --git a/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal_test.cc b/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal_test.cc
index 59501d96..8eaa13c3 100644
--- a/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal_test.cc
+++ b/runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal_test.cc
@@ -108,19 +108,41 @@ class EndOfMultiModalEmbeddingTest : public testing::Test {
   std::optional<litert::Model> model_;
 };
 
-TEST_F(EndOfMultiModalEmbeddingTest, LookupDecodeVector) {
+TEST_F(EndOfMultiModalEmbeddingTest, LookupDecodeVectorBadOutputVector) {
   std::unique_ptr<EndOfMultiModalEmbedding> embedding =
       GetEndOfMultiModalEmbedding();
   ASSERT_NE(embedding, nullptr);
 
-  std::vector<float> output_vector(4 * 32);
+  std::vector<float> output_vector(4 * 32 + 1);
 
   int32_t token = -3;
   ASSERT_THAT(embedding->LookupDecode(token, output_vector),
               testing::status::StatusIs(
-                  absl::StatusCode::kUnimplemented,
-                  testing::HasSubstr("LookupDecode is not implemented for "
-                                     "EndOfMultiModalEmbedding.")));
+                  absl::StatusCode::kInvalidArgument,
+                  testing::HasSubstr("The output vector is not the correct "
+                                     "size for the end of multi-modal")));
+}
+
+TEST_F(EndOfMultiModalEmbeddingTest, LookupDecodeVector) {
+  std::unique_ptr<EndOfMultiModalEmbedding> embedding =
+      GetEndOfMultiModalEmbedding();
+  ASSERT_NE(embedding, nullptr);
+
+  std::vector<float> output_vector(128);
+
+  int32_t token = -3;
+  ASSERT_OK(embedding->LookupDecode(token, output_vector));
+
+  size_t offset = 0;
+  // Dimensions 0 and 1 both have size 1.
+  for (int idx2 = 0; idx2 < 4; ++idx2) {
+    for (int idx3 = 0; idx3 < 32; ++idx3) {
+      // Dimensions 0 and 1 both have size 1 so offset and expected value can
+      // ignore them.
+      float expected_value = (100.0 * idx2 + idx3) * 2;
+      ASSERT_NEAR(output_vector[offset++], expected_value, 1e-5);
+    }
+  }
 }
 
 TEST_F(EndOfMultiModalEmbeddingTest, LookupDecode) {
@@ -133,12 +155,23 @@ TEST_F(EndOfMultiModalEmbeddingTest, LookupDecode) {
                               GetTensorBuffer(dimensions));
 
   int32_t token = -3;
-  ASSERT_THAT(
-      embedding->LookupDecode(token, &output_tensor),
-      testing::status::StatusIs(
-          absl::StatusCode::kUnimplemented,
-          testing::HasSubstr(
-              "LookupDecode is not implemented for EndOfMultiModalEmbedding")));
+  ASSERT_OK(embedding->LookupDecode(token, &output_tensor));
+
+  auto output_tensor_lock_and_addr = ::litert::TensorBufferScopedLock::Create(
+      output_tensor, ::litert::TensorBuffer::LockMode::kRead);
+  auto output_tensor_ptr =
+      reinterpret_cast<const float*>(output_tensor_lock_and_addr->second);
+
+  float expected_value = 0.0;
+  for (int idx2 = 0; idx2 < dimensions[2]; ++idx2) {
+    for (int idx3 = 0; idx3 < dimensions[3]; ++idx3) {
+      // Since dimension 1 is of size 1, the offset and expected value can
+      // ignore it.
+      size_t offset = idx2 * dimensions[3] + idx3;
+      expected_value = (100.0 * idx2 + idx3) * 2;
+      ASSERT_NEAR(output_tensor_ptr[offset], expected_value, 1e-5);
+    }
+  }
 }
 
 TEST_F(EndOfMultiModalEmbeddingTest, LookupPrefillVector) {
diff --git a/runtime/components/embedding_lookup/embedding_lookup_manager.cc b/runtime/components/embedding_lookup/embedding_lookup_manager.cc
index fe63033f..fd016d2c 100644
--- a/runtime/components/embedding_lookup/embedding_lookup_manager.cc
+++ b/runtime/components/embedding_lookup/embedding_lookup_manager.cc
@@ -115,29 +115,15 @@ absl::Status EmbeddingLookupManager::LookupDecode(
   }
   const size_t floats_per_token = text_embedding_lookup_->GetFloatsPerToken();
   output_vector.resize(floats_per_token);
-
-  if (token < 0) {
-    return absl::InvalidArgumentError(
-        "Multimodal embeddings are not supported during decode.");
-  }
-
-  return text_embedding_lookup_->LookupDecode(token, output_vector);
+  return LookupPrefill(token, output_vector);
 }
 
 absl::Status EmbeddingLookupManager::LookupDecode(
     int token, litert::TensorBuffer* output_tensor) {
-  if (text_embedding_lookup_ == nullptr) {
-    return absl::InternalError(
-        "Text embedding lookup is null. Please ensure that the "
-        "EmbeddingLookupManager is initialized properly.");
+  if (output_tensor == nullptr) {
+    return absl::InvalidArgumentError("Decode output tensor buffer is null.");
   }
-
-  if (token < 0) {
-    return absl::InvalidArgumentError(
-        "Multimodal embeddings are not supported during decode.");
-  }
-
-  return text_embedding_lookup_->LookupDecode(token, output_tensor);
+  return LookupPrefill({token}, output_tensor, /*byte_offset=*/0);
 }
 
 absl::Status EmbeddingLookupManager::LookupPrefill(
@@ -165,6 +151,11 @@ absl::Status EmbeddingLookupManager::LookupPrefill(
     memcpy(output_vector.data(), default_embedding_vector_.data(),
            default_embedding_vector_.size() * sizeof(float));
   }
+  // Remove fully used multi modal embedding lookups.
+  std::erase_if(multi_modal_embedding_lookups_,
+                [](const auto& embedding_lookup) {
+                  return !embedding_lookup->HasRemainingEmbeddings();
+                });
   return absl::OkStatus();
 }
 
diff --git a/runtime/components/embedding_lookup/embedding_lookup_manager_test.cc b/runtime/components/embedding_lookup/embedding_lookup_manager_test.cc
index 14684a9c..ed4bee2c 100644
--- a/runtime/components/embedding_lookup/embedding_lookup_manager_test.cc
+++ b/runtime/components/embedding_lookup/embedding_lookup_manager_test.cc
@@ -245,23 +245,21 @@ TEST_F(EmbeddingLookupManagerTest, LookupDecodeTextSingleTokenVectorNonEmpty) {
 
 TEST_F(EmbeddingLookupManagerTest, LookupDecodeTextSingleNegativeTokenVector) {
   ASSERT_NE(embedding_lookup_manager_, nullptr);
-
+  ASSERT_OK(UpdateMultiModalEmbeddings());
   std::vector<float> output_vector;
   int32_t token = -1;
-  ASSERT_THAT(
-      embedding_lookup_manager_->LookupDecode(token, output_vector),
-      testing::status::StatusIs(
-          absl::StatusCode::kInvalidArgument,
-          testing::HasSubstr(
-              "Multimodal embeddings are not supported during decode")));
+  ASSERT_OK(embedding_lookup_manager_->LookupDecode(token, output_vector));
+  ASSERT_EQ(output_vector.size(), 128);
+  for (int i = 0; i < 128; i++) {
+    EXPECT_EQ(output_vector[i], 1.0 + i);
+  }
 
   token = -2;
-  ASSERT_THAT(
-      embedding_lookup_manager_->LookupDecode(token, output_vector),
-      testing::status::StatusIs(
-          absl::StatusCode::kInvalidArgument,
-          testing::HasSubstr(
-              "Multimodal embeddings are not supported during decode")));
+  ASSERT_OK(embedding_lookup_manager_->LookupDecode(token, output_vector));
+  ASSERT_EQ(output_vector.size(), 128);
+  for (int i = 0; i < 128; i++) {
+    EXPECT_EQ(output_vector[i], 257.0 + i);
+  }
 }
 
 TEST_F(EmbeddingLookupManagerTest, LookupDecodeTextSingleToken) {
diff --git a/runtime/components/embedding_lookup/embedding_lookup_multi_modal.cc b/runtime/components/embedding_lookup/embedding_lookup_multi_modal.cc
index 3260f16a..dac1ac7c 100644
--- a/runtime/components/embedding_lookup/embedding_lookup_multi_modal.cc
+++ b/runtime/components/embedding_lookup/embedding_lookup_multi_modal.cc
@@ -38,20 +38,12 @@ namespace litert::lm {
 
 absl::Status EmbeddingLookupMultiModal::LookupDecode(
    int token, std::vector<float>& output_vector) {
-  // Multimodal lookup is not supported for single token case because decode
-  // does not use multimodal embedding lookup.
-  return absl::UnimplementedError(
-      "Multimodal embedding lookup is not supported for single token decode "
-      "case.");
+  return LookupPrefill(token, output_vector);
 }
 
 absl::Status EmbeddingLookupMultiModal::LookupDecode(
    int token, litert::TensorBuffer* output_tensor) {
-  // Multimodal lookup is not supported for single token case because decode
-  // does not use multimodal embedding lookup.
-  return absl::UnimplementedError(
-      "Multimodal embedding lookup is not supported for single token decode "
-      "case.");
+  return LookupPrefill({token}, output_tensor, 0);
 }
 
 absl::Status EmbeddingLookupMultiModal::LookupPrefill(
@@ -178,9 +170,10 @@ absl::Status EmbeddingLookupMultiModal::Initialize(
         "null.");
   }
   LITERT_ASSIGN_OR_RETURN_ABSL(
-      embedding_,
-      ::litert::lm::ReferTensorBufferAsSpan<float>(*embedding_buffer));
+      embedding_buffer_,
+      ::litert::lm::CopyFromTensorBuffer<float>(*embedding_buffer));
   special_token_ = special_token;
+  embedding_ = absl::MakeSpan(embedding_buffer_);
   return absl::OkStatus();
 }
 
diff --git a/runtime/components/embedding_lookup/embedding_lookup_multi_modal.h b/runtime/components/embedding_lookup/embedding_lookup_multi_modal.h
index c54012d9..67fdc5a1 100644
--- a/runtime/components/embedding_lookup/embedding_lookup_multi_modal.h
+++ b/runtime/components/embedding_lookup/embedding_lookup_multi_modal.h
@@ -94,7 +94,7 @@ class EmbeddingLookupMultiModal : public EmbeddingLookup {
  protected:
   absl::Status Initialize(const ::litert::TensorBuffer* embedding_buffer,
                           int special_token);
-
+  std::vector<float> embedding_buffer_;
   absl::Span<const float> embedding_;
   int special_token_;
 };
diff --git a/runtime/components/embedding_lookup/embedding_lookup_multi_modal_test.cc b/runtime/components/embedding_lookup/embedding_lookup_multi_modal_test.cc
index a173d02b..c827c506 100644
--- a/runtime/components/embedding_lookup/embedding_lookup_multi_modal_test.cc
+++ b/runtime/components/embedding_lookup/embedding_lookup_multi_modal_test.cc
@@ -516,12 +516,14 @@ TEST_F(EmbeddingLookupMultiModalTest, LookupDecode) {
   LITERT_ASSERT_OK_AND_ASSIGN(litert::TensorBuffer output_tensor,
                               GetTensorBuffer(dimensions));
   int token = special_token_;
-  ASSERT_THAT(
-      embedding->LookupDecode(token, &output_tensor),
-      testing::status::StatusIs(
-          absl::StatusCode::kUnimplemented,
-          testing::HasSubstr("Multimodal embedding lookup is not supported for "
-                             "single token decode case.")));
+  ASSERT_OK(embedding->LookupDecode(token, &output_tensor));
+  auto output_tensor_lock_and_addr = ::litert::TensorBufferScopedLock::Create(
+      output_tensor, ::litert::TensorBuffer::LockMode::kRead);
+  auto output_tensor_ptr =
+      reinterpret_cast<const float*>(output_tensor_lock_and_addr->second);
+  // 2 * 3 = 6 floats per token.
+  EXPECT_THAT(absl::MakeSpan(output_tensor_ptr, 6),
+              testing::ElementsAre(1.0, 2.0, 3.0, 4.0, 5.0, 6.0));
 }
 
 TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorNoSpecialToken) {
@@ -531,12 +533,8 @@ TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorNoSpecialToken) {
   std::vector<float> output_vector(2 * 3);
 
   int token = 1;
-  ASSERT_THAT(
-      embedding->LookupDecode(token, output_vector),
-      testing::status::StatusIs(
-          absl::StatusCode::kUnimplemented,
-          testing::HasSubstr("Multimodal embedding lookup is not supported for "
-                             "single token decode case.")));
+  ASSERT_OK(embedding->LookupDecode(token, output_vector));
+  EXPECT_THAT(output_vector, testing::Each(testing::FloatEq(0.0)));
 }
 
 TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorSpecialToken) {
@@ -546,12 +544,10 @@ TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorSpecialToken) {
   std::vector<float> output_vector(2 * 3);
 
   int token = special_token_;
-  ASSERT_THAT(
-      embedding->LookupDecode(token, output_vector),
-      testing::status::StatusIs(
-          absl::StatusCode::kUnimplemented,
-          testing::HasSubstr("Multimodal embedding lookup is not supported for "
-                             "single token decode case.")));
+  ASSERT_OK(embedding->LookupDecode(token, output_vector));
+  // 2 * 3 = 6 floats per token.
+  EXPECT_THAT(output_vector,
+              testing::ElementsAre(1.0, 2.0, 3.0, 4.0, 5.0, 6.0));
 }
 
 }  // namespace litert::lm