[MLIR][AArch64] Refactor lowering of vector.contract to Neon I8MM #149810


Merged

Conversation

momchil-velikov
Collaborator

This patch refactors the pattern in Transforms/LowerContractionToNeonI8MMPattern.cpp using an approach similar to #147052, to prepare for adding BF16 support.
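
For orientation, here is a condensed sketch of the class structure this patch introduces: a VectorContractRewriter base class holding the shared matching and unrolling machinery, and a VectorContractRewriterI8MM subclass carrying the integer-specific checks, so that a BF16 rewriter can later be added as another subclass. Only declarations are shown; the function bodies are elided and the include list is an assumption, not part of the patch (see the diff below for the real code).

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {

class VectorContractRewriter {
protected:
  // Which ArmNeon sub-tile multiply to emit; Nop means "not yet chosen".
  enum class MMLA { Nop, Signed, Unsigned, Mixed, MixedSwapped };

  MMLA mmlaOp = MMLA::Nop;
  Value lhs, rhs, acc;                  // operand tiles (values before extension)
  int64_t dimM = 0, dimN = 0, dimK = 0; // logical MxK * KxN -> MxN dimensions
  SmallVector<int64_t> iterationBounds; // unroll bounds
  SmallVector<int64_t> subTileShape;    // shape handled by one MMLA op

  // Emit one sub-tile multiply-accumulate according to `mmlaOp`.
  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
                   Value lhs, Value rhs);

  // Common preconditions (iterator types, ranks, no scalable vectors) plus
  // initialization of the logical dimensions.
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter);

public:
  // Shared unroll-and-replace driver (renamed to `lower` during review).
  void rewrite(vector::ContractionOp op, PatternRewriter &rewriter);
};

// I8MM-specific matching: [2, 2, 8] sub-tile multiples and sign-/zero-
// extended sub-byte integer inputs.
class VectorContractRewriterI8MM : public VectorContractRewriter {
public:
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter);
};

} // namespace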

@llvmbot
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-mlir-neon

Author: Momchil Velikov (momchil-velikov)

Changes

This patch refactors the pattern in Transforms/LowerContractionToNeonI8MMPattern.cpp using an approach similar to #147052, to prepare for adding BF16 support.


Patch is 22.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149810.diff

1 File Affected:

  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp (+247-184)
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index a07be7801869f..0738de6c7788c 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -31,23 +31,15 @@ using namespace mlir;
 using namespace mlir::arm_neon;
 
 namespace {
-
-/// Return the shaped type with new element type.
-static Type matchContainerType(Type element, Type container) {
-  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
-    return shapedTy.clone(element);
-  }
-  return element;
-}
-
-// Get the operand of a `vector.contract`. This function is intended to abstract
-// away from the particular way a value is extended before feeding it into the
-// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
-// (for implicit sign-extension see `vector.contract` documentation).
-//
-// The template parameter `Op` indicates the extension operation (explicit or
-// implicit) for which we are checking.
-//
+/// Get the operand of a `vector.contract`. This function is intended to
+/// abstract away from the particular way a value is extended before feeding it
+/// into the `vector.contract` - via zero-extend or an explicit or implicit
+/// sign-extend (for implicit sign-extension see `vector.contract`
+/// documentation).
+///
+/// The template parameter `Op` indicates the extension operation (explicit or
+/// implicit) for which we are checking.
+///
 // Return success only for extensions from `iN` (N <= 8) to `i32`.
 template <typename Op>
 std::optional<Value> getExtOperand(Value v) {
@@ -85,202 +77,186 @@ std::optional<Value> getExtOperand(Value v) {
   return inOp;
 }
 
-// Designate the operation (resp. instruction) used to do sub-tile matrix
-// multiplications.
-enum class MMLA {
-  Signed,      // smmla
-  Unsigned,    // ummla
-  Mixed,       // usmmla
-  MixedSwapped // usmmla with LHS and RHS swapped
-};
+/// Helper function to extend a vector with elements iN, N < 8 to
+/// a vector of i8. Do sign extension if the parameter `signExt` is true,
+/// zero extension otherwise.
+Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
+                           bool signExt, PatternRewriter &rewriter) {
+  Type targetTy = srcTy.clone(rewriter.getI8Type());
+  return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
+                 : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
+}
 
-// Create the matrix mulitply and accumulate operation according to `op`.
-Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
-                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
-  switch (op) {
-  case MMLA::Signed:
-    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
-                                                    rhs);
-  case MMLA::Unsigned:
-    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
-                                                    rhs);
-  case MMLA::Mixed:
-    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
-                                                     rhs);
-  case MMLA::MixedSwapped:
-    // The accumulator comes transposed and the result will be transposed
-    // later, so all we have to do here is swap the operands.
-    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
-                                                     lhs);
+class VectorContractRewriter {
+protected:
+  // Designate the operation (resp. instruction) used to do sub-tile matrix
+  // multiplications.
+  enum class MMLA {
+    Nop,
+    Signed,      // smmla
+    Unsigned,    // ummla
+    Mixed,       // usmmla
+    MixedSwapped // usmmla with LHS and RHS swapped
+  };
+
+  // Lower-level operation to be emitted.
+  MMLA mmlaOp = MMLA::Nop;
+
+  // The operand tiles. These are not necessarily the operands of
+  // `vector.contract`, for example they could be operands to `arith.extsi`
+  // that is in turn fed into `vector.contract`.
+  Value lhs;
+  Value rhs;
+  Value acc;
+
+  // The dimensions logically corresponding to matrix multiplication of
+  // MxK * KxN -> MxN. The operands and the result do not necessarily have these
+  // shapes, for example RHS could be NxK with a transposing indexing map.
+  int64_t dimM = 0;
+  int64_t dimN = 0;
+  int64_t dimK = 0;
+
+  // Unroll iteration bounds. See documentaiton for `StaticTileOffsetRange`.
+  SmallVector<int64_t> iterationBounds;
+
+  // Sub-tile shape. The algorithm handles operand shapes, which are multiples
+  // of this shape.
+  SmallVector<int64_t> subTileShape;
+
+  // Create the matrix multiply and accumulate operation according to `mmlaOp`.
+  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+                   Value lhs, Value rhs) {
+    switch (mmlaOp) {
+    case MMLA::Signed:
+      return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::Unsigned:
+      return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::Mixed:
+      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+                                                       lhs, rhs);
+    case MMLA::MixedSwapped:
+      // The accumulator comes transposed and the result will be transposed
+      // later, so all we have to do here is swap the operands.
+      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+                                                       rhs, lhs);
+    case MMLA::Nop:
+      llvm_unreachable("Uninitialized operation type");
+    }
   }
-}
 
-/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
-/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
-/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
-/// necessary, a single smmla instruction is emitted.
-class LowerContractionToNeonI8MMPattern
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
-    // Note: RHS is not transposed.
-    mlir::VectorType lhsType = op.getLhsType();
-    mlir::VectorType rhsType = op.getRhsType();
+  // Check common preconditions for applying the patterns and initialize
+  // logical dimensions.
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    // Check iterator types for matrix multiplication.
+    SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
+    if (!((itTypes.size() == 3 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::parallel &&
+            itTypes[2] == vector::IteratorType::reduction)) ||
+          (itTypes.size() == 2 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::reduction))))
+      return rewriter.notifyMatchFailure(
+          op, "iterator types do not correspond to matrix multiplication");
+
     // Avoid 0-D vectors and 1-D rhs:
-    if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
-      return failure();
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
+        rhsType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Invalid operand rank");
+
     // This codegen does not work for scalable vectors. Return failure so this
     // pattern is not accidentally chosen over patterns that lower to ArmSVE.
     if (lhsType.isScalable() || rhsType.isScalable())
-      return failure();
-    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
-    auto dimN = rhsType.getDimSize(0);
-    auto dimK = rhsType.getDimSize(1);
-    bool isVecmat = dimM == 1 ? true : false;
-    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
-        rhsType.getDimSize(rhsType.getRank() - 1)) {
-      return failure(); // dimK mismatch
-    }
-    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
-    // tiling.
-    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
-      return failure();
-    }
-
-    // Check iterator types for contract. All iterators except inner-most
-    // dimension must be parallel.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
-                                        vector::IteratorType::reduction) {
-      return failure();
-    }
-    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
-                     [](vector::IteratorType iteratorType) {
-                       return iteratorType != vector::IteratorType::parallel;
-                     })) {
-      return failure();
+      return rewriter.notifyMatchFailure(op,
+                                         "Not applicable to scalable vectors");
+
+    // Initialize dimensions and check for a matching K dimension.
+    dimM = lhsType.getDimSize(0);
+    dimN = rhsType.getDimSize(0);
+    dimK = rhsType.getDimSize(1);
+
+    int64_t lhsDimK;
+    if (lhsType.getRank() == 1) {
+      dimM = 1;
+      lhsDimK = lhsType.getDimSize(0);
+    } else {
+      lhsDimK = lhsType.getDimSize(1);
     }
 
-    // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
-    // values before the extension. All four signed/unsigned combinations for
-    // input operands are supported, but they are lowered to different
-    // operations. Determine which is the appropriate operation to lower to.
-    MMLA mmlaOp = MMLA::Signed;
-    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
-    if (!maybeLhs) {
-      mmlaOp = MMLA::Unsigned;
-      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
-    }
-    if (!maybeLhs)
-      return failure();
+    if (lhsDimK != dimK)
+      return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
 
-    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
-    if (maybeRhs) {
-      if (mmlaOp == MMLA::Unsigned)
-        mmlaOp = MMLA::Mixed;
-    } else {
-      if (mmlaOp == MMLA::Signed)
-        mmlaOp = MMLA::MixedSwapped;
-      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
-    }
-    if (!maybeRhs)
-      return failure();
+    return success();
+  }
 
-    Value origLhs = *maybeLhs;
-    Value origRhs = *maybeRhs;
-
-    // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
-    // following neon instruction. Check inputs for extsi are <=i8
-    Value extLhs;
-    Value extRhs;
-    if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
-      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
-        Type targetLhsExtTy =
-            matchContainerType(rewriter.getI8Type(), lhsExtInType);
-        if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
-          extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
-                                                         origLhs);
-        else
-          extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
-                                                         origLhs);
-      }
-    }
-    if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
-      if (rhsExtInType.getElementTypeBitWidth() <= 8) {
-        Type targetRhsExtTy =
-            matchContainerType(rewriter.getI8Type(), rhsExtInType);
-        if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
-          extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
-                                                         origRhs);
-        else
-          extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
-                                                         origRhs);
-      }
-    }
+public:
+  void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
+    // Create some convenience types.
+    auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
+    auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
+    auto inputExpandedType =
+        VectorType::get({2, subTileShape.back()}, inputElementType);
+    auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+    // One-dimensional representation of logical sub-tiles as required by the
+    // ArmNeon ops.
+    auto collapsedInputType =
+        VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+    auto collapsedOutputType =
+        VectorType::get(outputExpandedType.getNumElements(), accElementType);
+
+    // Get indexing maps for a more concise/convenient access.
+    auto indexingMaps = op.getIndexingMapsArray();
+    AffineMap &lhsPermutationMap = indexingMaps[0];
+    AffineMap &rhsPermutationMap = indexingMaps[1];
+    AffineMap &accPermutationMap = indexingMaps[2];
 
-    if (!extLhs || !extRhs) {
-      return failure();
-    }
+    Location loc = op.getLoc();
 
     // Initial accumulator for the final result. This is the un-tiled result if
     // tiling is done.
     Value result = rewriter.create<arith::ConstantOp>(
         loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
 
-    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
-    SmallVector<int64_t> smmlaShape = {2, 8};
-    SmallVector<int64_t> loopOrder = {0, 1};
-    if (unrolledSize.size() == 3) {
-      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
+    SmallVector<int64_t, 3> loopOrder = {0, 1};
+    if (iterationBounds.size() == 3)
       loopOrder.push_back(2);
-    }
 
     // Keep track of the previous accumulator when tiling over K.
     Value kAcc;
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+         StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) {
       // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffsets) {
-        SmallVector<int64_t> operandShape =
-            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+        SmallVector<int64_t> operandShape = applyPermutationMap(
+            permutationMap, ArrayRef<int64_t>(subTileShape));
         SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
         return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
             loc, operand, operandOffsets, operandShape, operandStrides);
       };
 
       // Extract tiled lhs, rhs, and acc
-      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
       SmallVector<int64_t> lhsOffsets =
           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
-      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
+      Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
       SmallVector<int64_t> rhsOffsets =
           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
-      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
+      Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
       SmallVector<int64_t> accOffsets =
           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledAcc =
-          extractOperand(op.getAcc(), accPermutationMap, accOffsets);
-
-      auto inputElementType =
-          cast<ShapedType>(tiledLhs.getType()).getElementType();
-      auto accElementType =
-          cast<ShapedType>(tiledAcc.getType()).getElementType();
-      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
-      auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+      Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);
 
       // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
-      // rows along dimM. Expand their shapes to match the smmla op.
-      if (isVecmat) {
-        auto expandForSMMLA = [&](Value tiledOperand,
-                                  VectorType expandedTypeType) {
+      // rows along dimM. Expand their shapes to match the ArmNeon op.
+      if (dimM == 1) {
+        auto expandRowVector = [&](Value tiledOperand,
+                                   VectorType expandedTypeType) {
           auto emptyOperand = rewriter.create<arith::ConstantOp>(
               loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
           SmallVector<int64_t> offsets(
@@ -290,8 +266,8 @@ class LowerContractionToNeonI8MMPattern
           return rewriter.createOrFold<vector::InsertStridedSliceOp>(
               loc, tiledOperand, emptyOperand, offsets, strides);
         };
-        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
-        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
+        tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
+        tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
       }
 
       // Transpose ACC if doing signed by unsigned multiplication, because we're
@@ -301,15 +277,11 @@ class LowerContractionToNeonI8MMPattern
         tiledAcc = rewriter.create<vector::TransposeOp>(
             loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
 
-      // Collapse tiled operands to 1D vectors required by smmla intrinsic
-      auto collapsedInputType =
-          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+      // Collapse tiled operands to 1D vectors required by the ArmNeon ops
       auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledLhs.getLoc(), collapsedInputType, tiledLhs);
       auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledRhs.getLoc(), collapsedInputType, tiledRhs);
-      auto collapsedOutputType =
-          VectorType::get(outputExpandedType.getNumElements(), accElementType);
 
       bool initialKAcc = offsets.back() == 0;
       Value collapsedRes;
@@ -321,8 +293,8 @@ class LowerContractionToNeonI8MMPattern
       }
 
       // Insert contract op
-      kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
-                        collapsedRes, collapsedLhs, collapsedRhs);
+      kAcc =
+          createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);
 
       // Reshape output back to 2D
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
@@ -336,9 +308,8 @@ class LowerContractionToNeonI8MMPattern
 
       // With vecmat, only one row of tiled ACC can be inserted into the final
       // result
-      if (isVecmat) {
+      if (dimM == 1)
         tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
-      }
 
       // Insert the tiled result back into the non tiled result of the
       // contract op.
@@ -349,6 +320,98 @@ class LowerContractionToNeonI8MMPattern
     }
 
     rewriter.replaceOp(op, result);
+  }
+};
+
+class VectorContractRewriterI8MM : public VectorContractRewriter {
+public:
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+      return failure();
+
+    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
+    // tiling.
+    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
+      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+    // Check inputs are sign-/zero- extensions from iN (N <= 8...
[truncated]

@llvmbot
Member

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-mlir

Contributor

@banach-space banach-space left a comment

Thanks, LGTM!

Note to other reviewers - this was extracted from #148198, which is where I reviewed this change originally. Also, this update makes the NEON patterns more aligned with https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp, so it's a step towards #145559.

Thanks @momchil-velikov, just one small nit inline.

}
}
public:
void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
Contributor

Suggested change
void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
void lower(vector::ContractionOp op, PatternRewriter &rewriter) {

Collaborator Author

Renamed.

This patch refactors the pattern in `Transforms/LowerContractionToNeonI8MMPattern.cpp` using an approach similar to #147052, to prepare for adding BF16 support.
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/vector-contract-i8mm-refactor branch from 7acb647 to 0f64f86 on July 22, 2025 at 14:22
@momchil-velikov momchil-velikov merged commit 90e632e into main Jul 22, 2025
8 checks passed
@momchil-velikov momchil-velikov deleted the users/momchil-velikov/vector-contract-i8mm-refactor branch July 22, 2025 15:31
momchil-velikov added a commit that referenced this pull request Jul 23, 2025
…148198)

This builds upon the framework established by #149810 to add lowering to `bfmmla`.
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jul 23, 2025
…erations (#148198)

This builds upon the framework established by llvm/llvm-project#149810 to add lowering to `bfmmla`.
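
For context on how the refactoring pays off: BFMMLA multiplies a pair of 2x4 BF16 sub-tiles and accumulates into a 2x2 F32 sub-tile, so a BF16 rewriter mainly needs different element-type checks and a different sub-tile shape while reusing the base class's unroll driver. The sketch below is hypothetical and only illustrates that shape of extension; the class name, the checks, and the idea of adding a new MMLA enumerator are assumptions, and the real implementation in #148198 may differ.

// Hypothetical sketch, building on the base class sketched earlier on this
// page; not taken from #148198.
class VectorContractRewriterBfmmla : public VectorContractRewriter {
public:
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {
    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
      return failure();

    // BFMMLA consumes 2x4 BF16 sub-tiles and produces a 2x2 F32 sub-tile,
    // so operand shapes must be multiples of [2, 2, 4] (with the usual
    // dimM == 1 vecmat exception).
    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
      return rewriter.notifyMatchFailure(op, "unsupported operand shapes");

    // BF16 inputs accumulating into F32; no sign-/zero-extension handling
    // is needed, unlike the integer case.
    if (!op.getLhsType().getElementType().isBF16() ||
        !op.getRhsType().getElementType().isBF16() ||
        !cast<VectorType>(op.getAcc().getType()).getElementType().isF32())
      return rewriter.notifyMatchFailure(op, "not a bf16 x bf16 -> f32 op");

    lhs = op.getLhs();
    rhs = op.getRhs();
    acc = op.getAcc();

    // Sub-tile configuration consumed by the shared unroll driver (vecmat
    // handling elided for brevity). A new MMLA enumerator, with a matching
    // case in createMMLA that emits the ArmNeon bfmmla op, would be
    // selected here as well; both are assumptions, not part of this patch.
    iterationBounds = {dimM, dimN, dimK};
    subTileShape = {2, 2, 4};
    return success();
  }
};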