Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Callable
from ai_edge_torch import fx_infra
from ai_edge_torch import lowertools
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
import torch
import torch.utils._pytree as pytree

Expand Down Expand Up @@ -276,6 +277,7 @@ def embedding(*args, **kwargs):
# Explicitly reshape back to the original shape. This places the ReshapeOp
# outside of the HLFB.
output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
return output

node.target = embedding
Expand Down
71 changes: 71 additions & 0 deletions ai_edge_torch/odml_torch/optimization_barrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimization barrier op definition and lowering."""

from ai_edge_torch.odml_torch import _torch_library
from ai_edge_torch.odml_torch.lowerings import registry
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
import torch
import torch.utils._pytree as pytree

_torch_library.ODML_TORCH_LIB.define(
"optimization_barrier(Tensor[] inputs) -> Tensor[]"
)

optimization_barrier_op = torch.ops.odml_torch.optimization_barrier.default


def optimization_barrier(*inputs: pytree.PyTree):
"""Apply optimization barrier to the tensors nested within arbitrary pytrees.

Args:
*inputs: A list of tensors or tensor pytrees.

Returns:
The tensors after optimization barrier in the same pytrees structures.
"""
if len(inputs) == 1:
inputs = inputs[0]
tensors, spec = pytree.tree_flatten(inputs)
tensors = optimization_barrier_op(tuple(tensors))
outputs = pytree.tree_unflatten(tensors, spec)
return outputs


@torch.library.impl(
_torch_library.ODML_TORCH_LIB,
"optimization_barrier",
"CompositeExplicitAutograd",
)
def _optimization_barrier_impl(inputs: tuple[torch.Tensor, ...]):
return tuple(inputs)


@torch.library.impl(
_torch_library.ODML_TORCH_LIB,
"optimization_barrier",
"Meta",
)
def _optimization_barrier_fake(inputs: tuple[torch.Tensor, ...]):
return tuple([torch.empty_like(x) for x in inputs])


@registry.lower(torch.ops.odml_torch.optimization_barrier.default)
def _optimization_barrier_lowering(
lctx, inputs: tuple[ir.Value, ...]
) -> ir.Value:
del lctx
return stablehlo.optimization_barrier(inputs)
80 changes: 80 additions & 0 deletions ai_edge_torch/odml_torch/test/test_optimization_barrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2025 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from ai_edge_torch import odml_torch
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib # Import to register the op.
import torch

from absl.testing import absltest as googletest

optimization_barrier = optimization_barrier_lib.optimization_barrier


class TestOptimizationBarrier(googletest.TestCase):
"""Test optimization barrier op implementation and lowering."""

def test_applied_optimization_barrier_op(self):
"""Test optimization barrier op application and lowering."""

class TestModel(torch.nn.Module):

def forward(self, x, y):
x, _ = optimization_barrier(x, y)
return x

x = torch.randn(1, 5)
ep = torch.export.export(TestModel().eval(), (x, x))
mlir = odml_torch.export.exported_program_to_mlir(ep)
mlir_text = mlir.get_text()
self.assertEqual(
mlir_text.count(
"stablehlo.optimization_barrier %arg1, %arg1 : tensor<1x5xf32>,"
" tensor<1x5xf32>"
),
1,
)

def test_input_single_tensor(self):
"""Test optimization barrier with single tensor input."""
x = torch.randn(1, 5)
y = optimization_barrier(x)
self.assertIsInstance(y, torch.Tensor)
self.assertEqual(y.shape, (1, 5))

def test_input_multiple_tensors(self):
"""Test optimization barrier with multiple tensors input."""
x = torch.randn(1, 5)
y = torch.randn(1, 6)
z = optimization_barrier(x, y)
self.assertIsInstance(z, tuple)
self.assertLen(z, 2)
self.assertIsInstance(z[0], torch.Tensor)
self.assertIsInstance(z[1], torch.Tensor)
self.assertEqual(z[0].shape, (1, 5))
self.assertEqual(z[1].shape, (1, 6))

def test_input_nested_tensors(self):
"""Test optimization barrier with nested tensor inputs."""
x = {"foo": torch.randn(1, 5), "bar": torch.randn(1, 6)}
z = optimization_barrier(x)
self.assertIsInstance(z, dict)
self.assertLen(z, 2)
self.assertIsInstance(z["foo"], torch.Tensor)
self.assertIsInstance(z["bar"], torch.Tensor)
self.assertEqual(z["foo"].shape, (1, 5))
self.assertEqual(z["bar"].shape, (1, 6))


if __name__ == "__main__":
googletest.main()
Loading