@@ -2337,6 +2337,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   auto bias = adaptor.getBias();
 
   if (isa<Torch::NoneType>(bias.getType())) {
+    // ConvTranspose weights use IOHW; the helper expects OIHW, so swap
+    // dims 0/1 before we synthesize the bias.
     SmallVector<int64_t, 4> biasWeightShape =
         transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
                                              weightShape[2], weightShape[3]}
@@ -2407,12 +2409,13 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     transposedInputShape.push_back(inputShape[dim]);
   auto transposedInputType = RankedTensorType::get(
       makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
-  auto transposedInput =
-      tosa::TransposeOp::create(
-          rewriter, op->getLoc(),
-          getTypeConverter()->convertType(transposedInputType), input,
-          rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
-          .getResult();
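+  // Wrap the input NCHW->NHWC transpose in a lambda so each lowering path
+  // below materializes it only when it actually needs it.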
+  auto createTransposedInput = [&]() {
+    return rewriter
+        .create<tosa::TransposeOp>(
+            op->getLoc(), getTypeConverter()->convertType(transposedInputType),
+            input, rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
+        .getResult();
+  };
 
   if (transposed) {
     if (groups != 1)
@@ -2425,17 +2428,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24252428 " TOSA" );
24262429
24272430 SmallVector<int32_t > iohwToOhwi ({1 , 2 , 3 , 0 });
2428- SmallVector<int64_t , 4 > ohwiWeightShape;
2429- for (int32_t dim : iohwToOhwi)
2430- ohwiWeightShape.push_back (weightShape[dim]);
2431- auto ohwiWeightType = RankedTensorType::get (
2432- makeShapeLLVMCompatible (ohwiWeightShape), weightElemTy);
2433- Value transformedWeight =
2434- rewriter
2435- .create <tosa::TransposeOp>(
2436- op->getLoc (), getTypeConverter ()->convertType (ohwiWeightType),
2437- weight, rewriter.getDenseI32ArrayAttr (iohwToOhwi))
2438- .getResult ();
24392431
24402432 // TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
24412433 // Map from PyTorch's (padding, output_padding):
@@ -2457,6 +2449,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     SmallVector<int64_t, 4> outPad(
         {outPadTop, outPadBottom, outPadLeft, outPadRight});
 
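+    // All match-failure checks for the transposed path are done above, so
+    // the input and weight transposes below are only created once the
+    // pattern is guaranteed to apply.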
+    Value nhwcInput = createTransposedInput();
+    SmallVector<int64_t, 4> ohwiWeightShape;
+    for (int32_t dim : iohwToOhwi)
+      ohwiWeightShape.push_back(weightShape[dim]);
+    auto ohwiWeightType = RankedTensorType::get(
+        makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
+    Value transformedWeight =
+        rewriter
+            .create<tosa::TransposeOp>(
+                op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
+                weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
+            .getResult();
+
     // Result type is NHWC (we'll transpose back).
     auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
     SmallVector<int64_t, 4> outNHWC;
@@ -2480,7 +2485,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
         rewriter
             .create<tosa::TransposeConv2DOp>(
                 op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
-                transposedInput, transformedWeight, bias, inputZp, weightZp,
+                nhwcInput, transformedWeight, bias, inputZp, weightZp,
                 rewriter.getDenseI64ArrayAttr(outPad),
                 rewriter.getDenseI64ArrayAttr(stride), accType)
             .getResult();
@@ -2535,6 +2540,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     SmallVector<int32_t> transposedDims({2, 3, 0, 1});
     SmallVector<int64_t> transposedWeightShape = {
         weightShape[2], weightShape[3], weightShape[0], weightShape[1]};
+
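+    // Check the static output-channel requirement before creating the
+    // weight transpose below, so a failed match emits no new ops.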
+    // reshape: HWO(I/G) -> HWIM
+    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
+    if (outputCDim == kUnknownSize) {
+      return rewriter.notifyMatchFailure(
+          op, "number of output channels must be statically known for "
+              "depthwise convolutions");
+    }
+
     auto transposedWeightType = RankedTensorType::get(
         makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
     auto transposedWeight =
@@ -2544,13 +2558,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
             rewriter.getDenseI32ArrayAttr(transposedDims))
             .getResult();
 
-    // reshape: HWO(I/G) -> HWIM
-    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
-    if (outputCDim == kUnknownSize) {
-      return rewriter.notifyMatchFailure(
-          op, "number of output channels must be statically known for "
-              "depthwise convolutions");
-    }
     transformedWeightShape = {
         transposedWeightShape[0],
         transposedWeightShape[1],
@@ -2571,6 +2578,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }
 
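+  // Materialize the NHWC input transpose for the lowering paths below.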
+  Value transposedInput = createTransposedInput();
+
   int64_t outputHDim, outputWDim;
   int64_t inputHDim = inputShape[2];
   int64_t inputWDim = inputShape[3];