
Commit c896da5

sirakiin authored and copybara-github committed
Add EmbeddingGemma to generative
PiperOrigin-RevId: 812513918
1 parent af81648 commit c896da5

File tree

6 files changed: +481 -22 lines changed

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Example of converting an EmbeddingGemma model to a multi-signature tflite model."""

from absl import app
from ai_edge_torch.generative.examples.embedding_gemma import embedding_gemma
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities import loader

flags = converter.define_conversion_flags(
    'embedding_gemma',
    default_mask_as_input=False,
    default_transpose_kv_cache=False,
)

_NORMALIZE_OUTPUT = flags.DEFINE_bool(
    'normalize_output', True, 'Whether to normalize the output with L2 norm.'
)

_SEQ_LEN = flags.DEFINE_integer(
    'seq_len', 2048, 'The sequence length of the model.'
)


def main(_):
  checkpoint_path = flags.FLAGS.checkpoint_path
  pytorch_model = embedding_gemma.build_embedding_gemma(
      checkpoint_path,
      normalize_output=_NORMALIZE_OUTPUT.value,
      custom_loader=loader.maybe_get_custom_loader(
          checkpoint_path, flags.FLAGS.custom_checkpoint_loader
      ),
      mask_cache_size=converter.get_mask_cache_size_from_flags(),
  )
  embedding_gemma.convert_to_litert(
      pytorch_model,
      output_path=flags.FLAGS.output_path,
      output_name_prefix=flags.FLAGS.output_name_prefix,
      prefill_seq_len=_SEQ_LEN.value,
      quantize=flags.FLAGS.quantize,
  )


if __name__ == '__main__':
  app.run(main)
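A usage sketch for the script above: the file name convert_to_tflite.py and the paths below are assumptions for illustration, not shown in this commit. The --checkpoint_path, --output_path, and --output_name_prefix flags come from converter.define_conversion_flags, while --normalize_output and --seq_len are defined in this file; the quantize flag also comes from define_conversion_flags, but its accepted values are not visible here, so it is omitted.

python convert_to_tflite.py \
    --checkpoint_path=/path/to/embedding_gemma_checkpoint \
    --output_path=/tmp/embedding_gemma \
    --output_name_prefix=embedding_gemma \
    --normalize_output=true \
    --seq_len=2048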
