Skip to content

Commit 5ec4385

Browse files
committed
Working, encoder not delegated to XNNPack
1 parent 4b7520f commit 5ec4385

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

optimum/executorch/modeling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,8 +1346,8 @@ def text_generation(
13461346

13471347
# Sanity check
13481348
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id:
1349-
raise ValueError(
1350-
f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
1349+
logging.warning(
1350+
f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} is not the same as the model's bos_token_id={self.bos_token_id}."
13511351
)
13521352
if isinstance(self.tokenizer, PreTrainedTokenizer) and not verify_eos_tokens_in_pretrained_tokenizer(
13531353
self.eos_token_id, self.tokenizer

optimum/exporters/executorch/integrations.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,56 @@
3737

3838
from .utils import apply_chat_template_with_fallback, process_conversation_inputs, save_config_to_constant_methods
3939

40+
def _patch_idefics3_vision_embeddings_for_export(vision_model):
    """
    Replace ``vision_model.embeddings.forward`` with an export-friendly implementation.

    The replacement never reads ``patch_attention_mask`` (it behaves as if every patch
    were attended to) and derives the position ids from the shape of ``pixel_values``
    only, so no data-dependent operation is needed to compute them.

    Args:
        vision_model: Object whose ``embeddings`` attribute exposes ``patch_embedding``,
            ``patch_size``, ``num_patches_per_side`` and ``position_embedding``. The
            ``forward`` of that ``embeddings`` instance is overwritten in place.
    """
    from types import MethodType

    def _bucket_indices(num_patches, num_buckets, device):
        # Integer replacement for bucketize(x, boundaries=[1/N, 2/N, ...], right=True),
        # which is ≈ floor(x * N); there is no bucketize kernel available at the moment.
        patch_idx = torch.arange(num_patches, device=device, dtype=torch.long)
        buckets = (patch_idx * num_buckets) // num_patches
        return torch.clamp(buckets, max=num_buckets - 1)

    def export_friendly_forward(
        self,
        pixel_values: torch.FloatTensor,
        patch_attention_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        # `patch_attention_mask` is accepted for signature compatibility but never read:
        # a full mask is assumed, which removes the data-dependent per-sample loop.
        batch_size, _, img_height, img_width = pixel_values.shape

        # (batch, channels, h, w) -> (batch, h * w, channels)
        patch_features = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)

        side = self.num_patches_per_side
        device = pixel_values.device
        rows = _bucket_indices(img_height // self.patch_size, side, device)
        cols = _bucket_indices(img_width // self.patch_size, side, device)

        # Row-major index into the (side x side) position grid, flattened to 1D.
        flat_ids = (rows[:, None] * side + cols[None, :]).reshape(-1)
        position_ids = flat_ids.unsqueeze(0).expand(batch_size, -1)

        return patch_features + self.position_embedding(position_ids)

    # Bind the replacement to this specific embeddings instance.
    vision_model.embeddings.forward = MethodType(export_friendly_forward, vision_model.embeddings)
77+
4078

4179
class VisionExportableModule(torch.nn.Module):
4280
def __init__(self, model: torch.nn.Module):
    """
    Wrap ``model`` and patch its vision embeddings for export when it is an Idefics3 model.

    Args:
        model: Module to wrap. When it exposes ``model.model.vision_model`` and its
            ``config.model_type`` contains ``idefics3`` (case-insensitive), the vision
            embeddings are patched in place via
            ``_patch_idefics3_vision_embeddings_for_export``.
    """
    super().__init__()
    self.model = model

    # Only models that nest a vision tower under `model.model` are candidates for patching.
    has_vision_tower = hasattr(model, 'model') and hasattr(model.model, 'vision_model')
    if not has_vision_tower:
        return

    # Patch Idefics3 vision embeddings if needed.
    model_type = getattr(model.config, 'model_type', '')
    if 'idefics3' not in model_type.lower():
        return
    _patch_idefics3_vision_embeddings_for_export(model.model.vision_model)
89+
4690
def prepare_export_inputs(self):
4791
# 1. Get export inputs
4892
model_id = self.model.config.name_or_path
@@ -83,7 +127,9 @@ def forward(
83127
self,
84128
input_features: torch.FloatTensor,
85129
):
86-
image_embeds = self.model.get_image_features(input_features)
130+
# Pass pixel_attention_mask=None to avoid data-dependent operations during export.
131+
# The model will create a mask full of 1s internally if None is passed.
132+
image_embeds = self.model.get_image_features(input_features, pixel_attention_mask=None)
87133
if isinstance(image_embeds, list):
88134
image_embeds = torch.stack(image_embeds)
89135
return image_embeds

0 commit comments

Comments
 (0)