Skip to content

Commit 41fb026

Browse files
committed
[CIR][ThroughMLIR] Lower WhileOp with break
1 parent b7dffc3 commit 41fb026

File tree

3 files changed

+94
-1
lines changed

3 files changed

+94
-1
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
528528
for (auto continueOp : continues) {
529529
bool nested = false;
530530
// When there is another loop between this WhileOp and the ContinueOp,
531-
// we shouldn't change that loop instead.
531+
// we should change that loop instead.
532532
for (mlir::Operation *parent = continueOp->getParentOp();
533533
parent != whileOp; parent = parent->getParentOp()) {
534534
if (isa<WhileOp>(parent)) {
@@ -570,6 +570,73 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
570570
}
571571
}
572572

573+
void rewriteBreak(mlir::scf::WhileOp whileOp,
574+
mlir::ConversionPatternRewriter &rewriter) const {
575+
// Collect all BreakOp inside this while.
576+
llvm::SmallVector<cir::BreakOp> breaks;
577+
whileOp->walk([&](mlir::Operation *op) {
578+
if (auto breakOp = dyn_cast<BreakOp>(op))
579+
breaks.push_back(breakOp);
580+
});
581+
582+
if (breaks.empty())
583+
return;
584+
585+
for (auto breakOp : breaks) {
586+
// When there is another loop between this WhileOp and the BreakOp,
587+
// we should change that loop instead.
588+
if (breakOp->getParentOfType<mlir::scf::WhileOp>() != whileOp)
589+
continue;
590+
591+
// Similar to the case of ContinueOp, when there is an `IfOp`,
592+
// we need to take special care.
593+
for (mlir::Operation *parent = breakOp->getParentOp(); parent != whileOp;
594+
parent = parent->getParentOp()) {
595+
if (auto ifOp = dyn_cast<cir::IfOp>(parent))
596+
llvm_unreachable("NYI");
597+
}
598+
599+
// Operations after this BreakOp has to be removed.
600+
for (mlir::Operation *runner = breakOp->getNextNode(); runner;) {
601+
mlir::Operation *next = runner->getNextNode();
602+
runner->erase();
603+
runner = next;
604+
}
605+
606+
// Blocks after this BreakOp also has to be removed.
607+
for (mlir::Block *block = breakOp->getBlock()->getNextNode(); block;) {
608+
mlir::Block *next = block->getNextNode();
609+
block->erase();
610+
block = next;
611+
}
612+
613+
// We know this BreakOp isn't nested in any IfOp.
614+
// Therefore, the loop is executed only once.
615+
// We pull everything out of the loop.
616+
617+
auto &beforeOps = whileOp.getBeforeBody()->getOperations();
618+
for (mlir::Operation *op = &*beforeOps.begin(); op;) {
619+
if (isa<ConditionOp>(op))
620+
break;
621+
auto *next = op->getNextNode();
622+
op->moveBefore(whileOp);
623+
op = next;
624+
}
625+
626+
auto &afterOps = whileOp.getAfterBody()->getOperations();
627+
for (mlir::Operation *op = &*afterOps.begin(); op;) {
628+
if (isa<YieldOp>(op))
629+
break;
630+
auto *next = op->getNextNode();
631+
op->moveBefore(whileOp);
632+
op = next;
633+
}
634+
635+
// The loop itself should now be removed.
636+
rewriter.eraseOp(whileOp);
637+
}
638+
}
639+
573640
public:
574641
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;
575642

@@ -579,6 +646,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
579646
SCFWhileLoop loop(op, adaptor, &rewriter);
580647
auto whileOp = loop.transferToSCFWhileOp();
581648
rewriteContinue(whileOp, rewriter);
649+
rewriteBreak(whileOp, rewriter);
582650
rewriter.eraseOp(op);
583651
return mlir::success();
584652
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
void while_break() {
5+
int i = 0;
6+
while (i < 100) {
7+
i++;
8+
break;
9+
i++;
10+
}
11+
// This should be compiled into the condition `i < 100` and a single `i++`,
12+
// without the while-loop.
13+
14+
// CHECK: memref.alloca_scope {
15+
// CHECK: %[[IV:.+]] = memref.load %alloca[]
16+
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
17+
// CHECK: %[[_:.+]] = arith.cmpi slt, %[[IV]], %[[HUNDRED]]
18+
// CHECK: memref.alloca_scope {
19+
// CHECK: %[[IV2:.+]] = memref.load %alloca[]
20+
// CHECK: %[[ONE:.+]] = arith.constant 1
21+
// CHECK: %[[INCR:.+]] = arith.addi %[[IV2]], %[[ONE]]
22+
// CHECK: memref.store %[[INCR]], %alloca[]
23+
// CHECK: }
24+
// CHECK: }
25+
}

0 commit comments

Comments
 (0)