Skip to content

[mlir][vector] vector.splat deprecation: folding/canonicalizing parity with broadcast #150284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2780,6 +2780,10 @@ def Vector_SplatOp : Vector_Op<"splat", [
let assemblyFormat = "$input attr-dict `:` type($aggregate)";

let hasFolder = 1;

// vector.splat is deprecated, and vector.broadcast should be used instead.
// Canonicalize vector.splat to vector.broadcast.
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
131 changes: 84 additions & 47 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2476,17 +2476,19 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
return {};
}

/// Rewrite a vector.from_elements into a vector.splat if all elements are the
/// same SSA value. E.g.:
///
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
/// ==> rewrite to vector.splat %a : vector<3xf32>
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
PatternRewriter &rewriter) {
/// Rewrite vector.from_elements as vector.broadcast if the elements are the
/// same. Example:
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
/// =>
/// %0 = vector.broadcast %a : f32 to vector<3xf32>
static LogicalResult
rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
PatternRewriter &rewriter) {
if (!llvm::all_equal(fromElementsOp.getElements()))
return failure();
rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
fromElementsOp.getElements().front());
rewriter.replaceOpWithNewOp<BroadcastOp>(
fromElementsOp, fromElementsOp.getType(),
fromElementsOp.getElements().front());
return success();
}

Expand Down Expand Up @@ -2517,7 +2519,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
LogicalResult matchAndRewrite(FromElementsOp fromElements,
PatternRewriter &rewriter) const override {

// Handled by `rewriteFromElementsAsSplat`
// Handled by `rewriteFromElementsAsBroadcast`.
if (fromElements.getType().getNumElements() == 1)
return failure();

Expand Down Expand Up @@ -2610,7 +2612,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add(rewriteFromElementsAsSplat);
results.add(rewriteFromElementsAsBroadcast);
results.add<FromElementsToShapeCast>(context);
}

Expand Down Expand Up @@ -3058,23 +3060,50 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
}
};

/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
/// Consider the defining operation `defOp` of `value`. If `defOp` is a
/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
/// value that is splatted. Otherwise return null.
///
/// Examples:
///
/// scalar_source --> vector.splat --> value - return scalar_source
/// scalar_source --> vector.broadcast --> value - return scalar_source
static Value getScalarSplatSource(Value value) {
// Block argument:
Operation *defOp = value.getDefiningOp();
if (!defOp)
return {};

// Splat:
if (auto splat = dyn_cast<vector::SplatOp>(defOp))
return splat.getInput();

auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);

// Not broadcast (and not splat):
if (!broadcast)
return {};

// Broadcast of a vector:
if (isa<VectorType>(broadcast.getSourceType()))
return {};

// Broadcast of a scalar:
return broadcast.getSource();
}

/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(ShuffleOp op,
PatternRewriter &rewriter) const override {
auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
auto v2Splat = op.getV2().getDefiningOp<SplatOp>();

if (!v1Splat || !v2Splat)
Value splat = getScalarSplatSource(op.getV1());
if (!splat || getScalarSplatSource(op.getV2()) != splat)
return failure();

if (v1Splat.getInput() != v2Splat.getInput())
return failure();

rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
Expand Down Expand Up @@ -3230,23 +3259,19 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
}
};

/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v).
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
auto dstSplat = op.getDest().getDefiningOp<SplatOp>();

if (!srcSplat || !dstSplat)
return failure();

if (srcSplat.getInput() != dstSplat.getInput())
Value splat = getScalarSplatSource(op.getValueToStore());
if (!splat || getScalarSplatSource(op.getDest()) != splat)
return failure();

rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
Expand Down Expand Up @@ -3514,27 +3539,21 @@ LogicalResult InsertStridedSliceOp::verify() {
}

namespace {
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
class FoldInsertStridedSliceSplat final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
auto srcSplatOp =
insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
auto destSplatOp =
insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();

if (!srcSplatOp || !destSplatOp)
auto dst = insertStridedSliceOp.getDest();
auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
if (!splat || getScalarSplatSource(dst) != splat)
return failure();

if (srcSplatOp.getInput() != destSplatOp.getInput())
return failure();

rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
rewriter.replaceOp(insertStridedSliceOp, dst);
return success();
}
};
Expand Down Expand Up @@ -4189,17 +4208,18 @@ class StridedSliceBroadcast final
}
};

/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto splat = op.getVector().getDefiningOp<SplatOp>();

Value splat = getScalarSplatSource(op.getVector());
if (!splat)
return failure();
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
Expand Down Expand Up @@ -6354,19 +6374,19 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
}
};

// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
/// Replace transpose(splat-like(v)) with broadcast(v)
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
if (!splatOp)
Value splat = getScalarSplatSource(transposeOp.getVector());
if (!splat)
return failure();

rewriter.replaceOpWithNewOp<vector::SplatOp>(
transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
transposeOp, transposeOp.getResultVectorType(), splat);
return success();
}
};
Expand Down Expand Up @@ -7117,6 +7137,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(getType(), {constOperand});
}

// Canonicalizer for vector.splat. It always gets canonicalized to a
// vector.broadcast.
class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
public:
using OpRewritePattern<SplatOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SplatOp splatOp,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
splatOp.getOperand());
return success();
}
};
void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SplatToBroadcastPattern>(context);
}

void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
Expand Down
Loading