Skip to content

Commit 2494541

Browse files
ai-edge-botcopybara-github
authored andcommitted
This is an internal change
LiteRT-LM-PiperOrigin-RevId: 819561952
1 parent dfad938 commit 2494541

File tree

2 files changed

+30
-24
lines changed

2 files changed

+30
-24
lines changed

runtime/executor/litert_compiled_model_executor_utils.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,7 @@ GetOptimizedPrefillWorkGroups(
310310
return work_groups;
311311
}
312312

313-
absl::Status InitializeAttentionMask(litert::TensorBuffer& mask,
314-
bool is_f16) {
313+
absl::Status InitializeAttentionMask(litert::TensorBuffer& mask, bool is_f16) {
315314
auto mask_size = mask.PackedSize();
316315
RET_CHECK(mask_size) << "Failed to get attention mask buffer size.";
317316
auto mask_tensor_type = mask.TensorType();

runtime/executor/magic_number_configs_helper.cc

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ constexpr absl::string_view kInputPosSubstr = "pos";
5151
constexpr absl::string_view kOutputLogitsSubstr = "logits";
5252
constexpr int64_t kDefaultTargetNumberBase = 256;
5353

54-
Expected<int64_t> GetLastDimensionOfInput(const Subgraph& subgraph,
55-
absl::string_view input_name) {
54+
Expected<int64_t> GetLastDimensionOfInput(const Subgraph& subgraph,
55+
absl::string_view input_name) {
5656
LITERT_ASSIGN_OR_RETURN(auto tensor, subgraph.Input(input_name));
5757
LITERT_ASSIGN_OR_RETURN(auto type, tensor.RankedTensorType());
5858
return type.Layout().Dimensions()[type.Layout().Rank() - 1];
@@ -82,9 +82,8 @@ bool IsMagicNumber(int64_t number) {
8282
return true;
8383
}
8484

85-
Expected<void> SetMagicNumberIfPrime(const Subgraph& subgraph,
86-
absl::string_view tensor_name, bool input,
87-
int64_t& magic_number) {
85+
Expected<void> SetMagicNumberIfPrime(const Subgraph& subgraph,
86+
absl::string_view tensor_name, bool input, int64_t& magic_number) {
8887
auto expected_dim = input ? GetLastDimensionOfInput(subgraph, tensor_name)
8988
: GetFirstDimensionOfOutput(subgraph, tensor_name);
9089
LITERT_ASSIGN_OR_RETURN(auto dim, expected_dim);
@@ -106,17 +105,17 @@ Expected<MagicNumbers> GetMagicNumbersFromModel(const Model& litert_model) {
106105
for (int i = 0; i < num_signatures; ++i) {
107106
LITERT_ASSIGN_OR_RETURN(auto signature, litert_model.GetSignature(i));
108107
LITERT_ASSIGN_OR_RETURN(auto subgraph,
109-
litert_model.Subgraph(signature.Key()));
108+
litert_model.Subgraph(signature.Key()));
110109
if (signature.Key().starts_with(kPrefillSignaturePrefix)) {
111110
for (const auto& input_name : signature.InputNames()) {
112111
if (absl::StrContains(input_name, kMaskSubstr)) {
113-
LITERT_RETURN_IF_ERROR(
114-
SetMagicNumberIfPrime(subgraph, input_name, /*input=*/true,
115-
magic_numbers.context_length));
112+
LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime(
113+
subgraph, input_name, true,
114+
magic_numbers.context_length));
116115
} else if (absl::StrContains(input_name, kInputPosSubstr)) {
117116
int64_t prefill_length = 0;
118-
LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime(
119-
subgraph, input_name, /*input=*/true, prefill_length));
117+
LITERT_RETURN_IF_ERROR(
118+
SetMagicNumberIfPrime(subgraph, input_name, true, prefill_length));
120119
if (prefill_length > 0) {
121120
magic_numbers.prefill_lengths.push_back(prefill_length);
122121
}
@@ -125,16 +124,22 @@ Expected<MagicNumbers> GetMagicNumbersFromModel(const Model& litert_model) {
125124
} else if (signature.Key().starts_with(kDecodeSignaturePrefix)) {
126125
for (const auto& input_name : signature.InputNames()) {
127126
if (absl::StrContains(input_name, kMaskSubstr)) {
128-
LITERT_RETURN_IF_ERROR(
129-
SetMagicNumberIfPrime(subgraph, input_name, /*input=*/true,
130-
magic_numbers.context_length));
127+
LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime(
128+
subgraph, input_name, true,
129+
magic_numbers.context_length));
130+
LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime(
131+
subgraph
132+
,
133+
input_name, /*input=*/true, magic_numbers.context_length));
131134
}
132135
}
133136
for (const auto& output_name : signature.OutputNames()) {
134137
if (absl::StrContains(output_name, kOutputLogitsSubstr)) {
135-
LITERT_RETURN_IF_ERROR(
136-
SetMagicNumberIfPrime(subgraph, output_name, /*input=*/false,
137-
magic_numbers.num_output_candidates));
138+
LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime(
139+
subgraph
140+
,
141+
output_name, /*input=*/false,
142+
magic_numbers.num_output_candidates));
138143
}
139144
}
140145
}
@@ -158,8 +163,8 @@ GetVerificationPairs(const Model& litert_model,
158163
continue;
159164
}
160165

161-
LITERT_ASSIGN_OR_RETURN(auto subgraph,
162-
litert_model.Subgraph(signature.Key()));
166+
LITERT_ASSIGN_OR_RETURN(auto subgraph,
167+
litert_model.Subgraph(signature.Key()));
163168
for (int j = 0; j < num_signatures; ++j) {
164169
LITERT_ASSIGN_OR_RETURN(auto test_signature,
165170
litert_model.GetSignature(j));
@@ -169,16 +174,18 @@ GetVerificationPairs(const Model& litert_model,
169174
}
170175

171176
LITERT_ASSIGN_OR_RETURN(auto test_subgraph,
172-
litert_model.Subgraph(test_signature.Key()));
177+
litert_model.Subgraph(test_signature.Key()));
178+
173179
bool is_same_shape = true;
174180
for (const auto& input_name : signature.InputNames()) {
175181
if (absl::StrContains(input_name, kMaskSubstr) ||
176182
absl::StrContains(input_name, kInputPosSubstr)) {
177183
LITERT_ASSIGN_OR_RETURN(
178-
auto dim, GetLastDimensionOfInput(subgraph, input_name));
184+
auto dim,
185+
GetLastDimensionOfInput(subgraph, input_name));
179186
LITERT_ASSIGN_OR_RETURN(
180187
auto test_dim,
181-
GetLastDimensionOfInput(test_subgraph, input_name));
188+
GetLastDimensionOfInput(test_subgraph, input_name));
182189
// Check if dim is same as test_dim, or as a magic number when
183190
// test_dim is target number corresponding to the magic number.
184191
// Otherwise, the shapes are not same.

0 commit comments

Comments
 (0)