Skip to content

Commit 2348344

Browse files
rootroot
authored and committed
Initial implementation of AtenOuterOp
- Defined the op in Linear.cpp. TODO: add one or more tests inside torch-mlir.
1 parent b6c4e87 commit 2348344

File tree

1 file changed

+135
-0
lines changed

1 file changed

+135
-0
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,6 +1673,139 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
16731673

16741674
} // namespace
16751675

1676+
namespace {
1677+
class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
1678+
public:
1679+
using OpConversionPattern::OpConversionPattern;
1680+
LogicalResult
1681+
matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
1682+
ConversionPatternRewriter &rewriter) const override {
1683+
1684+
Location loc = op->getLoc();
1685+
Value lhs = adaptor.getSelf();
1686+
Value rhs = op->getOperand(1);
1687+
1688+
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
1689+
return failure();
1690+
}
1691+
auto lhsType = cast<RankedTensorType>(lhs.getType());
1692+
auto rhsType = cast<RankedTensorType>(rhs.getType());
1693+
1694+
auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
1695+
auto rhsTorchType = cast<ValueTensorType>(op.getOperand(1).getType());
1696+
1697+
// Get the rank of both matrix.
1698+
unsigned lhsRank = lhsType.getRank();
1699+
unsigned rhsRank = rhsType.getRank();
1700+
1701+
Value lhsZeroPoint, rhsZeroPoint;
1702+
getZeroPoint(op.getSelf(), lhsZeroPoint);
1703+
getZeroPoint(op.getOperand(1), rhsZeroPoint);
1704+
1705+
if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
1706+
return rewriter.notifyMatchFailure(
1707+
op, "unsupported: aten.outer with mixed quantization");
1708+
}
1709+
1710+
bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
1711+
bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType);
1712+
1713+
if (!lhsZeroPoint && lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
1714+
// Allows quantized types to mismatch
1715+
return rewriter.notifyMatchFailure(
1716+
op, "unsupported: aten.outer with different input element types");
1717+
}
1718+
1719+
Type newResultType = getTypeConverter()->convertType(op.getType());
1720+
auto resultType = cast<RankedTensorType>(newResultType);
1721+
Type elementType = resultType.getElementType();
1722+
1723+
// Quantized case
1724+
if (lhsZeroPoint) {
1725+
// get each zero point ready to pass to a quantized_matmul
1726+
lhsZeroPoint = typeConverter->materializeTargetConversion(
1727+
rewriter, loc,
1728+
getTypeConverter()->convertType(lhsZeroPoint.getType()),
1729+
lhsZeroPoint);
1730+
rhsZeroPoint = typeConverter->materializeTargetConversion(
1731+
rewriter, loc,
1732+
getTypeConverter()->convertType(rhsZeroPoint.getType()),
1733+
rhsZeroPoint);
1734+
lhsZeroPoint = rewriter.create<arith::TruncIOp>(
1735+
loc, rewriter.getI32Type(), lhsZeroPoint);
1736+
rhsZeroPoint = rewriter.create<arith::TruncIOp>(
1737+
loc, rewriter.getI32Type(), rhsZeroPoint);
1738+
1739+
// change uint8 quantization -> int8 quantization
1740+
int64_t numBits =
1741+
cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
1742+
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
1743+
numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
1744+
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
1745+
1746+
if (lhsRank == 1 && rhsRank == 1) {
1747+
int64_t lhsDim = lhsType.getShape()[0];
1748+
int64_t rhsDim = rhsType.getShape()[0];
1749+
1750+
// Unsqueeze: lhs: [n] -> [n, 1] and rhs: [m] -> [1, m]
1751+
auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
1752+
auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
1753+
SmallVector<ReassociationIndices> reassociation = {{0, 1}};
1754+
lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
1755+
rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
1756+
1757+
// Create a zero tensor with shape [lhsDim, rhsDim] for the accumulator.
1758+
Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
1759+
Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
1760+
Value zeroTensor = createZeroInitTensor(rewriter, loc,
1761+
ValueRange{lhsDimVal, rhsDimVal},
1762+
elementType);
1763+
1764+
// Use the quantized version of matmul.
1765+
Value outerProd = rewriter.create<linalg::QuantizedMatmulOp>(
1766+
loc, zeroTensor.getType(),
1767+
ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
1768+
zeroTensor).getResult(0);
1769+
1770+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
1771+
return success();
1772+
}
1773+
return rewriter.notifyMatchFailure(op, "unsupported: quantized aten.outer op case");
1774+
}
1775+
1776+
1777+
// Non Quantized Outter Product
1778+
if (lhsRank == 1 && rhsRank == 1) {
1779+
int64_t lhsDim = lhsType.getShape()[0];
1780+
int64_t rhsDim = rhsType.getShape()[0];
1781+
1782+
// Unsqueeze: lhs from [n] -> [n, 1] and rhs from [m] -> [1, m]
1783+
auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
1784+
auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
1785+
SmallVector<ReassociationIndices> reassociation = {{0, 1}};
1786+
lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
1787+
rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
1788+
1789+
// Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1790+
Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
1791+
Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
1792+
Value zeroTensor = createZeroInitTensor(rewriter, loc,
1793+
ValueRange{lhsDimVal, rhsDimVal},
1794+
elementType);
1795+
1796+
// Use linalg::MatmulOp to compute the outer product.
1797+
Value outerProd = rewriter.create<linalg::MatmulOp>(
1798+
loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor).getResult(0);
1799+
1800+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
1801+
return success();
1802+
}
1803+
1804+
return failure();
1805+
}
1806+
};
1807+
} // namespace
1808+
16761809
void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
16771810
TypeConverter &typeConverter, RewritePatternSet &patterns,
16781811
ConversionTarget &target) {
@@ -1689,4 +1822,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
16891822
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
16901823
target.addIllegalOp<AtenFftRfftOp>();
16911824
patterns.add<ConvertAtenFftRfftOp>(typeConverter, context);
1825+
target.addIllegalOp<AtenOuterOp>();
1826+
patterns.add<ConvertAtenOuterOp>(typeConverter, context);
16921827
}

0 commit comments

Comments
 (0)