@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
255
255
}
256
256
};
257
257
258
- // / Conversion pattern for vector.splat.
259
- // /
260
- // / Example:
261
- // /
262
- // / %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
263
- // /
264
- // / is converted to:
265
- // /
266
- // / %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
267
- // / %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
268
- // / step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
269
- // / {
270
- // / %tile_update = arm_sme.insert_tile_slice
271
- // / %broadcast_to_1d, %iter_tile[%tile_slice_index] :
272
- // / vector<[4]xi32> into vector<[4]x[4]xi32>
273
- // / scf.yield %tile_update : vector<[4]x[4]xi32>
274
- // / }
275
- // /
276
- // / This is identical to vector.broadcast of a scalar.
277
- struct SplatOpToArmSMELowering : public OpRewritePattern <vector::SplatOp> {
278
- using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
279
-
280
- LogicalResult matchAndRewrite (vector::SplatOp splatOp,
281
- PatternRewriter &rewriter) const final {
282
- auto tileType = splatOp.getResult ().getType ();
283
- if (!tileType || !arm_sme::isValidSMETileVectorType (tileType))
284
- return failure ();
285
-
286
- auto loc = splatOp.getLoc ();
287
- auto srcType = splatOp.getOperand ().getType ();
288
-
289
- assert (srcType.isIntOrFloat () && " Invalid source type for vector.splat" );
290
- // Avoid unused-variable warning when building without assertions.
291
- (void )srcType;
292
-
293
- // First, broadcast the scalar to a 1-d vector.
294
- VectorType tileSliceType = VectorType::Builder (tileType).dropDim (0 );
295
- Value broadcastOp1D = vector::BroadcastOp::create (
296
- rewriter, loc, tileSliceType, splatOp.getInput ());
297
-
298
- auto initTile = arm_sme::GetTileOp::create (rewriter, loc, tileType);
299
-
300
- auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
301
- Value currentTile) {
302
- auto nextTile = arm_sme::InsertTileSliceOp::create (
303
- b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
304
- return nextTile.getResult ();
305
- };
306
-
307
- // Next, create a loop over ZA tile slices and "move" the generated 1-d
308
- // vector to each slice.
309
- auto forOp =
310
- createLoopOverTileSlices (rewriter, loc, initTile, makeLoopBody);
311
-
312
- rewriter.replaceOp (splatOp, forOp.getResult (0 ));
313
-
314
- return success ();
315
- }
316
- };
317
-
318
258
// / Conversion pattern for vector.transpose.
319
259
// /
320
260
// / Stores the input tile to memory and reloads vertically.
@@ -791,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering
791
731
}
792
732
};
793
733
734
+ // Convert all `vector.splat` to `vector.broadcast`. There is a path from
735
+ // `vector.broadcast` to ArmSME via another pattern.
736
+ struct ConvertSplatToBroadcast : public OpRewritePattern <vector::SplatOp> {
737
+ using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
738
+
739
+ LogicalResult matchAndRewrite (vector::SplatOp splatOp,
740
+ PatternRewriter &rewriter) const final {
741
+
742
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(splatOp, splatOp.getType (),
743
+ splatOp.getInput ());
744
+ return success ();
745
+ }
746
+ };
747
+
794
748
} // namespace
795
749
796
750
void mlir::populateVectorToArmSMEPatterns (RewritePatternSet &patterns,
797
751
MLIRContext &ctx) {
798
- patterns.add <BroadcastOpToArmSMELowering, SplatOpToArmSMELowering ,
752
+ patterns.add <BroadcastOpToArmSMELowering, ConvertSplatToBroadcast ,
799
753
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
800
754
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
801
755
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
0 commit comments