Skip to content

Commit 1a7f9c9

Browse files
committed
[CIR][ThroughMLIR] Lower simple SwitchOp
1 parent fe8b9ca commit 1a7f9c9

File tree

5 files changed

+286
-17
lines changed

5 files changed

+286
-17
lines changed

clang/include/clang/CIR/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
#include <memory>
1919

2020
namespace cir {
21+
/// Create a pass for transforming CIR operations to more 'scf' dialect-friendly
22+
/// forms. It rewrites operations that aren't supported by 'scf', such as breaks
23+
/// and continues.
24+
std::unique_ptr<mlir::Pass> createMLIRCoreDialectsLoweringPreparePass();
25+
2126
/// Create a pass for lowering from MLIR builtin dialects such as `Affine` and
2227
/// `Std`, to the LLVM dialect for codegen.
2328
std::unique_ptr<mlir::Pass> createConvertMLIRToLLVMPass();

clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_clang_library(clangCIRLoweringThroughMLIR
99
LowerCIRLoopToSCF.cpp
1010
LowerCIRToMLIR.cpp
1111
LowerMLIRToLLVM.cpp
12+
MLIRCoreDialectsLoweringPrepare.cpp
1213

1314
DEPENDS
1415
MLIRCIROpsIncGen

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 110 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,28 +1444,119 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
14441444
}
14451445
};
14461446

1447+
class CIRSwitchOpLowering : public mlir::OpConversionPattern<cir::SwitchOp> {
1448+
public:
1449+
using OpConversionPattern<cir::SwitchOp>::OpConversionPattern;
1450+
1451+
mlir::LogicalResult
1452+
matchAndRewrite(cir::SwitchOp op, OpAdaptor adaptor,
1453+
mlir::ConversionPatternRewriter &rewriter) const override {
1454+
rewriter.setInsertionPointAfter(op);
1455+
llvm::SmallVector<CaseOp> cases;
1456+
if (!op.isSimpleForm(cases))
1457+
llvm_unreachable("NYI");
1458+
1459+
llvm::SmallVector<int64_t> caseValues;
1460+
// Maps the index of a CaseOp in `cases`, to the index in `caseValues`.
1461+
// This is necessary because some CaseOp might carry 0 or multiple values.
1462+
llvm::DenseMap<size_t, unsigned> indexMap;
1463+
caseValues.reserve(cases.size());
1464+
for (auto [i, caseOp] : llvm::enumerate(cases)) {
1465+
switch (caseOp.getKind()) {
1466+
case CaseOpKind::Equal: {
1467+
auto valueAttr = caseOp.getValue()[0];
1468+
auto value = cast<cir::IntAttr>(valueAttr);
1469+
indexMap[i] = caseValues.size();
1470+
caseValues.push_back(value.getUInt());
1471+
break;
1472+
}
1473+
case CaseOpKind::Default:
1474+
break;
1475+
case CaseOpKind::Range:
1476+
case CaseOpKind::Anyof:
1477+
llvm_unreachable("NYI");
1478+
}
1479+
}
1480+
1481+
auto operand = adaptor.getOperands()[0];
1482+
// `scf.index_switch` expects an index of type `index`.
1483+
auto indexType = mlir::IndexType::get(getContext());
1484+
auto indexCast = rewriter.create<mlir::arith::IndexCastOp>(
1485+
op.getLoc(), indexType, operand);
1486+
auto indexSwitch = rewriter.create<mlir::scf::IndexSwitchOp>(
1487+
op.getLoc(), mlir::TypeRange{}, indexCast, caseValues, cases.size());
1488+
1489+
bool metDefault = false;
1490+
for (auto [i, caseOp] : llvm::enumerate(cases)) {
1491+
auto &region = caseOp.getRegion();
1492+
switch (caseOp.getKind()) {
1493+
case CaseOpKind::Equal: {
1494+
auto &caseRegion = indexSwitch.getCaseRegions()[indexMap[i]];
1495+
rewriter.inlineRegionBefore(region, caseRegion, caseRegion.end());
1496+
break;
1497+
}
1498+
case CaseOpKind::Default: {
1499+
auto &defaultRegion = indexSwitch.getDefaultRegion();
1500+
rewriter.inlineRegionBefore(region, defaultRegion, defaultRegion.end());
1501+
metDefault = true;
1502+
break;
1503+
}
1504+
case CaseOpKind::Range:
1505+
case CaseOpKind::Anyof:
1506+
llvm_unreachable("NYI");
1507+
}
1508+
}
1509+
1510+
// `scf.index_switch` expects its default region to contain exactly one
1511+
// block. If we don't have a default region in `cir.switch`, we need to
1512+
// supply it here.
1513+
if (!metDefault) {
1514+
auto &defaultRegion = indexSwitch.getDefaultRegion();
1515+
mlir::Block *block =
1516+
rewriter.createBlock(&defaultRegion, defaultRegion.end());
1517+
rewriter.setInsertionPointToEnd(block);
1518+
rewriter.create<mlir::scf::YieldOp>(op.getLoc());
1519+
}
1520+
1521+
// The final `cir.break` should be replaced to `scf.yield`.
1522+
// After MLIRLoweringPrepare pass, every case must end with a `cir.break`.
1523+
for (auto &region : indexSwitch.getCaseRegions()) {
1524+
auto &lastBlock = region.back();
1525+
auto &lastOp = lastBlock.back();
1526+
assert(isa<BreakOp>(lastOp));
1527+
rewriter.setInsertionPointAfter(&lastOp);
1528+
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(&lastOp);
1529+
}
1530+
1531+
rewriter.replaceOp(op, indexSwitch);
1532+
1533+
return mlir::success();
1534+
}
1535+
};
1536+
14471537
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
14481538
mlir::TypeConverter &converter) {
14491539
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
14501540

14511541
patterns
1452-
.add<CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1453-
CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1454-
CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1455-
CIRFuncOpLowering, CIRScopeOpLowering, CIRBrCondOpLowering,
1456-
CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1457-
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1458-
CIRPtrStrideOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
1459-
CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1460-
CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1461-
CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
1462-
CIRPtrStrideOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1463-
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1464-
CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1465-
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1466-
CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1467-
CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1468-
CIRTrapOpLowering>(converter, patterns.getContext());
1542+
.add<CIRSwitchOpLowering, CIRATanOpLowering, CIRCmpOpLowering,
1543+
CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1544+
CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1545+
CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1546+
CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1547+
CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1548+
CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1549+
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1550+
CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1551+
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1552+
CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering,
1553+
CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1554+
CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1555+
CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1556+
CIRVectorInsertLowering, CIRVectorExtractLowering,
1557+
CIRVectorCmpOpLowering, CIRACosOpLowering, CIRASinOpLowering,
1558+
CIRUnreachableOpLowering, CIRTanOpLowering, CIRTrapOpLowering>(
1559+
converter, patterns.getContext());
14691560
}
14701561

