Skip to content

Commit cdbee58

Browse files
[fixup] Some refactoring
1 parent 01da5c0 commit cdbee58

File tree

3 files changed

+79
-90
lines changed

3 files changed

+79
-90
lines changed

mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def ApplyArmSVELowerContractionToI8MMPatternsOp
1616
: Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_i8mm",
1717
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
1818
let description = [{
19-
Indicates that vector contraction-like operations should be lowered to
20-
finer-grained vector primitives using the ArmSVE dialect.
19+
Indicates that vector contract operations should be lowered to
20+
ArmSVE dialect operations mapping to instructions from FEAT_I8MM.
2121
}];
2222

2323
let assemblyFormat = "attr-dict";
@@ -27,8 +27,8 @@ def ApplyArmSVELowerContractionToBFMMLAPatternsOp
2727
: Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_bfmmla",
2828
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
2929
let description = [{
30-
Indicates that vector contraction-like operations should be lowered to
31-
finer-grained vector primitives using the ArmSVE dialect.
30+
Indicates that vector contract operations should be lowered to
31+
ArmSVE dialect operations mapping to instructions from FEAT_BF16.
3232
}];
3333

3434
let assemblyFormat = "attr-dict";

mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp

Lines changed: 69 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/IR/AffineMap.h"
2626
#include "mlir/IR/PatternMatch.h"
2727

28+
#include <cassert>
2829
#include <numeric>
2930

