Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@

namespace litert::lm {

absl::Status EndOfMultiModalEmbedding::LookupDecode(
    int token, std::vector<float>& output_vector) {
  // For this embedding, decoding a single token is the same operation as
  // prefilling it, so delegate to the prefill lookup.
  const absl::Status status = LookupPrefill(token, output_vector);
  return status;
}

absl::Status EndOfMultiModalEmbedding::LookupDecode(
    int token, litert::TensorBuffer* output_tensor) {
  // A single-token decode is a one-token prefill written at byte offset 0.
  const absl::Status status = LookupPrefill({token}, output_tensor, 0);
  return status;
}

absl::Status EndOfMultiModalEmbedding::LookupPrefill(
int token, std::vector<float>& output_vector) {
if (token != special_token_) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,15 @@ class EndOfMultiModalEmbedding : public EmbeddingLookup {
static absl::StatusOr<std::unique_ptr<EndOfMultiModalEmbedding>> Create(
const litert::Model* absl_nonnull model, int special_token);

// For a given token, looks up the end of multi-modal embedding and stores it
// in the provided vector. (Resolved diff-merge artifact: the removed inline
// UnimplementedError body was interleaved with this declaration; the
// out-of-line definition now delegates to LookupPrefill.)
absl::Status LookupDecode(int token,
                          std::vector<float>& output_vector) override;

// For a given token, looks up the end of multi-modal embedding and stores it
// in the output tensor. (Resolved diff-merge artifact: the removed inline
// UnimplementedError body was interleaved with this declaration; the
// out-of-line definition now delegates to LookupPrefill.)
absl::Status LookupDecode(int token,
                          litert::TensorBuffer* output_tensor) override;

// If the token is the special token, looks up the end of multimodal
// embedding and stores it in the provided vector.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,41 @@ class EndOfMultiModalEmbeddingTest : public testing::Test {
std::optional<Model> model_;
};

// Verifies that LookupDecode rejects an output vector whose size does not
// match the embedding size (4 * 32 floats). The original span interleaved the
// removed test (expecting kUnimplemented) with the added one; this is the
// resolved added-lines version.
TEST_F(EndOfMultiModalEmbeddingTest, LookupDecodeVectorBadOutputVector) {
  std::unique_ptr<EndOfMultiModalEmbedding> embedding =
      GetEndOfMultiModalEmbedding();
  ASSERT_NE(embedding, nullptr);

  // One float larger than the expected 4 * 32 embedding size.
  std::vector<float> output_vector(4 * 32 + 1);

  int32_t token = -3;
  ASSERT_THAT(embedding->LookupDecode(token, output_vector),
              testing::status::StatusIs(
                  absl::StatusCode::kInvalidArgument,
                  testing::HasSubstr("The output vector is not the correct "
                                     "size for the end of multi-modal")));
}

// Verifies that LookupDecode fills a correctly sized (128-float) vector with
// the end of multi-modal embedding values.
TEST_F(EndOfMultiModalEmbeddingTest, LookupDecodeVector) {
  std::unique_ptr<EndOfMultiModalEmbedding> embedding =
      GetEndOfMultiModalEmbedding();
  ASSERT_NE(embedding, nullptr);

  std::vector<float> output_vector(128);

  int32_t token = -3;
  ASSERT_OK(embedding->LookupDecode(token, output_vector));

  // Dimensions 0 and 1 both have size 1, so the flat offset and the expected
  // value depend only on the last two dimensions (4 x 32).
  for (int row = 0; row < 4; ++row) {
    for (int col = 0; col < 32; ++col) {
      const float want = (100.0 * row + col) * 2;
      ASSERT_NEAR(output_vector[row * 32 + col], want, 1e-5);
    }
  }
}

TEST_F(EndOfMultiModalEmbeddingTest, LookupDecode) {
Expand All @@ -133,12 +155,23 @@ TEST_F(EndOfMultiModalEmbeddingTest, LookupDecode) {
GetTensorBuffer(dimensions));

int32_t token = -3;
ASSERT_THAT(
embedding->LookupDecode(token, &output_tensor),
testing::status::StatusIs(
absl::StatusCode::kUnimplemented,
testing::HasSubstr(
"LookupDecode is not implemented for EndOfMultiModalEmbedding")));
ASSERT_OK(embedding->LookupDecode(token, &output_tensor));

auto output_tensor_lock_and_addr = ::litert::TensorBufferScopedLock::Create(
output_tensor, ::litert::TensorBuffer::LockMode::kRead);
auto output_tensor_ptr =
reinterpret_cast<float*>(output_tensor_lock_and_addr->second);

float expected_value = 0.0;
for (int idx2 = 0; idx2 < dimensions[2]; ++idx2) {
for (int idx3 = 0; idx3 < dimensions[3]; ++idx3) {
// Since dimension 1 is of size 1, the offset and expected value can
// ignore it.
size_t offset = idx2 * dimensions[3] + idx3;
expected_value = (100.0 * idx2 + idx3) * 2;
ASSERT_NEAR(output_tensor_ptr[offset], expected_value, 1e-5);
}
}
}

TEST_F(EndOfMultiModalEmbeddingTest, LookupPrefillVector) {
Expand Down
27 changes: 9 additions & 18 deletions runtime/components/embedding_lookup/embedding_lookup_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,29 +115,15 @@ absl::Status EmbeddingLookupManager::LookupDecode(
}
const size_t floats_per_token = text_embedding_lookup_->GetFloatsPerToken();
output_vector.resize(floats_per_token);

if (token < 0) {
return absl::InvalidArgumentError(
"Multimodal embeddings are not supported during decode.");
}

return text_embedding_lookup_->LookupDecode(token, output_vector);
return LookupPrefill(token, output_vector);
}

absl::Status EmbeddingLookupManager::LookupDecode(
    int token, litert::TensorBuffer* output_tensor) {
  // Original span was an unresolved diff merge (unbalanced braces, two return
  // statements); this is the added-lines version. Guard against a null
  // destination before delegating.
  if (output_tensor == nullptr) {
    return absl::InvalidArgumentError("Decode output tensor buffer is null.");
  }

  // Decoding one token is a single-token prefill lookup written at the start
  // of the output buffer.
  return LookupPrefill({token}, output_tensor, /*byte_offset=*/0);
}

absl::Status EmbeddingLookupManager::LookupPrefill(
Expand Down Expand Up @@ -165,6 +151,11 @@ absl::Status EmbeddingLookupManager::LookupPrefill(
memcpy(output_vector.data(), default_embedding_vector_.data(),
default_embedding_vector_.size() * sizeof(float));
}
// Remove fully used multi modal embedding lookups.
std::erase_if(multi_modal_embedding_lookups_,
[](const auto& embedding_lookup) {
return !embedding_lookup->HasRemainingEmbeddings();
});
return absl::OkStatus();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,23 +245,21 @@ TEST_F(EmbeddingLookupManagerTest, LookupDecodeTextSingleTokenVectorNonEmpty) {

// Verifies that negative tokens are looked up from the multimodal embeddings
// during decode. The original span interleaved the removed assertions
// (expecting kInvalidArgument) with the added ones (expecting OK plus exact
// values); this is the resolved added-lines version.
TEST_F(EmbeddingLookupManagerTest, LookupDecodeTextSingleNegativeTokenVector) {
  ASSERT_NE(embedding_lookup_manager_, nullptr);
  ASSERT_OK(UpdateMultiModalEmbeddings());
  std::vector<float> output_vector;
  int32_t token = -1;
  ASSERT_OK(embedding_lookup_manager_->LookupDecode(token, output_vector));
  ASSERT_EQ(output_vector.size(), 128);
  for (int i = 0; i < 128; i++) {
    EXPECT_EQ(output_vector[i], 1.0 + i);
  }

  token = -2;
  ASSERT_OK(embedding_lookup_manager_->LookupDecode(token, output_vector));
  ASSERT_EQ(output_vector.size(), 128);
  for (int i = 0; i < 128; i++) {
    EXPECT_EQ(output_vector[i], 257.0 + i);
  }
}

TEST_F(EmbeddingLookupManagerTest, LookupDecodeTextSingleToken) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,12 @@ namespace litert::lm {

absl::Status EmbeddingLookupMultiModal::LookupDecode(
    int token, std::vector<float>& output_vector) {
  // Original span carried an unreachable second return (diff-merge artifact:
  // removed UnimplementedError body interleaved with the added delegation);
  // keep only the added behavior: a single-token decode is the same lookup as
  // prefill.
  return LookupPrefill(token, output_vector);
}

absl::Status EmbeddingLookupMultiModal::LookupDecode(
    int token, litert::TensorBuffer* output_tensor) {
  // Original span carried an unreachable second return (diff-merge artifact:
  // removed UnimplementedError body interleaved with the added delegation);
  // keep only the added behavior: a single-token decode is a one-token
  // prefill written at byte offset 0.
  return LookupPrefill({token}, output_tensor, /*byte_offset=*/0);
}

absl::Status EmbeddingLookupMultiModal::LookupPrefill(
Expand Down Expand Up @@ -178,9 +170,10 @@ absl::Status EmbeddingLookupMultiModal::Initialize(
"null.");
}
LITERT_ASSIGN_OR_RETURN_ABSL(
embedding_,
::litert::lm::ReferTensorBufferAsSpan<float>(*embedding_buffer));
embedding_buffer_,
::litert::lm::CopyFromTensorBuffer<float>(*embedding_buffer));
special_token_ = special_token;
embedding_ = absl::MakeSpan(embedding_buffer_);
return absl::OkStatus();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class EmbeddingLookupMultiModal : public EmbeddingLookup {
protected:
absl::Status Initialize(const ::litert::TensorBuffer* embedding_buffer,
int special_token);

std::vector<float> embedding_buffer_;
absl::Span<float> embedding_;
int special_token_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,12 +516,14 @@ TEST_F(EmbeddingLookupMultiModalTest, LookupDecode) {
LITERT_ASSERT_OK_AND_ASSIGN(litert::TensorBuffer output_tensor,
GetTensorBuffer(dimensions));
int token = special_token_;
ASSERT_THAT(
embedding->LookupDecode(token, &output_tensor),
testing::status::StatusIs(
absl::StatusCode::kUnimplemented,
testing::HasSubstr("Multimodal embedding lookup is not supported for "
"single token decode case.")));
ASSERT_OK(embedding->LookupDecode(token, &output_tensor));
auto output_tensor_lock_and_addr = ::litert::TensorBufferScopedLock::Create(
output_tensor, ::litert::TensorBuffer::LockMode::kRead);
auto output_tensor_ptr =
reinterpret_cast<float*>(output_tensor_lock_and_addr->second);
// 2 * 3 = 6 floats per token.
EXPECT_THAT(absl::MakeSpan(output_tensor_ptr, 6),
testing::ElementsAre(1.0, 2.0, 3.0, 4.0, 5.0, 6.0));
}

TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorNoSpecialToken) {
Expand All @@ -531,12 +533,8 @@ TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorNoSpecialToken) {

std::vector<float> output_vector(2 * 3);
int token = 1;
ASSERT_THAT(
embedding->LookupDecode(token, output_vector),
testing::status::StatusIs(
absl::StatusCode::kUnimplemented,
testing::HasSubstr("Multimodal embedding lookup is not supported for "
"single token decode case.")));
ASSERT_OK(embedding->LookupDecode(token, output_vector));
EXPECT_THAT(output_vector, testing::Each(testing::FloatEq(0.0)));
}

TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorSpecialToken) {
Expand All @@ -546,12 +544,10 @@ TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorSpecialToken) {

std::vector<float> output_vector(2 * 3);
int token = special_token_;
ASSERT_THAT(
embedding->LookupDecode(token, output_vector),
testing::status::StatusIs(
absl::StatusCode::kUnimplemented,
testing::HasSubstr("Multimodal embedding lookup is not supported for "
"single token decode case.")));
ASSERT_OK(embedding->LookupDecode(token, output_vector));
// 2 * 3 = 6 floats per token.
EXPECT_THAT(output_vector,
testing::ElementsAre(1.0, 2.0, 3.0, 4.0, 5.0, 6.0));
}

} // namespace litert::lm