Skip to content

Commit 3b21d94

Browse files
committed
[TOSA] Add transposed conv support
Lower aten.conv_transpose2d into tosa.transpose_conv2d. Refresh FX importer TOSA xfails to drop the transpose-conv cases that now pass, and document the weight layout mapping. Change-Id: I709579e40a1ccaf9b9188392c7c78fcb653109ce
1 parent 8d563af commit 3b21d94

File tree

4 files changed

+140
-23
lines changed

4 files changed

+140
-23
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 120 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2306,9 +2306,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23062306
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
23072307
return rewriter.notifyMatchFailure(
23082308
op, "Unimplemented: non-constant value for transposed not supported");
2309-
if (transposed)
2310-
return rewriter.notifyMatchFailure(
2311-
op, "Unimplemented: transposed convolution not supported");
23122309

23132310
auto input = adaptor.getInput();
23142311
auto weight = adaptor.getWeight();
@@ -2340,12 +2337,17 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23402337
auto bias = adaptor.getBias();
23412338

23422339
if (isa<Torch::NoneType>(bias.getType())) {
2343-
auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy,
2344-
outputElemTy, weightShape);
2345-
if (failed(bias_result))
2340+
SmallVector<int64_t, 4> biasWeightShape =
2341+
transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
2342+
weightShape[2], weightShape[3]}
2343+
: weightShape;
2344+
2345+
auto biasResult = tosa::getConvBiasForNoneType(
2346+
op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
2347+
if (failed(biasResult))
23462348
return rewriter.notifyMatchFailure(
23472349
op, "Failed to create bias tensor for none type.");
2348-
bias = bias_result.value();
2350+
bias = biasResult.value();
23492351
} else {
23502352
if (!isa<RankedTensorType>(bias.getType()))
23512353
return rewriter.notifyMatchFailure(
@@ -2372,8 +2374,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23722374
m_TorchListOfConstantInts(padding_2d)))
23732375
return rewriter.notifyMatchFailure(op,
23742376
"non-const padding list unsupported");
2375-
// TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
2376-
// padding {height, width}. The Torch OFM computation uses 2*pad in each
2377+
// TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
2378+
// padding {height, width}. The PyTorch OFM computation uses 2*pad in each
23772379
// spatial direction, implying the same top=bottom=height and left=right=width
23782380
// values for TOSA.
23792381
SmallVector<int64_t> padding(
@@ -2390,11 +2392,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23902392
return rewriter.notifyMatchFailure(
23912393
op, "failed to get accumulator type for convolution ops");
23922394

2395+
// Weight layout reference:
2396+
// Conv : PyTorch OIHW -> TOSA OHWI
2397+
// Depthwise : PyTorch OIHW* -> TOSA HWIM
2398+
// (PyTorch depthwise uses out_ch=in_ch*depth_multiplier)
2399+
// Grouped : PyTorch O(I/G)HW -> N/A
2400+
// Transposed : PyTorch IOHW -> TOSA OHWI
23932401
// TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
23942402
// Perform the necessary transformations.
23952403
SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
2396-
SmallVector<int64_t> transposedInputShape(
2397-
{inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
2404+
SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
2405+
SmallVector<int64_t, 4> transposedInputShape;
2406+
for (int32_t dim : nchwToNhwcDims)
2407+
transposedInputShape.push_back(inputShape[dim]);
23982408
auto transposedInputType = RankedTensorType::get(
23992409
makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
24002410
auto transposedInput =
@@ -2404,6 +2414,104 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24042414
rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
24052415
.getResult();
24062416

2417+
if (transposed) {
2418+
if (groups != 1)
2419+
return rewriter.notifyMatchFailure(
2420+
op, "Unimplemented: grouped transposed convolution not supported by "
2421+
"TOSA");
2422+
if (dilation[0] != 1 || dilation[1] != 1)
2423+
return rewriter.notifyMatchFailure(
2424+
op, "Unimplemented: dilated transposed convolution not supported by "
2425+
"TOSA");
2426+
2427+
SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
2428+
SmallVector<int64_t, 4> ohwiWeightShape;
2429+
for (int32_t dim : iohwToOhwi)
2430+
ohwiWeightShape.push_back(weightShape[dim]);
2431+
auto ohwiWeightType = RankedTensorType::get(
2432+
makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
2433+
Value transformedWeight =
2434+
rewriter
2435+
.create<tosa::TransposeOp>(
2436+
op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
2437+
weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
2438+
.getResult();
2439+
2440+
// TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
2441+
// Map from PyTorch's (padding, output_padding):
2442+
// out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W)
2443+
// Negative values are allowed and will be handled by the TOSA
2444+
// decomposition.
2445+
SmallVector<int64_t, 2> outPadding2D;
2446+
if (!matchPattern(adaptor.getOutputPadding(),
2447+
m_TorchListOfConstantInts(outPadding2D)))
2448+
return rewriter.notifyMatchFailure(
2449+
op, "non-const output_padding list unsupported for transposed conv");
2450+
2451+
int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0];
2452+
int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1];
2453+
int64_t outPadTop = outPadH / 2;
2454+
int64_t outPadBottom = outPadH - outPadTop;
2455+
int64_t outPadLeft = outPadW / 2;
2456+
int64_t outPadRight = outPadW - outPadLeft;
2457+
SmallVector<int64_t, 4> outPad(
2458+
{outPadTop, outPadBottom, outPadLeft, outPadRight});
2459+
2460+
// Result type is NHWC (we'll transpose back).
2461+
auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
2462+
SmallVector<int64_t, 4> outNHWC;
2463+
for (int32_t dim : nchwToNhwcDims)
2464+
outNHWC.push_back(outNCHW[dim]);
2465+
auto transConvOpTy =
2466+
RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy);
2467+
2468+
// Zero-points.
2469+
auto zps = tosa::createZPsAsConst(rewriter, input, weight);
2470+
Value inputZp = zps.first ? zps.first
2471+
: tosa::createZeroPointTensor(
2472+
rewriter, op->getLoc(), inputElemTy, 0)
2473+
.value();
2474+
Value weightZp = zps.second ? zps.second
2475+
: tosa::createZeroPointTensor(
2476+
rewriter, op->getLoc(), weightElemTy, 0)
2477+
.value();
2478+
2479+
Value convTOut =
2480+
rewriter
2481+
.create<tosa::TransposeConv2DOp>(
2482+
op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
2483+
transposedInput, transformedWeight, bias, inputZp, weightZp,
2484+
rewriter.getDenseI64ArrayAttr(outPad),
2485+
rewriter.getDenseI64ArrayAttr(stride), accType)
2486+
.getResult();
2487+
2488+
SmallVector<int64_t, 4> transposedOutputShape;
2489+
for (int32_t dim : nhwcToNchwDims)
2490+
transposedOutputShape.push_back(outNHWC[dim]);
2491+
auto transposedOutputType = RankedTensorType::get(
2492+
makeShapeLLVMCompatible(transposedOutputShape), biasElemTy);
2493+
Value transposedOutput =
2494+
rewriter
2495+
.create<tosa::TransposeOp>(
2496+
op->getLoc(),
2497+
getTypeConverter()->convertType(transposedOutputType), convTOut,
2498+
rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
2499+
.getResult();
2500+
2501+
// Quantized rescale.
2502+
Value rescaledResult = transposedOutput;
2503+
if (isa<quant::QuantizedType>(inputElemTy)) {
2504+
rescaledResult = tosa::buildRescaleOpConvOutput(
2505+
rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
2506+
}
2507+
2508+
// Final cast to requested output type.
2509+
rewriter.replaceOp(
2510+
op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy)
2511+
.value()});
2512+
return success();
2513+
}
2514+
24072515
SmallVector<int64_t> transformedWeightShape;
24082516
RankedTensorType transformedWeightType;
24092517
Value transformedWeight;
@@ -2485,7 +2593,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24852593
if (remainderHDim != 0) {
24862594
if (remainderHDim > padding[1]) {
24872595
SmallVector<int64_t> startHSlice(inputTy.getRank(), 0);
2488-
SmallVector<int64_t> sizeHSlice(transposedInputShape);
2596+
SmallVector<int64_t, 4> sizeHSlice(transposedInputShape);
24892597
// TOSA uses NHWC, so we will slice dim 1 for Height value
24902598
sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]);
24912599
transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
@@ -2579,7 +2687,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25792687
llvm_unreachable("Unhandled convolution type");
25802688
}
25812689

2582-
SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
25832690
SmallVector<int64_t> transposedOutputShape(
25842691
{outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
25852692
auto transposedOutputType = RankedTensorType::get(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3581,7 +3581,6 @@
35813581
"AvgPool3dCountIncludePadFalseWithoutPadding_basic",
35823582
"Conv_Transpose1dModule_basic",
35833583
"Conv_Transpose1dStaticModule_basic",
3584-
"Conv_Transpose2dStaticModule_basic",
35853584
"Conv_Transpose3dModule_basic",
35863585
"Conv_Transpose3dStaticModule_basic",
35873586
"IndexPutWithNoneAndBroadcastModule_basic",
@@ -3706,16 +3705,11 @@
37063705
"Conv3dWithValidPaddingModule_basic",
37073706
"ConvTbcModule_basic",
37083707
"ConvTranspose2DQInt8_basic",
3709-
"Conv_Transpose2dModule_basic",
37103708
"ConvolutionBackwardModule2DPadded_basic",
3711-
"ConvolutionBackwardModule2DStatic_basic",
37123709
"ConvolutionBackwardModule2DStrided_basic",
37133710
"ConvolutionBackwardModule2D_basic",
37143711
"ConvolutionModule2DGroups_basic",
37153712
"ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
3716-
"ConvolutionModule2DTransposeStridedStatic_basic",
3717-
"ConvolutionModule2DTransposeStrided_basic",
3718-
"ConvolutionModule2DTranspose_basic",
37193713
"ConvolutionModule2DGroupedTranspose_basic",
37203714
"ConvolutionModule3DGroups_basic",
37213715
"ConvolutionModule3DGroupsStrided_basic",

projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
# that depend on TOSA as well as TOSA-to-Standard.
3030
"tosa-to-arith",
3131
"tosa-to-scf",
32+
# Required for transposed convolution support (decomposes to conv ops).
33+
"tosa-optional-decompositions",
3234
# Named ops must be legalized prior to general tosa-to-linalg
3335
"tosa-to-linalg-named",
3436
# TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them
Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,23 @@
1-
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics
1+
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file | FileCheck %s
22

3-
// The following test ensures that a transposed convolution op is not
4-
// lowered in the torch-to-tosa conversion pass.
3+
// The lowering now legalizes transpose convolutions into the TOSA dialect.
4+
// Verify that we emit tosa.transpose_conv2d with the expected reshapes/
5+
// permutations.
56

7+
// CHECK-LABEL: func.func @forward
8+
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
9+
// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[1,64,1,100],f32> -> tensor<1x64x1x100xf32>
10+
// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64x64x3x3xf32>}> : () -> tensor<64x64x3x3xf32>
11+
// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64xf32>}> : () -> tensor<64xf32>
12+
// CHECK: %[[TRANS_IN:.*]] = tosa.transpose %[[IN_TENSOR]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x64x1x100xf32>) -> tensor<1x1x100x64xf32>
13+
// CHECK: %[[W_OHWI:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 1, 2, 3, 0>} : (tensor<64x64x3x3xf32>) -> tensor<64x3x3x64xf32>
14+
// CHECK: %[[ZP0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
15+
// CHECK: %[[ZP1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
16+
// CHECK: %[[TCONV:.*]] = tosa.transpose_conv2d %[[TRANS_IN]], %[[W_OHWI]], %[[BIAS]], %[[ZP0]], %[[ZP1]] {acc_type = f32, out_pad = array<i64: 0, -1, 0, -1>, stride = array<i64: 2, 2>} : (tensor<1x1x100x64xf32>, tensor<64x3x3x64xf32>, tensor<64xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x200x64xf32>
17+
// CHECK: %[[TRANS_OUT:.*]] = tosa.transpose %[[TCONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x2x200x64xf32>) -> tensor<1x64x2x200xf32>
18+
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[TRANS_OUT]] : tensor<1x64x2x200xf32> -> !torch.vtensor<[1,64,2,200],f32>
19+
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,64,2,200],f32>
20+
// CHECK: }
621
func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
722
%true = torch.constant.bool true
823
%int1 = torch.constant.int 1
@@ -11,7 +26,6 @@ func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[
1126
%bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
1227
%stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
1328
%int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
14-
// expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}}
1529
%output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>
1630
return %output : !torch.vtensor<[1,64,2,200],f32>
1731
}

0 commit comments

Comments
 (0)