Skip to content

Commit 8cef853

Browse files
committed
[TOSA] Update lit tests
Change-Id: I7eb03f6173a7779248a4ed14e1d9d1016b94ec79
1 parent ca7cb28 commit 8cef853

File tree

2 files changed

+45
-47
lines changed

2 files changed

+45
-47
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,10 +2410,10 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24102410
auto transposedInputType = RankedTensorType::get(
24112411
makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
24122412
auto createTransposedInput = [&]() {
2413-
return rewriter
2414-
.create<tosa::TransposeOp>(
2415-
op->getLoc(), getTypeConverter()->convertType(transposedInputType),
2416-
input, rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
2413+
return tosa::TransposeOp::create(
2414+
rewriter, op->getLoc(),
2415+
getTypeConverter()->convertType(transposedInputType), input,
2416+
rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
24172417
.getResult();
24182418
};
24192419

@@ -2456,10 +2456,10 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24562456
auto ohwiWeightType = RankedTensorType::get(
24572457
makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
24582458
Value transformedWeight =
2459-
rewriter
2460-
.create<tosa::TransposeOp>(
2461-
op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
2462-
weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
2459+
tosa::TransposeOp::create(
2460+
rewriter, op->getLoc(),
2461+
getTypeConverter()->convertType(ohwiWeightType), weight,
2462+
rewriter.getDenseI32ArrayAttr(iohwToOhwi))
24632463
.getResult();
24642464

24652465
// Result type is NHWC (we'll transpose back).
@@ -2481,26 +2481,24 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24812481
rewriter, op->getLoc(), weightElemTy, 0)
24822482
.value();
24832483

2484-
Value convTOut =
2485-
rewriter
2486-
.create<tosa::TransposeConv2DOp>(
2487-
op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
2488-
nhwcInput, transformedWeight, bias, inputZp, weightZp,
2489-
rewriter.getDenseI64ArrayAttr(outPad),
2490-
rewriter.getDenseI64ArrayAttr(stride), accType)
2491-
.getResult();
2484+
Value convTOut = tosa::TransposeConv2DOp::create(
2485+
rewriter, op->getLoc(),
2486+
getTypeConverter()->convertType(transConvOpTy),
2487+
nhwcInput, transformedWeight, bias, inputZp, weightZp,
2488+
rewriter.getDenseI64ArrayAttr(outPad),
2489+
rewriter.getDenseI64ArrayAttr(stride), accType)
2490+
.getResult();
24922491

24932492
SmallVector<int64_t, 4> transposedOutputShape;
24942493
for (int32_t dim : nhwcToNchwDims)
24952494
transposedOutputShape.push_back(outNHWC[dim]);
24962495
auto transposedOutputType = RankedTensorType::get(
24972496
makeShapeLLVMCompatible(transposedOutputShape), biasElemTy);
24982497
Value transposedOutput =
2499-
rewriter
2500-
.create<tosa::TransposeOp>(
2501-
op->getLoc(),
2502-
getTypeConverter()->convertType(transposedOutputType), convTOut,
2503-
rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
2498+
tosa::TransposeOp::create(
2499+
rewriter, op->getLoc(),
2500+
getTypeConverter()->convertType(transposedOutputType), convTOut,
2501+
rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
25042502
.getResult();
25052503

25062504
// Quantized rescale.

0 commit comments

Comments
 (0)