@@ -528,7 +528,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
528
528
for (auto continueOp : continues) {
529
529
bool nested = false ;
530
530
// When there is another loop between this WhileOp and the ContinueOp,
531
- // we shouldn't change that loop instead.
531
+ // we should change that loop instead.
532
532
for (mlir::Operation *parent = continueOp->getParentOp ();
533
533
parent != whileOp; parent = parent->getParentOp ()) {
534
534
if (isa<WhileOp>(parent)) {
@@ -570,6 +570,73 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
570
570
}
571
571
}
572
572
573
+ void rewriteBreak (mlir::scf::WhileOp whileOp,
574
+ mlir::ConversionPatternRewriter &rewriter) const {
575
+ // Collect all BreakOp inside this while.
576
+ llvm::SmallVector<cir::BreakOp> breaks;
577
+ whileOp->walk ([&](mlir::Operation *op) {
578
+ if (auto breakOp = dyn_cast<BreakOp>(op))
579
+ breaks.push_back (breakOp);
580
+ });
581
+
582
+ if (breaks.empty ())
583
+ return ;
584
+
585
+ for (auto breakOp : breaks) {
586
+ // When there is another loop between this WhileOp and the BreakOp,
587
+ // we should change that loop instead.
588
+ if (breakOp->getParentOfType <mlir::scf::WhileOp>() != whileOp)
589
+ continue ;
590
+
591
+ // Similar to the case of ContinueOp, when there is an `IfOp`,
592
+ // we need to take special care.
593
+ for (mlir::Operation *parent = breakOp->getParentOp (); parent != whileOp;
594
+ parent = parent->getParentOp ()) {
595
+ if (auto ifOp = dyn_cast<cir::IfOp>(parent))
596
+ llvm_unreachable (" NYI" );
597
+ }
598
+
599
+ // Operations after this BreakOp has to be removed.
600
+ for (mlir::Operation *runner = breakOp->getNextNode (); runner;) {
601
+ mlir::Operation *next = runner->getNextNode ();
602
+ runner->erase ();
603
+ runner = next;
604
+ }
605
+
606
+ // Blocks after this BreakOp also has to be removed.
607
+ for (mlir::Block *block = breakOp->getBlock ()->getNextNode (); block;) {
608
+ mlir::Block *next = block->getNextNode ();
609
+ block->erase ();
610
+ block = next;
611
+ }
612
+
613
+ // We know this BreakOp isn't nested in any IfOp.
614
+ // Therefore, the loop is executed only once.
615
+ // We pull everything out of the loop.
616
+
617
+ auto &beforeOps = whileOp.getBeforeBody ()->getOperations ();
618
+ for (mlir::Operation *op = &*beforeOps.begin (); op;) {
619
+ if (isa<ConditionOp>(op))
620
+ break ;
621
+ auto *next = op->getNextNode ();
622
+ op->moveBefore (whileOp);
623
+ op = next;
624
+ }
625
+
626
+ auto &afterOps = whileOp.getAfterBody ()->getOperations ();
627
+ for (mlir::Operation *op = &*afterOps.begin (); op;) {
628
+ if (isa<YieldOp>(op))
629
+ break ;
630
+ auto *next = op->getNextNode ();
631
+ op->moveBefore (whileOp);
632
+ op = next;
633
+ }
634
+
635
+ // The loop itself should now be removed.
636
+ rewriter.eraseOp (whileOp);
637
+ }
638
+ }
639
+
573
640
public:
574
641
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;
575
642
@@ -579,6 +646,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
579
646
SCFWhileLoop loop (op, adaptor, &rewriter);
580
647
auto whileOp = loop.transferToSCFWhileOp ();
581
648
rewriteContinue (whileOp, rewriter);
649
+ rewriteBreak (whileOp, rewriter);
582
650
rewriter.eraseOp (op);
583
651
return mlir::success ();
584
652
}
0 commit comments