Commit 276c188
Add multimodal support to optimum-executorch
This commit adds comprehensive support for image-text-to-text models to optimum-executorch, extending the existing recipe system to handle multimodal vision-language models.

Key changes:
- Added new image-text-to-text task to the task registry
- Created ImageTextToTextExportableModule for multimodal model export
- Extended integrations to support both vision encoder and text decoder export
- Added comprehensive tests for multimodal functionality
- CLI now supports --task image-text-to-text for multimodal models

This enables users to export models like Gemma-3, LLaVA, and other vision-language models using the familiar optimum-executorch workflow:

optimum-cli export executorch --model google/gemma-3-4b-it --task image-text-to-text --recipe xnnpack
1 parent ab6261d commit 276c188

4 files changed: +492 -1 lines changed

optimum/exporters/executorch/integrations.py

Lines changed: 143 additions & 0 deletions
@@ -496,3 +496,146 @@ def generate(self, prompt_token_ids, max_new_tokens):
                break

        return generated_ids


class ImageTextToTextExportableModule(torch.nn.Module):
    """
    A wrapper module designed to make an image-text-to-text model exportable with `torch.export`.
    This module ensures that the exported model is compatible with ExecuTorch.
    """

    def __init__(self, model, use_custom_kv_cache=False, use_custom_sdpa=False):
        super().__init__()
        self.model = model
        self.config = model.config
        self.use_custom_kv_cache = use_custom_kv_cache
        self.use_custom_sdpa = use_custom_sdpa
        from .utils import save_config_to_constant_methods

        self.metadata = save_config_to_constant_methods(
            model.config.text_config, model.generation_config
        )
        logging.info(f"Metadata to be recorded in PTE: {self.metadata}")

    def _prepare_vision_embedding_export_inputs(self):
        """
        Prepare example inputs and configurations for export.

        Returns:
            pixel_values (torch.Tensor): Example pixel values tensor.
            dynamic_shapes (dict or None): Dynamic shape specifications for export.
            strict (bool): Whether to use strict export mode.
        """
        image_size = self.config.vision_config.image_size
        pixel_values = torch.rand((1, 3, image_size, image_size))
        dynamic_shapes = None
        strict = False

        return pixel_values, dynamic_shapes, strict

    def _prepare_text_embedding_export_inputs(self):
        """
        Prepare example inputs and configurations for export.

        Returns:
            inputs_embeds (torch.Tensor): Example inputs embeddings tensor.
            cache_position (torch.Tensor): Example cache position tensor.
            dynamic_shapes (dict or None): Dynamic shape specifications for export.
            strict (bool): Whether to use strict export mode.
        """
        # Prepare inputs with dynamic shapes
        seq_length = 3  # Sequence length > 1 to avoid specialization issues
        hidden_size = self.config.text_config.hidden_size
        example_inputs_embeds = torch.zeros((1, seq_length, hidden_size), dtype=torch.float32)
        example_cache_position = torch.arange(seq_length, dtype=torch.long)
        max_seq_len = self.metadata.get("get_max_seq_len")
        sliding_window = self.metadata.get("sliding_window", float("inf"))
        max_dim = min(max_seq_len, sliding_window) - 1
        seq_len_dim = torch.export.Dim("seq_length_dim", max=max_dim)
        dynamic_shapes = {
            "inputs_embeds": {1: seq_len_dim},
            "cache_position": {0: seq_len_dim},
        }
        strict = False

        return example_inputs_embeds, example_cache_position, dynamic_shapes, strict

    def export(
        self,
    ) -> Dict[str, ExportedProgram]:
        """
        Export both the vision encoder and text decoder components.

        Returns:
            Dict[str, ExportedProgram]: Dictionary containing exported programs for vision and text components.
        """
        # Export vision encoder
        pixel_values, vision_dynamic_shapes, vision_strict = self._prepare_vision_embedding_export_inputs()
        logging.info(
            f"Exporting vision encoder using pixel_values({pixel_values.shape}), dynamic_shapes={vision_dynamic_shapes}, strict={vision_strict}"
        )

        # Create vision encoder wrapper
        vision_encoder = VisionEncoderExportableModule(self.model)
        with torch.no_grad():
            vision_exported_program = vision_encoder.export(pixel_values)["model"]

        # Export text decoder
        inputs_embeds, cache_position, text_dynamic_shapes, text_strict = self._prepare_text_embedding_export_inputs()
        logging.info(
            f"Exporting text decoder using inputs_embeds({inputs_embeds.shape}), cache_position({cache_position.shape}), dynamic_shapes={text_dynamic_shapes}, strict={text_strict}"
        )

        # Use the enhanced transformers integration for multimodal support
        if is_transformers_version(">", "4.52.0"):
            from transformers.integrations.executorch import (
                TorchExportableModuleForImageTextLM,
            )

            exportable_module = TorchExportableModuleForImageTextLM(
                self.model.language_model,
                max_batch_size=1,
                max_cache_len=self.metadata.get("get_max_seq_len"),
            )
            self._register_attention_mask_for_4_53(exportable_module)

            if self.use_custom_kv_cache:
                from optimum.executorch.attentions.custom_kv_cache import (
                    replace_with_et_custom_kv_cache,
                )

                replace_with_et_custom_kv_cache(
                    exportable_module.model,
                    self.model.language_model.config,
                    self.model.generation_config,
                    self.model.dtype,
                )

            with torch.no_grad():
                text_exported_program = exportable_module.export(inputs_embeds, cache_position, text_dynamic_shapes, text_strict)
        else:
            raise ValueError("Image-text-to-text export requires transformers > 4.52.0")

        return {
            "vision_encoder": vision_exported_program,
            "text_decoder": text_exported_program,
        }

    def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
        """Register attention mask for transformers >= 4.53.0"""
        if is_transformers_version(">=", "4.53.0.dev0"):
            from transformers.integrations.executorch import sdpa_mask_without_vmap
            from transformers.masking_utils import AttentionMaskInterface
            from transformers.modeling_utils import AttentionInterface

            _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
            if self.use_custom_sdpa:
                if self.use_custom_kv_cache:
                    AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
                    AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
                    # Manually set the attention implementation to custom_sdpa_ring_kv_cache
                    # This handles both regular sdpa and one for sliding window/local attention
                    exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
                else:
                    # Manually set the attention implementation to custom_sdpa
                    exportable_module.model.model.config._attn_implementation = "custom_sdpa"
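
For orientation, a rough usage sketch of the new wrapper (not part of the diff; the checkpoint name and generation settings are illustrative and mirror what the image-text-to-text task loader further below sets up):

from transformers import AutoModelForCausalLM, GenerationConfig

from optimum.exporters.executorch.integrations import ImageTextToTextExportableModule

# Illustrative checkpoint; any model exposing config.text_config and config.vision_config should fit.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-4b-it",
    torch_dtype="float32",
    attn_implementation="sdpa",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=2048,
        cache_config={"batch_size": 1, "max_cache_len": 2048},
    ),
)

# Wrap the eager model and export both components in one call.
wrapper = ImageTextToTextExportableModule(model, use_custom_kv_cache=False, use_custom_sdpa=False)
exported = wrapper.export()
print(exported.keys())  # dict_keys(['vision_encoder', 'text_decoder'])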

optimum/exporters/executorch/tasks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from . import causal_lm, image_classification, masked_lm, seq2seq_lm
+from . import causal_lm, image_classification, image_text_to_text, masked_lm, seq2seq_lm
optimum/exporters/executorch/tasks/image_text_to_text.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torchao
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig

from ..integrations import ImageTextToTextExportableModule
from ..quantization import quantize_model_
from ..task_registry import register_task


# NOTE: It's important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py.
# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier.
@register_task("image-text-to-text")
def load_image_text_to_text_model(model_name_or_path: str, **kwargs) -> ImageTextToTextExportableModule:
    """
    Loads a vision-language model for image-text-to-text generation and registers it under the task
    'image-text-to-text' using Hugging Face's AutoModelForCausalLM.

    Args:
        model_name_or_path (str):
            Model ID on huggingface.co or path on disk to the model repository to export. For example:
            `model_name_or_path="google/gemma-3-4b-it"` or `model_name_or_path="/path/to/model_folder"`
        **kwargs:
            Additional configuration options for the model:
                - dtype (str, optional):
                    Data type for model weights (default: "float32").
                    Options include "float16" and "bfloat16".
                - attn_implementation (str, optional):
                    Attention mechanism implementation (default: "sdpa").
                - cache_implementation (str, optional):
                    Cache management strategy (default: "static").
                - max_length (int, optional):
                    Maximum sequence length for generation (default: 2048).

    Returns:
        ImageTextToTextExportableModule:
            An instance of `ImageTextToTextExportableModule` for exporting and lowering to ExecuTorch.
    """
    device = "cpu"
    batch_size = 1
    dtype = kwargs.get("dtype", "float32")
    use_custom_sdpa = kwargs.get("use_custom_sdpa", False)
    use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False)
    attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
    cache_implementation = kwargs.get("cache_implementation", "static")
    use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa"
    max_length = kwargs.get("max_length", 2048)
    config = kwargs.get("config") or AutoConfig.from_pretrained(model_name_or_path)

    # Make sure config has text_config and vision_config:
    if not hasattr(config, "text_config") or not hasattr(config, "vision_config"):
        raise ValueError(
            f"The model {model_name_or_path} does not have a `text_config` or `vision_config` attribute in its config. "
            "This is required for image-text-to-text models."
        )

    if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
        # NOTE: To make the model exportable we need to set the rope scaling to default to avoid hitting
        # the data-dependent control flow in _longrope_frequency_update. Alternatively, users should rewrite
        # that function to avoid the data-dependent control flow.
        config.rope_scaling["type"] = "default"

    if hasattr(config, "use_cache") and config.use_cache is False:
        config.use_cache = True

    def _load_eager_pretrained(
        model_name_or_path,
        device,
        dtype,
        config,
        attn_implementation,
        cache_implementation,
        batch_size,
        max_length,
    ):
        eager_model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map=device,
            torch_dtype=dtype,
            config=config,
            attn_implementation=attn_implementation,
            generation_config=GenerationConfig(
                use_cache=True,
                cache_implementation=cache_implementation,
                max_length=max_length,
                cache_config={
                    "batch_size": batch_size,
                    "max_cache_len": max_length,
                },
            ),
        )
        return eager_model

    try:
        eager_model = _load_eager_pretrained(
            model_name_or_path,
            device,
            dtype,
            config,
            attn_implementation,
            cache_implementation,
            batch_size,
            max_length,
        )
    except ValueError as e:
        if "torch.nn.functional.scaled_dot_product_attention" in str(e):
            logging.info("⚠ SDPA attention not supported, falling back to eager implementation")
            attn_implementation = "eager"
            eager_model = _load_eager_pretrained(
                model_name_or_path,
                device,
                dtype,
                config,
                attn_implementation,
                cache_implementation,
                batch_size,
                max_length,
            )
        else:
            raise

    # Make sure model has language_model as well as vision_tower:
    if not hasattr(eager_model, "language_model") or not hasattr(eager_model, "vision_tower"):
        raise ValueError(
            f"The model {model_name_or_path} does not have a `language_model` or `vision_tower` attribute. "
            "This is required for image-text-to-text models."
        )

    for param in eager_model.parameters():
        # Must disable gradient for quantized checkpoint
        if isinstance(param, torchao.utils.TorchAOBaseTensor):
            param.requires_grad = False

    qlinear_config = kwargs.get("qlinear", None)
    qembedding_config = kwargs.get("qembedding", None)
    quantize_model_(eager_model, qlinear_config=qlinear_config, qembedding_config=qembedding_config)

    return ImageTextToTextExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa)
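
A corresponding sketch of the registered-task path (not part of the diff; it assumes this file lands as tasks/image_text_to_text.py, as the __init__.py change above suggests, and the flag values are illustrative). This is the loader the CLI workflow resolves through the task registry before a recipe such as xnnpack lowers the exported programs:

from optimum.exporters.executorch.tasks.image_text_to_text import load_image_text_to_text_model

# Load the eager vision-language model with the task-specific defaults,
# optionally enabling the ExecuTorch custom kernels, then export both components.
module = load_image_text_to_text_model(
    "google/gemma-3-4b-it",      # illustrative model ID
    use_custom_sdpa=True,        # custom SDPA kernel (optional)
    use_custom_kv_cache=True,    # custom (ring) KV cache (optional)
    max_length=2048,
)
exported_programs = module.export()  # keys: "vision_encoder", "text_decoder"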
