Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions runtime/executor/litert_compiled_model_executor_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,7 @@ GetOptimizedPrefillWorkGroups(
return work_groups;
}

absl::Status InitializeAttentionMask(litert::TensorBuffer& mask,
bool is_f16) {
absl::Status InitializeAttentionMask(litert::TensorBuffer& mask, bool is_f16) {
auto mask_size = mask.PackedSize();
RET_CHECK(mask_size) << "Failed to get attention mask buffer size.";
auto mask_tensor_type = mask.TensorType();
Expand Down
51 changes: 29 additions & 22 deletions runtime/executor/magic_number_configs_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ constexpr absl::string_view kInputPosSubstr = "pos";
constexpr absl::string_view kOutputLogitsSubstr = "logits";
constexpr int64_t kDefaultTargetNumberBase = 256;

Expected<int64_t> GetLastDimensionOfInput(const Subgraph& subgraph,
absl::string_view input_name) {
Expected<int64_t> GetLastDimensionOfInput(const Subgraph& subgraph,
absl::string_view input_name) {
LITERT_ASSIGN_OR_RETURN(auto tensor, subgraph.Input(input_name));
LITERT_ASSIGN_OR_RETURN(auto type, tensor.RankedTensorType());
return type.Layout().Dimensions()[type.Layout().Rank() - 1];
Expand Down Expand Up @@ -82,9 +82,8 @@ bool IsMagicNumber(int64_t number) {
return true;
}

Expected<void> SetMagicNumberIfPrime(const Subgraph& subgraph,
absl::string_view tensor_name, bool input,
int64_t& magic_number) {
Expected<void> SetMagicNumberIfPrime(const Subgraph& subgraph,
absl::string_view tensor_name, bool input, int64_t& magic_number) {
auto expected_dim = input ? GetLastDimensionOfInput(subgraph, tensor_name)
: GetFirstDimensionOfOutput(subgraph, tensor_name);
LITERT_ASSIGN_OR_RETURN(auto dim, expected_dim);
Expand All @@ -106,17 +105,17 @@ Expected<MagicNumbers> GetMagicNumbersFromModel(const Model& litert_model) {
for (int i = 0; i < num_signatures; ++i) {
LITERT_ASSIGN_OR_RETURN(auto signature, litert_model.GetSignature(i));
LITERT_ASSIGN_OR_RETURN(auto subgraph,
litert_model.Subgraph(signature.Key()));
litert_model.Subgraph(signature.Key()));
if (signature.Key().starts_with(kPrefillSignaturePrefix)) {
for (const auto& input_name : signature.InputNames()) {
if (absl::StrContains(input_name, kMaskSubstr)) {
LITERT_RETURN_IF_ERROR(
SetMagicNumberIfPrime(subgraph, input_name, /*input=*/true,
magic_numbers.context_length));
LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime(
subgraph, input_name, true,
magic_numbers.context_length));
} else if (absl::StrContains(input_name, kInputPosSubstr)) {
int64_t prefill_length = 0;
LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime(
subgraph, input_name, /*input=*/true, prefill_length));
LITERT_RETURN_IF_ERROR(
SetMagicNumberIfPrime(subgraph, input_name, true, prefill_length));
if (prefill_length > 0) {
magic_numbers.prefill_lengths.push_back(prefill_length);
}
Expand All @@ -125,16 +124,22 @@ Expected<MagicNumbers> GetMagicNumbersFromModel(const Model& litert_model) {
} else if (signature.Key().starts_with(kDecodeSignaturePrefix)) {
for (const auto& input_name : signature.InputNames()) {
if (absl::StrContains(input_name, kMaskSubstr)) {
LITERT_RETURN_IF_ERROR(
SetMagicNumberIfPrime(subgraph, input_name, /*input=*/true,
magic_numbers.context_length));
LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime(
subgraph, input_name, true,
magic_numbers.context_length));
LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime(
subgraph
,
input_name, /*input=*/true, magic_numbers.context_length));
}
}
for (const auto& output_name : signature.OutputNames()) {
if (absl::StrContains(output_name, kOutputLogitsSubstr)) {
LITERT_RETURN_IF_ERROR(
SetMagicNumberIfPrime(subgraph, output_name, /*input=*/false,
magic_numbers.num_output_candidates));
LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime(
subgraph
,
output_name, /*input=*/false,
magic_numbers.num_output_candidates));
}
}
}
Expand All @@ -158,8 +163,8 @@ GetVerificationPairs(const Model& litert_model,
continue;
}

LITERT_ASSIGN_OR_RETURN(auto subgraph,
litert_model.Subgraph(signature.Key()));
LITERT_ASSIGN_OR_RETURN(auto subgraph,
litert_model.Subgraph(signature.Key()));
for (int j = 0; j < num_signatures; ++j) {
LITERT_ASSIGN_OR_RETURN(auto test_signature,
litert_model.GetSignature(j));
Expand All @@ -169,16 +174,18 @@ GetVerificationPairs(const Model& litert_model,
}

LITERT_ASSIGN_OR_RETURN(auto test_subgraph,
litert_model.Subgraph(test_signature.Key()));
litert_model.Subgraph(test_signature.Key()));

bool is_same_shape = true;
for (const auto& input_name : signature.InputNames()) {
if (absl::StrContains(input_name, kMaskSubstr) ||
absl::StrContains(input_name, kInputPosSubstr)) {
LITERT_ASSIGN_OR_RETURN(
auto dim, GetLastDimensionOfInput(subgraph, input_name));
auto dim,
GetLastDimensionOfInput(subgraph, input_name));
LITERT_ASSIGN_OR_RETURN(
auto test_dim,
GetLastDimensionOfInput(test_subgraph, input_name));
GetLastDimensionOfInput(test_subgraph, input_name));
// Check if dim is same as test_dim, or as a magic number when
// test_dim is target number corresponding to the magic number.
// Otherwise, the shapes are not same.
Expand Down