Skip to content

Commit 2ba36dc

Browse files
chunnienccopybara-github
authored andcommitted
Update optimization barrier output to match the input structure.
PiperOrigin-RevId: 789062186
1 parent 46ff6d0 commit 2ba36dc

File tree

3 files changed

+153
-0
lines changed

3 files changed

+153
-0
lines changed

ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Any, Callable
1717
from ai_edge_torch import fx_infra
1818
from ai_edge_torch import lowertools
19+
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
1920
import torch
2021
import torch.utils._pytree as pytree
2122

@@ -276,6 +277,7 @@ def embedding(*args, **kwargs):
276277
# Explicitly reshape back to the original shape. This places the ReshapeOp
277278
# outside of the HLFB.
278279
output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
280+
output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
279281
return output
280282

281283
node.target = embedding
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Optimization barrier op definition and lowering."""
16+
17+
from ai_edge_torch.odml_torch import _torch_library
18+
from ai_edge_torch.odml_torch.lowerings import registry
19+
from jax._src.lib.mlir import ir
20+
from jax._src.lib.mlir.dialects import hlo as stablehlo
21+
import torch
22+
import torch.utils._pytree as pytree
23+
24+
_torch_library.ODML_TORCH_LIB.define(
25+
"optimization_barrier(Tensor[] inputs) -> Tensor[]"
26+
)
27+
28+
optimization_barrier_op = torch.ops.odml_torch.optimization_barrier.default
29+
30+
31+
def optimization_barrier(*inputs: pytree.PyTree):
32+
"""Apply optimization barrier to the tensors nested within arbitrary pytrees.
33+
34+
Args:
35+
*inputs: A list of tensors or tensor pytrees.
36+
37+
Returns:
38+
The tensors after optimization barrier in the same pytrees structures.
39+
"""
40+
if len(inputs) == 1:
41+
inputs = inputs[0]
42+
tensors, spec = pytree.tree_flatten(inputs)
43+
tensors = optimization_barrier_op(tuple(tensors))
44+
outputs = pytree.tree_unflatten(tensors, spec)
45+
return outputs
46+
47+
48+
@torch.library.impl(
49+
_torch_library.ODML_TORCH_LIB,
50+
"optimization_barrier",
51+
"CompositeExplicitAutograd",
52+
)
53+
def _optimization_barrier_impl(inputs: tuple[torch.Tensor, ...]):
54+
return tuple(inputs)
55+
56+
57+
@torch.library.impl(
58+
_torch_library.ODML_TORCH_LIB,
59+
"optimization_barrier",
60+
"Meta",
61+
)
62+
def _optimization_barrier_fake(inputs: tuple[torch.Tensor, ...]):
63+
return tuple([torch.empty_like(x) for x in inputs])
64+
65+
66+
@registry.lower(torch.ops.odml_torch.optimization_barrier.default)
67+
def _optimization_barrier_lowering(
68+
lctx, inputs: tuple[ir.Value, ...]
69+
) -> ir.Value:
70+
del lctx
71+
return stablehlo.optimization_barrier(inputs)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from ai_edge_torch import odml_torch
16+
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib # Import to register the op.
17+
import torch
18+
19+
from absl.testing import absltest as googletest
20+
21+
optimization_barrier = optimization_barrier_lib.optimization_barrier
22+
23+
24+
class TestOptimizationBarrier(googletest.TestCase):
25+
"""Test optimization barrier op implementation and lowering."""
26+
27+
def test_applied_optimization_barrier_op(self):
28+
"""Test optimization barrier op application and lowering."""
29+
30+
class TestModel(torch.nn.Module):
31+
32+
def forward(self, x, y):
33+
x, _ = optimization_barrier(x, y)
34+
return x
35+
36+
x = torch.randn(1, 5)
37+
ep = torch.export.export(TestModel().eval(), (x, x))
38+
mlir = odml_torch.export.exported_program_to_mlir(ep)
39+
mlir_text = mlir.get_text()
40+
self.assertEqual(
41+
mlir_text.count(
42+
"stablehlo.optimization_barrier %arg1, %arg1 : tensor<1x5xf32>,"
43+
" tensor<1x5xf32>"
44+
),
45+
1,
46+
)
47+
48+
def test_input_single_tensor(self):
49+
"""Test optimization barrier with single tensor input."""
50+
x = torch.randn(1, 5)
51+
y = optimization_barrier(x)
52+
self.assertIsInstance(y, torch.Tensor)
53+
self.assertEqual(y.shape, (1, 5))
54+
55+
def test_input_multiple_tensors(self):
56+
"""Test optimization barrier with multiple tensors input."""
57+
x = torch.randn(1, 5)
58+
y = torch.randn(1, 6)
59+
z = optimization_barrier(x, y)
60+
self.assertIsInstance(z, tuple)
61+
self.assertLen(z, 2)
62+
self.assertIsInstance(z[0], torch.Tensor)
63+
self.assertIsInstance(z[1], torch.Tensor)
64+
self.assertEqual(z[0].shape, (1, 5))
65+
self.assertEqual(z[1].shape, (1, 6))
66+
67+
def test_input_nested_tensors(self):
68+
"""Test optimization barrier with nested tensor inputs."""
69+
x = {"foo": torch.randn(1, 5), "bar": torch.randn(1, 6)}
70+
z = optimization_barrier(x)
71+
self.assertIsInstance(z, dict)
72+
self.assertLen(z, 2)
73+
self.assertIsInstance(z["foo"], torch.Tensor)
74+
self.assertIsInstance(z["bar"], torch.Tensor)
75+
self.assertEqual(z["foo"].shape, (1, 5))
76+
self.assertEqual(z["bar"].shape, (1, 6))
77+
78+
79+
if __name__ == "__main__":
80+
googletest.main()

0 commit comments

Comments
 (0)