diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 01710aa8d80..aaf65afd279 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -17,6 +17,7 @@ from .decompose_einsum import DecomposeEinsum from .decompose_expm1 import DecomposeExpM1 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm +from .decompose_minmaxdim import DecomposeMinMaxDim from .decompose_roll import DecomposeRoll from .decompose_silu import DecomposeSilu from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast @@ -54,6 +55,7 @@ DecomposeEinsum, DecomposeExpM1, DecomposeLinalgVectorNorm, + DecomposeMinMaxDim, DecomposeRoll, DecomposeSilu, DecomposeWrapWithAutocast, diff --git a/backends/qualcomm/_passes/decompose_minmaxdim.py b/backends/qualcomm/_passes/decompose_minmaxdim.py new file mode 100644 index 00000000000..0b79b04518e --- /dev/null +++ b/backends/qualcomm/_passes/decompose_minmaxdim.py @@ -0,0 +1,109 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import operator +from collections import Counter + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +class DecomposeMinMaxDim(ExportPass): + """ + Since QNN does not support multi-output ops, this pass decomposes + `torch.min(dim=...)` and `torch.max(dim=...)` into two separate operations: + - `aten.min.dim` / `aten.max.dim` for the value + - `aten.argmin` / `aten.argmax` for the index + + Example transformation in the exported FX graph: + + Python source: + val, idx = torch.min(x, dim=1) + + Before: + %min = aten.min(%x, dim=1) + %val = getitem(%min, 0) + %idx = getitem(%min, 1) + + After: + %min = aten.min(%x, dim=1) + %val = getitem(%min, 0) + %idx = aten.argmin(%x, dim=1) + + This pass preserves the value output if used, and transforms only the index path. 
+ """ + + def __init__(self): + super().__init__() + self.min_dim = exir_ops.edge.aten.min.dim + self.max_dim = exir_ops.edge.aten.max.dim + self.argmin = exir_ops.edge.aten.argmin.default + self.argmax = exir_ops.edge.aten.argmax.default + self.getitem = operator.getitem + + # index-only op + self.replace_table = { + self.min_dim: self.argmin, + self.max_dim: self.argmax, + } + + self.patterns = [ + # Only index is used (e.g., _, idx = torch.min(x, dim=1)) + {self.min_dim: 1, self.getitem: 1}, + {self.max_dim: 1, self.getitem: 1}, + # Both value and index are used (e.g., val, idx = torch.max(x, dim=1)) + {self.min_dim: 1, self.getitem: 2}, + {self.max_dim: 1, self.getitem: 2}, + ] + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + partitions = get_source_partitions(graph, [torch.min, torch.max]) + for _, src_partitions in partitions.items(): + for partition in src_partitions: + if Counter([n.target for n in partition.nodes]) not in self.patterns: + continue + binary_output_node = partition.nodes[0] + + # Ensure the binary-output node has exactly 2 outputs + if len(binary_output_node.meta["val"]) != 2: + continue + + input_tensor = binary_output_node.args[0] + dim = binary_output_node.args[1] + keepdim = ( + binary_output_node.args[2] + if len(binary_output_node.args) > 2 + else False + ) + + idx_node = next( + ( + output_node + for output_node in partition.output_nodes + if output_node.meta["val"].dtype == torch.int64 + ), + None, + ) + + if idx_node: + with graph.inserting_before(idx_node): + argmin_node = graph.create_node( + "call_function", + self.replace_table[binary_output_node.target], + (input_tensor, dim, keepdim), + ) + argmin_node.meta = idx_node.meta + + for user in list(idx_node.users): + user.replace_input_with(idx_node, argmin_node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/i64_to_i32.py b/backends/qualcomm/_passes/i64_to_i32.py index 5c9310c3d59..986dd60543f 100644 --- a/backends/qualcomm/_passes/i64_to_i32.py +++ b/backends/qualcomm/_passes/i64_to_i32.py @@ -26,6 +26,7 @@ class I64toI32(ExportPass): """ I64_OPS = { + exir_ops.edge.aten.argmax.default, exir_ops.edge.aten.argmin.default, exir_ops.edge.aten.arange.start_step, exir_ops.edge.aten.cumsum.default, diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 0c6b3152561..919c445fafa 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -91,8 +91,10 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.logical_not.default, exir_ops.edge.aten.lt.Scalar, exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.max.dim, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.min.dim, exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.ne.Scalar, @@ -167,10 +169,12 @@ def is_layout_sensitive(self, node: torch.fx.Node) -> bool: return node.target in self.layout_sensitive_ops def is_layout_agnostic(self, node: torch.fx.Node) -> bool: - if node.target in [ + if node.target in { + exir_ops.edge.aten.max.dim, exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.min.dim, exir_ops.edge.aten.sum.dim_IntList, - ]: + }: # if dimemsion is not kept, we'll have no clue how to do layout transform if len(node.args) < 3 or not node.args[2]: return False diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py 
b/backends/qualcomm/_passes/qnn_pass_manager.py index 8340fa6209e..152433195cd 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -22,6 +22,7 @@ DecomposeEinsum, DecomposeExpM1, DecomposeLinalgVectorNorm, + DecomposeMinMaxDim, DecomposeRoll, DecomposeSilu, DecomposeWrapWithAutocast, @@ -84,6 +85,7 @@ def get_capture_program_passes(): (ConvertConv1dToConv2d, True), (DecomposeAny, True), (DecomposeColIm, True), + (DecomposeMinMaxDim, True), (ExpandBroadcastTensorShape, False), (FixedLinearKeepDim, True), (FoldQDQ, True), diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 4e150f1eeaa..9c62e1080fe 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -360,9 +360,9 @@ The operator now should be functional for Qualcomm backends. For operator to wor ## Operator Support Status Please help update following table if you are contributing new operators: -| Operators | HTP - 80/116 Enabled | +| Operators | HTP - 82/116 Enabled | |-----------|---------| -| Argmax | ✗ | +| Argmax | ✓ | | Argmin | ✓ | | BatchNorm | ✓ | | BatchToSpace | ✗ | @@ -449,7 +449,7 @@ Please help update following table if you are contributing new operators: | Quantize | ✓ | | ReduceMax | ✓ | | ReduceMean | ✓ | -| ReduceMin | ✗ | +| ReduceMin | ✓ | | ReduceSum | ✓ | | Relu | ✓ | | Relu1 | ✗ | diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index f8b2f11ff4c..709df7006f8 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -12,6 +12,7 @@ op_amax, op_and, op_arange, + op_argmax, op_argmin, op_atan, op_avg_pool2d, @@ -54,9 +55,11 @@ op_lt, op_matmul, op_max, + op_max_dim, op_max_pool2d, op_mean_dim, op_min, + op_min_dim, op_mul, op_ne, op_neg, @@ -105,6 +108,7 @@ op_amax, op_and, op_arange, + op_argmax, op_argmin, op_atan, op_avg_pool2d, @@ -147,9 +151,11 @@ op_lt, op_matmul, op_max, + op_max_dim, op_max_pool2d, op_mean_dim, op_min, + op_min_dim, op_mul, op_neg, op_ne, diff --git a/backends/qualcomm/builders/op_argmax.py b/backends/qualcomm/builders/op_argmax.py new file mode 100644 index 00000000000..e81b0dd1d95 --- /dev/null +++ b/backends/qualcomm/builders/op_argmax.py @@ -0,0 +1,78 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+from typing import cast, Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import numpy as np +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpArgmax, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class argmax(NodeVisitor): + target = ["aten.argmax.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + output_tensor = self.get_tensor(node, node) + argmax_inp_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + argmax_input_tensors = [argmax_inp_tensor_wrapper] + argmax_out_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor.to(torch.int32), + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + argmax_output_tensors = [argmax_out_tensor_wrapper] + + dim = cast(int, node.args[1]) + if dim < 0: + dim = dim % len(input_tensor.shape) + if QCOM_AXIS_ORDER in node.meta: + dim = node.meta[QCOM_AXIS_ORDER].index(dim) + + argmax_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpArgmax.op_name, + ) + argmax_op.AddInputTensors(argmax_input_tensors) + argmax_op.AddOutputTensors(argmax_output_tensors) + + argmax_op.AddScalarParam( + OpArgmax.param_axis, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(dim)}, + ) + + if len(node.args) > 2: + keep_dims = cast(bool, node.args[2]) + argmax_op.AddScalarParam( + OpArgmax.param_keep_dims, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + {QCOM_DATA: keep_dims}, + ) + + return argmax_op diff --git a/backends/qualcomm/builders/op_max_dim.py b/backends/qualcomm/builders/op_max_dim.py new file mode 100644 index 00000000000..354444da550 --- /dev/null +++ b/backends/qualcomm/builders/op_max_dim.py @@ -0,0 +1,89 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import numpy as np +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpReduceMax, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class MaxDim(NodeVisitor): + target = ["aten.max.dim"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> List[PyQnnWrapper.PyQnnOpWrapper]: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + # QNN does not support multiple outputs for a single op. 
+ # Since torch.max(input, dim) returns both values and indices, + # we only support the value output for OpReduceMax. The index output will be handled + # separately by OpArgmax. + # Therefore, we update node.meta["val"] to only keep the value part. + if len(node.meta["val"]) == 2: + node.meta["val"] = node.meta["val"][0] + + output_tensor = self.get_tensor(node, node) + out_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + dims = cast(List[int], [node.args[1]]) + dims = [max_dim % len(input_node.meta["val"].shape) for max_dim in dims] + if QCOM_AXIS_ORDER in node.meta: + dims = [node.meta[QCOM_AXIS_ORDER].index(max_dim) for max_dim in dims] + dims_shape = [len(dims)] + + reduce_max_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpReduceMax.op_name, + ) + reduce_max_op.AddInputTensors([input_tensor_wrapper]) + reduce_max_op.AddOutputTensors([out_tensor_wrapper]) + + reduce_max_op.AddTensorParam( + OpReduceMax.param_axes, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(dims_shape), + dims_shape, + np.array(dims, dtype=np.uint32), + True, + ) + if len(node.args) > 2: + keep_dims = cast(bool, node.args[2]) + reduce_max_op.AddScalarParam( + OpReduceMax.param_keep_dims, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + {QCOM_DATA: keep_dims}, + ) + + return reduce_max_op diff --git a/backends/qualcomm/builders/op_min_dim.py b/backends/qualcomm/builders/op_min_dim.py new file mode 100644 index 00000000000..6425a9aa755 --- /dev/null +++ b/backends/qualcomm/builders/op_min_dim.py @@ -0,0 +1,89 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import numpy as np +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpReduceMin, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class MinDim(NodeVisitor): + target = ["aten.min.dim"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> List[PyQnnWrapper.PyQnnOpWrapper]: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + # QNN does not support multiple outputs for a single op. + # Since torch.min(input, dim) returns both values and indices, + # we only support the value output for OpReduceMin. The index output will be handled + # separately by OpArgmin. + # Therefore, we update node.meta["val"] to only keep the value part. 
+ if len(node.meta["val"]) == 2: + node.meta["val"] = node.meta["val"][0] + + output_tensor = self.get_tensor(node, node) + out_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + dims = cast(List[int], [node.args[1]]) + dims = [min_dim % len(input_node.meta["val"].shape) for min_dim in dims] + if QCOM_AXIS_ORDER in node.meta: + dims = [node.meta[QCOM_AXIS_ORDER].index(min_dim) for min_dim in dims] + dims_shape = [len(dims)] + + reduce_min_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpReduceMin.op_name, + ) + reduce_min_op.AddInputTensors([input_tensor_wrapper]) + reduce_min_op.AddOutputTensors([out_tensor_wrapper]) + + reduce_min_op.AddTensorParam( + OpReduceMin.param_axes, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(dims_shape), + dims_shape, + np.array(dims, dtype=np.uint32), + True, + ) + if len(node.args) > 2: + keep_dims = cast(bool, node.args[2]) + reduce_min_op.AddScalarParam( + OpReduceMin.param_keep_dims, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + {QCOM_DATA: keep_dims}, + ) + + return reduce_min_op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index aa245442f67..74ffe24e3c4 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -14,6 +14,13 @@ # instead of replicating them here. +@dataclass(init=False, frozen=True) +class OpArgmax: + op_name: str = "Argmax" + param_axis: str = "axis" + param_keep_dims: str = "keep_dims" + + @dataclass(init=False, frozen=True) class OpArgmin: op_name: str = "Argmin" @@ -399,6 +406,13 @@ class OpReduceMean: param_keep_dims: str = "keep_dims" +@dataclass(init=False, frozen=True) +class OpReduceMin: + op_name: str = "ReduceMin" + param_axes: str = "axes" + param_keep_dims: str = "keep_dims" + + @dataclass(init=False, frozen=True) class OpReduceSum: op_name: str = "ReduceSum" diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index cc7e0054ebe..58b1a036955 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -213,10 +213,13 @@ def annotate_amax(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) +@register_annotator([torch.ops.aten.argmax.default]) +def annotate_argmax(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in(node, quantization_config) + + @register_annotator([torch.ops.aten.argmin.default]) def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None: - if _is_annotated([node]): - return annotate_single_in(node, quantization_config) @@ -285,11 +288,21 @@ def annotate_max(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) +@register_annotator([torch.ops.aten.max.dim]) +def annotate_max_dim(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in(node, quantization_config) + + @register_annotator([torch.ops.aten.min.other, torch.ops.aten.minimum.default]) def annotate_min(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) +@register_annotator([torch.ops.aten.min.dim]) +def annotate_min_dim(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in(node, quantization_config) + + @register_annotator( 
[torch.ops.aten.div, torch.ops.aten.div.Tensor, torch.ops.aten.divide.Tensor] ) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 6e396c69c54..b091819f0ff 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -119,6 +119,15 @@ def forward(self, y): ) +class Argmax(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.argmax(x, dim=0, keepdim=True) + return x + + class Argmin(torch.nn.Module): def __init__(self): super().__init__() @@ -1146,6 +1155,15 @@ def forward(self, attn_mask): ) +class MaxDim(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, logits): + max_logits, max_indices = torch.max(logits, dim=1) + return max_logits, max_indices + + class Maximum(torch.nn.Module): def __init__(self): super().__init__() @@ -1154,6 +1172,15 @@ def forward(self, x, y): return torch.maximum(x, y) +class MinDim(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, logits): + min_logits, min_indices = torch.min(logits, dim=1) + return min_logits, min_indices + + class Minimum(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index b697e81f2d1..a7942404d18 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -152,6 +152,11 @@ def test_qnn_backend_arange(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_argmax(self): + module = Argmax() # noqa: F405 + sample_input = (torch.randn(16, 3, 4, 4),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_argmin(self): module = Argmin() # noqa: F405 sample_input = (torch.randn(16, 3, 4, 4),) @@ -861,6 +866,11 @@ def test_qnn_backend_maximum(self): sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_max_dim(self): + module = MaxDim() # noqa: F405 + sample_input = (torch.randn(4, 10),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_max_pool2d(self): module = MaxPool2d() # noqa: F405 sample_input = (torch.randn(4, 3, 24, 24),) @@ -884,6 +894,11 @@ def test_qnn_backend_minimum(self): sample_input = (torch.randn(1, 2, 3, 4), torch.randn(2, 3, 4)) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_min_dim(self): + module = MinDim() # noqa: F405 + sample_input = (torch.randn(4, 10),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_neg(self): module = Neg() # noqa: F405 sample_input = (torch.randn(1, 4, 16, 16),) @@ -1424,6 +1439,12 @@ def test_qnn_backend_arange(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_argmax(self): + module = Argmax() # noqa: F405 + sample_input = (torch.randn(16, 3, 4, 4),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_argmin(self): module = Argmin() # noqa: F405 sample_input = (torch.randn(16, 3, 4, 4),) @@ -2219,6 +2240,12 @@ def test_qnn_backend_maximum(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_max_dim(self): + module = MaxDim() # noqa: F405 + sample_input = 
(torch.randn(4, 10),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_max_pool2d(self): module = MaxPool2d() # noqa: F405 sample_input = (torch.randn(4, 3, 24, 24),) @@ -2245,6 +2272,12 @@ def test_qnn_backend_minimum(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_min_dim(self): + module = MinDim() # noqa: F405 + sample_input = (torch.randn(4, 10),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_neg(self): module = Neg() # noqa: F405 sample_input = (torch.randn(1, 4, 16, 16),)
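
For reference, the graph rewrite performed by `DecomposeMinMaxDim` can be sanity-checked in eager mode. The sketch below is an illustration only, not part of the patch: the helper name `min_dim_decomposed` is made up for this example, and it uses only public `torch` APIs. The `max` case is symmetric (`torch.max` / `torch.argmax`, lowered to QNN ReduceMax / Argmax by op_max_dim.py and op_argmax.py).

import torch

def min_dim_decomposed(x: torch.Tensor, dim: int, keepdim: bool = False):
    # Value path: keeps aten.min.dim, which op_min_dim.py lowers to a
    # single-output QNN ReduceMin.
    values = torch.min(x, dim=dim, keepdim=keepdim).values
    # Index path: recomputed with aten.argmin, which op_argmin.py lowers to
    # QNN Argmin. After the rewrite, each QNN op has exactly one output.
    indices = torch.argmin(x, dim=dim, keepdim=keepdim)
    return values, indices

# Equivalence check, mirroring the (4, 10) input used by
# test_qnn_backend_min_dim. Ties could in principle make the two index
# paths disagree, but are vanishingly unlikely with randn inputs.
x = torch.randn(4, 10)
val, idx = min_dim_decomposed(x, dim=1)
ref = torch.min(x, dim=1)
assert torch.equal(val, ref.values) and torch.equal(idx, ref.indices)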