Skip to content

Commit 1a4f6eb

Browse files
protobird-git authored and copybara-github committed
Export models with magic numbers if gpu_dynamic_shapes=true
- Remove duplicated code to convert a PyTorch model to TFLite
- Promote decode_batch_size to a common flag, not only for smollm
- Exporting test signatures will be in a following CL

PiperOrigin-RevId: 811525306
1 parent af81648 commit 1a4f6eb

18 files changed

+99
-349
lines changed

ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,12 @@
1818
from absl import app
1919
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
2020
from ai_edge_torch.generative.utilities import converter
21-
from ai_edge_torch.generative.utilities import export_config
22-
from ai_edge_torch.generative.utilities import loader
2321

2422
flags = converter.define_conversion_flags("amd-llama-135m")
2523

2624

2725
def main(_):
28-
checkpoint_path = flags.FLAGS.checkpoint_path
29-
pytorch_model = amd_llama_135m.build_model(
30-
checkpoint_path,
31-
custom_loader=loader.maybe_get_custom_loader(
32-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33-
),
34-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
35-
)
36-
converter.convert_to_tflite(
37-
pytorch_model,
38-
output_path=flags.FLAGS.output_path,
39-
output_name_prefix=flags.FLAGS.output_name_prefix,
40-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
41-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42-
quantize=flags.FLAGS.quantize,
43-
lora_ranks=flags.FLAGS.lora_ranks,
44-
export_config=export_config.get_from_flags(),
45-
)
26+
converter.build_and_convert_to_tflite_from_flags(amd_llama_135m.build_model)
4627

4728

4829
if __name__ == '__main__':

ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,14 @@
1818
from absl import app
1919
from ai_edge_torch.generative.examples.deepseek import deepseek
2020
from ai_edge_torch.generative.utilities import converter
21-
from ai_edge_torch.generative.utilities import export_config
22-
from ai_edge_torch.generative.utilities import loader
2321

2422
flags = converter.define_conversion_flags(
2523
'deepseek', default_mask_as_input=True, default_transpose_kv_cache=True
2624
)
2725

2826

2927
def main(_):
30-
checkpoint_path = flags.FLAGS.checkpoint_path
31-
pytorch_model = deepseek.build_model(
32-
checkpoint_path,
33-
custom_loader=loader.maybe_get_custom_loader(
34-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
35-
),
36-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
37-
)
38-
converter.convert_to_tflite(
39-
pytorch_model,
40-
output_path=flags.FLAGS.output_path,
41-
output_name_prefix=flags.FLAGS.output_name_prefix,
42-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
43-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
44-
quantize=flags.FLAGS.quantize,
45-
lora_ranks=flags.FLAGS.lora_ranks,
46-
export_config=export_config.get_from_flags(),
47-
)
28+
converter.build_and_convert_to_tflite_from_flags(deepseek.build_model)
4829

4930

5031
if __name__ == '__main__':

ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,12 @@
1818
from absl import app
1919
from ai_edge_torch.generative.examples.gemma import gemma1
2020
from ai_edge_torch.generative.utilities import converter
21-
from ai_edge_torch.generative.utilities import export_config
22-
from ai_edge_torch.generative.utilities import loader
2321

2422
flags = converter.define_conversion_flags("gemma-2b")
2523

2624

2725
def main(_):
28-
checkpoint_path = flags.FLAGS.checkpoint_path
29-
pytorch_model = gemma1.build_2b_model(
30-
checkpoint_path,
31-
custom_loader=loader.maybe_get_custom_loader(
32-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33-
),
34-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
35-
)
36-
converter.convert_to_tflite(
37-
pytorch_model,
38-
output_path=flags.FLAGS.output_path,
39-
output_name_prefix=flags.FLAGS.output_name_prefix,
40-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
41-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42-
quantize=flags.FLAGS.quantize,
43-
lora_ranks=flags.FLAGS.lora_ranks,
44-
export_config=export_config.get_from_flags(),
45-
)
26+
converter.build_and_convert_to_tflite_from_flags(gemma1.build_2b_model)
4627

4728

4829
if __name__ == '__main__':

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,14 @@
1818
from absl import app
1919
from ai_edge_torch.generative.examples.gemma import gemma2
2020
from ai_edge_torch.generative.utilities import converter
21-
from ai_edge_torch.generative.utilities import export_config
22-
from ai_edge_torch.generative.utilities import loader
2321

2422
flags = converter.define_conversion_flags(
2523
"gemma2-2b", default_mask_as_input=True, default_transpose_kv_cache=True
2624
)
2725

