Skip to content

Commit 05acab0

Browse files
DannyYuyang-quiccccclai
authored andcommitted
Qualcomm AI Engine Direct - Support min.dim, max.dim and argmax ops (pytorch#12669)
### Summary: - Support aten.min.dim, aten.max.dim and aten.argmax ### Test plan aten.min.dim ``` python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperators.test_qnn_backend_min_dim -s ${device_id} -H ${host_id} -m ${soc} -b build-android python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_min_dim -s ${device_id} -H ${host_id} -m ${soc} -b build-android ``` aten.max.dim ``` python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperators.test_qnn_backend_max_dim -s ${device_id} -H ${host_id} -m ${soc} -b build-android python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_max_dim -s ${device_id} -H ${host_id} -m ${soc} -b build-android ``` argmax ``` python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperators.test_qnn_backend_argmax -s ${device_id} -H ${host_id} -m ${soc} -b build-android python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_argmax -s ${device_id} -H ${host_id} -m ${soc} -b build-android ``` cc: @haowhsu-quic @cccclai --------- Co-authored-by: Chen Lai <[email protected]>
1 parent 888bc4d commit 05acab0

File tree

14 files changed

+474
-7
lines changed

14 files changed

+474
-7
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .decompose_einsum import DecomposeEinsum
1818
from .decompose_expm1 import DecomposeExpM1
1919
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
20+
from .decompose_minmaxdim import DecomposeMinMaxDim
2021
from .decompose_roll import DecomposeRoll
2122
from .decompose_silu import DecomposeSilu
2223
from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
@@ -54,6 +55,7 @@
5455
DecomposeEinsum,
5556
DecomposeExpM1,
5657
DecomposeLinalgVectorNorm,
58+
DecomposeMinMaxDim,
5759
DecomposeRoll,
5860
DecomposeSilu,
5961
DecomposeWrapWithAutocast,
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import operator
9+
from collections import Counter
10+
11+
import torch
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
15+
16+
17+
class DecomposeMinMaxDim(ExportPass):
18+
"""
19+
Since QNN does not support multi-output ops, this pass decomposes
20+
`torch.min(dim=...)` and `torch.max(dim=...)` into two separate operations:
21+
- `aten.min.dim` / `aten.max.dim` for the value
22+
- `aten.argmin` / `aten.argmax` for the index
23+
24+
Example transformation in the exported FX graph:
25+
26+
Python source:
27+
val, idx = torch.min(x, dim=1)
28+
29+
Before:
30+
%min = aten.min(%x, dim=1)
31+
%val = getitem(%min, 0)
32+
%idx = getitem(%min, 1)
33+
34+
After:
35+
%min = aten.min(%x, dim=1)
36+
%val = getitem(%min, 0)
37+
%idx = aten.argmin(%x, dim=1)
38+
39+
This pass preserves the value output if used, and transforms only the index path.
40+
"""
41+
42+
def __init__(self):
43+
super().__init__()
44+
self.min_dim = exir_ops.edge.aten.min.dim
45+
self.max_dim = exir_ops.edge.aten.max.dim
46+
self.argmin = exir_ops.edge.aten.argmin.default
47+
self.argmax = exir_ops.edge.aten.argmax.default
48+
self.getitem = operator.getitem
49+
50+
# index-only op
51+
self.replace_table = {
52+
self.min_dim: self.argmin,
53+
self.max_dim: self.argmax,
54+
}
55+
56+
self.patterns = [
57+
# Only index is used (e.g., _, idx = torch.min(x, dim=1))
58+
{self.min_dim: 1, self.getitem: 1},
59+
{self.max_dim: 1, self.getitem: 1},
60+
# Both value and index are used (e.g., val, idx = torch.max(x, dim=1))
61+
{self.min_dim: 1, self.getitem: 2},
62+
{self.max_dim: 1, self.getitem: 2},
63+
]
64+
65+
def call(self, graph_module: torch.fx.GraphModule):
66+
graph = graph_module.graph
67+
partitions = get_source_partitions(graph, [torch.min, torch.max])
68+
for _, src_partitions in partitions.items():
69+
for partition in src_partitions:
70+
if Counter([n.target for n in partition.nodes]) not in self.patterns:
71+
continue
72+
binary_output_node = partition.nodes[0]
73+
74+
# Ensure the binary-output node has exactly 2 outputs
75+
if len(binary_output_node.meta["val"]) != 2:
76+
continue
77+
78+
input_tensor = binary_output_node.args[0]
79+
dim = binary_output_node.args[1]
80+
keepdim = (
81+
binary_output_node.args[2]
82+
if len(binary_output_node.args) > 2
83+
else False
84+
)
85+
86+
idx_node = next(
87+
(
88+
output_node
89+
for output_node in partition.output_nodes
90+
if output_node.meta["val"].dtype == torch.int64
91+
),
92+
None,
93+
)
94+
95+
if idx_node:
96+
with graph.inserting_before(idx_node):
97+
argmin_node = graph.create_node(
98+
"call_function",
99+
self.replace_table[binary_output_node.target],
100+
(input_tensor, dim, keepdim),
101+
)
102+
argmin_node.meta = idx_node.meta
103+
104+
for user in list(idx_node.users):
105+
user.replace_input_with(idx_node, argmin_node)
106+
107+
graph.eliminate_dead_code()
108+
graph_module.recompile()
109+
return PassResult(graph_module, True)

backends/qualcomm/_passes/i64_to_i32.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class I64toI32(ExportPass):
2626
"""
2727

2828
I64_OPS = {
29+
exir_ops.edge.aten.argmax.default,
2930
exir_ops.edge.aten.argmin.default,
3031
exir_ops.edge.aten.arange.start_step,
3132
exir_ops.edge.aten.cumsum.default,

backends/qualcomm/_passes/layout_transform.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ class LayoutTransform(ExportPass):
9191
exir_ops.edge.aten.logical_not.default,
9292
exir_ops.edge.aten.lt.Scalar,
9393
exir_ops.edge.aten.lt.Tensor,
94+
exir_ops.edge.aten.max.dim,
9495
exir_ops.edge.aten.maximum.default,
9596
exir_ops.edge.aten.mean.dim,
97+
exir_ops.edge.aten.min.dim,
9698
exir_ops.edge.aten.minimum.default,
9799
exir_ops.edge.aten.mul.Tensor,
98100
exir_ops.edge.aten.ne.Scalar,
@@ -167,10 +169,12 @@ def is_layout_sensitive(self, node: torch.fx.Node) -> bool:
167169
return node.target in self.layout_sensitive_ops
168170

169171
def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
170-
if node.target in [
172+
if node.target in {
173+
exir_ops.edge.aten.max.dim,
171174
exir_ops.edge.aten.mean.dim,
175+
exir_ops.edge.aten.min.dim,
172176
exir_ops.edge.aten.sum.dim_IntList,
173-
]:
177+
}:
174178
# if dimemsion is not kept, we'll have no clue how to do layout transform
175179
if len(node.args) < 3 or not node.args[2]:
176180
return False

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
DecomposeEinsum,
2323
DecomposeExpM1,
2424
DecomposeLinalgVectorNorm,
25+
DecomposeMinMaxDim,
2526
DecomposeRoll,
2627
DecomposeSilu,
2728
DecomposeWrapWithAutocast,
@@ -84,6 +85,7 @@ def get_capture_program_passes():
8485
(ConvertConv1dToConv2d, True),
8586
(DecomposeAny, True),
8687
(DecomposeColIm, True),
88+
(DecomposeMinMaxDim, True),
8789
(ExpandBroadcastTensorShape, False),
8890
(FixedLinearKeepDim, True),
8991
(FoldQDQ, True),

backends/qualcomm/builders/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,9 +360,9 @@ The operator now should be functional for Qualcomm backends. For operator to wor
360360
## Operator Support Status
361361
Please help update following table if you are contributing new operators:
362362

363-
| Operators | HTP - 80/116 Enabled |
363+
| Operators | HTP - 82/116 Enabled |
364364
|-----------|---------|
365-
| Argmax | &cross; |
365+
| Argmax | &check; |
366366
| Argmin | &check; |
367367
| BatchNorm | &check; |
368368
| BatchToSpace | &cross; |
@@ -449,7 +449,7 @@ Please help update following table if you are contributing new operators:
449449
| Quantize | &check; |
450450
| ReduceMax | &check; |
451451
| ReduceMean | &check; |
452-
| ReduceMin | &cross; |
452+
| ReduceMin | &check; |
453453
| ReduceSum | &check; |
454454
| Relu | &check; |
455455
| Relu1 | &cross; |

backends/qualcomm/builders/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
op_amax,
1313
op_and,
1414
op_arange,
15+
op_argmax,
1516
op_argmin,
1617
op_atan,
1718
op_avg_pool2d,
@@ -54,9 +55,11 @@
5455
op_lt,
5556
op_matmul,
5657
op_max,
58+
op_max_dim,
5759
op_max_pool2d,
5860
op_mean_dim,
5961
op_min,
62+
op_min_dim,
6063
op_mul,
6164
op_ne,
6265
op_neg,
@@ -105,6 +108,7 @@
105108
op_amax,
106109
op_and,
107110
op_arange,
111+
op_argmax,
108112
op_argmin,
109113
op_atan,
110114
op_avg_pool2d,
@@ -147,9 +151,11 @@
147151
op_lt,
148152
op_matmul,
149153
op_max,
154+
op_max_dim,
150155
op_max_pool2d,
151156
op_mean_dim,
152157
op_min,
158+
op_min_dim,
153159
op_mul,
154160
op_neg,
155161
op_ne,
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import cast, Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
import numpy as np
10+
import torch
11+
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
12+
13+
from .node_visitor import NodeVisitor
14+
from .node_visitor_manager import register_node_visitor
15+
from .qnn_constants import OpArgmax, QNN_OP_PACKAGE_NAME_QTI_AISW
16+
17+
18+
@register_node_visitor
19+
class argmax(NodeVisitor):
20+
target = ["aten.argmax.default"]
21+
22+
def __init__(self, *args) -> None:
23+
super().__init__(*args)
24+
25+
def define_node(
26+
self,
27+
node: torch.fx.Node,
28+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
29+
) -> PyQnnWrapper.PyQnnOpWrapper:
30+
input_node = self.get_node(node.args[0])
31+
input_tensor = self.get_tensor(input_node, node)
32+
output_tensor = self.get_tensor(node, node)
33+
argmax_inp_tensor_wrapper = self.define_tensor(
34+
input_node,
35+
node,
36+
input_tensor,
37+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
38+
nodes_to_wrappers,
39+
)
40+
argmax_input_tensors = [argmax_inp_tensor_wrapper]
41+
argmax_out_tensor_wrapper = self.define_tensor(
42+
node,
43+
node,
44+
output_tensor.to(torch.int32),
45+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
46+
nodes_to_wrappers,
47+
)
48+
argmax_output_tensors = [argmax_out_tensor_wrapper]
49+
50+
dim = cast(int, node.args[1])
51+
if dim < 0:
52+
dim = dim % len(input_tensor.shape)
53+
if QCOM_AXIS_ORDER in node.meta:
54+
dim = node.meta[QCOM_AXIS_ORDER].index(dim)
55+
56+
argmax_op = PyQnnWrapper.PyQnnOpWrapper(
57+
node.name,
58+
QNN_OP_PACKAGE_NAME_QTI_AISW,
59+
OpArgmax.op_name,
60+
)
61+
argmax_op.AddInputTensors(argmax_input_tensors)
62+
argmax_op.AddOutputTensors(argmax_output_tensors)
63+
64+
argmax_op.AddScalarParam(
65+
OpArgmax.param_axis,
66+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
67+
{QCOM_DATA: np.uint32(dim)},
68+
)
69+
70+
if len(node.args) > 2:
71+
keep_dims = cast(bool, node.args[2])
72+
argmax_op.AddScalarParam(
73+
OpArgmax.param_keep_dims,
74+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
75+
{QCOM_DATA: keep_dims},
76+
)
77+
78+
return argmax_op

0 commit comments

Comments
 (0)