Skip to content

Commit 8e9ca05

Browse files
[flang][fir] Add conversion of fir.if to scf.if. (#149959)
This commmit is a supplement for #140374. RFC:https://discourse.llvm.org/t/rfc-add-fir-affine-optimization-fir-pass-pipeline/86190/6
1 parent 14e6390 commit 8e9ca05

File tree

2 files changed

+98
-2
lines changed

2 files changed

+98
-2
lines changed

flang/lib/Optimizer/Transforms/FIRToSCF.cpp

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,52 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
8787
return success();
8888
}
8989
};
90+
91+
void copyBlockAndTransformResult(PatternRewriter &rewriter, Block &srcBlock,
92+
Block &dstBlock) {
93+
Operation *srcTerminator = srcBlock.getTerminator();
94+
auto resultOp = cast<fir::ResultOp>(srcTerminator);
95+
96+
dstBlock.getOperations().splice(dstBlock.begin(), srcBlock.getOperations(),
97+
srcBlock.begin(), std::prev(srcBlock.end()));
98+
99+
if (!resultOp->getOperands().empty()) {
100+
rewriter.setInsertionPointToEnd(&dstBlock);
101+
scf::YieldOp::create(rewriter, resultOp->getLoc(), resultOp->getOperands());
102+
}
103+
104+
rewriter.eraseOp(srcTerminator);
105+
}
106+
107+
struct IfConversion : public OpRewritePattern<fir::IfOp> {
108+
using OpRewritePattern<fir::IfOp>::OpRewritePattern;
109+
LogicalResult matchAndRewrite(fir::IfOp ifOp,
110+
PatternRewriter &rewriter) const override {
111+
bool hasElse = !ifOp.getElseRegion().empty();
112+
auto scfIfOp =
113+
scf::IfOp::create(rewriter, ifOp.getLoc(), ifOp.getResultTypes(),
114+
ifOp.getCondition(), hasElse);
115+
116+
copyBlockAndTransformResult(rewriter, ifOp.getThenRegion().front(),
117+
scfIfOp.getThenRegion().front());
118+
119+
if (hasElse) {
120+
copyBlockAndTransformResult(rewriter, ifOp.getElseRegion().front(),
121+
scfIfOp.getElseRegion().front());
122+
}
123+
124+
scfIfOp->setAttrs(ifOp->getAttrs());
125+
rewriter.replaceOp(ifOp, scfIfOp);
126+
return success();
127+
}
128+
};
90129
} // namespace
91130

92131
void FIRToSCFPass::runOnOperation() {
93132
RewritePatternSet patterns(&getContext());
94-
patterns.add<DoLoopConversion>(patterns.getContext());
133+
patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
95134
ConversionTarget target(getContext());
96-
target.addIllegalOp<fir::DoLoopOp>();
135+
target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
97136
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
98137
if (failed(
99138
applyPartialConversion(getOperation(), target, std::move(patterns))))

flang/test/Fir/FirToSCF/if.fir

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: fir-opt %s --fir-to-scf | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @test_only(
4+
// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32) {
5+
// CHECK: scf.if %[[ARG0]] {
6+
// CHECK: %[[VAL_1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : i32
7+
// CHECK: }
8+
// CHECK: return
9+
// CHECK: }
10+
func.func @test_only(%arg0 : i1, %arg1 : i32) {
11+
fir.if %arg0 {
12+
%0 = arith.addi %arg1, %arg1 : i32
13+
}
14+
return
15+
}
16+
17+
// CHECK-LABEL: func.func @test_else() {
18+
// CHECK: %[[VAL_1:.*]] = arith.constant false
19+
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : i32
20+
// CHECK: scf.if %[[VAL_1]] {
21+
// CHECK: %[[VAL_3:.*]] = arith.constant 3 : i32
22+
// CHECK: } else {
23+
// CHECK: %[[VAL_3:.*]] = arith.constant 3 : i32
24+
// CHECK: }
25+
// CHECK: return
26+
// CHECK: }
27+
func.func @test_else() {
28+
%false = arith.constant false
29+
%1 = arith.constant 2 : i32
30+
fir.if %false {
31+
%2 = arith.constant 3 : i32
32+
} else {
33+
%3 = arith.constant 3 : i32
34+
}
35+
return
36+
}
37+
38+
// CHECK-LABEL: func.func @test_two_result() {
39+
// CHECK: %[[VAL_1:.*]] = arith.constant 2.000000e+00 : f32
40+
// CHECK: %[[VAL_2:.*]] = arith.constant false
41+
// CHECK: %[[RES:[0-9]+]]:2 = scf.if %[[VAL_2]] -> (f32, f32) {
42+
// CHECK: scf.yield %[[VAL_1]], %[[VAL_1]] : f32, f32
43+
// CHECK: } else {
44+
// CHECK: scf.yield %[[VAL_1]], %[[VAL_1]] : f32, f32
45+
// CHECK: }
46+
// CHECK: return
47+
// CHECK: }
48+
func.func @test_two_result() {
49+
%1 = arith.constant 2.0 : f32
50+
%cmp = arith.constant false
51+
%x, %y = fir.if %cmp -> (f32, f32) {
52+
fir.result %1, %1 : f32, f32
53+
} else {
54+
fir.result %1, %1 : f32, f32
55+
}
56+
return
57+
}

0 commit comments

Comments
 (0)