
Commit 3599379

Qualcomm AI Engine Direct - Support min.dim, max.dim and argmax ops
Summary: Support aten.min.dim, aten.max.dim and aten.argmax.
1 parent 858102a commit 3599379
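
For orientation, these are the PyTorch patterns the QNN delegate can now consume. A minimal, hypothetical example module (not part of the commit; names and shapes are made up) exercising all three ops:

import torch

# Hypothetical example; shapes are illustrative only.
class MinMaxArgmaxExample(torch.nn.Module):
    def forward(self, x):
        max_val, max_idx = torch.max(x, dim=1)        # exported as aten.max.dim (+ aten.argmax for the index path)
        min_val, _ = torch.min(x, dim=1)              # exported as aten.min.dim
        best = torch.argmax(x, dim=-1, keepdim=True)  # exported as aten.argmax.default
        return max_val, max_idx, min_val, best

x = torch.randn(2, 8, 16)
print([t.shape for t in MinMaxArgmaxExample()(x)])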

14 files changed, +443 −3 lines


backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,7 @@
 from .decompose_einsum import DecomposeEinsum
 from .decompose_expm1 import DecomposeExpM1
 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
+from .decompose_minmaxdim import DecomposeMinMaxDim
 from .decompose_roll import DecomposeRoll
 from .decompose_silu import DecomposeSilu
 from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
@@ -54,6 +55,7 @@
     DecomposeEinsum,
     DecomposeExpM1,
     DecomposeLinalgVectorNorm,
+    DecomposeMinMaxDim,
     DecomposeRoll,
     DecomposeSilu,
     DecomposeWrapWithAutocast,
backends/qualcomm/_passes/decompose_minmaxdim.py

Lines changed: 104 additions & 0 deletions (new file)

@@ -0,0 +1,104 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import operator
+from collections import Counter
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+
+
+class DecomposeMinMaxDim(ExportPass):
+    """
+    Since QNN does not support multi-output ops, this pass decomposes
+    `torch.min(dim=...)` and `torch.max(dim=...)` into two separate operations:
+    - `aten.min.dim` / `aten.max.dim` for the value
+    - `aten.argmin` / `aten.argmax` for the index
+
+    Example transformation in the exported FX graph:
+
+    Python source:
+        val, idx = torch.min(x, dim=1)
+
+    Before:
+        %min = aten.min(%x, dim=1)
+        %val = getitem(%min, 0)
+        %idx = getitem(%min, 1)
+
+    After:
+        %min = aten.min(%x, dim=1)
+        %val = getitem(%min, 0)
+        %idx = aten.argmin(%x, dim=1)
+
+    This pass preserves the value output if used, and transforms only the index path.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.min_dim = exir_ops.edge.aten.min.dim
+        self.max_dim = exir_ops.edge.aten.max.dim
+        self.argmin = exir_ops.edge.aten.argmin.default
+        self.argmax = exir_ops.edge.aten.argmax.default
+        self.getitem = operator.getitem
+
+        # index-only op
+        self.replace_table = {
+            self.min_dim: self.argmin,
+            self.max_dim: self.argmax,
+        }
+
+        self.patterns = [
+            # Only index is used (e.g., _, idx = torch.min(x, dim=1))
+            {self.min_dim: 1, self.getitem: 1},
+            {self.max_dim: 1, self.getitem: 1},
+            # Both value and index are used (e.g., val, idx = torch.max(x, dim=1))
+            {self.min_dim: 1, self.getitem: 2},
+            {self.max_dim: 1, self.getitem: 2},
+        ]
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        partitions = get_source_partitions(graph, [torch.min, torch.max])
+        for _, src_partitions in partitions.items():
+            for partition in src_partitions:
+                multi_output_node = partition.nodes[0]
+                if Counter([n.target for n in partition.nodes]) not in self.patterns:
+                    continue
+                input_tensor = multi_output_node.args[0]
+                dim = multi_output_node.args[1]
+                keepdim = (
+                    multi_output_node.args[2]
+                    if len(multi_output_node.args) > 2
+                    else False
+                )
+
+                idx_node = next(
+                    (
+                        user
+                        for user in multi_output_node.users
+                        if user.meta["val"].dtype == torch.int64
+                    ),
+                    None,
+                )
+
+                if idx_node:
+                    with graph.inserting_before(idx_node):
+                        argmin_node = graph.create_node(
+                            "call_function",
+                            self.replace_table[multi_output_node.target],
+                            (input_tensor, dim, keepdim),
+                        )
+                        argmin_node.meta = idx_node.meta
+
+                    for user in list(idx_node.users):
+                        user.replace_input_with(idx_node, argmin_node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
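
The decomposition is sound because the index half of min.dim / max.dim can be recomputed independently with argmin / argmax. A quick eager-mode sanity check of that equivalence (plain PyTorch written for this note, not taken from the commit):

import torch

x = torch.randn(4, 7)  # random floats, so ties are practically impossible

val, idx = torch.max(x, dim=1, keepdim=False)

# Shape of the graph after DecomposeMinMaxDim: the value still comes from
# max.dim, while the index is produced by a separate single-output argmax.
val_only = torch.max(x, dim=1, keepdim=False).values
idx_only = torch.argmax(x, dim=1, keepdim=False)

assert torch.equal(val, val_only)
assert torch.equal(idx, idx_only)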

backends/qualcomm/_passes/i64_to_i32.py

Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@ class I64toI32(ExportPass):
     """

     I64_OPS = {
+        exir_ops.edge.aten.argmax.default,
         exir_ops.edge.aten.argmin.default,
         exir_ops.edge.aten.arange.start_step,
         exir_ops.edge.aten.cumsum.default,

backends/qualcomm/_passes/layout_transform.py

Lines changed: 4 additions & 0 deletions

@@ -93,6 +93,8 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.lt.Tensor,
         exir_ops.edge.aten.maximum.default,
         exir_ops.edge.aten.mean.dim,
+        exir_ops.edge.aten.max.dim,
+        exir_ops.edge.aten.min.dim,
         exir_ops.edge.aten.minimum.default,
         exir_ops.edge.aten.mul.Tensor,
         exir_ops.edge.aten.ne.Scalar,
@@ -169,6 +171,8 @@ def is_layout_sensitive(self, node: torch.fx.Node) -> bool:
     def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
         if node.target in [
             exir_ops.edge.aten.mean.dim,
+            exir_ops.edge.aten.max.dim,
+            exir_ops.edge.aten.min.dim,
             exir_ops.edge.aten.sum.dim_IntList,
         ]:
             # if dimension is not kept, we'll have no clue how to do layout transform

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,7 @@
     DecomposeEinsum,
     DecomposeExpM1,
     DecomposeLinalgVectorNorm,
+    DecomposeMinMaxDim,
     DecomposeRoll,
     DecomposeSilu,
     DecomposeWrapWithAutocast,
@@ -84,6 +85,7 @@ def get_capture_program_passes():
         (ConvertConv1dToConv2d, True),
         (DecomposeAny, True),
         (DecomposeColIm, True),
+        (DecomposeMinMaxDim, True),
         (ExpandBroadcastTensorShape, False),
         (FixedLinearKeepDim, True),
         (FoldQDQ, True),

backends/qualcomm/builders/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -12,6 +12,7 @@
     op_amax,
     op_and,
     op_arange,
+    op_argmax,
     op_argmin,
     op_atan,
     op_avg_pool2d,
@@ -54,9 +55,11 @@
     op_lt,
     op_matmul,
     op_max,
+    op_max_dim,
     op_max_pool2d,
     op_mean_dim,
     op_min,
+    op_min_dim,
     op_mul,
     op_ne,
     op_neg,
@@ -105,6 +108,7 @@
     op_amax,
     op_and,
     op_arange,
+    op_argmax,
     op_argmin,
     op_atan,
     op_avg_pool2d,
@@ -147,9 +151,11 @@
     op_lt,
     op_matmul,
     op_max,
+    op_max_dim,
     op_max_pool2d,
     op_mean_dim,
     op_min,
+    op_min_dim,
     op_mul,
     op_neg,
     op_ne,
backends/qualcomm/builders/op_argmax.py

Lines changed: 78 additions & 0 deletions (new file)

@@ -0,0 +1,78 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import numpy as np
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor
+from .node_visitor_manager import register_node_visitor
+from .qnn_constants import OpArgmax, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class argmax(NodeVisitor):
+    target = ["aten.argmax.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        output_tensor = self.get_tensor(node, node)
+        argmax_inp_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        argmax_input_tensors = [argmax_inp_tensor_wrapper]
+        argmax_out_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor.to(torch.int32),
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        argmax_output_tensors = [argmax_out_tensor_wrapper]
+
+        dim = cast(int, node.args[1])
+        if dim < 0:
+            dim = dim % len(input_tensor.shape)
+        if QCOM_AXIS_ORDER in node.meta:
+            dim = node.meta[QCOM_AXIS_ORDER].index(dim)
+
+        argmax_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpArgmax.op_name,
+        )
+        argmax_op.AddInputTensors(argmax_input_tensors)
+        argmax_op.AddOutputTensors(argmax_output_tensors)
+
+        argmax_op.AddScalarParam(
+            OpArgmax.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(dim)},
+        )
+
+        if len(node.args) > 2:
+            keep_dims = cast(bool, node.args[2])
+            argmax_op.AddScalarParam(
+                OpArgmax.param_keep_dims,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+                {QCOM_DATA: keep_dims},
+            )
+
+        return argmax_op
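
The only non-obvious step in the builder above is the axis fix-up: a negative dim is normalized by the input rank, and if the layout-transform pass has permuted the tensor (recorded under QCOM_AXIS_ORDER in node.meta), the original dim is remapped to its position in the permuted axis order. A standalone sketch of that arithmetic with made-up values:

# Illustrative values only; mirrors the dim handling in op_argmax above.
rank = 4
dim = -3                          # user-facing dim on an NCHW tensor
dim = dim % rank                  # -> 1, the channel axis

axis_order = (0, 2, 3, 1)         # assumed NCHW -> NHWC permutation stored in QCOM_AXIS_ORDER
qnn_axis = axis_order.index(dim)  # channel now sits at axis 3 of the QNN tensor
print(qnn_axis)                   # prints 3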
backends/qualcomm/builders/op_max_dim.py

Lines changed: 89 additions & 0 deletions (new file)

@@ -0,0 +1,89 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast, Dict, List
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor
+from .node_visitor_manager import register_node_visitor
+from .qnn_constants import OpReduceMax, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class MaxDim(NodeVisitor):
+    target = ["aten.max.dim"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> List[PyQnnWrapper.PyQnnOpWrapper]:
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        # QNN does not support multiple outputs for a single op.
+        # Since torch.max(input, dim) returns both values and indices,
+        # we only support the value output for OpReduceMax. The index output will be handled
+        # separately by OpArgmax.
+        # Therefore, we update node.meta["val"] to only keep the value part.
+        if len(node.meta["val"]) == 2:
+            node.meta["val"] = node.meta["val"][0]
+
+        output_tensor = self.get_tensor(node, node)
+        out_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        dims = cast(List[int], [node.args[1]])
+        dims = [mean_dim % len(input_node.meta["val"].shape) for mean_dim in dims]
+        if QCOM_AXIS_ORDER in node.meta:
+            dims = [node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in dims]
+        dims_shape = [len(dims)]
+
+        reduce_max_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpReduceMax.op_name,
+        )
+        reduce_max_op.AddInputTensors([input_tensor_wrapper])
+        reduce_max_op.AddOutputTensors([out_tensor_wrapper])
+
+        reduce_max_op.AddTensorParam(
+            OpReduceMax.param_axes,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(dims_shape),
+            dims_shape,
+            np.array(dims, dtype=np.uint32),
+            True,
+        )
+        if len(node.args) > 2:
+            keep_dims = cast(bool, node.args[2])
+            reduce_max_op.AddScalarParam(
+                OpReduceMax.param_keep_dims,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+                {QCOM_DATA: keep_dims},
+            )
+
+        return reduce_max_op

backends/qualcomm/builders/op_min.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@

 @register_node_visitor
 class Min(NodeVisitor):
-    target = ["aten.minimum.default", "aten.min.dim"]
+    target = ["aten.minimum.default"]

     def __init__(self, *args) -> None:
         super().__init__(*args)