14711562
static mlir::TypeConverter prepareTypeConverter() {
@@ -1571,6 +1662,7 @@ mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,
15711662

15721663
mlir::PassManager pm(mlirCtx);
15731664

1665+
pm.addPass(createMLIRCoreDialectsLoweringPreparePass());
15741666
pm.addPass(createConvertCIRToMLIRPass());
15751667
pm.addPass(createConvertMLIRToLLVMPass());
15761668

@@ -1616,6 +1708,7 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
16161708

16171709
mlir::PassManager pm(mlirCtx);
16181710

1711+
pm.addPass(createMLIRCoreDialectsLoweringPreparePass());
16191712
pm.addPass(createConvertCIRToMLIRPass());
16201713

16211714
auto result = !mlir::failed(pm.run(theModule));
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
//===- MLIRCoreDialectsLoweringPrepare.cpp - CIR lowering preparation -----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/IR/BuiltinOps.h"
10+
#include "mlir/IR/IRMapping.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Transforms/DialectConversion.h"
13+
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
14+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
15+
16+
using namespace llvm;
17+
using namespace cir;
18+
19+
namespace cir {
20+
21+
struct MLIRLoweringPrepare
22+
: public mlir::PassWrapper<MLIRLoweringPrepare,
23+
mlir::OperationPass<mlir::ModuleOp>> {
24+
// `scf.index_switch` requires that switch branches do not fall through.
25+
// We need to copy the next branch's body when the current `cir.case` does not
26+
// terminate with a break.
27+
void removeFallthrough(llvm::SmallVector<CaseOp> &cases);
28+
29+
void runOnOp(mlir::Operation *op);
30+
void runOnOperation() final;
31+
32+
StringRef getDescription() const override {
33+
return "Rewrite CIR module to be more 'scf' dialect-friendly";
34+
}
35+
36+
StringRef getArgument() const override { return "mlir-lowering-prepare"; }
37+
};
38+
39+
// `scf.index_switch` requires that switch branches do not fall through.
40+
// We need to copy the next branch's body when the current `cir.case` does not
41+
// terminate with a break.
42+
void MLIRLoweringPrepare::removeFallthrough(llvm::SmallVector<CaseOp> &cases) {
43+
CIRBaseBuilderTy builder(getContext());
44+
// Note we enumerate in the reverse order, to facilitate the cloning.
45+
for (auto it = cases.rbegin(); it != cases.rend(); it++) {
46+
auto caseOp = *it;
47+
auto &region = caseOp.getRegion();
48+
auto &lastBlock = region.back();
49+
mlir::Operation &last = lastBlock.back();
50+
if (isa<BreakOp>(last))
51+
continue;
52+
53+
// The last op must be a `cir.yield`. As it falls through, we copy the
54+
// previous case's body to this one.
55+
if (!isa<YieldOp>(last)) {
56+
caseOp->dump();
57+
continue;
58+
}
59+
assert(isa<YieldOp>(last));
60+
61+
// If there's no previous case, we can simply change the yield into a break.
62+
if (it == cases.rbegin()) {
63+
builder.setInsertionPointAfter(&last);
64+
builder.create<BreakOp>(last.getLoc());
65+
last.erase();
66+
continue;
67+
}
68+
69+
auto prevIt = it;
70+
--prevIt;
71+
CaseOp &prev = *prevIt;
72+
auto &prevRegion = prev.getRegion();
73+
mlir::IRMapping mapping;
74+
builder.cloneRegionBefore(prevRegion, region, region.end());
75+
76+
// We inline the block to the end.
77+
// This is required because `scf.index_switch` expects that each of its
78+
// region contains a single block.
79+
mlir::Block *cloned = lastBlock.getNextNode();
80+
for (auto it = cloned->begin(); it != cloned->end();) {
81+
auto next = it;
82+
next++;
83+
it->moveBefore(&last);
84+
it = next;
85+
}
86+
cloned->erase();
87+
last.erase();
88+
}
89+
}
90+
91+
void MLIRLoweringPrepare::runOnOp(mlir::Operation *op) {
92+
if (auto switchOp = dyn_cast<SwitchOp>(op)) {
93+
llvm::SmallVector<CaseOp> cases;
94+
if (!switchOp.isSimpleForm(cases))
95+
op->emitError("NYI");
96+
97+
removeFallthrough(cases);
98+
return;
99+
}
100+
op->emitError("unexpected op type");
101+
}
102+
103+
void MLIRLoweringPrepare::runOnOperation() {
104+
auto module = getOperation();
105+
106+
llvm::SmallVector<mlir::Operation *> opsToTransform;
107+
module->walk([&](mlir::Operation *op) {
108+
if (isa<SwitchOp>(op))
109+
opsToTransform.push_back(op);
110+
});
111+
112+
for (auto *op : opsToTransform)
113+
runOnOp(op);
114+
}
115+
116+
std::unique_ptr<mlir::Pass> createMLIRCoreDialectsLoweringPreparePass() {
117+
return std::make_unique<MLIRLoweringPrepare>();
118+
}
119+
120+
} // namespace cir
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
void fallthrough() {
5+
int i = 0;
6+
switch (i) {
7+
case 2:
8+
i++;
9+
case 3:
10+
i++;
11+
break;
12+
case 8:
13+
i++;
14+
}
15+
16+
// This should copy the `i++; break` in case 3 to case 2.
17+
18+
// CHECK: memref.alloca_scope {
19+
// CHECK: %[[I:.+]] = memref.load %alloca[]
20+
// CHECK: %[[CASTED:.+]] = arith.index_cast %[[I]]
21+
// CHECK: scf.index_switch %[[CASTED]]
22+
// CHECK: case 2 {
23+
// CHECK: %[[I:.+]] = memref.load %alloca[]
24+
// CHECK: %[[ONE:.+]] = arith.constant 1
25+
// CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]]
26+
// CHECK: memref.store %[[ADD]], %alloca[]
27+
// CHECK: %[[I:.+]] = memref.load %alloca[]
28+
// CHECK: %[[ONE:.+]] = arith.constant 1
29+
// CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]]
30+
// CHECK: memref.store %[[ADD]], %alloca[]
31+
// CHECK: scf.yield
32+
// CHECK: }
33+
// CHECK: case 3 {
34+
// CHECK: %[[I:.+]] = memref.load %alloca[]
35+
// CHECK: %[[ONE:.+]] = arith.constant 1
36+
// CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]]
37+
// CHECK: memref.store %[[ADD]], %alloca[]
38+
// CHECK: scf.yield
39+
// CHECK: }
40+
// CHECK: case 8 {
41+
// CHECK: %[[I:.+]] = memref.load %alloca[]
42+
// CHECK: %[[ONE:.+]] = arith.constant 1
43+
// CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]]
44+
// CHECK: memref.store %[[ADD]], %alloca[]
45+
// CHECK: scf.yield
46+
// CHECK: }
47+
// CHECK: default {
48+
// CHECK: }
49+
// CHECK: }
50+
}

0 commit comments

Comments
 (0)