3031
#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
@@ -169,6 +170,11 @@ class VectorContractRewriter {
169170
// Lower-level operation to be emitted.
170171
MMLA mmlaOp = MMLA::Nop;
171172

173+
// Indicate if the operands for the ArmSVE dialect operation need to be
174+
// swapped. Currently this is needed in order to emulate an "summla"
175+
// operation.
176+
bool swapOperands = false;
177+
172178
// The operand tiles. These are not necessarily the operands of
173179
// `vector.contract`, for example they could be operands to `arith.extsi`
174180
// that is in turn fed into `vector.contract`.
@@ -181,34 +187,6 @@ class VectorContractRewriter {
181187
int64_t N = 0;
182188
int64_t K = 0;
183189

184-
// Single-dimensional vector types for the operands of the ArmSVE dialect
185-
// op.
186-
VectorType flatLhsType;
187-
VectorType flatRhsType;
188-
VectorType flatAccType;
189-
190-
// Single-dimension vector type for the entire RHS tile.
191-
VectorType flatRhsTileType;
192-
193-
// Vector type having the same number of elements as a row in the
194-
// accumulator/output tile and the same element type.
195-
VectorType accRowTy;
196-
197-
// Vector type having twice the number of elements as a row in the
198-
// accumulator/output tile the same element type.
199-
VectorType accRowX2Ty;
200-
201-
// Vector type having half the number of elements as a row in the
202-
// accumulator/output tile and an integer element type with twice the bit
203-
// width.
204-
VectorType accRow64Ty;
205-
VectorType accRowX264Ty;
206-
207-
// Indicate if the operands for the ArmSVE dialect operation need to be
208-
// swapped. Currently this is needed in order to emulate an "summla"
209-
// operation.
210-
bool swapOperands = false;
211-
212190
// Create the matrix multiply and accumulate operation according to
213191
// `mmlaOp`.
214192
Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
@@ -229,18 +207,20 @@ class VectorContractRewriter {
229207
Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
230208
Location loc, Value acc, Value lhs,
231209
Value rhs) {
210+
211+
Type resTy = acc.getType();
232212
if (swapOperands)
233213
std::swap(lhs, rhs);
234214

235215
switch (mmlaOp) {
236216
case MMLA::SignedInt:
237-
return rewriter.create<arm_sve::SmmlaOp>(loc, flatAccType, acc, lhs, rhs);
217+
return rewriter.create<arm_sve::SmmlaOp>(loc, resTy, acc, lhs, rhs);
238218
case MMLA::UnsignedInt:
239-
return rewriter.create<arm_sve::UmmlaOp>(loc, flatAccType, acc, lhs, rhs);
219+
return rewriter.create<arm_sve::UmmlaOp>(loc, resTy, acc, lhs, rhs);
240220
case MMLA::MixedInt:
241-
return rewriter.create<arm_sve::UsmmlaOp>(loc, flatAccType, acc, lhs, rhs);
221+
return rewriter.create<arm_sve::UsmmlaOp>(loc, resTy, acc, lhs, rhs);
242222
case MMLA::Bfloat:
243-
return rewriter.create<arm_sve::BfmmlaOp>(loc, flatAccType, acc, lhs, rhs);
223+
return rewriter.create<arm_sve::BfmmlaOp>(loc, resTy, acc, lhs, rhs);
244224
default:
245225
llvm_unreachable("Uninitialized operation kind");
246226
}
@@ -280,6 +260,55 @@ LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
280260

281261
Value VectorContractRewriter::rewrite(vector::ContractionOp op,
282262
PatternRewriter &rewriter) {
263+
264+
// Initialize some helper types.
265+
Type operandEltType = cast<VectorType>(lhs.getType()).getElementType();
266+
Type resultEltType = cast<VectorType>(op.getResultType()).getElementType();
267+
268+
const int64_t numOperandSubTileElts =
269+
128 / operandEltType.getIntOrFloatBitWidth();
270+
271+
assert(resultEltType.getIntOrFloatBitWidth() == 32 &&
272+
"Only implemented for i32 or f32 output");
273+
const int64_t numResultSubTileElts = 4;
274+
275+
// Single-dimensional vector types for the operands of the ArmSVE dialect
276+
// op.
277+
auto flatLhsType =
278+
VectorType::get(/*shape=*/numOperandSubTileElts, operandEltType,
279+
/*scalableDims=*/{true});
280+
auto flatRhsType =
281+
VectorType::get(/*shape=*/numOperandSubTileElts, operandEltType,
282+
/*scalableDims=*/{true});
283+
auto flatAccType =
284+
VectorType::get(/*shape=*/numResultSubTileElts, resultEltType,
285+
/*scalableDims=*/{true});
286+
287+
// Single-dimension vector type for the entire RHS tile.
288+
289+
auto flatRhsTileType = VectorType::get(/*shape=*/K * N, operandEltType,
290+
/*scalableDims=*/{true});
291+
292+
// Vector type having the same number of elements as a row in the
293+
// accumulator/output tile and the same element type.
294+
auto accRowTy = VectorType::get(/*shape=*/N, resultEltType,
295+
/*scalableDims=*/{true});
296+
297+
// Vector type having twice the number of elements as a row in the
298+
// accumulator/output tile and the same element type.
299+
auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, resultEltType,
300+
/*scalableDims=*/{true});
301+
// Vector type having half the number of elements as a row in the
302+
// accumulator/output tile and an integer element type with twice the bit
303+
// width.
304+
auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
305+
/*scalableDims=*/{true});
306+
// Vector type having the same number of elements as a row in the
307+
// accumulator/output tile and an integer element type with twice the bit
308+
// width.
309+
auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
310+
/*scalableDims=*/{true});
311+
283312
Location loc = op.getLoc();
284313

285314
// Extract LHS sub-tiles with logical shape <2xK>.
@@ -394,9 +423,9 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
394423
public:
395424
// Check the specific preconditions for the integer case. Initialise
396425
// parametrisation types and dimensions.
397-
LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) {
398-
399-
if (failed(VectorContractRewriter::match(op, rewriter)))
426+
LogicalResult matchAndInit(vector::ContractionOp op,
427+
PatternRewriter &rewriter) {
428+
if (failed(match(op, rewriter)))
400429
return failure();
401430

402431
VectorType lhsType = op.getLhsType();
@@ -458,26 +487,6 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
458487
rhs = *maybeRhs;
459488
acc = op.getAcc();
460489

461-
flatLhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
462-
/*scalableDims=*/{true});
463-
flatRhsType = VectorType::get(/*shape=*/16, rewriter.getI8Type(),
464-
/*scalableDims=*/{true});
465-
466-
flatAccType = VectorType::get(/*shape=*/4, rewriter.getI32Type(),
467-
/*scalableDims=*/{true});
468-
469-
flatRhsTileType = VectorType::get(/*shape=*/8 * N, rewriter.getI8Type(),
470-
/*scalableDims=*/{true});
471-
472-
accRowTy = VectorType::get(/*shape=*/N, rewriter.getI32Type(),
473-
/*scalableDims=*/{true});
474-
accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getI32Type(),
475-
/*scalableDims=*/{true});
476-
accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
477-
/*scalableDims=*/{true});
478-
accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
479-
/*scalableDims=*/{true});
480-
481490
return success();
482491
}
483492
};
@@ -486,9 +495,9 @@ class VectorContractRewriterBfloat : public VectorContractRewriter {
486495
public:
487496
// Check the specific preconditions for the bfloat16 case. Initialise
488497
// parametrisation types and dimensions.
489-
LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter) {
490-
491-
if (failed(VectorContractRewriter::match(op, rewriter)))
498+
LogicalResult matchAndInit(vector::ContractionOp op,
499+
PatternRewriter &rewriter) {
500+
if (failed(match(op, rewriter)))
492501
return failure();
493502

494503
VectorType lhsType = op.getLhsType();
@@ -527,26 +536,6 @@ class VectorContractRewriterBfloat : public VectorContractRewriter {
527536
rhs = op.getRhs();
528537
acc = op.getAcc();
529538

530-
flatLhsType = VectorType::get(/*shape=*/8, rewriter.getBF16Type(),
531-
/*scalableDims=*/{true});
532-
flatRhsType = VectorType::get(/*shape=*/8, rewriter.getBF16Type(),
533-
/*scalableDims=*/{true});
534-
535-
flatAccType = VectorType::get(/*shape=*/4, rewriter.getF32Type(),
536-
/*scalableDims=*/{true});
537-
538-
flatRhsTileType = VectorType::get(/*shape=*/4 * N, rewriter.getBF16Type(),
539-
/*scalableDims=*/{true});
540-
541-
accRowTy = VectorType::get(/*shape=*/N, rewriter.getF32Type(),
542-
/*scalableDims=*/{true});
543-
accRowX2Ty = VectorType::get(/*shape=*/2 * N, rewriter.getF32Type(),
544-
/*scalableDims=*/{true});
545-
accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(),
546-
/*scalableDims=*/{true});
547-
accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(),
548-
/*scalableDims=*/{true});
549-
550539
return success();
551540
}
552541
};
@@ -560,7 +549,7 @@ class LowerContractionToSVEI8MMPattern
560549

