diff --git a/runtime/executor/litert_compiled_model_executor_utils.cc b/runtime/executor/litert_compiled_model_executor_utils.cc index b1e4d897..a3bb3c59 100644 --- a/runtime/executor/litert_compiled_model_executor_utils.cc +++ b/runtime/executor/litert_compiled_model_executor_utils.cc @@ -310,8 +310,7 @@ GetOptimizedPrefillWorkGroups( return work_groups; } -absl::Status InitializeAttentionMask(litert::TensorBuffer& mask, - bool is_f16) { +absl::Status InitializeAttentionMask(litert::TensorBuffer& mask, bool is_f16) { auto mask_size = mask.PackedSize(); RET_CHECK(mask_size) << "Failed to get attention mask buffer size."; auto mask_tensor_type = mask.TensorType(); diff --git a/runtime/executor/magic_number_configs_helper.cc b/runtime/executor/magic_number_configs_helper.cc index 7f5ae348..5646cd19 100644 --- a/runtime/executor/magic_number_configs_helper.cc +++ b/runtime/executor/magic_number_configs_helper.cc @@ -51,8 +51,8 @@ constexpr absl::string_view kInputPosSubstr = "pos"; constexpr absl::string_view kOutputLogitsSubstr = "logits"; constexpr int64_t kDefaultTargetNumberBase = 256; -Expected GetLastDimensionOfInput(const Subgraph& subgraph, - absl::string_view input_name) { + Expected GetLastDimensionOfInput(const Subgraph& subgraph, + absl::string_view input_name) { LITERT_ASSIGN_OR_RETURN(auto tensor, subgraph.Input(input_name)); LITERT_ASSIGN_OR_RETURN(auto type, tensor.RankedTensorType()); return type.Layout().Dimensions()[type.Layout().Rank() - 1]; @@ -82,9 +82,8 @@ bool IsMagicNumber(int64_t number) { return true; } -Expected SetMagicNumberIfPrime(const Subgraph& subgraph, - absl::string_view tensor_name, bool input, - int64_t& magic_number) { + Expected SetMagicNumberIfPrime(const Subgraph& subgraph, + absl::string_view tensor_name, bool input, int64_t& magic_number) { auto expected_dim = input ? GetLastDimensionOfInput(subgraph, tensor_name) : GetFirstDimensionOfOutput(subgraph, tensor_name); LITERT_ASSIGN_OR_RETURN(auto dim, expected_dim); @@ -106,17 +105,17 @@ Expected GetMagicNumbersFromModel(const Model& litert_model) { for (int i = 0; i < num_signatures; ++i) { LITERT_ASSIGN_OR_RETURN(auto signature, litert_model.GetSignature(i)); LITERT_ASSIGN_OR_RETURN(auto subgraph, - litert_model.Subgraph(signature.Key())); + litert_model.Subgraph(signature.Key())); if (signature.Key().starts_with(kPrefillSignaturePrefix)) { for (const auto& input_name : signature.InputNames()) { if (absl::StrContains(input_name, kMaskSubstr)) { - LITERT_RETURN_IF_ERROR( - SetMagicNumberIfPrime(subgraph, input_name, /*input=*/true, - magic_numbers.context_length)); + LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime( + subgraph, input_name, true, + magic_numbers.context_length)); } else if (absl::StrContains(input_name, kInputPosSubstr)) { int64_t prefill_length = 0; - LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime( - subgraph, input_name, /*input=*/true, prefill_length)); + LITERT_RETURN_IF_ERROR( + SetMagicNumberIfPrime(subgraph, input_name, true, prefill_length)); if (prefill_length > 0) { magic_numbers.prefill_lengths.push_back(prefill_length); } @@ -125,16 +124,22 @@ Expected GetMagicNumbersFromModel(const Model& litert_model) { } else if (signature.Key().starts_with(kDecodeSignaturePrefix)) { for (const auto& input_name : signature.InputNames()) { if (absl::StrContains(input_name, kMaskSubstr)) { - LITERT_RETURN_IF_ERROR( - SetMagicNumberIfPrime(subgraph, input_name, /*input=*/true, - magic_numbers.context_length)); + LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime( + subgraph, input_name, true, + magic_numbers.context_length)); + LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime( + subgraph + , + input_name, /*input=*/true, magic_numbers.context_length)); } } for (const auto& output_name : signature.OutputNames()) { if (absl::StrContains(output_name, kOutputLogitsSubstr)) { - LITERT_RETURN_IF_ERROR( - SetMagicNumberIfPrime(subgraph, output_name, /*input=*/false, - magic_numbers.num_output_candidates)); + LITERT_RETURN_IF_ERROR(SetMagicNumberIfPrime( + subgraph + , + output_name, /*input=*/false, + magic_numbers.num_output_candidates)); } } } @@ -158,8 +163,8 @@ GetVerificationPairs(const Model& litert_model, continue; } - LITERT_ASSIGN_OR_RETURN(auto subgraph, - litert_model.Subgraph(signature.Key())); + LITERT_ASSIGN_OR_RETURN(auto subgraph, + litert_model.Subgraph(signature.Key())); for (int j = 0; j < num_signatures; ++j) { LITERT_ASSIGN_OR_RETURN(auto test_signature, litert_model.GetSignature(j)); @@ -169,16 +174,18 @@ GetVerificationPairs(const Model& litert_model, } LITERT_ASSIGN_OR_RETURN(auto test_subgraph, - litert_model.Subgraph(test_signature.Key())); + litert_model.Subgraph(test_signature.Key())); + bool is_same_shape = true; for (const auto& input_name : signature.InputNames()) { if (absl::StrContains(input_name, kMaskSubstr) || absl::StrContains(input_name, kInputPosSubstr)) { LITERT_ASSIGN_OR_RETURN( - auto dim, GetLastDimensionOfInput(subgraph, input_name)); + auto dim, + GetLastDimensionOfInput(subgraph, input_name)); LITERT_ASSIGN_OR_RETURN( auto test_dim, - GetLastDimensionOfInput(test_subgraph, input_name)); + GetLastDimensionOfInput(test_subgraph, input_name)); // Check if dim is same as test_dim, or as a magic number when // test_dim is target number corresponding to the magic number. // Otherwise, the shapes are not same.