|
3 | 3 | # This source code is licensed under the BSD-style license found in the
|
4 | 4 | # LICENSE file in the root directory of this source tree.
|
5 | 5 |
|
| 6 | +from copy import copy |
6 | 7 | from math import prod
|
7 | 8 |
|
8 | 9 | import torch
|
@@ -75,35 +76,47 @@ def call_operator(self, op, args, kwargs, meta):
|
75 | 76 | return super().call_operator(op, args, kwargs, meta)
|
76 | 77 |
|
77 | 78 | x = get_node_arg(args, 0)
|
78 |
| - input_shape = x.data.size() |
79 |
| - output_shape = meta["val"].size() |
| 79 | + input_shape = list(x.data.shape) |
| 80 | + output_shape = list(meta["val"].shape) |
80 | 81 | dims_to_reduce = get_node_arg(args, 1)
|
81 | 82 | dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]
|
| 83 | + dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1] |
82 | 84 |
|
83 | 85 | dtype = meta["val"].dtype
|
84 | 86 | view_op = get_view(op)
|
85 | 87 |
|
86 |
| - if len(input_shape) > 4: |
87 |
| - raise NotImplementedError( |
88 |
| - f"{op} with rank > 4 is currently not supported for the TOSA backend." |
89 |
| - ) |
| 88 | + # Reshape to 4D |
| 89 | + if len(input_shape) != 4: |
| 90 | + new_shape = copy(input_shape) |
| 91 | + |
| 92 | + while len(new_shape) < 4: |
| 93 | + new_shape.insert(0, 1) |
| 94 | + dims_to_reduce = [dim + 1 for dim in dims_to_reduce] |
90 | 95 |
|
91 |
| - # Unsqueeze to 4D |
92 |
| - if len(input_shape) < 4: |
93 |
| - pad_n = 4 - len(input_shape) |
94 |
| - new_shape = [1] * pad_n + list(input_shape) |
95 |
| - dims_to_reduce = [dim + pad_n for dim in dims_to_reduce] |
| 96 | + while len(new_shape) > 4: |
| 97 | + i = new_shape.pop(0) |
| 98 | + new_shape[0] = new_shape[0] * i |
| 99 | + dims_to_reduce = [dim - 1 for dim in dims_to_reduce] |
96 | 100 |
|
97 | 101 | x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
|
98 | 102 |
|
99 | 103 | # Reduce (h,w) dims by avg pool if possible
|
100 | 104 | x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
|
101 | 105 |
|
| 106 | + # Reshape back to 5D if necessary |
| 107 | + if len(input_shape) > 4: |
| 108 | + original_dims = input_shape[0:-4] |
| 109 | + temp_shape = list(x.data.shape)[1:] |
| 110 | + temp_shape = original_dims + temp_shape |
| 111 | + dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce] |
| 112 | + |
| 113 | + x = super().call_operator(view_op, (x, temp_shape), {}, meta, True) |
| 114 | + |
102 | 115 | # Reduce remaining dims by sum
|
103 | 116 | x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)
|
104 | 117 |
|
105 | 118 | # Reshape to correct output shape if necessary
|
106 |
| - if x.data.size() != output_shape: |
| 119 | + if list(x.data.shape) != output_shape: |
107 | 120 | x = super().call_operator(view_op, (x, output_shape), {}, meta, True)
|
108 | 121 |
|
109 | 122 | return x
|
|
0 commit comments