@@ -24,19 +24,7 @@ static diopiError_t transpose(diopiContextHandle_t& ctx, DiopiTensor& in, DiopiT
24
24
return diopiSuccess;
25
25
}
26
26
27
- // static diopiError_t calTensordiopiMemoryFormat_t(const DiopiTensor& tensor, diopiMemoryFormat_t& memoryFormatOut) {
28
- // if (tensor.isContiguous(diopiMemoryFormat_t::ChannelsLast)) {
29
- // memoryFormatOut = diopiMemoryFormat_t::ChannelsLast;
30
- // } else if (tensor.isContiguous(diopiMemoryFormat_t::ChannelsLast3d)) {
31
- // memoryFormatOut = diopiMemoryFormat_t::ChannelsLast3d;
32
- // } else if (tensor.isContiguous(diopiMemoryFormat_t::Contiguous)) {
33
- // memoryFormatOut = diopiMemoryFormat_t::Contiguous;
34
- // } else {
35
- // return diopiNoImplement;
36
- // }
37
- // return diopiSuccess;
38
- // }
39
- static diopiError_t getPermuteOrder (const DiopiTensor& src, std::vector<int32_t >& orderOut, std::vector<int32_t >& reverseOrder) {
27
+ diopiError_t getPermuteOrder (const DiopiTensor& src, std::vector<int32_t >& orderOut, std::vector<int32_t >& reverseOrder) {
40
28
if (src.isContiguous ()) {
41
29
orderOut.resize (src.dim ());
42
30
for (int i = 0 ; i < src.dim (); ++i) {
@@ -59,6 +47,7 @@ static diopiError_t getPermuteOrder(const DiopiTensor& src, std::vector<int32_t>
59
47
stridesSizes[i] = std::pair<int , int >(inputStrides[i], inputSizes[i]);
60
48
}
61
49
50
+ // shape:2,3,4,5 stride:60,1,15,3 -> orderOut: 0,3,1,2, reverseOrder: 0,2,3,1
62
51
sort (stridesSizes.begin (), stridesSizes.end (), [](std::pair<int , int > a, std::pair<int , int > b) { return a.first > b.first ; });
63
52
for (int i = 0 ; i < dim; ++i) {
64
53
auto pair = stridesSizes[i];
@@ -83,73 +72,6 @@ static diopiError_t getPermuteOrder(const DiopiTensor& src, std::vector<int32_t>
83
72
return diopiSuccess;
84
73
}
85
74
86
- static diopiError_t calOrderAndSrcMemoryFormat (const DiopiTensor& src, diopiMemoryFormat_t destMemoryFormat, diopiMemoryFormat_t& srcMemoryFormatOut,
87
- std::vector<int32_t >& orderOut, std::vector<int32_t >& reverseOrder) {
88
- if (src.isContiguous (destMemoryFormat)) {
89
- srcMemoryFormatOut = destMemoryFormat;
90
- orderOut.resize (src.dim ());
91
- for (int i = 0 ; i < src.dim (); ++i) {
92
- orderOut[i] = i;
93
- }
94
- reverseOrder = orderOut;
95
- return diopiSuccess;
96
- }
97
- if (src.isContiguous (diopiMemoryFormat_t::ChannelsLast1d) && destMemoryFormat == diopiMemoryFormat_t::Contiguous) {
98
- if (src.dim () != 3 ) {
99
- setLastErrorString (" the dim of the tensor should be 4, but now is %d." , src.dim ());
100
- return diopiNoImplement;
101
- }
102
- srcMemoryFormatOut = diopiMemoryFormat_t::ChannelsLast1d;
103
- orderOut = {0 , 2 , 1 };
104
- reverseOrder = {0 , 2 , 1 };
105
- } else if (src.isContiguous (diopiMemoryFormat_t::Contiguous) && destMemoryFormat == diopiMemoryFormat_t::ChannelsLast1d) {
106
- if (src.dim () != 3 ) {
107
- setLastErrorString (" the dim of the tensor should be 4, but now is %d." , src.dim ());
108
- return diopiNoImplement;
109
- }
110
- srcMemoryFormatOut = diopiMemoryFormat_t::Contiguous;
111
- orderOut = {0 , 2 , 1 };
112
- reverseOrder = {0 , 2 , 1 };
113
- } else if (src.isContiguous (diopiMemoryFormat_t::ChannelsLast) && destMemoryFormat == diopiMemoryFormat_t::Contiguous) {
114
- if (src.dim () != 4 ) {
115
- setLastErrorString (" the dim of the tensor should be 4, but now is %d." , src.dim ());
116
- return diopiNoImplement;
117
- }
118
- srcMemoryFormatOut = diopiMemoryFormat_t::ChannelsLast;
119
- orderOut = {0 , 3 , 1 , 2 };
120
- reverseOrder = {0 , 2 , 3 , 1 };
121
- } else if (src.isContiguous (diopiMemoryFormat_t::Contiguous) && destMemoryFormat == diopiMemoryFormat_t::ChannelsLast) {
122
- if (src.dim () != 4 ) {
123
- setLastErrorString (" the dim of the tensor should be 4, but now is %d." , src.dim ());
124
- return diopiNoImplement;
125
- }
126
- srcMemoryFormatOut = diopiMemoryFormat_t::Contiguous;
127
- orderOut = {0 , 2 , 3 , 1 };
128
- reverseOrder = {0 , 3 , 1 , 2 };
129
- } else if (src.isContiguous (diopiMemoryFormat_t::Contiguous) && destMemoryFormat == diopiMemoryFormat_t::ChannelsLast3d) {
130
- if (src.dim () != 5 ) {
131
- setLastErrorString (" the dim of the tensor should be 5, but now is %d." , src.dim ());
132
- return diopiNoImplement;
133
- }
134
- srcMemoryFormatOut = diopiMemoryFormat_t::Contiguous;
135
- orderOut = {0 , 2 , 3 , 4 , 1 };
136
- reverseOrder = {0 , 4 , 1 , 2 , 3 };
137
- } else if (src.isContiguous (diopiMemoryFormat_t::ChannelsLast3d) && destMemoryFormat == diopiMemoryFormat_t::Contiguous) {
138
- if (src.dim () != 5 ) {
139
- setLastErrorString (" the dim of the tensor should be 5, but now is %d." , src.dim ());
140
- return diopiNoImplement;
141
- }
142
- srcMemoryFormatOut = diopiMemoryFormat_t::ChannelsLast3d;
143
- orderOut = {0 , 4 , 1 , 2 , 3 };
144
- reverseOrder = {0 , 2 , 3 , 4 , 1 };
145
- } else {
146
- // convert to contiguous format
147
- srcMemoryFormatOut = diopiMemoryFormat_t::Preserve;
148
- return diopiSuccess;
149
- }
150
- return diopiSuccess;
151
- }
152
-
153
75
diopiError_t calCnnlLayout (diopiMemoryFormat_t memoryFormat, int64_t dim, cnnlTensorLayout_t& cnnlLayout) {
154
76
switch (memoryFormat) {
155
77
case diopiMemoryFormat_t::ChannelsLast1d:
@@ -234,68 +156,61 @@ diopiError_t contiguous(diopiContextHandle_t ctx, DiopiTensor& src, diopiMemoryF
234
156
235
157
int64_t dim = src.dim ();
236
158
DIOPI_CHECK (dim <= 8 , " only support less than 8d tensor currently" );
237
- diopiMemoryFormat_t srcMemoryFormat;
238
- std::vector<int32_t > order;
239
- std::vector<int32_t > reverseOrder;
240
159
DiopiTensor dest;
241
- DIOPI_CALL (calOrderAndSrcMemoryFormat (src, memoryFormat, srcMemoryFormat, order, reverseOrder));
242
- if (srcMemoryFormat == diopiMemoryFormat_t::Preserve) {
243
- DIOPI_CALL (clone (ctx, src, dest, memoryFormat));
244
- src = dest;
245
- return diopiSuccess;
246
- }
247
- dest = requiresTensor (ctx, src.shape (), src.dtype (), memoryFormat);
248
- // set CNNL_LAYOUT_ARRAY because NLC->NCL failed ( no layout NCL);
249
- cnnlTensorLayout_t srcLayout = CNNL_LAYOUT_ARRAY;
250
- cnnlTensorLayout_t destLayout = CNNL_LAYOUT_ARRAY;
251
-
252
- std::vector<int64_t > olderDestStride = dest.stride ();
253
- std::vector<int64_t > olderDestShape = dest.shape ();
254
- if (memoryFormat != diopiMemoryFormat_t::Contiguous) {
255
- DIOPI_CALL (permuteTensor (dest, order));
256
- } else {
257
- DIOPI_CALL (permuteTensor (src, reverseOrder));
258
- }
259
- DIOPI_CALL (transpose (ctx, src, dest, srcLayout, destLayout, order));
260
- // recovery the shape
261
- dest.asStrided (olderDestShape, olderDestStride);
160
+ DIOPI_CALL (clone (ctx, src, dest, memoryFormat));
262
161
src = dest;
263
162
return diopiSuccess;
264
163
}
265
164
266
- // inplace contiguous
267
- diopiError_t contiguousOut (diopiContextHandle_t ctx, DiopiTensor& src, DiopiTensor& dest) {
165
+ diopiError_t permuteCopy (diopiContextHandle_t ctx, DiopiTensor& src, DiopiTensor& dest) {
166
+ // using input permute + output permute + cnnltranspose to copy
268
167
DIOPI_CHECK (src.shape () == dest.shape (), " src's shape should be the same as dest's" );
269
168
int64_t dim = src.dim ();
270
169
DIOPI_CHECK (dim <= 8 , " only support less than 8d tensor currently" );
271
- std::vector<int32_t > order (dim, 0 );
272
- std::vector<int32_t > reverseOrder (dim, 0 );
170
+ bool srcIsContiguous = src.isContiguous ();
171
+ bool destIsContiguous = dest.isContiguous ();
172
+ std::vector<int32_t > inputOrder (dim, 0 );
173
+ std::vector<int32_t > inputBackOrder (dim, 0 ); // permuteTensor(input,inputBackOrder)->contiguous
174
+ std::vector<int32_t > outputOrder (dim, 0 );
175
+ std::vector<int32_t > outputBackOrder (dim, 0 ); // permuteTensor(output,outputBackOrder)->contiguous
176
+ std::vector<int32_t > inputToOutputOrder (dim, 0 ); // into cnnltranspose
177
+
178
+ // input shape:2,3,4,5 stride:60,1,15,3 -> inputBackOrder: 0,2,3,1, inputOrder: 0,3,1,2
179
+ // output shape:2,3,4,5 stride:60,20,1,4 -> outputBackOrder: 0,1,3,2, outputOrder: 0,1,3,2
180
+ // inputToOutputOrder: 0,2,1,3
181
+
182
+ getPermuteOrder (src, inputOrder, inputBackOrder);
183
+ getPermuteOrder (dest, outputOrder, outputBackOrder);
273
184
274
- if (src.isContiguous ()) {
275
- getPermuteOrder (dest, reverseOrder, order);
276
- } else {
277
- getPermuteOrder (src, order, reverseOrder);
278
- }
279
- // set CNNL_LAYOUT_ARRAY because NLC->NCL failed ( no layout NCL);
280
185
cnnlTensorLayout_t srcLayout = CNNL_LAYOUT_ARRAY;
281
186
cnnlTensorLayout_t destLayout = CNNL_LAYOUT_ARRAY;
282
187
283
188
std::vector<int64_t > olderDestStride = dest.stride ();
284
189
std::vector<int64_t > olderDestShape = dest.shape ();
285
190
std::vector<int64_t > olderSrcStride = src.stride ();
286
191
std::vector<int64_t > olderSrcShape = src.shape ();
287
- // if (destMemoryFormat != diopiMemoryFormat_t::Contiguous) {
288
- if (src.isContiguous ()) {
289
- DIOPI_CALL (permuteTensor (dest, order));
290
- } else {
291
- DIOPI_CALL (permuteTensor (src, reverseOrder));
192
+
193
+ // permute to get contiguous tensor
194
+ if (!destIsContiguous) {
195
+ DIOPI_CALL (permuteTensor (dest, outputBackOrder));
196
+ }
197
+
198
+ if (!srcIsContiguous) {
199
+ DIOPI_CALL (permuteTensor (src, inputBackOrder));
200
+ }
201
+
202
+ for (int i = 0 ; i < dim; ++i) {
203
+ inputToOutputOrder[i] = inputOrder[outputBackOrder[i]];
292
204
}
293
- DIOPI_CALL (transpose (ctx, src, dest, srcLayout, destLayout, order));
205
+
206
+ DIOPI_CALL (transpose (ctx, src, dest, srcLayout, destLayout, inputToOutputOrder));
207
+
294
208
// recovery the shape and strides
295
- // if (destMemoryFormat != diopiMemoryFormat_t::Contiguous) {
296
- if (src.isContiguous ()) {
209
+ if (!destIsContiguous) {
297
210
dest.asStrided (olderDestShape, olderDestStride);
298
- } else {
211
+ }
212
+
213
+ if (!srcIsContiguous) {
299
214
src.asStrided (olderSrcShape, olderSrcStride);
300
215
}
301
216
return diopiSuccess;
0 commit comments