561550
// Match i8xi8 -> i32 matrix multiply and accumulate.
562551
VectorContractRewriterI8MM vcr;
563-
if (failed(vcr.match(op, rewriter)))
552+
if (failed(vcr.matchAndInit(op, rewriter)))
564553
return failure();
565554

566555
Value result = vcr.rewrite(op, rewriter);
@@ -579,7 +568,7 @@ class LowerContractionToSVEBFMMLAPattern
579568

580569
// Match bf16xbf16 -> f32 matrix multiply and accumulate.
581570
VectorContractRewriterBfloat vcr;
582-
if (failed(vcr.match(op, rewriter)))
571+
if (failed(vcr.matchAndInit(op, rewriter)))
583572
return failure();
584573

585574
Value result = vcr.rewrite(op, rewriter);

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ func.func private @prepareAccTestData(%in: vector<4x4xf32>) -> memref<4x?xf32> {
5858
%c4 = arith.constant 4 : index
5959

6060
%vs = vector.vscale
61-
%d = arith.muli %c4, %vs : index
62-
%mem = memref.alloc(%d) : memref<4x?xf32>
61+
%nCols = arith.muli %c4, %vs : index
62+
%mem = memref.alloc(%nCols) : memref<4x?xf32>
6363

64-
scf.for %j = %c0 to %d step %c4 {
64+
scf.for %j = %c0 to %nCols step %c4 {
6565
vector.transfer_write %in, %mem[%c0, %j] {in_bounds = [true, true]} : vector<4x4xf32>, memref<4x?xf32>
6666
}
6767

@@ -95,10 +95,10 @@ func.func private @prepareRHSTestData(%in: vector<4x4xbf16>) -> memref<?xbf16> {
9595
%c4 = arith.constant 4 : index
9696

9797
%vs = vector.vscale
98-
%d = arith.muli %c4, %vs : index
99-
%mem = memref.alloc(%d) : memref<?x4xbf16>
98+
%nRows = arith.muli %c4, %vs : index
99+
%mem = memref.alloc(%nRows) : memref<?x4xbf16>
100100

101-
scf.for %i = %c0 to %d step %c4 {
101+
scf.for %i = %c0 to %nRows step %c4 {
102102
vector.transfer_write %in, %mem[%i, %c0] {in_bounds = [true, true]} : vector<4x4xbf16>, memref<?x4xbf16>
103103
}
104104

0 commit comments

Comments
 (0)