
Commit 36e3dd5

Enable option to disable dynamic shapes in optimum-et (#114)
1 parent ab6261d commit 36e3dd5

File tree

3 files changed: +17 −3 lines changed

optimum/commands/export/executorch.py
optimum/exporters/executorch/integrations.py
optimum/exporters/executorch/tasks/causal_lm.py

optimum/commands/export/executorch.py

Lines changed: 8 additions & 0 deletions

@@ -67,6 +67,12 @@ def parse_args_executorch(parser):
         action="store_true",
         help="For decoder-only models to use custom kv cache for static cache that updates cache using custom op. Defaults to False.",
     )
+    required_group.add_argument(
+        "--disable_dynamic_shapes",
+        required=False,
+        action="store_true",
+        help="When this flag is set on decoder-only models, dynamic shapes are disabled during export.",
+    )
     required_group.add_argument(
         "--qlinear",
         type=str,
@@ -109,6 +115,8 @@ def run(self):
             kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
         if self.args.use_custom_kv_cache:
             kwargs["use_custom_kv_cache"] = self.args.use_custom_kv_cache
+        if self.args.disable_dynamic_shapes:
+            kwargs["disable_dynamic_shapes"] = self.args.disable_dynamic_shapes
         if self.args.qlinear:
             kwargs["qlinear"] = self.args.qlinear
         if self.args.qembedding:
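
With the new flag wired into the CLI, a hypothetical invocation might look like the following (the model id, recipe, and output path are illustrative, not part of this commit):

optimum-cli export executorch \
  --model HuggingFaceTB/SmolLM2-135M \
  --task text-generation \
  --recipe xnnpack \
  --disable_dynamic_shapes \
  --output_dir ./smollm2_pte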

optimum/exporters/executorch/integrations.py

Lines changed: 7 additions & 2 deletions

@@ -40,12 +40,13 @@ class CausalLMExportableModule(torch.nn.Module):
     This module ensures that the exported model is compatible with ExecuTorch.
     """

-    def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False):
+    def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False, disable_dynamic_shapes=False):
         super().__init__()
         self.model = model
         self.config = model.config
         self.use_custom_kv_cache = use_custom_kv_cache
         self.use_custom_sdpa = use_custom_sdpa
+        self.disable_dynamic_shapes = disable_dynamic_shapes
         self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
         logging.info(f"Metadata to be recorded in PTE: {self.metadata}")

@@ -71,7 +72,11 @@ def _prepare_export_inputs(self):
             and not (self.use_custom_kv_cache and self.use_custom_sdpa)
         )

-        if is_transformers_version(">", "4.52.0") and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache:
+        if (
+            not self.disable_dynamic_shapes
+            and is_transformers_version(">", "4.52.0")
+            and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache
+        ):
             # Prepare inputs with dynamic shapes
             seq_length = 3  # Sequence length > 1 to avoid specialization issues
             example_input_ids = torch.zeros((1, seq_length), dtype=torch.long)
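
The branch above decides whether export inputs carry dynamic-shape specs. A minimal torch.export sketch of the distinction, using a toy module rather than the repo's code:

import torch
from torch.export import Dim, export

class TinyDecoder(torch.nn.Module):
    # Stand-in for a decoder-only LM; CausalLMExportableModule wraps a real one.
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(128, 16)
        self.proj = torch.nn.Linear(16, 128)

    def forward(self, input_ids):
        return self.proj(self.embed(input_ids))

model = TinyDecoder()
example_input_ids = torch.zeros((1, 3), dtype=torch.long)  # seq_length > 1

# Dynamic shapes: the sequence dimension is a symbolic Dim and may vary at runtime.
seq = Dim("seq_length", min=1, max=128)
dynamic_program = export(model, (example_input_ids,), dynamic_shapes={"input_ids": {1: seq}})

# Dynamic shapes disabled: the program specializes to the example's (1, 3) shape.
static_program = export(model, (example_input_ids,))

With the second form, feeding a different sequence length at runtime trips a shape guard; that specialization is the trade-off --disable_dynamic_shapes opts into.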

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 2 additions & 1 deletion

@@ -53,6 +53,7 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportableModule:
     device = "cpu"
     batch_size = 1
     dtype = kwargs.get("dtype", "float32")
+    disable_dynamic_shapes = kwargs.get("disable_dynamic_shapes", False)
     use_custom_sdpa = kwargs.get("use_custom_sdpa", False)
     use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False)
     attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
@@ -133,4 +134,4 @@ def _load_eager_pretrained(
     qembedding_config = kwargs.get("qembedding", None)
     quantize_model_(eager_model, qlinear_config=qlinear_config, qembedding_config=qembedding_config)

-    return CausalLMExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa)
+    return CausalLMExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa, disable_dynamic_shapes)
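
End to end, the kwarg travels from the CLI into the exportable module. A short usage sketch assuming the loader API shown above (the model id is illustrative):

from optimum.exporters.executorch.tasks.causal_lm import load_causal_lm_model

# Hypothetical checkpoint; any decoder-only model id would do.
module = load_causal_lm_model(
    "HuggingFaceTB/SmolLM2-135M",
    disable_dynamic_shapes=True,  # new in this commit: export with static shapes
)
assert module.disable_dynamic_shapes  # forwarded into CausalLMExportableModule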
