Skip to content

Commit 73b85f8

Browse files
[MLIR][AArch64] Lower vector.contract to SVE FEAT_BF16 operations (#147052)
This patch adds lowering of Bfloat16 widening matrix multiply and accumulate `vector.contract`, by parametrising and refactoring the pattern for 8-bit integers.
1 parent 81651e9 commit 73b85f8

File tree

12 files changed

+932
-372
lines changed

12 files changed

+932
-372
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,6 +1449,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14491449
"bool", /*default=*/"false",
14501450
"Enables the use of Arm FEAT_I8MM instructions while lowering "
14511451
"the vector dialect.">,
1452+
Option<"armBF16", "enable-arm-bf16",
1453+
"bool", /*default=*/"false",
1454+
"Enables the use of Arm FEAT_BF16 instructions while lowering "
1455+
"the vector dialect.">,
14521456
Option<"x86Vector", "enable-x86vector",
14531457
"bool", /*default=*/"false",
14541458
"Enables the use of X86Vector dialect while lowering the vector "

mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,25 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
1212
include "mlir/Dialect/Transform/IR/TransformDialect.td"
1313
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
1414

15-
def ApplyArmSVELowerContractionPatternsOp
15+
def ApplyArmSVELowerContractionToI8MMPatternsOp
1616
: Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_i8mm",
1717
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
1818
let description = [{
19-
Indicates that vector contraction-like operations should be lowered to
20-
finer-grained vector primitives using the ArmSVE dialect.
19+
Indicates that vector contract operations should be lowered to
20+
to ArmSVE dialect operations mapping to instructions from FEAT_I8MM.
2121
}];
2222

2323
let assemblyFormat = "attr-dict";
2424
}
2525

26+
def ApplyArmSVELowerContractionToBFMMLAPatternsOp
27+
: Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_bfmmla",
28+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
29+
let description = [{
30+
Indicates that vector contract operations should be lowered to
31+
ArmSVE dialect operations mapping to instructions from FEAT_BF16.
32+
}];
33+
34+
let assemblyFormat = "attr-dict";
35+
}
2636
#endif // ARMSVE_VECTOR_TRANSFORM_OPS

mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ void populateArmSVELegalizeForLLVMExportPatterns(
2323
void populateLowerContractionToSVEI8MMPatternPatterns(
2424
RewritePatternSet &patterns);
2525

26+
void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns);
27+
2628
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
2729
/// intrinsics.
2830
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
100100
if (armSVE)
101101
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
102102
}
103+
if (armBF16)
104+
populateLowerContractionToSVEBFMMLAPatterns(patterns);
105+
103106
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
104107
}
105108

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// TODO: There may be opportunities to unify this with a similar pattern
1313
// for SVE. See:
1414
// https://github.com/llvm/llvm-project/issues/145559
15-
// LowerContractionToSVEI8MMPattern.cpp
15+
// LowerContractToSVEPatterns.cpp
1616
//
1717
//===----------------------------------------------------------------------===//
1818

mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,16 @@ using namespace mlir;
1818
// Apply...PatternsOp
1919
//===----------------------------------------------------------------------===//
2020

21-
void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
21+
void transform::ApplyArmSVELowerContractionToI8MMPatternsOp::populatePatterns(
2222
RewritePatternSet &patterns) {
2323
mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
2424
}
2525

26+
void transform::ApplyArmSVELowerContractionToBFMMLAPatternsOp::populatePatterns(
27+
RewritePatternSet &patterns) {
28+
mlir::populateLowerContractionToSVEBFMMLAPatterns(patterns);
29+
}
30+
2631
//===----------------------------------------------------------------------===//
2732
// Transform op registration
2833
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
add_mlir_dialect_library(MLIRArmSVETransforms
22
LegalizeForLLVMExport.cpp
33
LegalizeVectorStorage.cpp
4-
LowerContractionToSVEI8MMPattern.cpp
4+
LowerContractToSVEPatterns.cpp
55

66
DEPENDS
77
MLIRArmSVEConversionsIncGen

0 commit comments

Comments
 (0)