Skip to content

Commit 442e29a

Browse files
[MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16 operations
1 parent 5ec6287 commit 442e29a

File tree

11 files changed

+531
-36
lines changed

11 files changed

+531
-36
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14371437
"bool", /*default=*/"false",
14381438
"Enables the use of Arm FEAT_I8MM instructions while lowering "
14391439
"the vector dialect.">,
1440+
Option<"armBF16", "enable-arm-bf16",
1441+
"bool", /*default=*/"false",
1442+
"Enables the use of Arm FEAT_BF16 instructions while lowering "
1443+
"the vector dialect.">,
14401444
Option<"x86Vector", "enable-x86vector",
14411445
"bool", /*default=*/"false",
14421446
"Enables the use of X86Vector dialect while lowering the vector "

mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,19 @@ def ApplyArmNeonContractionToI8MMPatternsOp
1717
"apply_patterns.arm_neon.vector_contract_to_i8mm",
1818
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
1919
let description = [{
20-
Indicates that vector.contract operations should be lowered to
21-
finer-grained vector primitives from the ArmNeon dialect.
20+
Indicates that vector contract operations should be lowered to
21+
ArmNeon dialect operations mapping to instructions from FEAT_I8MM.
22+
}];
23+
24+
let assemblyFormat = "attr-dict";
25+
}
26+
27+
def ApplyArmNeonContractionToBFMMLAPatternsOp
28+
: Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_bfmmla",
29+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
30+
let description = [{
31+
Indicates that vector contract operations should be lowered to
32+
ArmNeon dialect operations mapping to instructions from FEAT_BF16.
2233
}];
2334

2435
let assemblyFormat = "attr-dict";

mlir/include/mlir/Dialect/ArmNeon/Transforms.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ namespace mlir {
1313
class RewritePatternSet;
1414

1515
namespace arm_neon {
16-
void populateLowerContractionToNeonI8MMPatternPatterns(
17-
RewritePatternSet &patterns);
16+
void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns);
17+
void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns);
1818
} // namespace arm_neon
1919

2020
} // namespace mlir

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
8484
populateVectorGatherLoweringPatterns(patterns);
8585
if (armI8MM) {
8686
if (armNeon)
87-
arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
87+
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
8888
if (armSVE)
8989
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
9090
}
91+
if (armBF16 && armNeon)
92+
arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
9193
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
9294
}
9395

mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ using namespace mlir;
2020

2121
void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
2222
RewritePatternSet &patterns) {
23-
arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
23+
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
24+
}
25+
26+
void transform::ApplyArmNeonContractionToBFMMLAPatternsOp::populatePatterns(
27+
RewritePatternSet &patterns) {
28+
arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
2429
}
2530

2631
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
add_mlir_dialect_library(MLIRArmNeonTransforms
2-
LowerContractionToNeonI8MMPattern.cpp
2+
LowerContractToNeonPatterns.cpp
33

44
DEPENDS
55
MLIRArmNeonIncGen

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp renamed to mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp

Lines changed: 99 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
1+
//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -95,15 +95,20 @@ class VectorContractRewriter {
9595
// multiplications.
9696
enum class MMLA {
9797
Nop,
98-
Signed, // smmla
99-
Unsigned, // ummla
100-
Mixed, // usmmla
101-
MixedSwapped // usmmla with LHS and RHS swapped
98+
SignedInt, // smmla
99+
UnsignedInt, // ummla
100+
MixedInt, // usmmla
101+
Bfloat // bfmmla
102102
};
103103

104104
// Lower-level operation to be emitted.
105105
MMLA mmlaOp = MMLA::Nop;
106106

107+
// Indicate if the operands for the ArmNeon dialect operation need to be
108+
// swapped. Currently this is needed in order to emulate an "summla"
109+
// operation.
110+
bool swapOperands = false;
111+
107112
// The operand tiles. These are not necessarily the operands of
108113
// `vector.contract`, for example they could be operands to `arith.extsi`
109114
// that is in turn fed into `vector.contract`.
@@ -128,21 +133,22 @@ class VectorContractRewriter {
128133
// Create the matrix multiply and accumulate operation according to `mmlaOp`.
129134
Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
130135
Value lhs, Value rhs) {
136+
137+
if (swapOperands)
138+
std::swap(lhs, rhs);
131139
switch (mmlaOp) {
132-
case MMLA::Signed:
140+
case MMLA::SignedInt:
133141
return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
134142
lhs, rhs);
135-
case MMLA::Unsigned:
143+
case MMLA::UnsignedInt:
136144
return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
137145
lhs, rhs);
138-
case MMLA::Mixed:
146+
case MMLA::MixedInt:
139147
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
140148
lhs, rhs);
141-
case MMLA::MixedSwapped:
142-
// The accumulator comes transposed and the result will be transposed
143-
// later, so all we have to do here is swap the operands.
144-
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
145-
rhs, lhs);
149+
case MMLA::Bfloat:
150+
return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
151+
rhs);
146152
case MMLA::Nop:
147153
llvm_unreachable("Uninitialized operation type");
148154
}
@@ -275,7 +281,7 @@ class VectorContractRewriter {
275281
// Transpose ACC if doing signed by unsigned multiplication, because we're
276282
// using the instruction for unsigned by signed multiplication with
277283
// reversed operands.
278-
if (mmlaOp == MMLA::MixedSwapped)
284+
if (swapOperands)
279285
tiledAcc = rewriter.create<vector::TransposeOp>(
280286
loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
281287

@@ -304,7 +310,7 @@ class VectorContractRewriter {
304310

305311
// Because of the reversed operands the result is obtained transposed.
306312
// Transpose it back.
307-
if (mmlaOp == MMLA::MixedSwapped)
313+
if (swapOperands)
308314
tiledRes = rewriter.create<vector::TransposeOp>(
309315
loc, tiledRes, ArrayRef<int64_t>({1, 0}));
310316

@@ -341,10 +347,10 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
341347
// values before the extension. All four signed/unsigned combinations for
342348
// input operands are supported, but they are lowered to different
343349
// operations. Determine which is the appropriate operation to lower to.
344-
mmlaOp = MMLA::Signed;
350+
mmlaOp = MMLA::SignedInt;
345351
auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
346352
if (!maybeLhs) {
347-
mmlaOp = MMLA::Unsigned;
353+
mmlaOp = MMLA::UnsignedInt;
348354
maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
349355
}
350356
if (!maybeLhs)
@@ -353,11 +359,13 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
353359

354360
auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
355361
if (maybeRhs) {
356-
if (mmlaOp == MMLA::Unsigned)
357-
mmlaOp = MMLA::Mixed;
362+
if (mmlaOp == MMLA::UnsignedInt)
363+
mmlaOp = MMLA::MixedInt;
358364
} else {
359-
if (mmlaOp == MMLA::Signed)
360-
mmlaOp = MMLA::MixedSwapped;
365+
if (mmlaOp == MMLA::SignedInt) {
366+
mmlaOp = MMLA::MixedInt;
367+
swapOperands = true;
368+
}
361369
maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
362370
}
363371

@@ -374,16 +382,17 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
374382
auto lhsExtInType = cast<VectorType>(lhs.getType());
375383
if (lhsExtInType.getElementTypeBitWidth() < 8)
376384
lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
377-
/* signExt */ mmlaOp == MMLA::Signed ||
378-
mmlaOp == MMLA::Mixed,
385+
/* signExt */
386+
(mmlaOp == MMLA::SignedInt ||
387+
(mmlaOp == MMLA::MixedInt && !swapOperands)),
379388
rewriter);
380389

381390
auto rhsExtInType = cast<VectorType>(rhs.getType());
382391
if (rhsExtInType.getElementTypeBitWidth() < 8)
383-
384392
rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
385-
/* signExt */ mmlaOp != MMLA::Unsigned &&
386-
mmlaOp != MMLA::Mixed,
393+
/* signExt */
394+
(mmlaOp == MMLA::SignedInt ||
395+
(mmlaOp == MMLA::MixedInt && swapOperands)),
387396
rewriter);
388397

389398
// Initialize parameters for unrolling.
@@ -397,6 +406,47 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
397406
}
398407
};
399408

