@@ -1673,6 +1673,139 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
1673
1673
1674
1674
} // namespace
1675
1675
1676
+ namespace {
1677
+ class ConvertAtenOuterOp : public OpConversionPattern <AtenOuterOp> {
1678
+ public:
1679
+ using OpConversionPattern::OpConversionPattern;
1680
+ LogicalResult
1681
+ matchAndRewrite (AtenOuterOp op, OpAdaptor adaptor,
1682
+ ConversionPatternRewriter &rewriter) const override {
1683
+
1684
+ Location loc = op->getLoc ();
1685
+ Value lhs = adaptor.getSelf ();
1686
+ Value rhs = op->getOperand (1 );
1687
+
1688
+ if (failed (verifyLinalgCompatibleTypes (op, rewriter))) {
1689
+ return failure ();
1690
+ }
1691
+ auto lhsType = cast<RankedTensorType>(lhs.getType ());
1692
+ auto rhsType = cast<RankedTensorType>(rhs.getType ());
1693
+
1694
+ auto lhsTorchType = cast<ValueTensorType>(op.getSelf ().getType ());
1695
+ auto rhsTorchType = cast<ValueTensorType>(op.getOperand (1 ).getType ());
1696
+
1697
+ // Get the rank of both matrix.
1698
+ unsigned lhsRank = lhsType.getRank ();
1699
+ unsigned rhsRank = rhsType.getRank ();
1700
+
1701
+ Value lhsZeroPoint, rhsZeroPoint;
1702
+ getZeroPoint (op.getSelf (), lhsZeroPoint);
1703
+ getZeroPoint (op.getOperand (1 ), rhsZeroPoint);
1704
+
1705
+ if (static_cast <bool >(lhsZeroPoint) != static_cast <bool >(rhsZeroPoint)) {
1706
+ return rewriter.notifyMatchFailure (
1707
+ op, " unsupported: aten.outer with mixed quantization" );
1708
+ }
1709
+
1710
+ bool isUnsigned = torch_to_linalg::isUnsignedTorchType (lhsTorchType);
1711
+ bool isUnsignedR = torch_to_linalg::isUnsignedTorchType (rhsTorchType);
1712
+
1713
+ if (!lhsZeroPoint && lhsTorchType.getDtype () != rhsTorchType.getDtype ()) {
1714
+ // Allows quantized types to mismatch
1715
+ return rewriter.notifyMatchFailure (
1716
+ op, " unsupported: aten.outer with different input element types" );
1717
+ }
1718
+
1719
+ Type newResultType = getTypeConverter ()->convertType (op.getType ());
1720
+ auto resultType = cast<RankedTensorType>(newResultType);
1721
+ Type elementType = resultType.getElementType ();
1722
+
1723
+ // Quantized case
1724
+ if (lhsZeroPoint) {
1725
+ // get each zero point ready to pass to a quantized_matmul
1726
+ lhsZeroPoint = typeConverter->materializeTargetConversion (
1727
+ rewriter, loc,
1728
+ getTypeConverter ()->convertType (lhsZeroPoint.getType ()),
1729
+ lhsZeroPoint);
1730
+ rhsZeroPoint = typeConverter->materializeTargetConversion (
1731
+ rewriter, loc,
1732
+ getTypeConverter ()->convertType (rhsZeroPoint.getType ()),
1733
+ rhsZeroPoint);
1734
+ lhsZeroPoint = rewriter.create <arith::TruncIOp>(
1735
+ loc, rewriter.getI32Type (), lhsZeroPoint);
1736
+ rhsZeroPoint = rewriter.create <arith::TruncIOp>(
1737
+ loc, rewriter.getI32Type (), rhsZeroPoint);
1738
+
1739
+ // change uint8 quantization -> int8 quantization
1740
+ int64_t numBits =
1741
+ cast<mlir::IntegerType>(lhsType.getElementType ()).getWidth ();
1742
+ signShift (rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
1743
+ numBits = cast<mlir::IntegerType>(rhsType.getElementType ()).getWidth ();
1744
+ signShift (rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
1745
+
1746
+ if (lhsRank == 1 && rhsRank == 1 ) {
1747
+ int64_t lhsDim = lhsType.getShape ()[0 ];
1748
+ int64_t rhsDim = rhsType.getShape ()[0 ];
1749
+
1750
+ // Unsqueeze: lhs: [n] -> [n, 1] and rhs: [m] -> [1, m]
1751
+ auto lhsUnsqueezeType = RankedTensorType::get ({lhsDim, 1 }, lhsType.getElementType ());
1752
+ auto rhsUnsqueezeType = RankedTensorType::get ({1 , rhsDim}, rhsType.getElementType ());
1753
+ SmallVector<ReassociationIndices> reassociation = {{0 , 1 }};
1754
+ lhs = rewriter.create <tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
1755
+ rhs = rewriter.create <tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
1756
+
1757
+ // Create a zero tensor with shape [lhsDim, rhsDim] for the accumulator.
1758
+ Value lhsDimVal = rewriter.create <tensor::DimOp>(loc, lhs, 0 );
1759
+ Value rhsDimVal = rewriter.create <tensor::DimOp>(loc, rhs, 1 );
1760
+ Value zeroTensor = createZeroInitTensor (rewriter, loc,
1761
+ ValueRange{lhsDimVal, rhsDimVal},
1762
+ elementType);
1763
+
1764
+ // Use the quantized version of matmul.
1765
+ Value outerProd = rewriter.create <linalg::QuantizedMatmulOp>(
1766
+ loc, zeroTensor.getType (),
1767
+ ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
1768
+ zeroTensor).getResult (0 );
1769
+
1770
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, outerProd);
1771
+ return success ();
1772
+ }
1773
+ return rewriter.notifyMatchFailure (op, " unsupported: quantized aten.outer op case" );
1774
+ }
1775
+
1776
+
1777
+ // Non Quantized Outter Product
1778
+ if (lhsRank == 1 && rhsRank == 1 ) {
1779
+ int64_t lhsDim = lhsType.getShape ()[0 ];
1780
+ int64_t rhsDim = rhsType.getShape ()[0 ];
1781
+
1782
+ // Unsqueeze: lhs from [n] -> [n, 1] and rhs from [m] -> [1, m]
1783
+ auto lhsUnsqueezeType = RankedTensorType::get ({lhsDim, 1 }, lhsType.getElementType ());
1784
+ auto rhsUnsqueezeType = RankedTensorType::get ({1 , rhsDim}, rhsType.getElementType ());
1785
+ SmallVector<ReassociationIndices> reassociation = {{0 , 1 }};
1786
+ lhs = rewriter.create <tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
1787
+ rhs = rewriter.create <tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
1788
+
1789
+ // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1790
+ Value lhsDimVal = rewriter.create <tensor::DimOp>(loc, lhs, 0 );
1791
+ Value rhsDimVal = rewriter.create <tensor::DimOp>(loc, rhs, 1 );
1792
+ Value zeroTensor = createZeroInitTensor (rewriter, loc,
1793
+ ValueRange{lhsDimVal, rhsDimVal},
1794
+ elementType);
1795
+
1796
+ // Use linalg::MatmulOp to compute the outer product.
1797
+ Value outerProd = rewriter.create <linalg::MatmulOp>(
1798
+ loc, zeroTensor.getType (), ValueRange{lhs, rhs}, zeroTensor).getResult (0 );
1799
+
1800
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, outerProd);
1801
+ return success ();
1802
+ }
1803
+
1804
+ return failure ();
1805
+ }
1806
+ };
1807
+ } // namespace
1808
+
1676
1809
void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality (
1677
1810
TypeConverter &typeConverter, RewritePatternSet &patterns,
1678
1811
ConversionTarget &target) {
@@ -1689,4 +1822,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
1689
1822
patterns.add <ConvertAtenConvolutionOp>(typeConverter, context);
1690
1823
target.addIllegalOp <AtenFftRfftOp>();
1691
1824
patterns.add <ConvertAtenFftRfftOp>(typeConverter, context);
1825
+ target.addIllegalOp <AtenOuterOp>();
1826
+ patterns.add <ConvertAtenOuterOp>(typeConverter, context);
1692
1827
}
0 commit comments