diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp index 15de736480c5e..59acb362191a7 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp @@ -31,23 +31,15 @@ using namespace mlir; using namespace mlir::arm_neon; namespace { - -/// Return the shaped type with new element type. -static Type matchContainerType(Type element, Type container) { - if (auto shapedTy = dyn_cast(container)) { - return shapedTy.clone(element); - } - return element; -} - -// Get the operand of a `vector.contract`. This function is intended to abstract -// away from the particular way a value is extended before feeding it into the -// `vector.contract` - via zero-extend or an explicit or implicit sign-extend -// (for implicit sign-extension see `vector.contract` documentation). -// -// The template parameter `Op` indicates the extension operation (explicit or -// implicit) for which we are checking. -// +/// Get the operand of a `vector.contract`. This function is intended to +/// abstract away from the particular way a value is extended before feeding it +/// into the `vector.contract` - via zero-extend or an explicit or implicit +/// sign-extend (for implicit sign-extension see `vector.contract` +/// documentation). +/// +/// The template parameter `Op` indicates the extension operation (explicit or +/// implicit) for which we are checking. +/// // Return success only for extensions from `iN` (N <= 8) to `i32`. template std::optional getExtOperand(Value v) { @@ -85,202 +77,186 @@ std::optional getExtOperand(Value v) { return inOp; } -// Designate the operation (resp. instruction) used to do sub-tile matrix -// multiplications. 
-enum class MMLA { - Signed, // smmla - Unsigned, // ummla - Mixed, // usmmla - MixedSwapped // usmmla with LHS and RHS swapped -}; +/// Helper function to extend a vector with elements iN, N < 8 to +/// a vector of i8. Do sign extension if the parameter `signExt` is true, +/// zero extension otherwise. +Value extendSmallIntVector(Location loc, VectorType srcTy, Value val, + bool signExt, PatternRewriter &rewriter) { + Type targetTy = srcTy.clone(rewriter.getI8Type()); + return signExt ? rewriter.createOrFold(loc, targetTy, val) + : rewriter.createOrFold(loc, targetTy, val); +} -// Create the matrix mulitply and accumulate operation according to `op`. -Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc, - mlir::Type accType, Value acc, Value lhs, Value rhs) { - switch (op) { - case MMLA::Signed: - return rewriter.createOrFold(loc, accType, acc, lhs, - rhs); - case MMLA::Unsigned: - return rewriter.createOrFold(loc, accType, acc, lhs, - rhs); - case MMLA::Mixed: - return rewriter.createOrFold(loc, accType, acc, lhs, - rhs); - case MMLA::MixedSwapped: - // The accumulator comes transposed and the result will be transposed - // later, so all we have to do here is swap the operands. - return rewriter.createOrFold(loc, accType, acc, rhs, - lhs); +class VectorContractRewriter { +protected: + // Designate the operation (resp. instruction) used to do sub-tile matrix + // multiplications. + enum class MMLA { + Nop, + Signed, // smmla + Unsigned, // ummla + Mixed, // usmmla + MixedSwapped // usmmla with LHS and RHS swapped + }; + + // Lower-level operation to be emitted. + MMLA mmlaOp = MMLA::Nop; + + // The operand tiles. These are not necessarily the operands of + // `vector.contract`, for example they could be operands to `arith.extsi` + // that is in turn fed into `vector.contract`. + Value lhs; + Value rhs; + Value acc; + + // The dimensions logically corresponding to matrix multiplication of + // MxK * KxN -> MxN. 
The operands and the result do not necessarily have these + // shapes, for example RHS could be NxK with a transposing indexing map. + int64_t dimM = 0; + int64_t dimN = 0; + int64_t dimK = 0; + + // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`. + SmallVector iterationBounds; + + // Sub-tile shape. The algorithm handles operand shapes, which are multiples + // of this shape. + SmallVector subTileShape; + + // Create the matrix multiply and accumulate operation according to `mmlaOp`. + Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc, + Value lhs, Value rhs) { + switch (mmlaOp) { + case MMLA::Signed: + return rewriter.createOrFold(loc, acc.getType(), acc, + lhs, rhs); + case MMLA::Unsigned: + return rewriter.createOrFold(loc, acc.getType(), acc, + lhs, rhs); + case MMLA::Mixed: + return rewriter.createOrFold(loc, acc.getType(), acc, + lhs, rhs); + case MMLA::MixedSwapped: + // The accumulator comes transposed and the result will be transposed + // later, so all we have to do here is swap the operands. + return rewriter.createOrFold(loc, acc.getType(), acc, + rhs, lhs); + case MMLA::Nop: + llvm_unreachable("Uninitialized operation type"); + } } -} -/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile -/// any vector.contract into multiple smmla instructions with unrolling so long -/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM -/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is -/// necessary, a single smmla instruction is emitted. -class LowerContractionToNeonI8MMPattern - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim. - // Note: RHS is not transposed. 
- mlir::VectorType lhsType = op.getLhsType(); - mlir::VectorType rhsType = op.getRhsType(); + // Check common preconditions for applying the patterns and initialize + // logical dimensions. + LogicalResult matchAndInit(vector::ContractionOp op, + PatternRewriter &rewriter) { + // Check iterator types for matrix multiplication. + SmallVector itTypes = op.getIteratorTypesArray(); + if (!((itTypes.size() == 3 && + (itTypes[0] == vector::IteratorType::parallel && + itTypes[1] == vector::IteratorType::parallel && + itTypes[2] == vector::IteratorType::reduction)) || + (itTypes.size() == 2 && + (itTypes[0] == vector::IteratorType::parallel && + itTypes[1] == vector::IteratorType::reduction)))) + return rewriter.notifyMatchFailure( + op, "iterator types do not correspond to matrix multiplication"); + // Avoid 0-D vectors and 1-D rhs: - if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2) - return failure(); + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 || + rhsType.getRank() != 2) + return rewriter.notifyMatchFailure(op, "Invalid operand rank"); + // This codegen does not work for scalable vectors. Return failure so this // pattern is not accidentally chosen over patterns that lower to ArmSVE. if (lhsType.isScalable() || rhsType.isScalable()) - return failure(); - auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0); - auto dimN = rhsType.getDimSize(0); - auto dimK = rhsType.getDimSize(1); - bool isVecmat = dimM == 1 ? true : false; - if (lhsType.getDimSize(lhsType.getRank() - 1) != - rhsType.getDimSize(rhsType.getRank() - 1)) { - return failure(); // dimK mismatch - } - // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for - // tiling. - if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) { - return failure(); - } - - // Check iterator types for contract. 
All iterators except inner-most - // dimension must be parallel. - auto iteratorTypes = op.getIteratorTypesArray(); - if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] != - vector::IteratorType::reduction) { - return failure(); - } - if (llvm::any_of(ArrayRef(iteratorTypes).drop_back(1), - [](vector::IteratorType iteratorType) { - return iteratorType != vector::IteratorType::parallel; - })) { - return failure(); + return rewriter.notifyMatchFailure(op, + "Not applicable to scalable vectors"); + + // Initialize dimensions and check for a matching K dimension. + dimM = lhsType.getDimSize(0); + dimN = rhsType.getDimSize(0); + dimK = rhsType.getDimSize(1); + + int64_t lhsDimK; + if (lhsType.getRank() == 1) { + dimM = 1; + lhsDimK = lhsType.getDimSize(0); + } else { + lhsDimK = lhsType.getDimSize(1); } - // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the - // values before the extension. All four signed/unsigned combinations for - // input operands are supported, but they are lowered to different - // operations. Determine which is the appropriate operation to lower to. - MMLA mmlaOp = MMLA::Signed; - auto maybeLhs = getExtOperand(op.getLhs()); - if (!maybeLhs) { - mmlaOp = MMLA::Unsigned; - maybeLhs = getExtOperand(op.getLhs()); - } - if (!maybeLhs) - return failure(); + if (lhsDimK != dimK) + return rewriter.notifyMatchFailure(op, "Dimensions mismatch"); - auto maybeRhs = getExtOperand(op.getRhs()); - if (maybeRhs) { - if (mmlaOp == MMLA::Unsigned) - mmlaOp = MMLA::Mixed; - } else { - if (mmlaOp == MMLA::Signed) - mmlaOp = MMLA::MixedSwapped; - maybeRhs = getExtOperand(op.getRhs()); - } - if (!maybeRhs) - return failure(); + return success(); + } - Value origLhs = *maybeLhs; - Value origRhs = *maybeRhs; - - // Match any iX to i32 for X<8 then turn into an i8 output. Feed into - // following neon instruction. 
Check inputs for extsi are <=i8 - Value extLhs; - Value extRhs; - if (auto lhsExtInType = dyn_cast(origLhs.getType())) { - if (lhsExtInType.getElementTypeBitWidth() <= 8) { - Type targetLhsExtTy = - matchContainerType(rewriter.getI8Type(), lhsExtInType); - if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed) - extLhs = rewriter.createOrFold(loc, targetLhsExtTy, - origLhs); - else - extLhs = rewriter.createOrFold(loc, targetLhsExtTy, - origLhs); - } - } - if (auto rhsExtInType = dyn_cast(origRhs.getType())) { - if (rhsExtInType.getElementTypeBitWidth() <= 8) { - Type targetRhsExtTy = - matchContainerType(rewriter.getI8Type(), rhsExtInType); - if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed) - extRhs = rewriter.createOrFold(loc, targetRhsExtTy, - origRhs); - else - extRhs = rewriter.createOrFold(loc, targetRhsExtTy, - origRhs); - } - } +public: + void lower(vector::ContractionOp op, PatternRewriter &rewriter) { + // Create some convenience types. + auto inputElementType = cast(lhs.getType()).getElementType(); + auto accElementType = cast(acc.getType()).getElementType(); + auto inputExpandedType = + VectorType::get({2, subTileShape.back()}, inputElementType); + auto outputExpandedType = VectorType::get({2, 2}, accElementType); + + // One-dimensional representation of logical sub-tiles as required by the + // ArmNeon ops. + auto collapsedInputType = + VectorType::get(inputExpandedType.getNumElements(), inputElementType); + auto collapsedOutputType = + VectorType::get(outputExpandedType.getNumElements(), accElementType); + + // Get indexing maps for a more concise/convenient access. + auto indexingMaps = op.getIndexingMapsArray(); + AffineMap &lhsPermutationMap = indexingMaps[0]; + AffineMap &rhsPermutationMap = indexingMaps[1]; + AffineMap &accPermutationMap = indexingMaps[2]; - if (!extLhs || !extRhs) { - return failure(); - } + Location loc = op.getLoc(); // Initial accumulator for the final result. This is the un-tiled result if // tiling is done. 
Value result = rewriter.create( loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType())); - SmallVector unrolledSize = *op.getShapeForUnroll(); - SmallVector smmlaShape = {2, 8}; - SmallVector loopOrder = {0, 1}; - if (unrolledSize.size() == 3) { - smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2); + SmallVector loopOrder = {0, 1}; + if (iterationBounds.size() == 3) loopOrder.push_back(2); - } // Keep track of the previous accumulator when tiling over K. Value kAcc; for (SmallVector offsets : - StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) { + StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) { // Helper to compute the new shape of each operand and extract the slice. auto extractOperand = [&](Value operand, AffineMap permutationMap, ArrayRef operandOffsets) { - SmallVector operandShape = - applyPermutationMap(permutationMap, ArrayRef(smmlaShape)); + SmallVector operandShape = applyPermutationMap( + permutationMap, ArrayRef(subTileShape)); SmallVector operandStrides(operandOffsets.size(), 1); return rewriter.createOrFold( loc, operand, operandOffsets, operandShape, operandStrides); }; // Extract tiled lhs, rhs, and acc - AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0]; SmallVector lhsOffsets = applyPermutationMap(lhsPermutationMap, ArrayRef(offsets)); - Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets); - AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1]; + Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets); SmallVector rhsOffsets = applyPermutationMap(rhsPermutationMap, ArrayRef(offsets)); - Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets); - AffineMap accPermutationMap = op.getIndexingMapsArray()[2]; + Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets); SmallVector accOffsets = applyPermutationMap(accPermutationMap, ArrayRef(offsets)); - Value tiledAcc = - extractOperand(op.getAcc(), accPermutationMap, accOffsets); - - auto 
inputElementType = - cast(tiledLhs.getType()).getElementType(); - auto accElementType = - cast(tiledAcc.getType()).getElementType(); - auto inputExpandedType = VectorType::get({2, 8}, inputElementType); - auto outputExpandedType = VectorType::get({2, 2}, accElementType); + Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets); // With vecmat, tiled LHS and ACC will contain only one of 2 necessary - // rows along dimM. Expand their shapes to match the smmla op. - if (isVecmat) { - auto expandForSMMLA = [&](Value tiledOperand, - VectorType expandedTypeType) { + // rows along dimM. Expand their shapes to match the ArmNeon op. + if (dimM == 1) { + auto expandRowVector = [&](Value tiledOperand, + VectorType expandedTypeType) { auto emptyOperand = rewriter.create( loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType)); SmallVector offsets( @@ -290,8 +266,8 @@ class LowerContractionToNeonI8MMPattern return rewriter.createOrFold( loc, tiledOperand, emptyOperand, offsets, strides); }; - tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType); - tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType); + tiledLhs = expandRowVector(tiledLhs, inputExpandedType); + tiledAcc = expandRowVector(tiledAcc, outputExpandedType); } // Transpose ACC if doing signed by unsigned multiplication, because we're @@ -301,15 +277,11 @@ class LowerContractionToNeonI8MMPattern tiledAcc = rewriter.create( loc, tiledAcc, ArrayRef({1, 0})); - // Collapse tiled operands to 1D vectors required by smmla intrinsic - auto collapsedInputType = - VectorType::get(inputExpandedType.getNumElements(), inputElementType); + // Collapse tiled operands to 1D vectors required by the ArmNeon ops auto collapsedLhs = rewriter.createOrFold( tiledLhs.getLoc(), collapsedInputType, tiledLhs); auto collapsedRhs = rewriter.createOrFold( tiledRhs.getLoc(), collapsedInputType, tiledRhs); - auto collapsedOutputType = - VectorType::get(outputExpandedType.getNumElements(), accElementType); bool initialKAcc = 
offsets.back() == 0; Value collapsedRes; @@ -321,8 +293,8 @@ class LowerContractionToNeonI8MMPattern } // Insert contract op - kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(), - collapsedRes, collapsedLhs, collapsedRhs); + kAcc = + createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs); // Reshape output back to 2D Value tiledRes = rewriter.createOrFold( @@ -336,9 +308,8 @@ class LowerContractionToNeonI8MMPattern // With vecmat, only one row of tiled ACC can be inserted into the final // result - if (isVecmat) { + if (dimM == 1) tiledRes = rewriter.createOrFold(loc, tiledRes, 0); - } // Insert the tiled result back into the non tiled result of the // contract op. @@ -349,6 +320,98 @@ class LowerContractionToNeonI8MMPattern } rewriter.replaceOp(op, result); + } +}; + +class VectorContractRewriterI8MM : public VectorContractRewriter { +public: + LogicalResult matchAndInit(vector::ContractionOp op, + PatternRewriter &rewriter) { + if (failed(VectorContractRewriter::matchAndInit(op, rewriter))) + return failure(); + + // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for + // tiling. + if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0) + return rewriter.notifyMatchFailure(op, "Unsupported operand shapes"); + + // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the + // values before the extension. All four signed/unsigned combinations for + // input operands are supported, but they are lowered to different + // operations. Determine which is the appropriate operation to lower to. 
+ mmlaOp = MMLA::Signed; + auto maybeLhs = getExtOperand(op.getLhs()); + if (!maybeLhs) { + mmlaOp = MMLA::Unsigned; + maybeLhs = getExtOperand(op.getLhs()); + } + if (!maybeLhs) + return rewriter.notifyMatchFailure( + op, "LHS is not a sign- or zero- extended iN, N <= 8"); + + auto maybeRhs = getExtOperand(op.getRhs()); + if (maybeRhs) { + if (mmlaOp == MMLA::Unsigned) + mmlaOp = MMLA::Mixed; + } else { + if (mmlaOp == MMLA::Signed) + mmlaOp = MMLA::MixedSwapped; + maybeRhs = getExtOperand(op.getRhs()); + } + + if (!maybeRhs) + return rewriter.notifyMatchFailure( + op, "RHS is not a sign- or zero- extended iN, N <= 8"); + + lhs = *maybeLhs; + rhs = *maybeRhs; + acc = op.getAcc(); + + // Extend inputs from iN, N < 8 to i8. + Location loc = op.getLoc(); + auto lhsExtInType = cast(lhs.getType()); + if (lhsExtInType.getElementTypeBitWidth() < 8) + lhs = extendSmallIntVector(loc, lhsExtInType, lhs, + /* signExt */ mmlaOp == MMLA::Signed || + mmlaOp == MMLA::Mixed, + rewriter); + + auto rhsExtInType = cast(rhs.getType()); + if (rhsExtInType.getElementTypeBitWidth() < 8) + + rhs = extendSmallIntVector(loc, rhsExtInType, rhs, + /* signExt */ mmlaOp != MMLA::Unsigned && + mmlaOp != MMLA::Mixed, + rewriter); + + // Initialize parameters for unrolling. + iterationBounds = *op.getShapeForUnroll(); + if (iterationBounds.size() == 3) + subTileShape = SmallVector({dimM == 1 ? 1 : 2, 2, 8}); + else + subTileShape = SmallVector({2, 8}); + + return success(); + } +}; + +/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile +/// any vector.contract into multiple smmla instructions with unrolling so long +/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM +/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is +/// necessary, a single smmla instruction is emitted. 
+class LowerContractionToNeonI8MMPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + + VectorContractRewriterI8MM vcr; + if (failed(vcr.matchAndInit(op, rewriter))) + return failure(); + vcr.lower(op, rewriter); + return success(); } };