@@ -2306,9 +2306,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23062306 if (!matchPattern (op.getTransposed (), m_TorchConstantBool (&transposed)))
23072307 return rewriter.notifyMatchFailure (
23082308 op, " Unimplemented: non-constant value for transposed not supported" );
2309- if (transposed)
2310- return rewriter.notifyMatchFailure (
2311- op, " Unimplemented: transposed convolution not supported" );
23122309
23132310 auto input = adaptor.getInput ();
23142311 auto weight = adaptor.getWeight ();
@@ -2340,12 +2337,17 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23402337 auto bias = adaptor.getBias ();
23412338
23422339 if (isa<Torch::NoneType>(bias.getType ())) {
2343- auto bias_result = tosa::getConvBiasForNoneType (op, rewriter, inputElemTy,
2344- outputElemTy, weightShape);
2345- if (failed (bias_result))
2340+ SmallVector<int64_t , 4 > biasWeightShape =
2341+ transposed ? SmallVector<int64_t , 4 >{weightShape[1 ], weightShape[0 ],
2342+ weightShape[2 ], weightShape[3 ]}
2343+ : weightShape;
2344+
2345+ auto biasResult = tosa::getConvBiasForNoneType (
2346+ op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
2347+ if (failed (biasResult))
23462348 return rewriter.notifyMatchFailure (
23472349 op, " Failed to create bias tensor for none type." );
2348- bias = bias_result .value ();
2350+ bias = biasResult .value ();
23492351 } else {
23502352 if (!isa<RankedTensorType>(bias.getType ()))
23512353 return rewriter.notifyMatchFailure (
@@ -2372,8 +2374,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23722374 m_TorchListOfConstantInts (padding_2d)))
23732375 return rewriter.notifyMatchFailure (op,
23742376 " non-const padding list unsupported" );
2375- // TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
2376- // padding {height, width}. The Torch OFM computation uses 2*pad in each
2377+ // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
2378+ // padding {height, width}. The PyTorch OFM computation uses 2*pad in each
23772379 // spatial direction, implying the same top=bottom=height and left=right=width
23782380 // values for TOSA.
23792381 SmallVector<int64_t > padding (
@@ -2390,11 +2392,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23902392 return rewriter.notifyMatchFailure (
23912393 op, " failed to get accumulator type for convolution ops" );
23922394
2395+ // Weight layout reference:
2396+ // Conv : PyTorch OIHW -> TOSA OHWI
2397+ // Depthwise : PyTorch OIHW* -> TOSA HWIM
2398+ // (PyTorch depthwise uses out_ch=in_ch*depth_multiplier)
2399+ // Grouped : PyTorch O(I/G)HW -> N/A
2400+ // Transposed : PyTorch IOHW -> TOSA OHWI
23932401 // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
23942402 // Perform the necessary transformations.
23952403 SmallVector<int32_t > nchwToNhwcDims ({0 , 2 , 3 , 1 });
2396- SmallVector<int64_t > transposedInputShape (
2397- {inputShape[0 ], inputShape[2 ], inputShape[3 ], inputShape[1 ]});
2404+ SmallVector<int32_t > nhwcToNchwDims ({0 , 3 , 1 , 2 });
2405+ SmallVector<int64_t , 4 > transposedInputShape;
2406+ for (int32_t dim : nchwToNhwcDims)
2407+ transposedInputShape.push_back (inputShape[dim]);
23982408 auto transposedInputType = RankedTensorType::get (
23992409 makeShapeLLVMCompatible (transposedInputShape), inputElemTy);
24002410 auto transposedInput =
@@ -2404,6 +2414,104 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24042414 rewriter.getDenseI32ArrayAttr (nchwToNhwcDims))
24052415 .getResult ();
24062416
2417+ if (transposed) {
2418+ if (groups != 1 )
2419+ return rewriter.notifyMatchFailure (
2420+ op, " Unimplemented: grouped transposed convolution not supported by "
2421+ " TOSA" );
2422+ if (dilation[0 ] != 1 || dilation[1 ] != 1 )
2423+ return rewriter.notifyMatchFailure (
2424+ op, " Unimplemented: dilated transposed convolution not supported by "
2425+ " TOSA" );
2426+
2427+ SmallVector<int32_t > iohwToOhwi ({1 , 2 , 3 , 0 });
2428+ SmallVector<int64_t , 4 > ohwiWeightShape;
2429+ for (int32_t dim : iohwToOhwi)
2430+ ohwiWeightShape.push_back (weightShape[dim]);
2431+ auto ohwiWeightType = RankedTensorType::get (
2432+ makeShapeLLVMCompatible (ohwiWeightShape), weightElemTy);
2433+ Value transformedWeight =
2434+ rewriter
2435+ .create <tosa::TransposeOp>(
2436+ op->getLoc (), getTypeConverter ()->convertType (ohwiWeightType),
2437+ weight, rewriter.getDenseI32ArrayAttr (iohwToOhwi))
2438+ .getResult ();
2439+
2440+ // TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
2441+ // Map from PyTorch's (padding, output_padding):
2442+ // out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W)
2443+ // Negative values are allowed and will be handled by the TOSA
2444+ // decomposition.
2445+ SmallVector<int64_t , 2 > outPadding2D;
2446+ if (!matchPattern (adaptor.getOutputPadding (),
2447+ m_TorchListOfConstantInts (outPadding2D)))
2448+ return rewriter.notifyMatchFailure (
2449+ op, " non-const output_padding list unsupported for transposed conv" );
2450+
2451+ int64_t outPadH = outPadding2D[0 ] - 2 * padding_2d[0 ];
2452+ int64_t outPadW = outPadding2D[1 ] - 2 * padding_2d[1 ];
2453+ int64_t outPadTop = outPadH / 2 ;
2454+ int64_t outPadBottom = outPadH - outPadTop;
2455+ int64_t outPadLeft = outPadW / 2 ;
2456+ int64_t outPadRight = outPadW - outPadLeft;
2457+ SmallVector<int64_t , 4 > outPad (
2458+ {outPadTop, outPadBottom, outPadLeft, outPadRight});
2459+
2460+ // Result type is NHWC (we'll transpose back).
2461+ auto outNCHW = makeShapeTorchCompatible (outputTy.getShape ());
2462+ SmallVector<int64_t , 4 > outNHWC;
2463+ for (int32_t dim : nchwToNhwcDims)
2464+ outNHWC.push_back (outNCHW[dim]);
2465+ auto transConvOpTy =
2466+ RankedTensorType::get (makeShapeLLVMCompatible (outNHWC), biasElemTy);
2467+
2468+ // Zero-points.
2469+ auto zps = tosa::createZPsAsConst (rewriter, input, weight);
2470+ Value inputZp = zps.first ? zps.first
2471+ : tosa::createZeroPointTensor (
2472+ rewriter, op->getLoc (), inputElemTy, 0 )
2473+ .value ();
2474+ Value weightZp = zps.second ? zps.second
2475+ : tosa::createZeroPointTensor (
2476+ rewriter, op->getLoc (), weightElemTy, 0 )
2477+ .value ();
2478+
2479+ Value convTOut =
2480+ rewriter
2481+ .create <tosa::TransposeConv2DOp>(
2482+ op->getLoc (), getTypeConverter ()->convertType (transConvOpTy),
2483+ transposedInput, transformedWeight, bias, inputZp, weightZp,
2484+ rewriter.getDenseI64ArrayAttr (outPad),
2485+ rewriter.getDenseI64ArrayAttr (stride), accType)
2486+ .getResult ();
2487+
2488+ SmallVector<int64_t , 4 > transposedOutputShape;
2489+ for (int32_t dim : nhwcToNchwDims)
2490+ transposedOutputShape.push_back (outNHWC[dim]);
2491+ auto transposedOutputType = RankedTensorType::get (
2492+ makeShapeLLVMCompatible (transposedOutputShape), biasElemTy);
2493+ Value transposedOutput =
2494+ rewriter
2495+ .create <tosa::TransposeOp>(
2496+ op->getLoc (),
2497+ getTypeConverter ()->convertType (transposedOutputType), convTOut,
2498+ rewriter.getDenseI32ArrayAttr (nhwcToNchwDims))
2499+ .getResult ();
2500+
2501+ // Quantized rescale.
2502+ Value rescaledResult = transposedOutput;
2503+ if (isa<quant::QuantizedType>(inputElemTy)) {
2504+ rescaledResult = tosa::buildRescaleOpConvOutput (
2505+ rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
2506+ }
2507+
2508+ // Final cast to requested output type.
2509+ rewriter.replaceOp (
2510+ op, {tosa::tosaCastTensorToType (rewriter, rescaledResult, outputTy)
2511+ .value ()});
2512+ return success ();
2513+ }
2514+
24072515 SmallVector<int64_t > transformedWeightShape;
24082516 RankedTensorType transformedWeightType;
24092517 Value transformedWeight;
@@ -2485,7 +2593,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24852593 if (remainderHDim != 0 ) {
24862594 if (remainderHDim > padding[1 ]) {
24872595 SmallVector<int64_t > startHSlice (inputTy.getRank (), 0 );
2488- SmallVector<int64_t > sizeHSlice (transposedInputShape);
2596+ SmallVector<int64_t , 4 > sizeHSlice (transposedInputShape);
24892597 // TOSA uses NHWC, so we will slice dim 1 for Height value
24902598 sizeHSlice[1 ] = inputHDim - (remainderHDim - padding[1 ]);
24912599 transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
@@ -2579,7 +2687,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25792687 llvm_unreachable (" Unhandled convolution type" );
25802688 }
25812689
2582- SmallVector<int32_t > nhwcToNchwDims ({0 , 3 , 1 , 2 });
25832690 SmallVector<int64_t > transposedOutputShape (
25842691 {outputShape[0 ], outputShape[3 ], outputShape[1 ], outputShape[2 ]});
25852692 auto transposedOutputType = RankedTensorType::get (
0 commit comments