Skip to content

Qualcomm AI Engine Direct - Support min.dim, max.dim and argmax ops #12669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .decompose_einsum import DecomposeEinsum
from .decompose_expm1 import DecomposeExpM1
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
from .decompose_minmaxdim import DecomposeMinMaxDim
from .decompose_roll import DecomposeRoll
from .decompose_silu import DecomposeSilu
from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast
Expand Down Expand Up @@ -54,6 +55,7 @@
DecomposeEinsum,
DecomposeExpM1,
DecomposeLinalgVectorNorm,
DecomposeMinMaxDim,
DecomposeRoll,
DecomposeSilu,
DecomposeWrapWithAutocast,
Expand Down
109 changes: 109 additions & 0 deletions backends/qualcomm/_passes/decompose_minmaxdim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import operator
from collections import Counter

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class DecomposeMinMaxDim(ExportPass):
    """
    Since QNN does not support multi-output ops, this pass decomposes
    `torch.min(dim=...)` and `torch.max(dim=...)` into two separate operations:
    - `aten.min.dim` / `aten.max.dim` for the value
    - `aten.argmin` / `aten.argmax` for the index

    Example transformation in the exported FX graph:

    Python source:
        val, idx = torch.min(x, dim=1)

    Before:
        %min = aten.min(%x, dim=1)
        %val = getitem(%min, 0)
        %idx = getitem(%min, 1)

    After:
        %min = aten.min(%x, dim=1)
        %val = getitem(%min, 0)
        %idx = aten.argmin(%x, dim=1)

    This pass preserves the value output if used, and transforms only the index path.
    The index path is identified structurally (``getitem(..., 1)``), not by dtype,
    so int64 inputs (whose value output is also int64) are handled correctly.
    """

    def __init__(self):
        super().__init__()
        self.min_dim = exir_ops.edge.aten.min.dim
        self.max_dim = exir_ops.edge.aten.max.dim
        self.argmin = exir_ops.edge.aten.argmin.default
        self.argmax = exir_ops.edge.aten.argmax.default
        self.getitem = operator.getitem

        # Maps each two-output reduction to its index-only counterpart.
        self.replace_table = {
            self.min_dim: self.argmin,
            self.max_dim: self.argmax,
        }

        # Accepted partition shapes, expressed as target -> occurrence count.
        self.patterns = [
            # Only one output is used (e.g., _, idx = torch.min(x, dim=1))
            {self.min_dim: 1, self.getitem: 1},
            {self.max_dim: 1, self.getitem: 1},
            # Both value and index are used (e.g., val, idx = torch.max(x, dim=1))
            {self.min_dim: 1, self.getitem: 2},
            {self.max_dim: 1, self.getitem: 2},
        ]

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite every index output of `min.dim` / `max.dim` into `argmin` / `argmax`.

        Args:
            graph_module: The graph module to transform in place.

        Returns:
            A PassResult wrapping the (recompiled) graph module.
        """
        graph = graph_module.graph
        partitions = get_source_partitions(graph, [torch.min, torch.max])
        for src_partitions in partitions.values():
            for partition in src_partitions:
                if Counter(n.target for n in partition.nodes) not in self.patterns:
                    continue

                # The matched pattern guarantees exactly one min.dim / max.dim node;
                # look it up by target instead of relying on node ordering.
                binary_output_node = next(
                    n for n in partition.nodes if n.target in self.replace_table
                )

                # Ensure the binary-output node has exactly 2 outputs
                if len(binary_output_node.meta["val"]) != 2:
                    continue

                input_tensor = binary_output_node.args[0]
                dim = binary_output_node.args[1]
                keepdim = (
                    binary_output_node.args[2]
                    if len(binary_output_node.args) > 2
                    else False
                )

                # The index output is element 1 of the (values, indices) pair.
                # Selecting it by dtype would misfire when the input is int64,
                # because the value output then has dtype int64 as well.
                idx_nodes = [
                    output_node
                    for output_node in partition.output_nodes
                    if output_node.target == self.getitem
                    and output_node.args[0] is binary_output_node
                    and output_node.args[1] == 1
                ]

                for idx_node in idx_nodes:
                    with graph.inserting_before(idx_node):
                        arg_node = graph.create_node(
                            "call_function",
                            self.replace_table[binary_output_node.target],
                            (input_tensor, dim, keepdim),
                        )
                    # Give the new node its own meta dict rather than aliasing
                    # the dict of the node that is about to be removed.
                    arg_node.meta = dict(idx_node.meta)

                    # Snapshot users: replace_input_with mutates idx_node.users.
                    for user in list(idx_node.users):
                        user.replace_input_with(idx_node, arg_node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/i64_to_i32.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class I64toI32(ExportPass):
"""

