Skip to content

Commit efd133d

Browse files
Arm backend: Add meandim support for rank > 4 (pytorch#12631)
Signed-off-by: Adrian Lundell <[email protected]>
1 parent 4d7f9ca commit efd133d

File tree

2 files changed

+49
-14
lines changed

2 files changed

+49
-14
lines changed

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from copy import copy
67
from math import prod
78

89
import torch
@@ -75,35 +76,47 @@ def call_operator(self, op, args, kwargs, meta):
7576
return super().call_operator(op, args, kwargs, meta)
7677

7778
x = get_node_arg(args, 0)
78-
input_shape = x.data.size()
79-
output_shape = meta["val"].size()
79+
input_shape = list(x.data.shape)
80+
output_shape = list(meta["val"].shape)
8081
dims_to_reduce = get_node_arg(args, 1)
8182
dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]
83+
dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]
8284

8385
dtype = meta["val"].dtype
8486
view_op = get_view(op)
8587

86-
if len(input_shape) > 4:
87-
raise NotImplementedError(
88-
f"{op} with rank > 4 is currently not supported for the TOSA backend."
89-
)
88+
# Reshape to 4D
89+
if len(input_shape) != 4:
90+
new_shape = copy(input_shape)
91+
92+
while len(new_shape) < 4:
93+
new_shape.insert(0, 1)
94+
dims_to_reduce = [dim + 1 for dim in dims_to_reduce]
9095

91-
# Unsqueeze to 4D
92-
if len(input_shape) < 4:
93-
pad_n = 4 - len(input_shape)
94-
new_shape = [1] * pad_n + list(input_shape)
95-
dims_to_reduce = [dim + pad_n for dim in dims_to_reduce]
96+
while len(new_shape) > 4:
97+
i = new_shape.pop(0)
98+
new_shape[0] = new_shape[0] * i
99+
dims_to_reduce = [dim - 1 for dim in dims_to_reduce]
96100

97101
x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
98102

99103
# Reduce (h,w) dims by avg pool if possible
100104
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
101105

106+
# Reshape back to 5D if necessary
107+
if len(input_shape) > 4:
108+
original_dims = input_shape[0:-4]
109+
temp_shape = list(x.data.shape)[1:]
110+
temp_shape = original_dims + temp_shape
111+
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
112+
113+
x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
114+
102115
# Reduce remaining dims by sum
103116
x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)
104117

105118
# Reshape to correct output shape if necessary
106-
if x.data.size() != output_shape:
119+
if list(x.data.shape) != output_shape:
107120
x = super().call_operator(view_op, (x, output_shape), {}, meta, True)
108121

109122
return x

backends/arm/test/ops/test_mean_dim.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,21 @@ class MeanDim(torch.nn.Module):
195195
(-4, -3, -2, -1),
196196
False,
197197
),
198+
"rank5_01234": lambda: (
199+
torch.rand(1, 1, 7, 3, 2),
200+
(-5, -4, -3, -2, -1),
201+
False,
202+
),
203+
"rank5_234": lambda: (
204+
torch.rand(1, 1, 7, 3, 2),
205+
(-3, -2, -1),
206+
False,
207+
),
208+
"rank5_12": lambda: (
209+
torch.rand(1, 1, 7, 3, 2),
210+
(1, 2),
211+
False,
212+
),
198213
"u55_avg_pool_not_supported": lambda: (
199214
torch.rand(1, 1, 1, 257),
200215
(0, 1, 2, 3),
@@ -236,7 +251,14 @@ def test_mean_dim_tosa_BI(test_data):
236251
pipeline.run()
237252

238253

239-
@common.parametrize("test_data", MeanDim.test_data_suite)
254+
xfails = {
255+
"rank5_01234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)",
256+
"rank5_234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)",
257+
"rank5_12": "Rank 5 graph input currently not supported in EthosUBackend",
258+
}
259+
260+
261+
@common.parametrize("test_data", MeanDim.test_data_suite, xfails=xfails, strict=False)
240262
@common.XfailIfNoCorstone300
241263
def test_mean_dim_u55_BI(test_data):
242264
test_data, dim, keep_dim = test_data()
@@ -256,7 +278,7 @@ def test_mean_dim_u55_BI(test_data):
256278
pipeline.run()
257279

258280

259-
@common.parametrize("test_data", MeanDim.test_data_suite)
281+
@common.parametrize("test_data", MeanDim.test_data_suite, xfails=xfails, strict=False)
260282
@common.XfailIfNoCorstone320
261283
def test_mean_dim_u85_BI(test_data):
262284
test_data, dim, keep_dim = test_data()

0 commit comments

Comments
 (0)