diff --git a/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py b/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
index 3e92d9cd..fcc45513 100644
--- a/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
+++ b/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
@@ -16,6 +16,7 @@ from typing import Any, Callable
 
 from ai_edge_torch import fx_infra
 from ai_edge_torch import lowertools
+from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
 import torch
 import torch.utils._pytree as pytree
 
@@ -276,6 +277,9 @@ def embedding(*args, **kwargs):
     # Explicitly reshape back to the original shape. This places the ReshapeOp
     # outside of the HLFB.
     output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
+    # NOTE(review): presumably the barrier keeps the reshape above from being
+    # folded back into the HLFB; confirm and document the intent.
+    output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
     return output
 
   node.target = embedding
diff --git a/ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py b/ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py
new file mode 100644
index 00000000..8b66cf6e
--- /dev/null
+++ b/ai_edge_torch/generative/examples/smolvlm2/verify_encoder.py
@@ -0,0 +1,101 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Verifies the reauthored SmolVLM2 Image Encoder model."""
+
+import logging
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smolvlm2 import smolvlm2
+from ai_edge_torch.generative.examples.smolvlm2 import vision_encoder
+from PIL import Image
+import requests
+import torch
+import transformers
+
+_IMAGE_URL = flags.DEFINE_string(
+    "image_url",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
+    "The image URI to encode.",
+)
+
+_CHECKPOINT = flags.DEFINE_string(
+    "checkpoint",
+    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+    "The checkpoint to verify.",
+)
+
+_REAUTHORED_CHECKPOINT = flags.DEFINE_string(
+    "reauthored_checkpoint",
+    "/google/data/rw/users/ay/ayqzhang/smolvlm2_merged.safetensors",
+    "The checkpoint file to build the reauthored image encoder from.",
+)
+
+# Seconds to wait for the image download before giving up.
+_IMAGE_DOWNLOAD_TIMEOUT_SEC = 60
+
+
+def main(_):
+  """Checks that the reauthored image encoder matches the original model.
+
+  Raises:
+    AssertionError: If the outputs of the two models differ beyond tolerance.
+  """
+  checkpoint = _CHECKPOINT.value
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForImageTextToText.from_pretrained(
+      checkpoint
+  )
+  original_model = original_model.eval().model
+
+  reauthored_checkpoint = _REAUTHORED_CHECKPOINT.value
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = vision_encoder.build_image_encoder(reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  logging.info("Loading the image from: %s", _IMAGE_URL.value)
+  response = requests.get(
+      _IMAGE_URL.value, stream=True, timeout=_IMAGE_DOWNLOAD_TIMEOUT_SEC
+  )
+  # Fail on HTTP errors instead of handing an error page to the image decoder.
+  response.raise_for_status()
+  image = Image.open(response.raw)
+  pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
+
+  logging.info("Forwarding the original model...")
+  outputs_original = original_model.get_image_features(pixel_values)
+  logging.info("outputs_original's shape: %s", outputs_original.shape)
+
+  # Merge the first two dimensions before calling the reauthored model.
+  pixel_values = pixel_values.reshape(
+      pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:]
+  )
+  logging.info("Forwarding the reauthored model...")
+  outputs_reauthored = reauthored_model.forward(pixel_values=pixel_values)
+  logging.info("outputs_reauthored's shape: %s", outputs_reauthored.shape)
+
+  # Raise explicitly rather than via `assert`, which is stripped under -O.
+  if not torch.allclose(
+      outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04
+  ):
+    logging.error("*** FAILED *** verify with an image")
+    raise AssertionError("*** FAILED *** verify with an image")
+  logging.info("*** PASSED *** verify with an image")
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py b/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py
index 676c5c32..99c6e762 100644
--- a/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py
+++ b/ai_edge_torch/generative/examples/smolvlm2/vision_encoder.py
@@ -129,7 +129,20 @@ def forward(
       pixel_values: torch.Tensor,
       export_config: export_cfg.ExportConfig = None,
   ) -> torch.Tensor:
-    x = self.siglip_encoder(pixel_values)
+    # Embed the image according to SiglipVisionEmbeddings.
+    x = self.siglip_encoder.tok_embedding(pixel_values)
+    x = x.flatten(2).transpose(1, 2)
+    x = x + self.siglip_encoder.tok_embedding_position
+
+    # Pass a dummy mask because SDPA attention impl expects non-None mask.
+    mask = torch.zeros(
+        x.shape[0], 1, x.shape[1], x.shape[1], dtype=x.dtype, device=x.device
+    )
+    for block in self.siglip_encoder.transformer_blocks:
+      x = block(x, mask=mask)
+    x = self.siglip_encoder.final_norm(x)
+
+    # Project the image embeddings to text hidden size.
     x = self.connector(x)
     return x
 
diff --git a/ai_edge_torch/odml_torch/optimization_barrier.py b/ai_edge_torch/odml_torch/optimization_barrier.py
new file mode 100644
index 00000000..0fe7b3f9
--- /dev/null
+++ b/ai_edge_torch/odml_torch/optimization_barrier.py
@@ -0,0 +1,72 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Optimization barrier op definition and lowering."""
+
+from typing import Sequence
+
+from ai_edge_torch.odml_torch import _torch_library
+from ai_edge_torch.odml_torch.lowerings import registry
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+
+_torch_library.ODML_TORCH_LIB.define(
+    "optimization_barrier(Tensor[] inputs) -> Tensor[]"
+)
+
+optimization_barrier_op = torch.ops.odml_torch.optimization_barrier.default
+
+
+def optimization_barrier(*inputs: torch.Tensor):
+  """Passes the given tensors through the optimization barrier op.
+
+  Args:
+    *inputs: The tensors to group under a single barrier.
+
+  Returns:
+    The op's outputs, one per input and in the same order.
+  """
+  return optimization_barrier_op(tuple(inputs))
+
+
+@torch.library.impl(
+    _torch_library.ODML_TORCH_LIB,
+    "optimization_barrier",
+    "CompositeExplicitAutograd",
+)
+def _optimization_barrier_impl(inputs: Sequence[torch.Tensor]):
+  """CompositeExplicitAutograd kernel: returns the inputs unchanged."""
+  # NOTE(review): this returns the input tensors themselves while the schema
+  # declares no aliasing; confirm this is acceptable for eager use.
+  return tuple(inputs)
+
+
+@torch.library.impl(
+    _torch_library.ODML_TORCH_LIB,
+    "optimization_barrier",
+    "Meta",
+)
+def _optimization_barrier_fake(inputs: Sequence[torch.Tensor]):
+  """Meta kernel: returns `torch.empty_like` tensors, one per input."""
+  return tuple(torch.empty_like(x) for x in inputs)
+
+
+@registry.lower(torch.ops.odml_torch.optimization_barrier.default)
+def _optimization_barrier_lowering(
+    lctx, inputs: Sequence[ir.Value]
+) -> ir.Value:
+  """Lowers the op to `stablehlo.optimization_barrier`."""
+  del lctx  # Unused.
+  return stablehlo.optimization_barrier(inputs)
diff --git a/ai_edge_torch/odml_torch/test/test_optimization_barrier.py b/ai_edge_torch/odml_torch/test/test_optimization_barrier.py
new file mode 100644
index 00000000..82150d2a
--- /dev/null
+++ b/ai_edge_torch/odml_torch/test/test_optimization_barrier.py
@@ -0,0 +1,49 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from ai_edge_torch import odml_torch
+from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib  # Import to register the op.
+import torch
+
+from absl.testing import absltest as googletest
+
+optimization_barrier = optimization_barrier_lib.optimization_barrier
+
+
+class TestOptimizationBarrier(googletest.TestCase):
+  """Test optimization barrier op implementation and lowering."""
+
+  def test_optimization_barrier_op(self):
+
+    class TestModel(torch.nn.Module):
+
+      def forward(self, x, y):
+        x, _ = optimization_barrier(x, y)
+        return x
+
+    x = torch.randn(1, 5)
+    ep = torch.export.export(TestModel().eval(), (x, x))
+    mlir = odml_torch.export.exported_program_to_mlir(ep)
+    mlir_text = mlir.get_text()
+    self.assertEqual(
+        mlir_text.count(
+            "stablehlo.optimization_barrier %arg1, %arg1 : tensor<1x5xf32>,"
+            " tensor<1x5xf32>"
+        ),
+        1,
+    )
+
+
+if __name__ == "__main__":
+  googletest.main()