Commit b30034d
[mlir][linalg] Add folder for broadcast(broadcast) -> broadcast (#150825)
Back-to-back `linalg.broadcast` ops can be rewritten as a single broadcast.
1 parent 50f3a6b

File tree

2 files changed: +77 -1 lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 31 additions & 1 deletion
@@ -2293,9 +2293,39 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+/// Fold back-to-back broadcasts together.
+struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
+  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
+    if (!defBroadcastOp)
+      return failure();
+    ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    SmallVector<int64_t> foldedDims(dimensions);
+    Value init = broadcastOp.getInit();
+    int64_t initRank = cast<ShapedType>(init.getType()).getRank();
+    // Mapping from input dims to init dims.
+    SmallVector<int64_t> dimMap;
+    for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+      if (!llvm::is_contained(dimensions, dim))
+        dimMap.push_back(dim);
+    }
+    for (auto dim : defDimensions)
+      foldedDims.push_back(dimMap[dim]);
+
+    llvm::sort(foldedDims);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
+    return success();
+  }
+};
+
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
 }
 
 //===----------------------------------------------------------------------===//
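
To make the index bookkeeping concrete, the following is a minimal standalone sketch (plain C++, no MLIR dependencies; not part of the commit) of how dimMap and foldedDims are computed. The concrete shapes mirror the second test below: an inner broadcast with dimensions = [1] into tensor<2x4xf32>, followed by an outer broadcast with dimensions = [1] into tensor<2x3x4xf32>.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Outer broadcast: tensor<2x4xf32> -> tensor<2x3x4xf32>, dimensions = [1].
  const std::vector<int64_t> dimensions = {1};
  const int64_t initRank = 3;
  // Inner (producer) broadcast: tensor<2xf32> -> tensor<2x4xf32>,
  // dimensions = [1].
  const std::vector<int64_t> defDimensions = {1};

  // dimMap[i] = the init dimension that input dimension i of the outer
  // broadcast maps to, i.e. the init dims the outer op does NOT broadcast.
  std::vector<int64_t> dimMap;
  for (int64_t dim = 0; dim < initRank; ++dim)
    if (std::find(dimensions.begin(), dimensions.end(), dim) ==
        dimensions.end())
      dimMap.push_back(dim); // here: dimMap = [0, 2]

  // Start from the outer dims and remap the inner dims through dimMap.
  std::vector<int64_t> foldedDims(dimensions);
  for (int64_t dim : defDimensions)
    foldedDims.push_back(dimMap[dim]); // inner dim 1 -> init dim 2
  std::sort(foldedDims.begin(), foldedDims.end());

  for (int64_t dim : foldedDims) // prints: 1 2
    std::cout << dim << ' ';
  std::cout << '\n';
  return 0;
}

The folded dimensions [1, 2] are exactly what the CHECK lines below expect on the single remaining linalg.broadcast.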

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 46 additions & 0 deletions
@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
 
 // -----
 
+// CHECK-LABEL: @broadcast_broadcast_fold
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x3xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x3xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x3xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [2]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_broadcast_fold
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+                                    %init1: tensor<2x4xf32>,
+                                    %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %broadcast1 = linalg.broadcast
+      ins(%input: tensor<2xf32>)
+      outs(%init1: tensor<2x4xf32>)
+      dimensions = [1]
+  %broadcast2 = linalg.broadcast
+      ins(%broadcast1: tensor<2x4xf32>)
+      outs(%init2: tensor<2x3x4xf32>)
+      dimensions = [1]
+  func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
   %transpose = linalg.transpose
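
These tests exercise the pattern through canonicalization (the file runs under mlir-opt's -canonicalize pass). The snippet below is a hedged sketch of driving the same patterns programmatically: canonicalizeBroadcasts is a hypothetical helper, not part of the commit, and the greedy-driver entry point is spelled applyPatternsGreedily in recent MLIR (applyPatternsAndFoldGreedily in older releases).

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: collects BroadcastOp's canonicalization patterns
// (which now include FoldBroadcasts) and applies them greedily to a module.
LogicalResult canonicalizeBroadcasts(ModuleOp module) {
  RewritePatternSet patterns(module.getContext());
  linalg::BroadcastOp::getCanonicalizationPatterns(patterns,
                                                   module.getContext());
  // Older releases: applyPatternsAndFoldGreedily(module, std::move(patterns)).
  return applyPatternsGreedily(module, std::move(patterns));
}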
