25
25
#include " mlir/IR/AffineMap.h"
26
26
#include " mlir/IR/PatternMatch.h"
27
27
28
+ #include < cassert>
28
29
#include < numeric>
29
30
30
31
#define DEBUG_TYPE " lower-contract-to-arm-sve-i8mm"
@@ -169,6 +170,11 @@ class VectorContractRewriter {
169
170
// Lower-level operation to be emitted.
170
171
MMLA mmlaOp = MMLA::Nop;
171
172
173
+ // Indicate if the operands for the ArmSVE dialect operation need to be
174
+ // swapped. Currently this is needed in order to emulate an "summla"
175
+ // operation.
176
+ bool swapOperands = false ;
177
+
172
178
// The operand tiles. These are not necessarily the operands of
173
179
// `vector.contract`, for example they could be operands to `arith.extsi`
174
180
// that is in turn fed into `vector.contract`.
@@ -181,34 +187,6 @@ class VectorContractRewriter {
181
187
int64_t N = 0 ;
182
188
int64_t K = 0 ;
183
189
184
- // Single-dimensional vector types for the operands of the ArmSVE dialect
185
- // op.
186
- VectorType flatLhsType;
187
- VectorType flatRhsType;
188
- VectorType flatAccType;
189
-
190
- // Single-dimension vector type for the entire RHS tile.
191
- VectorType flatRhsTileType;
192
-
193
- // Vector type having the same number of elements as a row in the
194
- // accumulator/output tile and the same element type.
195
- VectorType accRowTy;
196
-
197
- // Vector type having twice the number of elements as a row in the
198
- // accumulator/output tile the same element type.
199
- VectorType accRowX2Ty;
200
-
201
- // Vector type having half the number of elements as a row in the
202
- // accumulator/output tile and an integer element type with twice the bit
203
- // width.
204
- VectorType accRow64Ty;
205
- VectorType accRowX264Ty;
206
-
207
- // Indicate if the operands for the ArmSVE dialect operation need to be
208
- // swapped. Currently this is needed in order to emulate an "summla"
209
- // operation.
210
- bool swapOperands = false ;
211
-
212
190
// Create the matrix multiply and accumulate operation according to
213
191
// `mmlaOp`.
214
192
Value createMMLA (PatternRewriter &rewriter, Location loc, Value acc,
@@ -229,18 +207,20 @@ class VectorContractRewriter {
229
207
Value VectorContractRewriter::createMMLA (PatternRewriter &rewriter,
230
208
Location loc, Value acc, Value lhs,
231
209
Value rhs) {
210
+
211
+ Type resTy = acc.getType ();
232
212
if (swapOperands)
233
213
std::swap (lhs, rhs);
234
214
235
215
switch (mmlaOp) {
236
216
case MMLA::SignedInt:
237
- return rewriter.create <arm_sve::SmmlaOp>(loc, flatAccType , acc, lhs, rhs);
217
+ return rewriter.create <arm_sve::SmmlaOp>(loc, resTy , acc, lhs, rhs);
238
218
case MMLA::UnsignedInt:
239
- return rewriter.create <arm_sve::UmmlaOp>(loc, flatAccType , acc, lhs, rhs);
219
+ return rewriter.create <arm_sve::UmmlaOp>(loc, resTy , acc, lhs, rhs);
240
220
case MMLA::MixedInt:
241
- return rewriter.create <arm_sve::UsmmlaOp>(loc, flatAccType , acc, lhs, rhs);
221
+ return rewriter.create <arm_sve::UsmmlaOp>(loc, resTy , acc, lhs, rhs);
242
222
case MMLA::Bfloat:
243
- return rewriter.create <arm_sve::BfmmlaOp>(loc, flatAccType , acc, lhs, rhs);
223
+ return rewriter.create <arm_sve::BfmmlaOp>(loc, resTy , acc, lhs, rhs);
244
224
default :
245
225
llvm_unreachable (" Uninitialized operation kind" );
246
226
}
@@ -280,6 +260,55 @@ LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
280
260
281
261
Value VectorContractRewriter::rewrite (vector::ContractionOp op,
282
262
PatternRewriter &rewriter) {
263
+
264
+ // Initialize some helper types.
265
+ Type operandEltType = cast<VectorType>(lhs.getType ()).getElementType ();
266
+ Type resultEltType = cast<VectorType>(op.getResultType ()).getElementType ();
267
+
268
+ const int64_t numOperandSubTileElts =
269
+ 128 / operandEltType.getIntOrFloatBitWidth ();
270
+
271
+ assert (resultEltType.getIntOrFloatBitWidth () == 32 &&
272
+ " Only implemented for i32 or f32 output" );
273
+ const int64_t numResultSubTileElts = 4 ;
274
+
275
+ // Single-dimensional vector types for the operands of the ArmSVE dialect
276
+ // op.
277
+ auto flatLhsType =
278
+ VectorType::get (/* shape=*/ numOperandSubTileElts, operandEltType,
279
+ /* scalableDims=*/ {true });
280
+ auto flatRhsType =
281
+ VectorType::get (/* shape=*/ numOperandSubTileElts, operandEltType,
282
+ /* scalableDims=*/ {true });
283
+ auto flatAccType =
284
+ VectorType::get (/* shape=*/ numResultSubTileElts, resultEltType,
285
+ /* scalableDims=*/ {true });
286
+
287
+ // Single-dimension vector type for the entire RHS tile.
288
+
289
+ auto flatRhsTileType = VectorType::get (/* shape=*/ K * N, operandEltType,
290
+ /* scalableDims=*/ {true });
291
+
292
+ // Vector type having the same number of elements as a row in the
293
+ // accumulator/output tile and the same element type.
294
+ auto accRowTy = VectorType::get (/* shape=*/ N, resultEltType,
295
+ /* scalableDims=*/ {true });
296
+
297
+ // Vector type having twice the number of elements as a row in the
298
+ // accumulator/output tile and the same element type.
299
+ auto accRowX2Ty = VectorType::get (/* shape=*/ 2 * N, resultEltType,
300
+ /* scalableDims=*/ {true });
301
+ // Vector type having half the number of elements as a row in the
302
+ // accumulator/output tile and an integer element type with twice the bit
303
+ // width.
304
+ auto accRow64Ty = VectorType::get (/* shape=*/ N / 2 , rewriter.getI64Type (),
305
+ /* scalableDims=*/ {true });
306
+ // Vector type having the same number of elements as a row in the
307
+ // accumulator/output tile and an integer element type with twice the bit
308
+ // width.
309
+ auto accRowX264Ty = VectorType::get (/* shape=*/ N, rewriter.getI64Type (),
310
+ /* scalableDims=*/ {true });
311
+
283
312
Location loc = op.getLoc ();
284
313
285
314
// Extract LHS sub-tiles with logical shape <2xK>.
@@ -394,9 +423,9 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
394
423
public:
395
424
// Check the specific preconditions for the integer case. Initialise
396
425
// parametrisation types and dimensions.
397
- LogicalResult match (vector::ContractionOp op, PatternRewriter &rewriter) {
398
-
399
- if (failed (VectorContractRewriter:: match (op, rewriter)))
426
+ LogicalResult matchAndInit (vector::ContractionOp op,
427
+ PatternRewriter &rewriter) {
428
+ if (failed (match (op, rewriter)))
400
429
return failure ();
401
430
402
431
VectorType lhsType = op.getLhsType ();
@@ -458,26 +487,6 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
458
487
rhs = *maybeRhs;
459
488
acc = op.getAcc ();
460
489
461
- flatLhsType = VectorType::get (/* shape=*/ 16 , rewriter.getI8Type (),
462
- /* scalableDims=*/ {true });
463
- flatRhsType = VectorType::get (/* shape=*/ 16 , rewriter.getI8Type (),
464
- /* scalableDims=*/ {true });
465
-
466
- flatAccType = VectorType::get (/* shape=*/ 4 , rewriter.getI32Type (),
467
- /* scalableDims=*/ {true });
468
-
469
- flatRhsTileType = VectorType::get (/* shape=*/ 8 * N, rewriter.getI8Type (),
470
- /* scalableDims=*/ {true });
471
-
472
- accRowTy = VectorType::get (/* shape=*/ N, rewriter.getI32Type (),
473
- /* scalableDims=*/ {true });
474
- accRowX2Ty = VectorType::get (/* shape=*/ 2 * N, rewriter.getI32Type (),
475
- /* scalableDims=*/ {true });
476
- accRow64Ty = VectorType::get (/* shape=*/ N / 2 , rewriter.getI64Type (),
477
- /* scalableDims=*/ {true });
478
- accRowX264Ty = VectorType::get (/* shape=*/ N, rewriter.getI64Type (),
479
- /* scalableDims=*/ {true });
480
-
481
490
return success ();
482
491
}
483
492
};
@@ -486,9 +495,9 @@ class VectorContractRewriterBfloat : public VectorContractRewriter {
486
495
public:
487
496
// Check the specific preconditions for the bfloat16 case. Initialise
488
497
// parametrisation types and dimensions.
489
- LogicalResult match (vector::ContractionOp op, PatternRewriter &rewriter) {
490
-
491
- if (failed (VectorContractRewriter:: match (op, rewriter)))
498
+ LogicalResult matchAndInit (vector::ContractionOp op,
499
+ PatternRewriter &rewriter) {
500
+ if (failed (match (op, rewriter)))
492
501
return failure ();
493
502
494
503
VectorType lhsType = op.getLhsType ();
@@ -527,26 +536,6 @@ class VectorContractRewriterBfloat : public VectorContractRewriter {
527
536
rhs = op.getRhs ();
528
537
acc = op.getAcc ();
529
538
530
- flatLhsType = VectorType::get (/* shape=*/ 8 , rewriter.getBF16Type (),
531
- /* scalableDims=*/ {true });
532
- flatRhsType = VectorType::get (/* shape=*/ 8 , rewriter.getBF16Type (),
533
- /* scalableDims=*/ {true });
534
-
535
- flatAccType = VectorType::get (/* shape=*/ 4 , rewriter.getF32Type (),
536
- /* scalableDims=*/ {true });
537
-
538
- flatRhsTileType = VectorType::get (/* shape=*/ 4 * N, rewriter.getBF16Type (),
539
- /* scalableDims=*/ {true });
540
-
541
- accRowTy = VectorType::get (/* shape=*/ N, rewriter.getF32Type (),
542
- /* scalableDims=*/ {true });
543
- accRowX2Ty = VectorType::get (/* shape=*/ 2 * N, rewriter.getF32Type (),
544
- /* scalableDims=*/ {true });
545
- accRow64Ty = VectorType::get (/* shape=*/ N / 2 , rewriter.getI64Type (),
546
- /* scalableDims=*/ {true });
547
- accRowX264Ty = VectorType::get (/* shape=*/ N, rewriter.getI64Type (),
548
- /* scalableDims=*/ {true });
549
-
550
539
return success ();
551
540
}
552
541
};
@@ -560,7 +549,7 @@ class LowerContractionToSVEI8MMPattern
560
549
561
550
// Match i8xi8 -> i32 matrix multiply and accumulate.
562
551
VectorContractRewriterI8MM vcr;
563
- if (failed (vcr.match (op, rewriter)))
552
+ if (failed (vcr.matchAndInit (op, rewriter)))
564
553
return failure ();
565
554
566
555
Value result = vcr.rewrite (op, rewriter);
@@ -579,7 +568,7 @@ class LowerContractionToSVEBFMMLAPattern
579
568
580
569
// Match bf16xbf16 -> f32 matrix multiply and accumulate.
581
570
VectorContractRewriterBfloat vcr;
582
- if (failed (vcr.match (op, rewriter)))
571
+ if (failed (vcr.matchAndInit (op, rewriter)))
583
572
return failure ();
584
573
585
574
Value result = vcr.rewrite (op, rewriter);
0 commit comments