Skip to content

Commit c5ff74c

Browse files
keyprocedure, Gasoonjia, and digantdesai
authored
[EXIR] Register _clone_dim_order op and map aten.clone (#13735)
### Summary This is PR 2 of 3 implementing a dim order aware clone op. This PR registers the new `_clone_dim_order` op and maps `aten.clone` to `dim_order_ops._clone_dim_order` in EXIR during export to preserve memory layout changes (contiguous/channels_last). It also updates the Core ML, ARM, and Qualcomm backends to handle the new clone op. Related PRs: - PR 1: [#12974](#12974) - Add `_clone_dim_order` portable kernel - PR 3: [#12976](#12976) - Update RemoveCloneOpsTransform to be dim order aware Fixes #12645 ### Test plan - Operator level tests to verify kernel behavior for layout preservation and changes. - Graph level checks to confirm that clone mapping occurs. - End to end tests to validate that functional clone behavior is unchanged. - Backend tests to ensure clone semantics are preserved. All tests pass via: `python -m unittest exir.tests.test_memory_format_ops_pass` `python -m unittest backends.apple.coreml.test.test_torch_ops` `pytest backends/arm/test/ops/test_clone.py` `pytest backends/arm/test/passes/test_remove_clone_pass.py` --------- Co-authored-by: Gasoonjia <[email protected]> Co-authored-by: Digant Desai <[email protected]>
1 parent f9ce98f commit c5ff74c

File tree

14 files changed

+212
-10
lines changed

14 files changed

+212
-10
lines changed

backends/apple/coreml/compiler/torch_ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from coremltools.converters.mil.frontend.torch.ops import (
1616
_get_inputs,
1717
_get_kwinputs,
18+
noop,
1819
NUM_TO_NUMPY_DTYPE,
1920
NUM_TO_TORCH_DTYPE,
2021
split,
@@ -91,6 +92,28 @@ def _to_dim_order_copy(context, node):
9192
to(context, node)
9293

9394

95+
@register_torch_op(
    torch_alias=[
        "dim_order_ops::_clone_dim_order",
        "dim_order_ops._clone_dim_order",
    ],
    override=False,
)
def _clone_dim_order(context, node):
    """Lower ``dim_order_ops._clone_dim_order`` to a no-op clone in CoreML.

    CoreML only supports the contiguous memory format, so the op is accepted
    only when its ``dim_order`` kwarg is absent (the op's default, meaning
    contiguous) or explicitly describes the contiguous layout; any other
    layout is rejected.
    """
    # dim_order is keyword-only and optional: _get_kwinputs returns [None]
    # when the caller did not pass it. Only pop/validate when it is present —
    # popping or dereferencing a missing kwarg would raise KeyError /
    # AttributeError.
    dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
    if dim_order is not None:
        node.kwinputs.pop("dim_order")

        # In CoreML, dim_order.val will be a ndarray, so we convert it to a
        # list to check memory format.
        dim_order = [int(d) for d in dim_order.val]
        memory_format = get_memory_format(dim_order)
        assert (
            memory_format == _torch.contiguous_format
        ), "Only contiguous memory format is supported in CoreML"

    # Since CoreML only supports contiguous format, no dim_order preservation
    # is needed. Treat this as a no-op clone.
    noop(context, node)
94117
# https://github.com/apple/coremltools/pull/2558
95118
@register_torch_op(
96119
torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,28 @@ def test_dequantize_codebook_embedding_per_grouped_row(self):
268268
et_prog = delegated_program.to_executorch()
269269
self._compare_outputs(et_prog, model, example_inputs)
270270

271+
def test__clone_dim_order_contiguous(self):
    """A _clone_dim_order with an identity dim_order should fully delegate to CoreML."""

    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.ops.dim_order_ops._clone_dim_order(
                x, dim_order=[0, 1, 2, 3]
            )

    model = Model()
    example_inputs = (torch.randn(1, 3, 8, 8),)
    exported = torch.export.export(model, example_inputs)
    lowered = executorch.exir.to_edge_transform_and_lower(
        exported,
        partitioner=[self._coreml_partitioner()],
    )

    # After delegation only the delegate call (and getitem on its results)
    # should remain as call_function nodes in the top-level graph.
    allowed = ("executorch_call_delegate", "getitem")
    graph = lowered.exported_program().graph
    for node in (n for n in graph.nodes if n.op == "call_function"):
        assert (
            node.target.__name__ in allowed
        ), f"Got unexpected node target after delegation: {node.target.__name__}"

    et_prog = lowered.to_executorch()
    self._compare_outputs(et_prog, model, example_inputs)
271293

272294
if __name__ == "__main__":
273295
test_runner = TestTorchOps()
@@ -280,3 +302,4 @@ def test_dequantize_codebook_embedding_per_grouped_row(self):
280302
test_runner.test_dequantize_codebook_linear_per_grouped_row()
281303
test_runner.test_dequantize_codebook_embedding_per_grouped_col()
282304
test_runner.test_dequantize_codebook_embedding_per_grouped_row()
305+
test_runner.test__clone_dim_order_contiguous()

backends/arm/_passes/remove_clone_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class RemoveClonePass(ExportPass):
1818
"""Remove all clones from graph_module"""
1919

2020
def call_operator(self, op, args, kwargs, meta):
21-
if op != exir_ops.edge.aten.clone.default:
21+
if op != exir_ops.edge.dim_order_ops._clone_dim_order.default:
2222
return super().call_operator(op, args, kwargs, meta)
2323

2424
if len(args) != 1:

backends/arm/operator_support/clone_support.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import logging
77

8+
import torch
89
import torch.fx as fx
910
from executorch.backends.arm.operator_support.tosa_supported_operators import (
1011
register_tosa_support_check,
@@ -18,7 +19,7 @@
1819

1920
@register_tosa_support_check
2021
class CloneSupported(SupportedTOSAOperatorCheck):
21-
targets = [exir_ops.edge.aten.clone.default]
22+
targets = [exir_ops.edge.dim_order_ops._clone_dim_order.default]
2223

2324
tosa_specs = [
2425
TosaSpecification.create_from_string("TOSA-1.0+INT"),
@@ -28,10 +29,62 @@ class CloneSupported(SupportedTOSAOperatorCheck):
2829
def is_node_tosa_supported(
    self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:
    """Return True when this _clone_dim_order node can be lowered to TOSA.

    Rejects (with a reported reason) nodes that: are not one of our targets,
    clone a non-tensor, have other than exactly one input, lack FakeTensor
    metadata on input or output, change dtype, or carry an unsupported
    memory_format / non-identity dim_order argument.
    """
    reject = self.reporter.report_reject

    if node.target not in self.targets:
        reject(node, f"Target {node.target} is not supported.")
        return False

    if not isinstance(node.args[0], fx.Node):
        reject(node, "Non tensor clones are not supported")
        return False

    # Exactly one input node is expected.
    n_inputs = len(node.all_input_nodes)
    if n_inputs != 1:
        reject(node, f"Expected 1 input node, got {n_inputs}")
        return False

    input_val = node.all_input_nodes[0].meta["val"]
    if not isinstance(input_val, torch._subclasses.FakeTensor):
        reject(node, "Expected input to be a FakeTensor.")
        return False

    output_val = node.meta["val"]
    if not isinstance(output_val, torch._subclasses.FakeTensor):
        reject(node, "Expected output to be a FakeTensor.")
        return False

    # A clone must never change dtype.
    if output_val.dtype != input_val.dtype:
        reject(
            node,
            f"Input dtype {input_val.dtype} does not match {output_val.dtype}.",
        )
        return False

    # memory_format: absent kwarg is fine; preserve_format is rejected.
    # (.get returns None when absent, which is not in the rejected tuple.)
    if node.kwargs.get("memory_format") in (torch.preserve_format,):
        reject(
            node,
            f"Argument 'memory_format' is not supported for "
            f"{node.target} right now.",
        )
        return False

    # dim_order: only the identity permutation (contiguous) is supported.
    if "dim_order" in node.kwargs:
        dim_order = node.kwargs["dim_order"]
        # pyre-ignore[6]
        identity = list(range(len(dim_order)))  # type: ignore[arg-type]
        if dim_order != identity:
            reject(
                node,
                f"Argument {dim_order=} is not supported for "
                f"{node.target} right now.",
            )
            return False

    return True

backends/arm/test/misc/test_partition_decomposed_quantized_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
]
3939
linear_residual_exir_op: list[str] = [
4040
"executorch_exir_dialects_edge__ops_aten_gelu_default",
41-
"executorch_exir_dialects_edge__ops_aten_clone_default",
41+
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
4242
"executorch_exir_dialects_edge__ops_aten_linear_default",
4343
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
4444
]

backends/arm/test/ops/test_clone.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020

2121
aten_op = "torch.ops.aten.clone.default"
22-
exir_op = "executorch_exir_dialects_edge__ops_aten_clone_default"
22+
exir_op = "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
2323

2424
input_t = Tuple[torch.Tensor]
2525

backends/arm/test/passes/test_remove_clone_pass.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@ def test_remove_clone_tosa_INT():
3535
module.get_inputs(),
3636
quantize=True,
3737
ops_before_pass={
38-
"executorch_exir_dialects_edge__ops_aten_clone_default": 1,
38+
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1,
3939
},
40-
ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_clone_default"],
40+
ops_not_after_pass=[
41+
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
42+
],
4143
pass_list=[RemoveClonePass],
4244
)
4345
pipeline.run()

backends/qualcomm/_passes/convert_bmm_to_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class ConvertBmmToMatmul(ExportPass):
2121

2222
view_copy = exir_ops.edge.aten.view_copy.default
2323
expand_copy = exir_ops.edge.aten.expand_copy.default
24-
clone = exir_ops.edge.aten.clone.default
24+
clone = exir_ops.edge.dim_order_ops._clone_dim_order.default
2525
bmm = exir_ops.edge.aten.bmm.default
2626
matmul = exir_ops.edge.aten.matmul.default
2727
patterns = [

backends/qualcomm/_passes/remove_redundancy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, quantization_capture=False):
1919
self.redundant_ops_general = {
2020
torch.clone: self._default_condition,
2121
torch.ops.aten.clone.default: self._default_condition,
22-
exir_ops.edge.aten.clone.default: self._default_condition,
22+
exir_ops.edge.dim_order_ops._clone_dim_order.default: self._default_condition,
2323
torch.ops.aten.alias.default: self._default_condition,
2424
exir_ops.edge.aten.alias.default: self._default_condition,
2525
exir_ops.edge.aten.alias_copy.default: self._default_condition,

backends/qualcomm/partition/common_defs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.exir.dialects._ops import ops as exir_ops
1111

1212
not_supported_operator = [
13-
exir_ops.edge.aten.clone.default,
13+
exir_ops.edge.dim_order_ops._clone_dim_order.default,
1414
exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
1515
]
1616

0 commit comments

Comments
 (0)