Skip to content

Commit 8b9af40

Browse files
ai-edge-botcopybara-github
authored andcommitted
Multimodal embedding lookup support decode.
LiteRT-LM-PiperOrigin-RevId: 817293972
1 parent 79d46dd commit 8b9af40

8 files changed

+100
-83
lines changed

runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@
3939

4040
namespace litert::lm {
4141

42+
absl::Status EndOfMultiModalEmbedding::LookupDecode(
43+
int token, std::vector<float>& output_vector) {
44+
return LookupPrefill(token, output_vector);
45+
}
46+
47+
absl::Status EndOfMultiModalEmbedding::LookupDecode(
48+
int token, litert::TensorBuffer* output_tensor) {
49+
return LookupPrefill({token}, output_tensor, 0);
50+
}
51+
4252
absl::Status EndOfMultiModalEmbedding::LookupPrefill(
4353
int token, std::vector<float>& output_vector) {
4454
if (token != special_token_) {

runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.h

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,15 @@ class EndOfMultiModalEmbedding : public EmbeddingLookup {
4747
static absl::StatusOr<std::unique_ptr<EndOfMultiModalEmbedding>> Create(
4848
const litert::Model* absl_nonnull model, int special_token);
4949

50-
// Multimodal embeddings are not supported during decode.
50+
// For a given token, looks up the end of multi-modal embedding and stores it
51+
// in the provided vector.
5152
absl::Status LookupDecode(int token,
52-
std::vector<float>& output_vector) override {
53-
return absl::UnimplementedError(
54-
"LookupDecode is not implemented for EndOfMultiModalEmbedding.");
55-
}
53+
std::vector<float>& output_vector) override;
5654

57-
// Multimodal embeddings are not supported during decode.
55+
// For a given token, looks up the end of multi-modal embedding and stores it
56+
// in the output tensor.
5857
absl::Status LookupDecode(int token,
59-
litert::TensorBuffer* output_tensor) override {
60-
return absl::UnimplementedError(
61-
"LookupDecode is not implemented for EndOfMultiModalEmbedding.");
62-
}
58+
litert::TensorBuffer* output_tensor) override;
6359

6460
// If the token is the special token, looks up the end of multimodal
6561
// embedding and stores it in the provided vector.

runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal_test.cc

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,19 +108,41 @@ class EndOfMultiModalEmbeddingTest : public testing::Test {
108108
std::optional<Model> model_;
109109
};
110110

111-
TEST_F(EndOfMultiModalEmbeddingTest, LookupDecodeVector) {
111+
TEST_F(EndOfMultiModalEmbeddingTest, LookupDecodeVectorBadOutputVector) {
112112
std::unique_ptr<EndOfMultiModalEmbedding> embedding =
113113
GetEndOfMultiModalEmbedding();
114114
ASSERT_NE(embedding, nullptr);
115115

116-
std::vector<float> output_vector(4 * 32);
116+
std::vector<float> output_vector(4 * 32 + 1);
117117

118118
int32_t token = -3;
119119
ASSERT_THAT(embedding->LookupDecode(token, output_vector),
120120
testing::status::StatusIs(
121-
absl::StatusCode::kUnimplemented,
122-
testing::HasSubstr("LookupDecode is not implemented for "
123-
"EndOfMultiModalEmbedding.")));
121+
absl::StatusCode::kInvalidArgument,
122+
testing::HasSubstr("The output vector is not the correct "
123+
"size for the end of multi-modal")));
124+
}
125+
126+
TEST_F(EndOfMultiModalEmbeddingTest, LookupDecodeVector) {
127+
std::unique_ptr<EndOfMultiModalEmbedding> embedding =
128+
GetEndOfMultiModalEmbedding();
129+
ASSERT_NE(embedding, nullptr);
130+
131+
std::vector<float> output_vector(128);
132+
133+
int32_t token = -3;
134+
ASSERT_OK(embedding->LookupDecode(token, output_vector));
135+
136+
size_t offset = 0;
137+
// Dimensions 0 and 1 both have size 1.
138+
for (int idx2 = 0; idx2 < 4; ++idx2) {
139+
for (int idx3 = 0; idx3 < 32; ++idx3) {
140+
// Dimensions 0 and 1 both have size 1 so offset and expected value can
141+
// ignore them.
142+
float expected_value = (100.0 * idx2 + idx3) * 2;
143+
ASSERT_NEAR(output_vector[offset++], expected_value, 1e-5);
144+
}
145+
}
124146
}
125147

126148
TEST_F(EndOfMultiModalEmbeddingTest, LookupDecode) {
@@ -133,12 +155,23 @@ TEST_F(EndOfMultiModalEmbeddingTest, LookupDecode) {
133155
GetTensorBuffer(dimensions));
134156

135157
int32_t token = -3;
136-
ASSERT_THAT(
137-
embedding->LookupDecode(token, &output_tensor),
138-
testing::status::StatusIs(
139-
absl::StatusCode::kUnimplemented,
140-
testing::HasSubstr(
141-
"LookupDecode is not implemented for EndOfMultiModalEmbedding")));
158+
ASSERT_OK(embedding->LookupDecode(token, &output_tensor));
159+
160+
auto output_tensor_lock_and_addr = ::litert::TensorBufferScopedLock::Create(
161+
output_tensor, ::litert::TensorBuffer::LockMode::kRead);
162+
auto output_tensor_ptr =
163+
reinterpret_cast<float*>(output_tensor_lock_and_addr->second);
164+
165+
float expected_value = 0.0;
166+
for (int idx2 = 0; idx2 < dimensions[2]; ++idx2) {
167+
for (int idx3 = 0; idx3 < dimensions[3]; ++idx3) {
168+
// Since dimension 1 is of size 1, the offset and expected value can
169+
// ignore it.
170+
size_t offset = idx2 * dimensions[3] + idx3;
171+
expected_value = (100.0 * idx2 + idx3) * 2;
172+
ASSERT_NEAR(output_tensor_ptr[offset], expected_value, 1e-5);
173+
}
174+
}
142175
}
143176

144177
TEST_F(EndOfMultiModalEmbeddingTest, LookupPrefillVector) {

runtime/components/embedding_lookup/embedding_lookup_manager.cc

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -115,29 +115,15 @@ absl::Status EmbeddingLookupManager::LookupDecode(
115115
}
116116
const size_t floats_per_token = text_embedding_lookup_->GetFloatsPerToken();
117117
output_vector.resize(floats_per_token);
118-
119-
if (token < 0) {
120-
return absl::InvalidArgumentError(
121-
"Multimodal embeddings are not supported during decode.");
122-
}
123-
124-
return text_embedding_lookup_->LookupDecode(token, output_vector);
118+
return LookupPrefill(token, output_vector);
125119
}
126120

127121
absl::Status EmbeddingLookupManager::LookupDecode(
128122
int token, litert::TensorBuffer* output_tensor) {
129-
if (text_embedding_lookup_ == nullptr) {
130-
return absl::InternalError(
131-
"Text embedding lookup is null. Please ensure that the "
132-
"EmbeddingLookupManager is initialized properly.");
123+
if (output_tensor == nullptr) {
124+
return absl::InvalidArgumentError("Decode output tensor buffer is null.");
133125
}
134-
135-
if (token < 0) {
136-
return absl::InvalidArgumentError(
137-
"Multimodal embeddings are not supported during decode.");
138-
}
139-
140-
return text_embedding_lookup_->LookupDecode(token, output_tensor);
126+
return LookupPrefill({token}, output_tensor, /*byte_offset=*/0);
141127
}
142128

143129
absl::Status EmbeddingLookupManager::LookupPrefill(
@@ -165,6 +151,11 @@ absl::Status EmbeddingLookupManager::LookupPrefill(
165151
memcpy(output_vector.data(), default_embedding_vector_.data(),
166152
default_embedding_vector_.size() * sizeof(float));
167153
}
154+
// Remove fully used multi modal embedding lookups.
155+
std::erase_if(multi_modal_embedding_lookups_,
156+
[](const auto& embedding_lookup) {
157+
return !embedding_lookup->HasRemainingEmbeddings();
158+
});
168159
return absl::OkStatus();
169160
}
170161

runtime/components/embedding_lookup/embedding_lookup_manager_test.cc

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -245,23 +245,21 @@ TEST_F(EmbeddingLookupManagerTest, LookupDecodeTextSingleTokenVectorNonEmpty) {
245245

246246
TEST_F(EmbeddingLookupManagerTest, LookupDecodeTextSingleNegativeTokenVector) {
247247
ASSERT_NE(embedding_lookup_manager_, nullptr);
248-
248+
ASSERT_OK(UpdateMultiModalEmbeddings());
249249
std::vector<float> output_vector;
250250
int32_t token = -1;
251-
ASSERT_THAT(
252-
embedding_lookup_manager_->LookupDecode(token, output_vector),
253-
testing::status::StatusIs(
254-
absl::StatusCode::kInvalidArgument,
255-
testing::HasSubstr(
256-
"Multimodal embeddings are not supported during decode")));
251+
ASSERT_OK(embedding_lookup_manager_->LookupDecode(token, output_vector));
252+
ASSERT_EQ(output_vector.size(), 128);
253+
for (int i = 0; i < 128; i++) {
254+
EXPECT_EQ(output_vector[i], 1.0 + i);
255+
}
257256

258257
token = -2;
259-
ASSERT_THAT(
260-
embedding_lookup_manager_->LookupDecode(token, output_vector),
261-
testing::status::StatusIs(
262-
absl::StatusCode::kInvalidArgument,
263-
testing::HasSubstr(
264-
"Multimodal embeddings are not supported during decode")));
258+
ASSERT_OK(embedding_lookup_manager_->LookupDecode(token, output_vector));
259+
ASSERT_EQ(output_vector.size(), 128);
260+
for (int i = 0; i < 128; i++) {
261+
EXPECT_EQ(output_vector[i], 257.0 + i);
262+
}
265263
}
266264

267265
TEST_F(EmbeddingLookupManagerTest, LookupDecodeTextSingleToken) {

runtime/components/embedding_lookup/embedding_lookup_multi_modal.cc

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,12 @@ namespace litert::lm {
3838

3939
absl::Status EmbeddingLookupMultiModal::LookupDecode(
4040
int token, std::vector<float>& output_vector) {
41-
// Multimodal lookup is not supported for single token case because decode
42-
// does not use multimodal embedding lookup.
43-
return absl::UnimplementedError(
44-
"Multimodal embedding lookup is not supported for single token decode "
45-
"case.");
41+
return LookupPrefill(token, output_vector);
4642
}
4743

4844
absl::Status EmbeddingLookupMultiModal::LookupDecode(
4945
int token, litert::TensorBuffer* output_tensor) {
50-
// Multimodal lookup is not supported for single token case because decode
51-
// does not use multimodal embedding lookup.
52-
return absl::UnimplementedError(
53-
"Multimodal embedding lookup is not supported for single token decode "
54-
"case.");
46+
return LookupPrefill({token}, output_tensor, 0);
5547
}
5648

5749
absl::Status EmbeddingLookupMultiModal::LookupPrefill(
@@ -178,9 +170,10 @@ absl::Status EmbeddingLookupMultiModal::Initialize(
178170
"null.");
179171
}
180172
LITERT_ASSIGN_OR_RETURN_ABSL(
181-
embedding_,
182-
::litert::lm::ReferTensorBufferAsSpan<float>(*embedding_buffer));
173+
embedding_buffer_,
174+
::litert::lm::CopyFromTensorBuffer<float>(*embedding_buffer));
183175
special_token_ = special_token;
176+
embedding_ = absl::MakeSpan(embedding_buffer_);
184177
return absl::OkStatus();
185178
}
186179

runtime/components/embedding_lookup/embedding_lookup_multi_modal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class EmbeddingLookupMultiModal : public EmbeddingLookup {
9494
protected:
9595
absl::Status Initialize(const ::litert::TensorBuffer* embedding_buffer,
9696
int special_token);
97-
97+
std::vector<float> embedding_buffer_;
9898
absl::Span<float> embedding_;
9999
int special_token_;
100100
};

runtime/components/embedding_lookup/embedding_lookup_multi_modal_test.cc

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -516,12 +516,14 @@ TEST_F(EmbeddingLookupMultiModalTest, LookupDecode) {
516516
LITERT_ASSERT_OK_AND_ASSIGN(litert::TensorBuffer output_tensor,
517517
GetTensorBuffer(dimensions));
518518
int token = special_token_;
519-
ASSERT_THAT(
520-
embedding->LookupDecode(token, &output_tensor),
521-
testing::status::StatusIs(
522-
absl::StatusCode::kUnimplemented,
523-
testing::HasSubstr("Multimodal embedding lookup is not supported for "
524-
"single token decode case.")));
519+
ASSERT_OK(embedding->LookupDecode(token, &output_tensor));
520+
auto output_tensor_lock_and_addr = ::litert::TensorBufferScopedLock::Create(
521+
output_tensor, ::litert::TensorBuffer::LockMode::kRead);
522+
auto output_tensor_ptr =
523+
reinterpret_cast<float*>(output_tensor_lock_and_addr->second);
524+
// 2 * 3 = 6 floats per token.
525+
EXPECT_THAT(absl::MakeSpan(output_tensor_ptr, 6),
526+
testing::ElementsAre(1.0, 2.0, 3.0, 4.0, 5.0, 6.0));
525527
}
526528

527529
TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorNoSpecialToken) {
@@ -531,12 +533,8 @@ TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorNoSpecialToken) {
531533

532534
std::vector<float> output_vector(2 * 3);
533535
int token = 1;
534-
ASSERT_THAT(
535-
embedding->LookupDecode(token, output_vector),
536-
testing::status::StatusIs(
537-
absl::StatusCode::kUnimplemented,
538-
testing::HasSubstr("Multimodal embedding lookup is not supported for "
539-
"single token decode case.")));
536+
ASSERT_OK(embedding->LookupDecode(token, output_vector));
537+
EXPECT_THAT(output_vector, testing::Each(testing::FloatEq(0.0)));
540538
}
541539

542540
TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorSpecialToken) {
@@ -546,12 +544,10 @@ TEST_F(EmbeddingLookupMultiModalTest, LookupDecodeVectorSpecialToken) {
546544

547545
std::vector<float> output_vector(2 * 3);
548546
int token = special_token_;
549-
ASSERT_THAT(
550-
embedding->LookupDecode(token, output_vector),
551-
testing::status::StatusIs(
552-
absl::StatusCode::kUnimplemented,
553-
testing::HasSubstr("Multimodal embedding lookup is not supported for "
554-
"single token decode case.")));
547+
ASSERT_OK(embedding->LookupDecode(token, output_vector));
548+
// 2 * 3 = 6 floats per token.
549+
EXPECT_THAT(output_vector,
550+
testing::ElementsAre(1.0, 2.0, 3.0, 4.0, 5.0, 6.0));
555551
}
556552

557553
} // namespace litert::lm

0 commit comments

Comments
 (0)