2826

2927
def main(_):
30-
checkpoint_path = flags.FLAGS.checkpoint_path
31-
pytorch_model = gemma2.build_2b_model(
32-
checkpoint_path,
33-
custom_loader=loader.maybe_get_custom_loader(
34-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
35-
),
36-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
37-
)
38-
converter.convert_to_tflite(
39-
pytorch_model,
40-
output_path=flags.FLAGS.output_path,
41-
output_name_prefix=flags.FLAGS.output_name_prefix,
42-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
43-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
44-
quantize=flags.FLAGS.quantize,
45-
lora_ranks=flags.FLAGS.lora_ranks,
46-
export_config=export_config.get_from_flags(),
47-
)
28+
converter.build_and_convert_to_tflite_from_flags(gemma2.build_2b_model)
4829

4930

5031
if __name__ == '__main__':

ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
from absl import app
1919
from ai_edge_torch.generative.examples.gemma3 import gemma3
2020
from ai_edge_torch.generative.utilities import converter
21-
from ai_edge_torch.generative.utilities import export_config
22-
from ai_edge_torch.generative.utilities import loader
2321

2422
flags = converter.define_conversion_flags(
2523
'gemma3-1b', default_mask_as_input=True, default_transpose_kv_cache=True
@@ -33,36 +31,14 @@
3331

3432

3533
def main(_):
36-
checkpoint_path = flags.FLAGS.checkpoint_path
3734
if _MODEL_SIZE.value == '1b':
38-
pytorch_model = gemma3.build_model_1b(
39-
checkpoint_path,
40-
custom_loader=loader.maybe_get_custom_loader(
41-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
42-
),
43-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
44-
)
35+
model_builder = gemma3.build_model_1b
4536
elif _MODEL_SIZE.value == '270m':
46-
pytorch_model = gemma3.build_model_270m(
47-
checkpoint_path,
48-
custom_loader=loader.maybe_get_custom_loader(
49-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
50-
),
51-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
52-
)
37+
model_builder = gemma3.build_model_270m
5338
else:
5439
raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
5540

56-
converter.convert_to_tflite(
57-
pytorch_model,
58-
output_path=flags.FLAGS.output_path,
59-
output_name_prefix=flags.FLAGS.output_name_prefix,
60-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
61-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
62-
quantize=flags.FLAGS.quantize,
63-
lora_ranks=flags.FLAGS.lora_ranks,
64-
export_config=export_config.get_from_flags(),
65-
)
41+
converter.build_and_convert_to_tflite_from_flags(model_builder)
6642

6743

6844
if __name__ == '__main__':

ai_edge_torch/generative/examples/hammer/convert_to_tflite.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
from absl import app
1919
from ai_edge_torch.generative.examples.hammer import hammer
2020
from ai_edge_torch.generative.utilities import converter
21-
from ai_edge_torch.generative.utilities import export_config
22-
from ai_edge_torch.generative.utilities import loader
2321

2422
flags = converter.define_conversion_flags('hammer')
2523

@@ -37,24 +35,7 @@
3735

3836

3937
def main(_):
40-
checkpoint_path = flags.FLAGS.checkpoint_path
41-
pytorch_model = _BUILDER[_MODEL_SIZE.value](
42-
checkpoint_path,
43-
custom_loader=loader.maybe_get_custom_loader(
44-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
45-
),
46-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
47-
)
48-
converter.convert_to_tflite(
49-
pytorch_model,
50-
output_path=flags.FLAGS.output_path,
51-
output_name_prefix=flags.FLAGS.output_name_prefix,
52-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
53-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
54-
quantize=flags.FLAGS.quantize,
55-
lora_ranks=flags.FLAGS.lora_ranks,
56-
export_config=export_config.get_from_flags(),
57-
)
38+
converter.build_and_convert_to_tflite_from_flags(_BUILDER[_MODEL_SIZE.value])
5839

5940

6041
if __name__ == '__main__':

ai_edge_torch/generative/examples/llama/convert_to_tflite.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
from absl import app
1919
from ai_edge_torch.generative.examples.llama import llama
2020
from ai_edge_torch.generative.utilities import converter
21-
from ai_edge_torch.generative.utilities import export_config
22-
from ai_edge_torch.generative.utilities import loader
2321

2422
flags = converter.define_conversion_flags('llama')
2523

@@ -37,24 +35,7 @@
3735

3836

3937
def main(_):
40-
checkpoint_path = flags.FLAGS.checkpoint_path
41-
pytorch_model = _BUILDER[_MODEL_SIZE.value](
42-
checkpoint_path,
43-
custom_loader=loader.maybe_get_custom_loader(
44-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
45-
),
46-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
47-
)
48-
converter.convert_to_tflite(
49-
pytorch_model,
50-
output_path=flags.FLAGS.output_path,
51-
output_name_prefix=flags.FLAGS.output_name_prefix,
52-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
53-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
54-
quantize=flags.FLAGS.quantize,
55-
lora_ranks=flags.FLAGS.lora_ranks,
56-
export_config=export_config.get_from_flags(),
57-
)
38+
converter.build_and_convert_to_tflite_from_flags(_BUILDER[_MODEL_SIZE.value])
5839

5940

6041
if __name__ == '__main__':

ai_edge_torch/generative/examples/openelm/convert_to_tflite.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,12 @@
1818
from absl import app
1919
from ai_edge_torch.generative.examples.openelm import openelm
2020
from ai_edge_torch.generative.utilities import converter
21-
from ai_edge_torch.generative.utilities import export_config
22-
from ai_edge_torch.generative.utilities import loader
2321

2422
flags = converter.define_conversion_flags("openelm")
2523

2624

2725
def main(_):
28-
checkpoint_path = flags.FLAGS.checkpoint_path
29-
pytorch_model = openelm.build_model(
30-
checkpoint_path,
31-
custom_loader=loader.maybe_get_custom_loader(
32-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33-
),
34-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
35-
)
36-
converter.convert_to_tflite(
37-
pytorch_model,
38-
output_path=flags.FLAGS.output_path,
39-
output_name_prefix=flags.FLAGS.output_name_prefix,
40-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
41-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42-
quantize=flags.FLAGS.quantize,
43-
lora_ranks=flags.FLAGS.lora_ranks,
44-
export_config=export_config.get_from_flags(),
45-
)
26+
converter.build_and_convert_to_tflite_from_flags(openelm.build_model)
4627

4728

4829
if __name__ == '__main__':

ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,12 @@
1818
from absl import app
1919
from ai_edge_torch.generative.examples.phi import phi3
2020
from ai_edge_torch.generative.utilities import converter
21-
from ai_edge_torch.generative.utilities import export_config
22-
from ai_edge_torch.generative.utilities import loader
2321

2422
flags = converter.define_conversion_flags("phi3")
2523

2624

2725
def main(_):
28-
checkpoint_path = flags.FLAGS.checkpoint_path
29-
pytorch_model = phi3.build_model(
30-
checkpoint_path,
31-
custom_loader=loader.maybe_get_custom_loader(
32-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33-
),
34-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
35-
)
36-
converter.convert_to_tflite(
37-
pytorch_model,
38-
output_path=flags.FLAGS.output_path,
39-
output_name_prefix=flags.FLAGS.output_name_prefix,
40-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
41-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42-
quantize=flags.FLAGS.quantize,
43-
lora_ranks=flags.FLAGS.lora_ranks,
44-
export_config=export_config.get_from_flags(),
45-
)
26+
converter.build_and_convert_to_tflite_from_flags(phi3.build_model)
4627

4728

4829
if __name__ == '__main__':

ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,12 @@
1818
from absl import app
1919
from ai_edge_torch.generative.examples.phi import phi4
2020
from ai_edge_torch.generative.utilities import converter
21-
from ai_edge_torch.generative.utilities import export_config
22-
from ai_edge_torch.generative.utilities import loader
2321

2422
flags = converter.define_conversion_flags("phi4")
2523

2624

2725
def main(_):
28-
checkpoint_path = flags.FLAGS.checkpoint_path
29-
pytorch_model = phi4.build_model(
30-
checkpoint_path,
31-
custom_loader=loader.maybe_get_custom_loader(
32-
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
33-
),
34-
mask_cache_size=converter.get_mask_cache_size_from_flags(),
35-
)
36-
converter.convert_to_tflite(
37-
pytorch_model,
38-
output_path=flags.FLAGS.output_path,
39-
output_name_prefix=flags.FLAGS.output_name_prefix,
40-
prefill_seq_len=flags.FLAGS.prefill_seq_lens,
41-
kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
42-
quantize=flags.FLAGS.quantize,
43-
lora_ranks=flags.FLAGS.lora_ranks,
44-
export_config=export_config.get_from_flags(),
45-
)
26+
converter.build_and_convert_to_tflite_from_flags(phi4.build_model)
4627

4728

4829
if __name__ == '__main__':

0 commit comments

Comments (0)