Skip to content

Commit 335f227

Browse files
olegshyshkovcopybara-github
authored andcommitted
[GmlSt] Do not transform ops that had been already tiles by another pass.
PiperOrigin-RevId: 495372204
1 parent a267389 commit 335f227

File tree

6 files changed

+26
-0
lines changed

6 files changed

+26
-0
lines changed

xla/mlir_hlo/gml_st/transforms/transform_map_for_cpu/transform_map_for_cpu.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ struct TileMapPattern : public OpRewritePattern<linalg::MapOp> {
4747

4848
LogicalResult matchAndRewrite(linalg::MapOp op,
4949
PatternRewriter &rewriter) const override {
50+
if (hasLabel(op, kMapTransformedLabel)) return failure();
51+
52+
if (isa<gml_st::ParallelOp, gml_st::ForOp>(op->getParentOp()))
53+
return rewriter.notifyMatchFailure(
54+
op, "has already been tiled by another pass.");
55+
5056
auto fuseFilterFn = [](Operation *op) {
5157
return isa<linalg::BroadcastOp, linalg::MapOp>(op);
5258
};

xla/mlir_hlo/gml_st/transforms/transform_matmul_for_cpu/transform_matmul_for_cpu.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,10 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {
519519
return rewriter.notifyMatchFailure(matmulOp,
520520
"has already been transformed.");
521521

522+
if (isa<gml_st::ParallelOp, gml_st::ForOp>(matmulOp->getParentOp()))
523+
return rewriter.notifyMatchFailure(
524+
matmulOp, "has already been tiled by another pass.");
525+
522526
SmallVector<Operation *> fusionCluster = getFusionCluster(matmulOp);
523527

524528
// First element of the cluster is always the root for tiling.

xla/mlir_hlo/gml_st/transforms/transform_matmul_for_triton/transform_matmul_for_triton.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ struct MatmulTransformPattern : public OpRewritePattern<linalg::MatmulOp> {
7676
return rewriter.notifyMatchFailure(matmulOp,
7777
"has already been transformed.");
7878

79+
if (isa<gml_st::ParallelOp, gml_st::ForOp>(matmulOp->getParentOp()))
80+
return rewriter.notifyMatchFailure(
81+
matmulOp, "has already been tiled by another pass.");
82+
7983
// First level tiling: parallel dimensions.
8084
SmallVector<int64_t> parallelDimsTileSizes{lhsParallelDimTileSize,
8185
rhsParallelDimTileSize, 0};

xla/mlir_hlo/gml_st/transforms/transform_reduce_for_cpu/transform_reduce_for_cpu.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ struct Reduce1DTransformPattern : public OpRewritePattern<linalg::ReduceOp> {
102102
return rewriter.notifyMatchFailure(reduceOp,
103103
"has already been transformed.");
104104

105+
if (isa<gml_st::ParallelOp, gml_st::ForOp>(reduceOp->getParentOp()))
106+
return rewriter.notifyMatchFailure(
107+
reduceOp, "has already been tiled by another pass.");
108+
105109
if (failed(validateOp(reduceOp, rewriter, /*expectedRank=*/1)))
106110
return failure();
107111

xla/mlir_hlo/gml_st/transforms/transform_sort_for_cpu/transform_sort_for_cpu.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ struct TileSortPattern : public OpRewritePattern<SortOp> {
4949
PatternRewriter &rewriter) const override {
5050
if (hasLabel(op, kSortTransformedLabel)) return failure();
5151

52+
if (isa<gml_st::ParallelOp, gml_st::ForOp>(op->getParentOp()))
53+
return rewriter.notifyMatchFailure(
54+
op, "has already been tiled by another pass.");
55+
5256
auto tilingResult =
5357
tile(options, rewriter, cast<TilingInterface>(op.getOperation()));
5458
if (failed(tilingResult)) return failure();

xla/mlir_hlo/gml_st/transforms/transform_transpose_for_cpu/transform_transpose_for_cpu.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ struct TileTransposePattern : public OpRewritePattern<linalg::TransposeOp> {
5050
PatternRewriter &rewriter) const override {
5151
if (hasLabel(op, kTransposeTransformedLabel)) return failure();
5252

53+
if (isa<gml_st::ParallelOp, gml_st::ForOp>(op->getParentOp()))
54+
return rewriter.notifyMatchFailure(
55+
op, "has already been tiled by another pass.");
56+
5357
auto tilingResult =
5458
tile(options, rewriter, cast<TilingInterface>(op.getOperation()));
5559
if (failed(tilingResult)) return failure();

0 commit comments

Comments
 (0)