Skip to content

Commit e67f323

Browse files
authored
[mlir][armsme][vector] Replace splat with broadcast (#148024)
Part of deprecation of vector.splat RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
1 parent 8ef0c50 commit e67f323

File tree

4 files changed

+20
-64
lines changed

4 files changed

+20
-64
lines changed

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,8 @@ struct InsertTileSliceConversion
607607
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
608608
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
609609
/*scalableDims=*/{true});
610-
auto allActiveMask = vector::SplatOp::create(rewriter, loc, predTy, one);
610+
auto allActiveMask =
611+
vector::BroadcastOp::create(rewriter, loc, predTy, one);
611612

612613
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
613614
switch (insertTileSliceOp.getLayout()) {

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
327327

328328
// Splat pad into 1-D vector matching type of tile slice.
329329
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
330-
auto pad1DOp = vector::SplatOp::create(rewriter, loc, tileSliceType, padOp);
330+
auto pad1DOp =
331+
vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
331332

332333
auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
333334
tileLoadOp.getBase(),

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 15 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
255255
}
256256
};
257257

258-
/// Conversion pattern for vector.splat.
259-
///
260-
/// Example:
261-
///
262-
/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
263-
///
264-
/// is converted to:
265-
///
266-
/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
267-
/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
268-
/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
269-
/// {
270-
/// %tile_update = arm_sme.insert_tile_slice
271-
/// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
272-
/// vector<[4]xi32> into vector<[4]x[4]xi32>
273-
/// scf.yield %tile_update : vector<[4]x[4]xi32>
274-
/// }
275-
///
276-
/// This is identical to vector.broadcast of a scalar.
277-
struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
278-
using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
279-
280-
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
281-
PatternRewriter &rewriter) const final {
282-
auto tileType = splatOp.getResult().getType();
283-
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
284-
return failure();
285-
286-
auto loc = splatOp.getLoc();
287-
auto srcType = splatOp.getOperand().getType();
288-
289-
assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
290-
// Avoid unused-variable warning when building without assertions.
291-
(void)srcType;
292-
293-
// First, broadcast the scalar to a 1-d vector.
294-
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
295-
Value broadcastOp1D = vector::BroadcastOp::create(
296-
rewriter, loc, tileSliceType, splatOp.getInput());
297-
298-
auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
299-
300-
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
301-
Value currentTile) {
302-
auto nextTile = arm_sme::InsertTileSliceOp::create(
303-
b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
304-
return nextTile.getResult();
305-
};
306-
307-
// Next, create a loop over ZA tile slices and "move" the generated 1-d
308-
// vector to each slice.
309-
auto forOp =
310-
createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
311-
312-
rewriter.replaceOp(splatOp, forOp.getResult(0));
313-
314-
return success();
315-
}
316-
};
317-
318258
/// Conversion pattern for vector.transpose.
319259
///
320260
/// Stores the input tile to memory and reloads vertically.
@@ -791,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering
791731
}
792732
};
793733

734+
// Convert all `vector.splat` to `vector.broadcast`. There is a path from
735+
// `vector.broadcast` to ArmSME via another pattern.
736+
struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
737+
using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
738+
739+
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
740+
PatternRewriter &rewriter) const final {
741+
742+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
743+
splatOp.getInput());
744+
return success();
745+
}
746+
};
747+
794748
} // namespace
795749

796750
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
797751
MLIRContext &ctx) {
798-
patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
752+
patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
799753
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
800754
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
801755
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
8787
// CHECK-NEXT: %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
8888
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
8989
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
90-
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
90+
// CHECK: %[[PAD_1D:.*]] = vector.broadcast %[[PAD]] : i32 to vector<[4]xi32>
9191
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
9292
// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
9393
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>

0 commit comments

Comments
 (0)