409+
class VectorContractRewriterBFMMLA : public VectorContractRewriter {
410+
public:
411+
LogicalResult matchAndInit(vector::ContractionOp op,
412+
PatternRewriter &rewriter) {
413+
414+
if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
415+
return failure();
416+
417+
// Unrolling patterns can handle any [2, 2, 4] shaped multiple of inputs for
418+
// tiling.
419+
if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
420+
return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
421+
422+
// Check the output is a vector of Float32 elements.
423+
auto outTy = dyn_cast<VectorType>(op.getResultType());
424+
if (!outTy || outTy.getElementType() != rewriter.getF32Type())
425+
return rewriter.notifyMatchFailure(op,
426+
"output type is not a vector of f32");
427+
428+
// Check the inputs are vectors of BFloat16 elements.
429+
if (op.getLhsType().getElementType() != rewriter.getBF16Type())
430+
return rewriter.notifyMatchFailure(op,
431+
"input type is not a vector of bf16");
432+
433+
mmlaOp = MMLA::Bfloat;
434+
swapOperands = false;
435+
lhs = op.getLhs();
436+
rhs = op.getRhs();
437+
acc = op.getAcc();
438+
439+
// Initialize parameters for unrolling.
440+
iterationBounds = *op.getShapeForUnroll();
441+
if (iterationBounds.size() == 3)
442+
subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
443+
else
444+
subTileShape = SmallVector<int64_t>({2, 4});
445+
446+
return success();
447+
}
448+
};
449+
400450
/// Lowering from a vector::contractOp to arm neon smmla intrinsic. This will tile
401451
/// any vector.contract into multiple smmla instructions with unrolling so long
402452
/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -418,10 +468,32 @@ class LowerContractionToNeonI8MMPattern
418468
}
419469
};
420470

471+
class LowerContractionToNeonBFMMLAPattern
472+
: public OpRewritePattern<vector::ContractionOp> {
473+
public:
474+
using OpRewritePattern::OpRewritePattern;
475+
LogicalResult matchAndRewrite(vector::ContractionOp op,
476+
PatternRewriter &rewriter) const override {
477+
478+
VectorContractRewriterBFMMLA vcr;
479+
if (failed(vcr.matchAndInit(op, rewriter)))
480+
return failure();
481+
vcr.rewrite(op, rewriter);
482+
483+
return success();
484+
}
485+
};
486+
421487
} // namespace
422488

423-
void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
489+
void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns(
424490
RewritePatternSet &patterns) {
425491
MLIRContext *context = patterns.getContext();
426492
patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
427493
}
494+
495+
void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(
496+
RewritePatternSet &patterns) {
497+
MLIRContext *context = patterns.getContext();
498+
patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2);
499+
}

mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// TODO: There may be opportunities to unify this with a similar pattern
1313
// for Neon. See:
1414
// https://github.com/llvm/llvm-project/issues/145559
15-
// LowerContractionToNeonI8MMPattern.cpp
15+
// LowerContractToNeonPatterns.cpp
1616
//
1717
//===----------------------------------------------------------------------===//
1818

0 commit comments

Comments
 (0)