Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Callable
from ai_edge_torch import fx_infra
from ai_edge_torch import lowertools
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
import torch
import torch.utils._pytree as pytree

Expand Down Expand Up @@ -276,6 +277,7 @@ def embedding(*args, **kwargs):
# Explicitly reshape back to the original shape. This places the ReshapeOp
# outside of the HLFB.
output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
return output

node.target = embedding
Expand Down
88 changes: 88 additions & 0 deletions ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Verifies the reauthored SmolVLM2 Image Encoder model."""

import logging

from absl import app
from absl import flags
from ai_edge_torch.generative.examples.smolvlm2 import smolvlm2
from ai_edge_torch.generative.examples.smolvlm2 import vision_encoder
from PIL import Image
import requests
import torch
import transformers

_IMAGE_URL = flags.DEFINE_string(
"image_url",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
"The image URI to encode.",
)

_CHECKPOINT = flags.DEFINE_string(
"checkpoint",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
"The checkpoint to verify.",
)


def main(_):
checkpoint = _CHECKPOINT.value
logging.info("Loading the original model from: %s", checkpoint)
original_model = transformers.AutoModelForImageTextToText.from_pretrained(
checkpoint
)
original_model = original_model.eval().model

logging.info("Building the reauthored checkpoint from: %s", checkpoint)
reauthored_checkpoint = "/google/data/rw/users/ay/ayqzhang/smolvlm2_merged.safetensors"

logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
reauthored_model = vision_encoder.build_image_encoder(reauthored_checkpoint)

logging.info("Loading the tokenizer from: %s", checkpoint)
processor = transformers.AutoProcessor.from_pretrained(checkpoint)

logging.info("Loading the image from: %s", _IMAGE_URL.value)
image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]

logging.info("Forwarding the original model...")
outputs_original = original_model.get_image_features(pixel_values)
# outputs_original = outputs_original.last_hidden_state
logging.info("outputs_original's shape: %s", outputs_original.shape)

pixel_values = pixel_values.reshape(
pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:]
)
logging.info("Forwarding the reauthored model...")
outputs_reauthored = reauthored_model.forward(
pixel_values=pixel_values
)
logging.info("outputs_reauthored's shape: %s", outputs_reauthored.shape)

try:
assert torch.allclose(
outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04
)
except AssertionError as e:
logging.error("*** FAILED *** verify with an image")
raise e
else:
logging.info("*** PASSED *** verify with an image")


if __name__ == "__main__":
app.run(main)
16 changes: 14 additions & 2 deletions ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,18 @@ def forward(
pixel_values: torch.Tensor,
export_config: export_cfg.ExportConfig = None,
) -> torch.Tensor:
x = self.siglip_encoder(pixel_values)
# Embed the image according to SiplipVisionEmbeddings.
x = self.siglip_encoder.tok_embedding(pixel_values)
x = x.flatten(2).transpose(1, 2)
x = x + self.siglip_encoder.tok_embedding_position

# Pass a dummy mask because SDPA attention impl expects non-None mask.
mask = torch.zeros(x.shape[0], 1, x.shape[1], x.shape[1])
for _, block in enumerate(self.siglip_encoder.transformer_blocks):
x = block(x, mask=mask)
x = self.siglip_encoder.final_norm(x)

# Project the image embeddings to text hidden size.
x = self.connector(x)
return x

Expand Down Expand Up @@ -166,7 +177,8 @@ def get_image_encoder_config() -> cfg.ModelConfig:
output_proj_use_bias=True,
)
norm_config = cfg.NormalizationConfig(
type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6
type=cfg.NormalizationType.LAYER_NORM, epsilon=1e-6,
# enable_hlfb=False
)
ff_config = cfg.FeedForwardConfig(
type=cfg.FeedForwardType.SEQUENTIAL,
Expand Down
57 changes: 57 additions & 0 deletions ai_edge_torch/odml_torch/optimization_barrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimization barrier op definition and lowering."""

from typing import Sequence

from ai_edge_torch.odml_torch import _torch_library
from ai_edge_torch.odml_torch.lowerings import registry
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
import torch

_torch_library.ODML_TORCH_LIB.define("optimization_barrier(Tensor[] inputs) -> Tensor[]")

optimization_barrier_op = torch.ops.odml_torch.optimization_barrier.default


def optimization_barrier(*inputs: torch.Tensor):
return optimization_barrier_op(tuple(inputs))


@torch.library.impl(
_torch_library.ODML_TORCH_LIB,
"optimization_barrier",
"CompositeExplicitAutograd",
)
def _optimization_barrier_impl(inputs: tuple[torch.Tensor, ...]):
return tuple(inputs)


@torch.library.impl(
_torch_library.ODML_TORCH_LIB,
"optimization_barrier",
"Meta",
)
def _optimization_barrier_fake(inputs: tuple[torch.Tensor, ...]):
return tuple([torch.empty_like(x) for x in inputs])


@registry.lower(torch.ops.odml_torch.optimization_barrier.default)
def _optimization_barrier_lowering(
lctx, inputs: tuple[ir.Value, ...]
) -> ir.Value:
del lctx
return stablehlo.optimization_barrier(inputs)
49 changes: 49 additions & 0 deletions ai_edge_torch/odml_torch/test/test_optimization_barrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from ai_edge_torch import odml_torch
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib # Import to register the op.
import torch

from absl.testing import absltest as googletest

optimization_barrier = optimization_barrier_lib.optimization_barrier


class TestOptimizationBarrier(googletest.TestCase):
"""Test optimization barrier op implementation and lowering."""

def test_optimization_barrier_op(self):

class TestModel(torch.nn.Module):

def forward(self, x, y):
x, _ = optimization_barrier(x, y)
return x

x = torch.randn(1, 5)
ep = torch.export.export(TestModel().eval(), (x, x))
mlir = odml_torch.export.exported_program_to_mlir(ep)
mlir_text = mlir.get_text()
self.assertEqual(
mlir_text.count(
"stablehlo.optimization_barrier %arg1, %arg1 : tensor<1x5xf32>,"
" tensor<1x5xf32>"
),
1,
)


if __name__ == "__main__":
googletest.main()
Loading