@@ -51,8 +51,8 @@ constexpr absl::string_view kInputPosSubstr = "pos";
5151constexpr  absl::string_view kOutputLogitsSubstr  = " logits" 
5252constexpr  int64_t  kDefaultTargetNumberBase  = 256 ;
5353
54- Expected<int64_t > GetLastDimensionOfInput (const  Subgraph& subgraph,
55-                                            absl::string_view input_name) {
54+      Expected<int64_t > GetLastDimensionOfInput (const  Subgraph& subgraph,
55+     absl::string_view input_name) {
5656  LITERT_ASSIGN_OR_RETURN (auto  tensor, subgraph.Input (input_name));
5757  LITERT_ASSIGN_OR_RETURN (auto  type, tensor.RankedTensorType ());
5858  return  type.Layout ().Dimensions ()[type.Layout ().Rank () - 1 ];
@@ -82,9 +82,8 @@ bool IsMagicNumber(int64_t number) {
8282  return  true ;
8383}
8484
85- Expected<void > SetMagicNumberIfPrime (const  Subgraph& subgraph,
86-                                      absl::string_view tensor_name, bool  input,
87-                                      int64_t & magic_number) {
85+     Expected<void > SetMagicNumberIfPrime (const  Subgraph& subgraph,
86+     absl::string_view tensor_name, bool  input, int64_t & magic_number) {
8887  auto  expected_dim = input ? GetLastDimensionOfInput (subgraph, tensor_name)
8988                            : GetFirstDimensionOfOutput (subgraph, tensor_name);
9089  LITERT_ASSIGN_OR_RETURN (auto  dim, expected_dim);
@@ -106,17 +105,17 @@ Expected<MagicNumbers> GetMagicNumbersFromModel(const Model& litert_model) {
106105  for  (int  i = 0 ; i < num_signatures; ++i) {
107106    LITERT_ASSIGN_OR_RETURN (auto  signature, litert_model.GetSignature (i));
108107    LITERT_ASSIGN_OR_RETURN (auto  subgraph,
109-                             litert_model.Subgraph (signature.Key ()));
108+                                  litert_model.Subgraph (signature.Key ()));
110109    if  (signature.Key ().starts_with (kPrefillSignaturePrefix )) {
111110      for  (const  auto & input_name : signature.InputNames ()) {
112111        if  (absl::StrContains (input_name, kMaskSubstr )) {
113-           LITERT_RETURN_IF_ERROR (
114-               SetMagicNumberIfPrime ( subgraph, input_name, /* input= */ true ,
115-                                      magic_numbers.context_length ));
112+           LITERT_RETURN_IF_ERROR (SetMagicNumberIfPrime ( 
113+               subgraph, input_name, true ,
114+               magic_numbers.context_length ));
116115        } else  if  (absl::StrContains (input_name, kInputPosSubstr )) {
117116          int64_t  prefill_length = 0 ;
118-           LITERT_RETURN_IF_ERROR (SetMagicNumberIfPrime ( 
119-                subgraph, input_name, /* input= */ true , prefill_length));
117+           LITERT_RETURN_IF_ERROR (
118+           SetMagicNumberIfPrime ( subgraph, input_name, true , prefill_length));
120119          if  (prefill_length > 0 ) {
121120            magic_numbers.prefill_lengths .push_back (prefill_length);
122121          }
@@ -125,16 +124,22 @@ Expected<MagicNumbers> GetMagicNumbersFromModel(const Model& litert_model) {
125124    } else  if  (signature.Key ().starts_with (kDecodeSignaturePrefix )) {
126125      for  (const  auto & input_name : signature.InputNames ()) {
127126        if  (absl::StrContains (input_name, kMaskSubstr )) {
128-           LITERT_RETURN_IF_ERROR (
129-               SetMagicNumberIfPrime (subgraph, input_name, /* input=*/ true ,
130-                                     magic_numbers.context_length ));
127+           LITERT_RETURN_IF_ERROR (SetMagicNumberIfPrime (
128+               subgraph, input_name, true ,
129+               magic_numbers.context_length ));
130+           LITERT_RETURN_IF_ERROR (SetMagicNumberIfPrime (
131+               subgraph
132+               ,
133+               input_name, /* input=*/ true , magic_numbers.context_length ));
131134        }
132135      }
133136      for  (const  auto & output_name : signature.OutputNames ()) {
134137        if  (absl::StrContains (output_name, kOutputLogitsSubstr )) {
135-           LITERT_RETURN_IF_ERROR (
136-               SetMagicNumberIfPrime (subgraph, output_name, /* input=*/ false ,
137-                                     magic_numbers.num_output_candidates ));
138+           LITERT_RETURN_IF_ERROR (SetMagicNumberIfPrime (
139+               subgraph
140+               ,
141+               output_name, /* input=*/ false ,
142+               magic_numbers.num_output_candidates ));
138143        }
139144      }
140145    }
@@ -158,8 +163,8 @@ GetVerificationPairs(const Model& litert_model,
158163      continue ;
159164    }
160165
161-     LITERT_ASSIGN_OR_RETURN (auto  subgraph,
162-                             litert_model.Subgraph (signature.Key ()));
166+          LITERT_ASSIGN_OR_RETURN (auto  subgraph,
167+                                  litert_model.Subgraph (signature.Key ()));
163168    for  (int  j = 0 ; j < num_signatures; ++j) {
164169      LITERT_ASSIGN_OR_RETURN (auto  test_signature,
165170                              litert_model.GetSignature (j));
@@ -169,16 +174,18 @@ GetVerificationPairs(const Model& litert_model,
169174      }
170175
171176      LITERT_ASSIGN_OR_RETURN (auto  test_subgraph,
172-                               litert_model.Subgraph (test_signature.Key ()));
177+                                 litert_model.Subgraph (test_signature.Key ()));
178+ 
173179      bool  is_same_shape = true ;
174180      for  (const  auto & input_name : signature.InputNames ()) {
175181        if  (absl::StrContains (input_name, kMaskSubstr ) ||
176182            absl::StrContains (input_name, kInputPosSubstr )) {
177183          LITERT_ASSIGN_OR_RETURN (
178-               auto  dim, GetLastDimensionOfInput (subgraph, input_name));
184+               auto  dim,
185+           GetLastDimensionOfInput (subgraph, input_name));
179186          LITERT_ASSIGN_OR_RETURN (
180187              auto  test_dim,
181-                GetLastDimensionOfInput (test_subgraph, input_name));
188+           GetLastDimensionOfInput (test_subgraph, input_name));
182189          //  Check if dim is same as test_dim, or as a magic number when
183190          //  test_dim is target number corresponding to the magic number.
184191          //  Otherwise, the shapes are not same.
0 commit comments