Skip to content

Commit b7eff08

Browse files
authored
Arm backend: Add sub tensor to match scalar tensors to input data types. (pytorch#12541)
Updates the existing pass to ensure that x1 and x2 are of the same type. This addresses an issue where scalars are not automatically cast to the input dtype, unlike the mul and add operators.
1 parent efd133d commit b7eff08

File tree

5 files changed

+19
-25
lines changed

5 files changed

+19
-25
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@
6767
)
6868
from .insert_rescales_pass import InsertRescalePass # noqa
6969
from .insert_table_ops import InsertTableOpsPass # noqa
70+
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
7071
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
71-
from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass # noqa
7272
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
7373
from .remove_clone_pass import RemoveClonePass # noqa
7474
from .replace_scalar_with_tensor_pass import ( # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@
6666
InsertCastForOpsWithInt64InputPass,
6767
InsertRescalePass,
6868
InsertTableOpsPass,
69+
MatchArgDtypePass,
6970
MatchArgRanksPass,
70-
MatchWhereSelfDtypePass,
7171
QuantizeOperatorArguments,
7272
RemoveClonePass,
7373
ReplaceInfValues,
@@ -116,7 +116,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
116116
self.add_pass(ConvertToClampPass())
117117
self.add_pass(ConvertMinMaxPass())
118118
self.add_pass(ConvertAnyDefaultDimDimsPass())
119-
self.add_pass(MatchWhereSelfDtypePass())
119+
self.add_pass(MatchArgDtypePass())
120120
if self.tosa_spec.is_U55_subset:
121121
self.add_pass(CastToInt32Pass())
122122

@@ -193,8 +193,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
193193
self.add_pass(ConvertToClampPass())
194194
self.add_pass(ConvertMinMaxPass())
195195
self.add_pass(ConvertAnyDefaultDimDimsPass())
196-
self.add_pass(MatchWhereSelfDtypePass())
197-
196+
self.add_pass(MatchArgDtypePass())
198197
self.add_pass(AnnotateDecomposedMatmulPass())
199198
self.add_pass(QuantizeOperatorArguments())
200199
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]

backends/arm/_passes/match_where_self_arg_dtype_pass.py renamed to backends/arm/_passes/match_arg_dtype_pass.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import torch
7-
from executorch.backends.arm._passes.arm_pass_utils import create_node
7+
from executorch.backends.arm._passes.arm_pass_utils import create_node, get_node_arg
88
from executorch.exir.dialects._ops import ops as exir_ops
99
from executorch.exir.pass_base import ExportPass, PassResult
1010

@@ -26,7 +26,7 @@ def get_largest_dtype(dtype_1, dtype_2):
2626
return dtype_1 if DTYPE_RANK[dtype_1] > DTYPE_RANK[dtype_2] else dtype_2
2727

2828

29-
class MatchWhereSelfDtypePass(ExportPass):
29+
class MatchArgDtypePass(ExportPass):
3030
"""Pass to match data types of non-condition input tensors.
3131
3232
Edge dialect allows different data types for non-condition tensors, while TOSA
@@ -38,14 +38,18 @@ class MatchWhereSelfDtypePass(ExportPass):
3838
3939
"""
4040

41+
targeted_ops = {exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.where.self}
42+
4143
def call(self, graph_module: torch.fx.GraphModule):
4244
modified_graph = False
4345
graph = graph_module.graph
44-
node_list = graph.find_nodes(
45-
op="call_function", target=exir_ops.edge.aten.where.self
46-
)
47-
for node in node_list:
48-
cond, input_, other_ = node.args
46+
47+
for node in list(graph.nodes):
48+
if node.op != "call_function" or node.target not in self.targeted_ops:
49+
continue
50+
51+
input_ = get_node_arg(node.args, 0)
52+
other_ = get_node_arg(node.args, 1)
4953

5054
input_dtype = input_.meta["val"].dtype
5155
other_dtype = other_.meta["val"].dtype

backends/arm/test/ops/test_scalars.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -242,21 +242,16 @@ def test_add_scalar_u85_BI():
242242

243243

244244
# SUB MI ------------------------------------------------------
245-
mi_sub_xfails = {
246-
"int_r1_ts": "TypeError: All IO needs to have the same data type, got input 1: 8, input 2: 6 and output: 8",
247-
"int_r4_ts": "TypeError: All IO needs to have the same data type, got input 1: 8, input 2: 6 and output: 8",
248-
**xfails,
249-
}
250245

251246

252-
@common.parametrize("test_data", tensor_scalar_tests, xfails=mi_sub_xfails)
247+
@common.parametrize("test_data", tensor_scalar_tests, xfails=xfails)
253248
def test_sub_tensor_tosa_MI_scalar(test_data):
254249
"""Tests regular sub with one scalar input."""
255250
pipeline = TosaPipelineMI[input_t1](Sub(), test_data, aten_op=Sub.aten_op)
256251
pipeline.run()
257252

258253

259-
@common.parametrize("test_data", tensor_scalar_tests, xfails=mi_sub_xfails)
254+
@common.parametrize("test_data", tensor_scalar_tests, xfails=xfails)
260255
def test_sub_tensor_tosa_MI_inplace(test_data):
261256
"""Tests inplace sub with one scalar input."""
262257
pipeline = TosaPipelineMI[input_t1](SubInplace(), test_data, aten_op=[])

backends/arm/test/ops/test_sub.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
torch.randn(1, 4, 4, 1),
4343
torch.randn(1, 1, 4, 4),
4444
),
45+
"rand_3d_rand_Scalar": lambda: (torch.rand(1, 6, 2), torch.rand(1)),
46+
"rand_3d_Scalar": lambda: (torch.rand(1, 6, 2), 1),
4547
}
4648
fvp_sub2_xfails = {"rand_4D_2x2x4x4": "MLETORCH-517 : Multiple batches not supported"}
4749

@@ -93,7 +95,6 @@ def test_sub_tensor_tosa_BI(test_data):
9395
aten_op,
9496
exir_op,
9597
)
96-
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
9798
pipeline.run()
9899

99100

@@ -106,7 +107,6 @@ def test_sub_tensor_tosa_BI_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
106107
aten_op,
107108
exir_op,
108109
)
109-
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
110110
pipeline.run()
111111

112112

@@ -121,7 +121,6 @@ def test_sub_tensor_u55_BI(test_data):
121121
exir_op,
122122
run_on_fvp=True,
123123
)
124-
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
125124
pipeline.run()
126125

127126

@@ -136,7 +135,6 @@ def test_sub_tensor_u55_BI_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
136135
exir_op,
137136
run_on_fvp=True,
138137
)
139-
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
140138
pipeline.run()
141139

142140

@@ -151,7 +149,6 @@ def test_sub_tensor_u85_BI_2(test_data):
151149
exir_op,
152150
run_on_fvp=True,
153151
)
154-
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
155152
pipeline.run()
156153

157154

@@ -166,5 +163,4 @@ def test_sub_tensor_u85_BI(test_data: Tuple[torch.Tensor, torch.Tensor]):
166163
exir_op,
167164
run_on_fvp=True,
168165
)
169-
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
170166
pipeline.run()

0 commit comments

Comments
 (0)