@@ -1,4 +1,4 @@
-//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
+//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -95,15 +95,20 @@ class VectorContractRewriter {
   // multiplications.
   enum class MMLA {
     Nop,
-    Signed,      // smmla
-    Unsigned,    // ummla
-    Mixed,       // usmmla
-    MixedSwapped // usmmla with LHS and RHS swapped
+    SignedInt,   // smmla
+    UnsignedInt, // ummla
+    MixedInt,    // usmmla
+    Bfloat       // bfmmla
   };

   // Lower-level operation to be emitted.
   MMLA mmlaOp = MMLA::Nop;

+  // Indicate if the operands for the ArmNeon dialect operation need to be
+  // swapped. Currently this is needed in order to emulate an "summla"
+  // operation.
+  bool swapOperands = false;
+
   // The operand tiles. These are not necessarily the operands of
   // `vector.contract`, for example they could be operands to `arith.extsi`
   // that is in turn fed into `vector.contract`.
@@ -128,21 +133,22 @@ class VectorContractRewriter {
   // Create the matrix multiply and accumulate operation according to `mmlaOp`.
   Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
                    Value lhs, Value rhs) {
+
+    if (swapOperands)
+      std::swap(lhs, rhs);
     switch (mmlaOp) {
-    case MMLA::Signed:
+    case MMLA::SignedInt:
       return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
                                                       lhs, rhs);
-    case MMLA::Unsigned:
+    case MMLA::UnsignedInt:
       return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
                                                       lhs, rhs);
-    case MMLA::Mixed:
+    case MMLA::MixedInt:
       return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
                                                        lhs, rhs);
-    case MMLA::MixedSwapped:
-      // The accumulator comes transposed and the result will be transposed
-      // later, so all we have to do here is swap the operands.
-      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
-                                                       rhs, lhs);
+    case MMLA::Bfloat:
+      return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
+                                                 rhs);
     case MMLA::Nop:
       llvm_unreachable("Uninitialized operation type");
     }
@@ -275,7 +281,7 @@ class VectorContractRewriter {
     // Transpose ACC if doing signed by unsigned multiplication, because we're
     // using the instruction for unsigned by signed multiplication with
     // reversed operands.
-    if (mmlaOp == MMLA::MixedSwapped)
+    if (swapOperands)
       tiledAcc = rewriter.create<vector::TransposeOp>(
           loc, tiledAcc, ArrayRef<int64_t>({1, 0}));

@@ -304,7 +310,7 @@ class VectorContractRewriter {

     // Because of the reversed operands the result is obtained transposed.
     // Transpose it back,
-    if (mmlaOp == MMLA::MixedSwapped)
+    if (swapOperands)
       tiledRes = rewriter.create<vector::TransposeOp>(
           loc, tiledRes, ArrayRef<int64_t>({1, 0}));

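A brief sketch of why the swap-and-transpose trick in the two hunks above works, based on the NEON MMLA semantics: each *mmla instruction updates a 2x2 sub-tile as acc += lhs * rhs^T, and usmmla specifically takes an unsigned LHS tile and a signed RHS tile. A signed-by-unsigned product can therefore be emulated by exchanging the operands and working with transposed tiles, since (A * B^T)^T = B * A^T: the accumulator tile is transposed before the call, createMMLA swaps lhs and rhs, and the (transposed) result tile is transposed back afterwards.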
@@ -341,10 +347,10 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
     // values before the extension. All four signed/unsigned combinations for
     // input operands are supported, but they are lowered to different
     // operations. Determine which is the appropriate operation to lower to.
-    mmlaOp = MMLA::Signed;
+    mmlaOp = MMLA::SignedInt;
     auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
     if (!maybeLhs) {
-      mmlaOp = MMLA::Unsigned;
+      mmlaOp = MMLA::UnsignedInt;
       maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
     }
     if (!maybeLhs)
@@ -353,11 +359,13 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {

     auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
     if (maybeRhs) {
-      if (mmlaOp == MMLA::Unsigned)
-        mmlaOp = MMLA::Mixed;
+      if (mmlaOp == MMLA::UnsignedInt)
+        mmlaOp = MMLA::MixedInt;
     } else {
-      if (mmlaOp == MMLA::Signed)
-        mmlaOp = MMLA::MixedSwapped;
+      if (mmlaOp == MMLA::SignedInt) {
+        mmlaOp = MMLA::MixedInt;
+        swapOperands = true;
+      }
       maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
     }

@@ -374,16 +382,17 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
     auto lhsExtInType = cast<VectorType>(lhs.getType());
     if (lhsExtInType.getElementTypeBitWidth() < 8)
       lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
-                                 /* signExt */ mmlaOp == MMLA::Signed ||
-                                     mmlaOp == MMLA::Mixed,
+                                 /* signExt */
+                                 (mmlaOp == MMLA::SignedInt ||
+                                  (mmlaOp == MMLA::MixedInt && !swapOperands)),
                                  rewriter);

     auto rhsExtInType = cast<VectorType>(rhs.getType());
     if (rhsExtInType.getElementTypeBitWidth() < 8)
-
       rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
-                                 /* signExt */ mmlaOp != MMLA::Unsigned &&
-                                     mmlaOp != MMLA::Mixed,
+                                 /* signExt */
+                                 (mmlaOp == MMLA::SignedInt ||
+                                  (mmlaOp == MMLA::MixedInt && swapOperands)),
                                  rewriter);

     // Initialize parameters for unrolling.
@@ -397,6 +406,47 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
   }
 };

+class VectorContractRewriterBFMMLA : public VectorContractRewriter {
+public:
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+
+    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+      return failure();
+
+    // Unrolling patterns can handle any [2, 2, 4] shaped multiple of inputs for
+    // tiling.
+    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
+      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+    // Check the output is a vector of Float32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getResultType());
+    if (!outTy || outTy.getElementType() != rewriter.getF32Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "output type is not a vector of f32");
+
+    // Check the inputs are vectors of BFloat16 elements.
+    if (op.getLhsType().getElementType() != rewriter.getBF16Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "input type is not a vector of bf16");
+
+    mmlaOp = MMLA::Bfloat;
+    swapOperands = false;
+    lhs = op.getLhs();
+    rhs = op.getRhs();
+    acc = op.getAcc();
+
+    // Initialize parameters for unrolling.
+    iterationBounds = *op.getShapeForUnroll();
+    if (iterationBounds.size() == 3)
+      subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
+    else
+      subTileShape = SmallVector<int64_t>({2, 4});
+
+    return success();
+  }
+};
+
 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
 /// any vector.contract into multiple smmla instructions with unrolling so long
 /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -418,10 +468,32 @@ class LowerContractionToNeonI8MMPattern
   }
 };

+class LowerContractionToNeonBFMMLAPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    VectorContractRewriterBFMMLA vcr;
+    if (failed(vcr.matchAndInit(op, rewriter)))
+      return failure();
+    vcr.rewrite(op, rewriter);
+
+    return success();
+  }
+};
+
 } // namespace

-void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
+void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
   patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
 }
+
+void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2);
+}
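As a usage sketch, the two populate functions exported above can be combined in a single greedy rewrite. The wrapper below is illustrative only (the helper name lowerContractToNeonMMLA is invented here), and the greedy-driver entry point differs across MLIR versions (older trees spell it applyPatternsAndFoldGreedily):

#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative helper: lower vector.contract ops nested under `root` to NEON
// smmla/ummla/usmmla and bfmmla operations using the patterns registered above.
static LogicalResult lowerContractToNeonMMLA(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);   // integer MMLA
  arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns); // bf16 MMLA
  return applyPatternsGreedily(root, std::move(patterns));
}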