Skip to content

Commit ca7cb28

Browse files
committed
[TOSA] Defer input transpose until guards pass
Lazily create the NHWC input transpose so that it is emitted only after the failure guards in the transposed- and depthwise-convolution rewrites have passed. Change-Id: Ia362deda898794397107f6da3c44cd89f219f58f
1 parent 3b21d94 commit ca7cb28

File tree

1 file changed

+34
-25
lines changed

1 file changed

+34
-25
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2337,6 +2337,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23372337
auto bias = adaptor.getBias();
23382338

23392339
if (isa<Torch::NoneType>(bias.getType())) {
2340+
// ConvTranspose weights use IOHW; the helper expects OIHW, so swap
2341+
// dims 0/1 before we synthesize the bias.
23402342
SmallVector<int64_t, 4> biasWeightShape =
23412343
transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
23422344
weightShape[2], weightShape[3]}
@@ -2407,12 +2409,13 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24072409
transposedInputShape.push_back(inputShape[dim]);
24082410
auto transposedInputType = RankedTensorType::get(
24092411
makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
2410-
auto transposedInput =
2411-
tosa::TransposeOp::create(
2412-
rewriter, op->getLoc(),
2413-
getTypeConverter()->convertType(transposedInputType), input,
2414-
rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
2415-
.getResult();
2412+
auto createTransposedInput = [&]() {
2413+
return rewriter
2414+
.create<tosa::TransposeOp>(
2415+
op->getLoc(), getTypeConverter()->convertType(transposedInputType),
2416+
input, rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
2417+
.getResult();
2418+
};
24162419

24172420
if (transposed) {
24182421
if (groups != 1)
@@ -2425,17 +2428,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24252428
"TOSA");
24262429

24272430
SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
2428-
SmallVector<int64_t, 4> ohwiWeightShape;
2429-
for (int32_t dim : iohwToOhwi)
2430-
ohwiWeightShape.push_back(weightShape[dim]);
2431-
auto ohwiWeightType = RankedTensorType::get(
2432-
makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
2433-
Value transformedWeight =
2434-
rewriter
2435-
.create<tosa::TransposeOp>(
2436-
op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
2437-
weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
2438-
.getResult();
24392431

24402432
// TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
24412433
// Map from PyTorch's (padding, output_padding):
@@ -2457,6 +2449,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24572449
SmallVector<int64_t, 4> outPad(
24582450
{outPadTop, outPadBottom, outPadLeft, outPadRight});
24592451

2452+
Value nhwcInput = createTransposedInput();
2453+
SmallVector<int64_t, 4> ohwiWeightShape;
2454+
for (int32_t dim : iohwToOhwi)
2455+
ohwiWeightShape.push_back(weightShape[dim]);
2456+
auto ohwiWeightType = RankedTensorType::get(
2457+
makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
2458+
Value transformedWeight =
2459+
rewriter
2460+
.create<tosa::TransposeOp>(
2461+
op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
2462+
weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
2463+
.getResult();
2464+
24602465
// Result type is NHWC (we'll transpose back).
24612466
auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
24622467
SmallVector<int64_t, 4> outNHWC;
@@ -2480,7 +2485,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24802485
rewriter
24812486
.create<tosa::TransposeConv2DOp>(
24822487
op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
2483-
transposedInput, transformedWeight, bias, inputZp, weightZp,
2488+
nhwcInput, transformedWeight, bias, inputZp, weightZp,
24842489
rewriter.getDenseI64ArrayAttr(outPad),
24852490
rewriter.getDenseI64ArrayAttr(stride), accType)
24862491
.getResult();
@@ -2535,6 +2540,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25352540
SmallVector<int32_t> transposedDims({2, 3, 0, 1});
25362541
SmallVector<int64_t> transposedWeightShape = {
25372542
weightShape[2], weightShape[3], weightShape[0], weightShape[1]};
2543+
2544+
// reshape: HWO(I/G) -> HWIM
2545+
outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
2546+
if (outputCDim == kUnknownSize) {
2547+
return rewriter.notifyMatchFailure(
2548+
op, "number of output channels must be statically known for "
2549+
"depthwise convolutions");
2550+
}
2551+
25382552
auto transposedWeightType = RankedTensorType::get(
25392553
makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
25402554
auto transposedWeight =
@@ -2544,13 +2558,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25442558
rewriter.getDenseI32ArrayAttr(transposedDims))
25452559
.getResult();
25462560

2547-
// reshape: HWO(I/G) -> HWIM
2548-
outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
2549-
if (outputCDim == kUnknownSize) {
2550-
return rewriter.notifyMatchFailure(
2551-
op, "number of output channels must be statically known for "
2552-
"depthwise convolutions");
2553-
}
25542561
transformedWeightShape = {
25552562
transposedWeightShape[0],
25562563
transposedWeightShape[1],
@@ -2571,6 +2578,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25712578
llvm_unreachable("Unhandled convolution type");
25722579
}
25732580

2581+
Value transposedInput = createTransposedInput();
2582+
25742583
int64_t outputHDim, outputWDim;
25752584
int64_t inputHDim = inputShape[2];
25762585
int64_t inputWDim = inputShape[3];

0 commit comments

Comments
 (0)