|
25 | 25 | #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
|
26 | 26 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
27 | 27 | #include "mlir/IR/AffineExpr.h"
|
| 28 | +#include "mlir/IR/AffineMap.h" |
28 | 29 | #include "mlir/IR/Builders.h"
|
29 | 30 | #include "mlir/IR/BuiltinTypeInterfaces.h"
|
30 | 31 | #include "mlir/IR/BuiltinTypes.h"
|
@@ -1681,10 +1682,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
|
1681 | 1682 | return write;
|
1682 | 1683 |
|
1683 | 1684 | // Compute the mask and mask the write Op.
|
1684 |
| - auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type()); |
| 1685 | + auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(), |
| 1686 | + vecToStoreType.getScalableDims()); |
1685 | 1687 |
|
1686 | 1688 | SmallVector<OpFoldResult> destSizes =
|
1687 |
| - tensor::getMixedSizes(builder, loc, dest); |
| 1689 | + isa<MemRefType>(dest.getType()) |
| 1690 | + ? memref::getMixedSizes(builder, loc, dest) |
| 1691 | + : tensor::getMixedSizes(builder, loc, dest); |
1688 | 1692 | SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
|
1689 | 1693 | destSizes.end());
|
1690 | 1694 |
|
@@ -2093,6 +2097,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
|
2093 | 2097 | return success();
|
2094 | 2098 | }
|
2095 | 2099 |
|
| 2100 | +/// Vectorize a named linalg contraction op into: |
| 2101 | +/// vector::TransferReadOp - Reads vectors from the operands |
| 2102 | +/// vector::ContractionOp - Performs contraction |
| 2103 | +/// vector::TransferWriteOp - Write the result vector back to the |
| 2104 | +/// destination |
| 2105 | +/// The operands shapes are preserved and loaded directly into vectors. |
| 2106 | +/// Any further permutations or numerical casting remain within contraction. |
| 2107 | +static LogicalResult |
| 2108 | +vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, |
| 2109 | + LinalgOp linalgOp, |
| 2110 | + SmallVectorImpl<Value> &newResults) { |
| 2111 | + Location loc = linalgOp.getLoc(); |
| 2112 | + MLIRContext *ctx = linalgOp.getContext(); |
| 2113 | + |
| 2114 | + if (!isa<ContractionOpInterface>(linalgOp.getOperation())) |
| 2115 | + return failure(); |
| 2116 | + |
| 2117 | + OpOperand *outOperand = linalgOp.getDpsInitOperand(0); |
| 2118 | + Operation *reduceOp = matchLinalgReduction(outOperand); |
| 2119 | + auto maybeKind = getCombinerOpKind(reduceOp); |
| 2120 | + if (!maybeKind) |
| 2121 | + return failure(); |
| 2122 | + |
| 2123 | + // Check that all dimensions are present in the input operands. |
| 2124 | + // Arbitrary broadcasts are not supported by the vector contraction. |
| 2125 | + // Broadcasts are expected to be materialized before vectorization. |
| 2126 | + AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0]; |
| 2127 | + AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1]; |
| 2128 | + if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) |
| 2129 | + return failure(); |
| 2130 | + |
| 2131 | + // Load operands. |
| 2132 | + SmallVector<Value> vecOperands; |
| 2133 | + for (OpOperand &opOperand : linalgOp->getOpOperands()) { |
| 2134 | + // The operand vector shape is computed by mapping the canonical vector |
| 2135 | + // shape to the operand's domain. Further permutations are left as a part of |
| 2136 | + // the contraction. |
| 2137 | + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); |
| 2138 | + AffineMap readMap = AffineMap::getMultiDimIdentityMap( |
| 2139 | + indexingMap.getNumResults(), rewriter.getContext()); |
| 2140 | + Type elemType = getElementTypeOrSelf(opOperand.get()); |
| 2141 | + VectorType readType = |
| 2142 | + state.getCanonicalVecType(elemType, readMap.compose(indexingMap)); |
| 2143 | + |
| 2144 | + Value read = mlir::vector::createReadOrMaskedRead( |
| 2145 | + rewriter, loc, opOperand.get(), readType.getShape(), |
| 2146 | + /*padding=*/arith::getZeroConstant(rewriter, loc, elemType), |
| 2147 | + /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims()); |
| 2148 | + vecOperands.push_back(read); |
| 2149 | + } |
| 2150 | + |
| 2151 | + // Remap iterators from linalg to vector. |
| 2152 | + SmallVector<Attribute> iterAttrs; |
| 2153 | + auto iterators = linalgOp.getIteratorTypesArray(); |
| 2154 | + for (utils::IteratorType iter : iterators) { |
| 2155 | + auto vecIter = iter == utils::IteratorType::parallel |
| 2156 | + ? vector::IteratorType::parallel |
| 2157 | + : vector::IteratorType::reduction; |
| 2158 | + iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter)); |
| 2159 | + } |
| 2160 | + |
| 2161 | + // Create contraction. |
| 2162 | + Value contractOp = rewriter.create<vector::ContractionOp>( |
| 2163 | + loc, /*lhs=*/vecOperands[0], |
| 2164 | + /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2], |
| 2165 | + linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind); |
| 2166 | + |
| 2167 | + // Store result. |
| 2168 | + Operation *write = |
| 2169 | + createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get()); |
| 2170 | + |
| 2171 | + // Finalize. |
| 2172 | + if (!write->getResults().empty()) |
| 2173 | + newResults.push_back(write->getResult(0)); |
| 2174 | + |
| 2175 | + return success(); |
| 2176 | +} |
| 2177 | + |
2096 | 2178 | namespace {
|
2097 | 2179 | enum class ConvOperationKind { Conv, Pool };
|
2098 | 2180 | } // namespace
|
@@ -2528,11 +2610,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
|
2528 | 2610 | tensor::InsertSliceOp>(op);
|
2529 | 2611 | }
|
2530 | 2612 |
|
2531 |
| -FailureOr<VectorizationResult> |
2532 |
| -mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, |
2533 |
| - ArrayRef<int64_t> inputVectorSizes, |
2534 |
| - ArrayRef<bool> inputScalableVecDims, |
2535 |
| - bool vectorizeNDExtract, bool flatten1DDepthwiseConv) { |
| 2613 | +FailureOr<VectorizationResult> mlir::linalg::vectorize( |
| 2614 | + RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes, |
| 2615 | + ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract, |
| 2616 | + bool flatten1DDepthwiseConv, bool createNamedContraction) { |
2536 | 2617 | LDBG("Attempting to vectorize:\n" << *op << "\n");
|
2537 | 2618 | LDBG("Input vector sizes: ");
|
2538 | 2619 | LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
|
@@ -2578,6 +2659,21 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
|
2578 | 2659 | return failure();
|
2579 | 2660 | }
|
2580 | 2661 |
|
| 2662 | + // For simplicity, contraction vectorization is limited to linalg |
| 2663 | + // named ops. Generic op is ignored as not every arbitrary |
| 2664 | + // contraction body can be expressed by a vector.contract. |
| 2665 | + if (createNamedContraction && |
| 2666 | + isa<ContractionOpInterface>(linalgOp.getOperation())) { |
| 2667 | + // Attempt vectorizing directly into a named contraction. |
| 2668 | + // In case of failure, fall back to the generic path. |
| 2669 | + LogicalResult res = vectorizeAsLinalgContraction( |
| 2670 | + rewriter, state, linalgOp, results); |
| 2671 | + if (succeeded(res)) |
| 2672 | + return success(); |
| 2673 | + |
| 2674 | + LDBG("Failed to vectorize as a named contraction.\n"); |
| 2675 | + } |
| 2676 | + |
2581 | 2677 | LDBG("Vectorize generic by broadcasting to the canonical vector "
|
2582 | 2678 | "shape\n");
|
2583 | 2679 |
|
|
0 commit comments