@@ -87,13 +87,52 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
87
87
return success ();
88
88
}
89
89
};
90
+
91
+ void copyBlockAndTransformResult (PatternRewriter &rewriter, Block &srcBlock,
92
+ Block &dstBlock) {
93
+ Operation *srcTerminator = srcBlock.getTerminator ();
94
+ auto resultOp = cast<fir::ResultOp>(srcTerminator);
95
+
96
+ dstBlock.getOperations ().splice (dstBlock.begin (), srcBlock.getOperations (),
97
+ srcBlock.begin (), std::prev (srcBlock.end ()));
98
+
99
+ if (!resultOp->getOperands ().empty ()) {
100
+ rewriter.setInsertionPointToEnd (&dstBlock);
101
+ scf::YieldOp::create (rewriter, resultOp->getLoc (), resultOp->getOperands ());
102
+ }
103
+
104
+ rewriter.eraseOp (srcTerminator);
105
+ }
106
+
107
+ struct IfConversion : public OpRewritePattern <fir::IfOp> {
108
+ using OpRewritePattern<fir::IfOp>::OpRewritePattern;
109
+ LogicalResult matchAndRewrite (fir::IfOp ifOp,
110
+ PatternRewriter &rewriter) const override {
111
+ bool hasElse = !ifOp.getElseRegion ().empty ();
112
+ auto scfIfOp =
113
+ scf::IfOp::create (rewriter, ifOp.getLoc (), ifOp.getResultTypes (),
114
+ ifOp.getCondition (), hasElse);
115
+
116
+ copyBlockAndTransformResult (rewriter, ifOp.getThenRegion ().front (),
117
+ scfIfOp.getThenRegion ().front ());
118
+
119
+ if (hasElse) {
120
+ copyBlockAndTransformResult (rewriter, ifOp.getElseRegion ().front (),
121
+ scfIfOp.getElseRegion ().front ());
122
+ }
123
+
124
+ scfIfOp->setAttrs (ifOp->getAttrs ());
125
+ rewriter.replaceOp (ifOp, scfIfOp);
126
+ return success ();
127
+ }
128
+ };
90
129
} // namespace
91
130
92
131
void FIRToSCFPass::runOnOperation () {
93
132
RewritePatternSet patterns (&getContext ());
94
- patterns.add <DoLoopConversion>(patterns.getContext ());
133
+ patterns.add <DoLoopConversion, IfConversion >(patterns.getContext ());
95
134
ConversionTarget target (getContext ());
96
- target.addIllegalOp <fir::DoLoopOp>();
135
+ target.addIllegalOp <fir::DoLoopOp, fir::IfOp >();
97
136
target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
98
137
if (failed (
99
138
applyPartialConversion (getOperation (), target, std::move (patterns))))
0 commit comments