From 8851f6b8363bbb75d62c332ed0e9c7d6c649af84 Mon Sep 17 00:00:00 2001
From: Teo Bergkvist
Date: Mon, 30 Jun 2025 14:37:16 +0200
Subject: [PATCH] Arm backend: Add addmm decomposition pass and test

Decompose addmm into mm, scalar multiplication, and add operators.

Change-Id: I9f67f6a6e6b63b22065ed2bc918bb4deaaaffd74
Signed-off-by: Teo Bergkvist
---
 backends/arm/_passes/__init__.py             |   1 +
 backends/arm/_passes/arm_pass_manager.py     |   3 +
 backends/arm/_passes/decompose_addmm_pass.py |  60 +++++++
 .../tosa_supported_operators.py              |   2 +
 backends/arm/test/ops/test_addmm.py          | 157 ++++++++++++++++++
 5 files changed, 223 insertions(+)
 create mode 100644 backends/arm/_passes/decompose_addmm_pass.py
 create mode 100644 backends/arm/test/ops/test_addmm.py

diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 1d6c34a0d35..286a5f4499b 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -24,6 +24,7 @@
 from .convert_to_clamp import ConvertToClampPass  # noqa
 from .decompose_acosh_pass import DecomposeAcoshPass  # noqa
 from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass  # noqa
+from .decompose_addmm_pass import DecomposeAddmmPass  # noqa
 from .decompose_asin_pass import DecomposeAsinPass  # noqa
 from .decompose_atan_pass import DecomposeAtanPass  # noqa
 from .decompose_atanh_pass import DecomposeAtanhPass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index bfae8f1b017..58a41f4d573 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -29,6 +29,7 @@
     ConvertToClampPass,
     DecomposeAcoshPass,
     DecomposeAdaptiveAvgPool2dPass,
+    DecomposeAddmmPass,
     DecomposeAsinPass,
     DecomposeAtanhPass,
     DecomposeAtanPass,
@@ -165,6 +166,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(DecomposeSqrtPass())
         self.add_pass(DecomposeAtanPass())
         self.add_pass(DecomposeAtanhPass())
+        self.add_pass(DecomposeAddmmPass())
         self.add_pass(ConvertIntPowToMuls())
         self.add_pass(CastBoolToInt8Pass())
         self.add_pass(DecomposeSinhPass())
@@ -258,6 +260,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeRoundPass())
         self.add_pass(CastBoolToInt8Pass())
         self.add_pass(DecomposeSignPass())
+        self.add_pass(DecomposeAddmmPass())
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeGroupNormPass())
diff --git a/backends/arm/_passes/decompose_addmm_pass.py b/backends/arm/_passes/decompose_addmm_pass.py
new file mode 100644
index 00000000000..b59a8cb02d3
--- /dev/null
+++ b/backends/arm/_passes/decompose_addmm_pass.py
@@ -0,0 +1,60 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+# For MI case
+edge_addmm = exir_ops.edge.aten.addmm.default
+# For BI case
+aten_addmm = torch.ops.aten.addmm.default
+
+
+def get_ops(op):
+    """Return the (mm, mul.Scalar, add) ops for the given addmm overload."""
+    if op == edge_addmm:
+        return (
+            exir_ops.edge.aten.mm.default,
+            exir_ops.edge.aten.mul.Scalar,
+            exir_ops.edge.aten.add.Tensor,
+        )
+    elif op == aten_addmm:
+        return (
+            torch.ops.aten.mm.default,
+            torch.ops.aten.mul.Scalar,
+            torch.ops.aten.add.Tensor,
+        )
+    else:
+        raise ValueError(f"Unsupported operator: {op}")
+
+
+class DecomposeAddmmPass(ArmPass):
+    """Decompose addmm into beta * input + alpha * (mat1 @ mat2)."""
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in [edge_addmm, aten_addmm]:
+            return super().call_operator(op, args, kwargs, meta)
+
+        input, mat1, mat2 = args
+        beta = kwargs.get("beta", 1.0)
+        alpha = kwargs.get("alpha", 1.0)
+
+        mm_op, mul_scalar_op, add_op = get_ops(op)
+
+        mm = super().call_operator(mm_op, (mat1, mat2), {}, meta, updated=True)
+        mm_alpha = super().call_operator(
+            mul_scalar_op, (mm, alpha), {}, meta, updated=True
+        )
+
+        input_beta = super().call_operator(
+            mul_scalar_op, (input, beta), {}, meta, updated=True
+        )
+
+        return super().call_operator(
+            add_op, (mm_alpha, input_beta), {}, meta, updated=True
+        )
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 1189bd969a9..6d1b8e66c2f 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -253,6 +253,7 @@ def is_node_supported(
             exir_ops.edge.aten.sign.default,
             exir_ops.edge.aten.asin.default,
             exir_ops.edge.aten.atanh.default,
+            exir_ops.edge.aten.addmm.default,
         ]
 
         return supported
@@ -293,6 +294,7 @@ def is_node_supported(
             exir_ops.edge.aten.div.Scalar: None,
             exir_ops.edge.aten.leaky_relu.default: None,
             exir_ops.edge.aten.round.default: None,
+            exir_ops.edge.aten.addmm.default: None,
         }
 
         if node.target in needs_decomp_dict:
diff --git a/backends/arm/test/ops/test_addmm.py b/backends/arm/test/ops/test_addmm.py
new file mode 100644
index 00000000000..7da5596ab00
--- /dev/null
+++ b/backends/arm/test/ops/test_addmm.py
@@ -0,0 +1,157 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineBI,
+    EthosU85PipelineBI,
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+aten_op = "torch.ops.aten.addmm.default"
+
+exir_op = "executorch_exir_dialects_edge__ops_aten__addmm_default"
+
+input_t1 = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float, float]  # x1, x2, x3, alpha, beta
+
+
+test_data_suite = {
+    "basic": [
+        torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
+        torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
+        torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
+        1.0,
+        1.0,
+    ],
+    "zeros": [torch.zeros(2, 2), torch.zeros(2, 3), torch.zeros(3, 2), 1.0, 1.0],
+    "beta_only": [
+        torch.tensor([[10.0, 20.0], [30.0, 40.0]]),
+        torch.randn(2, 3),
+        torch.randn(3, 2),
+        0.0,
+        1.0,
+    ],
+    "alpha_only": [
+        torch.tensor([[10.0, 20.0], [30.0, 40.0]]),
+        torch.randn(2, 3),
+        torch.randn(3, 2),
+        1.0,
+        0.0,
+    ],
+    "scaled": [
+        torch.ones(2, 2),
+        torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
+        torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
+        0.5,
+        2.0,
+    ],
+    "negative_scalars": [
+        torch.tensor([[1.0, -1.0], [-1.0, 1.0]]),
+        torch.tensor([[2.0, 0.0], [0.0, 2.0]]),
+        torch.tensor([[1.0, 1.0], [1.0, 1.0]]),
+        -1.0,
+        -1.0,
+    ],
+    "non_square": [torch.ones(3, 4), torch.rand(3, 2), torch.rand(2, 4), 1.0, 1.0],
+    "large_values": [
+        torch.full((2, 2), 1e6),
+        torch.full((2, 3), 1e3),
+        torch.full((3, 2), 1e3),
+        1.0,
+        1.0,
+    ],
+    "small_values": [
+        torch.full((2, 2), 1e-6),
+        torch.full((2, 3), 1e-3),
+        torch.full((3, 2), 1e-3),
+        1.0,
+        1.0,
+    ],
+    "random": [torch.randn(4, 5), torch.randn(4, 3), torch.randn(3, 5), 1.0, 1.0],
+    "broadcast_bias_row": [
+        torch.randn(1, 2),
+        torch.randn(3, 4),
+        torch.randn(4, 2),
+        1.0,
+        1.0,
+    ],
+    "row_bias": [
+        torch.randn(3, 1),
+        torch.randn(3, 4),
+        torch.randn(4, 4),
+        1.0,
+        1.0,
+    ],
+    "scalar_bias": [
+        torch.tensor(2.0),
+        torch.randn(5, 3),
+        torch.randn(3, 6),
+        1.0,
+        1.0,
+    ],
+}
+
+
+class Addmm(torch.nn.Module):
+    def forward(
+        self,
+        x1: torch.Tensor,
+        x2: torch.Tensor,
+        x3: torch.Tensor,
+        alpha: float,
+        beta: float,
+    ) -> torch.Tensor:
+        return torch.addmm(x1, x2, x3, alpha=alpha, beta=beta)
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_addmm_tosa_MI(test_data: Tuple):
+    pipeline = TosaPipelineMI[input_t1](
+        Addmm(),
+        (*test_data,),
+        aten_op=aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_addmm_tosa_BI(test_data: Tuple):
+    pipeline = TosaPipelineBI[input_t1](
+        Addmm(),
+        (*test_data,),
+        aten_op=[],
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("test_data", test_data_suite)
+def test_addmm_u55_BI(test_data: Tuple):
+    pipeline = EthosU55PipelineBI[input_t1](
+        Addmm(),
+        (*test_data,),
+        aten_ops=[],
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_data", test_data_suite)
+def test_addmm_u85_BI(test_data: Tuple):
+    pipeline = EthosU85PipelineBI[input_t1](
+        Addmm(),
+        (*test_data,),
+        aten_ops=[],
+        exir_ops=exir_op,
+    )
+    pipeline.run()
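
---

For reference, the pass relies on the identity addmm(input, mat1, mat2, beta=beta, alpha=alpha) = beta * input + alpha * (mat1 @ mat2). Below is a minimal standalone sketch, not part of the patch, that checks this identity numerically; plain torch ops stand in for the edge/aten mm, mul.Scalar and add targets the pass emits, and the helper name decomposed_addmm is purely illustrative.

import torch


def decomposed_addmm(bias, mat1, mat2, *, beta=1.0, alpha=1.0):
    # Mirrors the structure of DecomposeAddmmPass: one mm, two scalar muls, one add.
    mm = torch.mm(mat1, mat2)  # mat1 @ mat2
    return torch.add(torch.mul(mm, alpha), torch.mul(bias, beta))


if __name__ == "__main__":
    bias, mat1, mat2 = torch.randn(2, 2), torch.randn(2, 3), torch.randn(3, 2)
    # Same (alpha, beta) combinations exercised by the "scaled", "negative_scalars"
    # and "alpha_only" entries of test_data_suite.
    for alpha, beta in [(1.0, 1.0), (0.5, 2.0), (-1.0, -1.0), (1.0, 0.0)]:
        expected = torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)
        actual = decomposed_addmm(bias, mat1, mat2, beta=beta, alpha=alpha)
        assert torch.allclose(expected, actual, atol=1e-6), (alpha, beta)
    print("decomposed addmm matches torch.addmm for all tested alpha/beta pairs")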