I64_OPS = {
exir_ops.edge.aten.argmax.default,
exir_ops.edge.aten.argmin.default,
exir_ops.edge.aten.arange.start_step,
exir_ops.edge.aten.cumsum.default,
Expand Down
8 changes: 6 additions & 2 deletions backends/qualcomm/_passes/layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.logical_not.default,
exir_ops.edge.aten.lt.Scalar,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.max.dim,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.min.dim,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.ne.Scalar,
Expand Down Expand Up @@ -167,10 +169,12 @@ def is_layout_sensitive(self, node: torch.fx.Node) -> bool:
return node.target in self.layout_sensitive_ops

def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
if node.target in [
if node.target in {
exir_ops.edge.aten.max.dim,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.min.dim,
exir_ops.edge.aten.sum.dim_IntList,
]:
}:
# if dimension is not kept, we'll have no clue how to do layout transform
if len(node.args) < 3 or not node.args[2]:
return False
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DecomposeEinsum,
DecomposeExpM1,
DecomposeLinalgVectorNorm,
DecomposeMinMaxDim,
DecomposeRoll,
DecomposeSilu,
DecomposeWrapWithAutocast,
Expand Down Expand Up @@ -84,6 +85,7 @@ def get_capture_program_passes():
(ConvertConv1dToConv2d, True),
(DecomposeAny, True),
(DecomposeColIm, True),
(DecomposeMinMaxDim, True),
(ExpandBroadcastTensorShape, False),
(FixedLinearKeepDim, True),
(FoldQDQ, True),
Expand Down
6 changes: 3 additions & 3 deletions backends/qualcomm/builders/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,9 @@ The operator now should be functional for Qualcomm backends. For operator to wor
## Operator Support Status
Please help update following table if you are contributing new operators:

| Operators | HTP - 80/116 Enabled |
| Operators | HTP - 82/116 Enabled |
|-----------|---------|
| Argmax | &cross; |
| Argmax | &check; |
| Argmin | &check; |
| BatchNorm | &check; |
| BatchToSpace | &cross; |
Expand Down Expand Up @@ -449,7 +449,7 @@ Please help update following table if you are contributing new operators:
| Quantize | &check; |
| ReduceMax | &check; |
| ReduceMean | &check; |
| ReduceMin | &cross; |
| ReduceMin | &check; |
| ReduceSum | &check; |
| Relu | &check; |
| Relu1 | &cross; |
Expand Down
6 changes: 6 additions & 0 deletions backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
op_amax,
op_and,
op_arange,
op_argmax,
op_argmin,
op_atan,
op_avg_pool2d,
Expand Down Expand Up @@ -54,9 +55,11 @@
op_lt,
op_matmul,
op_max,
op_max_dim,
op_max_pool2d,
op_mean_dim,
op_min,
op_min_dim,
op_mul,
op_ne,
op_neg,
Expand Down Expand Up @@ -105,6 +108,7 @@
op_amax,
op_and,
op_arange,
op_argmax,
op_argmin,
op_atan,
op_avg_pool2d,
Expand Down Expand Up @@ -147,9 +151,11 @@
op_lt,
op_matmul,
op_max,
op_max_dim,
op_max_pool2d,
op_mean_dim,
op_min,
op_min_dim,
op_mul,
op_neg,
op_ne,
Expand Down
78 changes: 78 additions & 0 deletions backends/qualcomm/builders/op_argmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpArgmax, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class argmax(NodeVisitor):
    """Node visitor that builds a QNN Argmax op for `aten.argmax.default`."""

    target = ["aten.argmax.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """
        Build the op wrapper for an argmax node.

        Args:
            node: The argmax FX node; args are (input, dim[, keepdim]).
            nodes_to_wrappers: Tensor-wrapper cache passed through to `define_tensor`.

        Returns:
            The op wrapper with one input, one int32 output, the axis parameter
            and, when the node carries a third argument, the keep_dims parameter.
        """
        src_node = self.get_node(node.args[0])
        src_tensor = self.get_tensor(src_node, node)
        dst_tensor = self.get_tensor(node, node)
        native_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE

        src_wrapper = self.define_tensor(
            src_node,
            node,
            src_tensor,
            native_type,
            nodes_to_wrappers,
        )
        # The output tensor is defined as int32 rather than its original dtype.
        dst_wrapper = self.define_tensor(
            node,
            node,
            dst_tensor.to(torch.int32),
            native_type,
            nodes_to_wrappers,
        )

        # Normalize a negative axis, then remap it if an axis order is recorded.
        axis = cast(int, node.args[1])
        axis = axis % len(src_tensor.shape) if axis < 0 else axis
        if QCOM_AXIS_ORDER in node.meta:
            axis = node.meta[QCOM_AXIS_ORDER].index(axis)

        op_wrapper = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpArgmax.op_name,
        )
        op_wrapper.AddInputTensors([src_wrapper])
        op_wrapper.AddOutputTensors([dst_wrapper])
        op_wrapper.AddScalarParam(
            OpArgmax.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(axis)},
        )

        # keep_dims is only emitted when the node explicitly provides it.
        if len(node.args) > 2:
            op_wrapper.AddScalarParam(
                OpArgmax.param_keep_dims,
                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
                {QCOM_DATA: cast(bool, node.args[2])},
            )

        return op_wrapper
Loading
Loading