From 2ba36dcccc583c75ffc31a09ef69ce5af06c13dd Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Wed, 30 Jul 2025 15:12:57 -0700 Subject: [PATCH] Update optimization barrier output to match the input structure. PiperOrigin-RevId: 789062186 --- .../fx_passes/build_aten_composite_pass.py | 2 + .../odml_torch/optimization_barrier.py | 71 ++++++++++++++++ .../test/test_optimization_barrier.py | 80 +++++++++++++++++++ 3 files changed, 153 insertions(+) create mode 100644 ai_edge_torch/odml_torch/optimization_barrier.py create mode 100644 ai_edge_torch/odml_torch/test/test_optimization_barrier.py diff --git a/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py b/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py index 3e92d9cd..fcc45513 100644 --- a/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +++ b/ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py @@ -16,6 +16,7 @@ from typing import Any, Callable from ai_edge_torch import fx_infra from ai_edge_torch import lowertools +from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib import torch import torch.utils._pytree as pytree @@ -276,6 +277,7 @@ def embedding(*args, **kwargs): # Explicitly reshape back to the original shape. This places the ReshapeOp # outside of the HLFB. output = torch.reshape(output, (*(original_idx_shape), embedding_dim)) + output, _ = optimization_barrier_lib.optimization_barrier(output, idx) return output node.target = embedding diff --git a/ai_edge_torch/odml_torch/optimization_barrier.py b/ai_edge_torch/odml_torch/optimization_barrier.py new file mode 100644 index 00000000..88778b37 --- /dev/null +++ b/ai_edge_torch/odml_torch/optimization_barrier.py @@ -0,0 +1,71 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Optimization barrier op definition and lowering.""" + +from ai_edge_torch.odml_torch import _torch_library +from ai_edge_torch.odml_torch.lowerings import registry +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo as stablehlo +import torch +import torch.utils._pytree as pytree + +_torch_library.ODML_TORCH_LIB.define( + "optimization_barrier(Tensor[] inputs) -> Tensor[]" +) + +optimization_barrier_op = torch.ops.odml_torch.optimization_barrier.default + + +def optimization_barrier(*inputs: pytree.PyTree): + """Apply optimization barrier to the tensors nested within arbitrary pytrees. + + Args: + *inputs: A list of tensors or tensor pytrees. + + Returns: + The tensors after optimization barrier in the same pytrees structures. + """ + if len(inputs) == 1: + inputs = inputs[0] + tensors, spec = pytree.tree_flatten(inputs) + tensors = optimization_barrier_op(tuple(tensors)) + outputs = pytree.tree_unflatten(tensors, spec) + return outputs + + +@torch.library.impl( + _torch_library.ODML_TORCH_LIB, + "optimization_barrier", + "CompositeExplicitAutograd", +) +def _optimization_barrier_impl(inputs: tuple[torch.Tensor, ...]): + return tuple(inputs) + + +@torch.library.impl( + _torch_library.ODML_TORCH_LIB, + "optimization_barrier", + "Meta", +) +def _optimization_barrier_fake(inputs: tuple[torch.Tensor, ...]): + return tuple([torch.empty_like(x) for x in inputs]) + + +@registry.lower(torch.ops.odml_torch.optimization_barrier.default) +def _optimization_barrier_lowering( + lctx, inputs: tuple[ir.Value, ...] +) -> ir.Value: + del lctx + return stablehlo.optimization_barrier(inputs) diff --git a/ai_edge_torch/odml_torch/test/test_optimization_barrier.py b/ai_edge_torch/odml_torch/test/test_optimization_barrier.py new file mode 100644 index 00000000..d25d8c8e --- /dev/null +++ b/ai_edge_torch/odml_torch/test/test_optimization_barrier.py @@ -0,0 +1,80 @@ +# Copyright 2025 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from ai_edge_torch import odml_torch +from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib # Import to register the op. +import torch + +from absl.testing import absltest as googletest + +optimization_barrier = optimization_barrier_lib.optimization_barrier + + +class TestOptimizationBarrier(googletest.TestCase): + """Test optimization barrier op implementation and lowering.""" + + def test_applied_optimization_barrier_op(self): + """Test optimization barrier op application and lowering.""" + + class TestModel(torch.nn.Module): + + def forward(self, x, y): + x, _ = optimization_barrier(x, y) + return x + + x = torch.randn(1, 5) + ep = torch.export.export(TestModel().eval(), (x, x)) + mlir = odml_torch.export.exported_program_to_mlir(ep) + mlir_text = mlir.get_text() + self.assertEqual( + mlir_text.count( + "stablehlo.optimization_barrier %arg1, %arg1 : tensor<1x5xf32>," + " tensor<1x5xf32>" + ), + 1, + ) + + def test_input_single_tensor(self): + """Test optimization barrier with single tensor input.""" + x = torch.randn(1, 5) + y = optimization_barrier(x) + self.assertIsInstance(y, torch.Tensor) + self.assertEqual(y.shape, (1, 5)) + + def test_input_multiple_tensors(self): + """Test optimization barrier with multiple tensors input.""" + x = torch.randn(1, 5) + y = torch.randn(1, 6) + z = optimization_barrier(x, y) + self.assertIsInstance(z, tuple) + self.assertLen(z, 2) + self.assertIsInstance(z[0], torch.Tensor) + self.assertIsInstance(z[1], torch.Tensor) + self.assertEqual(z[0].shape, (1, 5)) + self.assertEqual(z[1].shape, (1, 6)) + + def test_input_nested_tensors(self): + """Test optimization barrier with nested tensor inputs.""" + x = {"foo": torch.randn(1, 5), "bar": torch.randn(1, 6)} + z = optimization_barrier(x) + self.assertIsInstance(z, dict) + self.assertLen(z, 2) + self.assertIsInstance(z["foo"], torch.Tensor) + self.assertIsInstance(z["bar"], torch.Tensor) + self.assertEqual(z["foo"].shape, (1, 5)) + self.assertEqual(z["bar"].shape, (1, 6)) + + +if __name__ == "__main__": + googletest.main()