diff --git a/clang/include/clang/CIR/Passes.h b/clang/include/clang/CIR/Passes.h index 3f8a174aac0c..ecb25ba856f4 100644 --- a/clang/include/clang/CIR/Passes.h +++ b/clang/include/clang/CIR/Passes.h @@ -18,6 +18,11 @@ #include namespace cir { +/// Create a pass for transforming CIR operations to more 'scf' dialect-friendly +/// forms. It rewrites operations that aren't supported by 'scf', such as breaks +/// and continues. +std::unique_ptr createMLIRCoreDialectsLoweringPreparePass(); + /// Create a pass for lowering from MLIR builtin dialects such as `Affine` and /// `Std`, to the LLVM dialect for codegen. std::unique_ptr createConvertMLIRToLLVMPass(); diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt b/clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt index 8c2631ab57d8..9a38b0d4a65d 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt +++ b/clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt @@ -9,6 +9,7 @@ add_clang_library(clangCIRLoweringThroughMLIR LowerCIRLoopToSCF.cpp LowerCIRToMLIR.cpp LowerMLIRToLLVM.cpp + MLIRCoreDialectsLoweringPrepare.cpp DEPENDS MLIRCIROpsIncGen diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 1a3cb5db1fa4..2ee5fe5ed88b 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -1496,29 +1496,118 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern { } }; +class CIRSwitchOpLowering : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::SwitchOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.setInsertionPointAfter(op); + llvm::SmallVector cases; + if (!op.isSimpleForm(cases)) + mlir::emitError(op.getLoc(), "not yet implemented"); + + llvm::SmallVector caseValues; + // Maps the index of a CaseOp in `cases`, to the index in `caseValues`. + // This is necessary because some CaseOp might carry 0 or multiple values. + llvm::DenseMap indexMap; + caseValues.reserve(cases.size()); + for (auto [i, caseOp] : llvm::enumerate(cases)) { + switch (caseOp.getKind()) { + case CaseOpKind::Equal: { + auto valueAttr = caseOp.getValue()[0]; + auto value = cast(valueAttr); + indexMap[i] = caseValues.size(); + caseValues.push_back(value.getUInt()); + break; + } + case CaseOpKind::Default: + break; + case CaseOpKind::Range: + case CaseOpKind::Anyof: + mlir::emitError(op.getLoc(), "not yet implemented"); + } + } + + auto operand = adaptor.getOperands()[0]; + // `scf.index_switch` expects an index of type `index`. + auto indexType = mlir::IndexType::get(getContext()); + auto indexCast = rewriter.create( + op.getLoc(), indexType, operand); + auto indexSwitch = rewriter.create( + op.getLoc(), mlir::TypeRange{}, indexCast, caseValues, cases.size()); + + bool metDefault = false; + for (auto [i, caseOp] : llvm::enumerate(cases)) { + auto ®ion = caseOp.getRegion(); + switch (caseOp.getKind()) { + case CaseOpKind::Equal: { + auto &caseRegion = indexSwitch.getCaseRegions()[indexMap[i]]; + rewriter.inlineRegionBefore(region, caseRegion, caseRegion.end()); + break; + } + case CaseOpKind::Default: { + auto &defaultRegion = indexSwitch.getDefaultRegion(); + rewriter.inlineRegionBefore(region, defaultRegion, defaultRegion.end()); + metDefault = true; + break; + } + case CaseOpKind::Range: + case CaseOpKind::Anyof: + mlir::emitError(op.getLoc(), "not yet implemented"); + } + } + + // `scf.index_switch` expects its default region to contain exactly one + // block. If we don't have a default region in `cir.switch`, we need to + // supply it here. + if (!metDefault) { + auto &defaultRegion = indexSwitch.getDefaultRegion(); + mlir::Block *block = + rewriter.createBlock(&defaultRegion, defaultRegion.end()); + rewriter.setInsertionPointToEnd(block); + rewriter.create(op.getLoc()); + } + + // The final `cir.break` should be replaced to `scf.yield`. + // After MLIRLoweringPrepare pass, every case must end with a `cir.break`. + for (auto ®ion : indexSwitch.getCaseRegions()) { + auto &lastBlock = region.back(); + auto &lastOp = lastBlock.back(); + assert(isa(lastOp)); + rewriter.setInsertionPointAfter(&lastOp); + rewriter.replaceOpWithNewOp(&lastOp); + } + + rewriter.replaceOp(op, indexSwitch); + + return mlir::success(); + } +}; + void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter) { patterns.add(patterns.getContext()); - patterns - .add( - converter, patterns.getContext()); + patterns.add< + CIRSwitchOpLowering, CIRGetElementOpLowering, CIRATanOpLowering, + CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering, + CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering, + CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering, + CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering, + CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering, + CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering, + CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering, + CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering, + CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering, + CIRPtrStrideOpLowering, CIRSinOpLowering, CIRShiftOpLowering, + CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering, + CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering, + CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering, + CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering, + CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering, + CIRTrapOpLowering>(converter, patterns.getContext()); } static mlir::TypeConverter prepareTypeConverter() { @@ -1624,6 +1713,7 @@ mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule, mlir::PassManager pm(mlirCtx); + pm.addPass(createMLIRCoreDialectsLoweringPreparePass()); pm.addPass(createConvertCIRToMLIRPass()); pm.addPass(createConvertMLIRToLLVMPass()); @@ -1669,6 +1759,7 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule, mlir::PassManager pm(mlirCtx); + pm.addPass(createMLIRCoreDialectsLoweringPreparePass()); pm.addPass(createConvertCIRToMLIRPass()); auto result = !mlir::failed(pm.run(theModule)); diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/MLIRCoreDialectsLoweringPrepare.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/MLIRCoreDialectsLoweringPrepare.cpp new file mode 100644 index 000000000000..755ce0f7887f --- /dev/null +++ b/clang/lib/CIR/Lowering/ThroughMLIR/MLIRCoreDialectsLoweringPrepare.cpp @@ -0,0 +1,120 @@ +//===- MLIRCoreDialectsLoweringPrepare.cpp - CIR lowering preparation -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h" +#include "clang/CIR/Dialect/IR/CIRDialect.h" + +using namespace llvm; +using namespace cir; + +namespace cir { + +struct MLIRLoweringPrepare + : public mlir::PassWrapper> { + // `scf.index_switch` requires that switch branches do not fall through. + // We need to copy the next branch's body when the current `cir.case` does + // not terminate with a break. + void removeFallthrough(llvm::SmallVector &cases); + + void runOnOp(mlir::Operation *op); + void runOnOperation() final; + + StringRef getDescription() const override { + return "Rewrite CIR module to be more 'scf' dialect-friendly"; + } + + StringRef getArgument() const override { return "mlir-lowering-prepare"; } +}; + +// `scf.index_switch` requires that switch branches do not fall through. +// We need to copy the next branch's body when the current `cir.case` does not +// terminate with a break. +void MLIRLoweringPrepare::removeFallthrough(llvm::SmallVector &cases) { + CIRBaseBuilderTy builder(getContext()); + // Note we enumerate in the reverse order, to facilitate the cloning. + for (auto it = cases.rbegin(); it != cases.rend(); it++) { + auto caseOp = *it; + auto ®ion = caseOp.getRegion(); + auto &lastBlock = region.back(); + mlir::Operation &last = lastBlock.back(); + if (isa(last)) + continue; + + // The last op must be a `cir.yield`. As it falls through, we copy the + // previous case's body to this one. + if (!isa(last)) { + caseOp->dump(); + continue; + } + assert(isa(last)); + + // If there's no previous case, we can simply change the yield into a break. + if (it == cases.rbegin()) { + builder.setInsertionPointAfter(&last); + builder.create(last.getLoc()); + last.erase(); + continue; + } + + auto prevIt = it; + --prevIt; + CaseOp &prev = *prevIt; + auto &prevRegion = prev.getRegion(); + mlir::IRMapping mapping; + builder.cloneRegionBefore(prevRegion, region, region.end()); + + // We inline the block to the end. + // This is required because `scf.index_switch` expects that each of its + // region contains a single block. + mlir::Block *cloned = lastBlock.getNextNode(); + for (auto it = cloned->begin(); it != cloned->end();) { + auto next = it; + next++; + it->moveBefore(&last); + it = next; + } + cloned->erase(); + last.erase(); + } +} + +void MLIRLoweringPrepare::runOnOp(mlir::Operation *op) { + if (auto switchOp = dyn_cast(op)) { + llvm::SmallVector cases; + if (!switchOp.isSimpleForm(cases)) + op->emitError("NYI"); + + removeFallthrough(cases); + return; + } + op->emitError("unexpected op type"); +} + +void MLIRLoweringPrepare::runOnOperation() { + auto module = getOperation(); + + llvm::SmallVector opsToTransform; + module->walk([&](mlir::Operation *op) { + if (isa(op)) + opsToTransform.push_back(op); + }); + + for (auto *op : opsToTransform) + runOnOp(op); +} + +std::unique_ptr createMLIRCoreDialectsLoweringPreparePass() { + return std::make_unique(); +} + +} // namespace cir diff --git a/clang/test/CIR/Lowering/ThroughMLIR/switch.c b/clang/test/CIR/Lowering/ThroughMLIR/switch.c new file mode 100644 index 000000000000..e41b44f622b5 --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/switch.c @@ -0,0 +1,50 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s + +void fallthrough() { + int i = 0; + switch (i) { + case 2: + i++; + case 3: + i++; + break; + case 8: + i++; + } + + // This should copy the `i++; break` in case 3 to case 2. + + // CHECK: memref.alloca_scope { + // CHECK: %[[I:.+]] = memref.load %alloca[] + // CHECK: %[[CASTED:.+]] = arith.index_cast %[[I]] + // CHECK: scf.index_switch %[[CASTED]] + // CHECK: case 2 { + // CHECK: %[[I:.+]] = memref.load %alloca[] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]] + // CHECK: memref.store %[[ADD]], %alloca[] + // CHECK: %[[I:.+]] = memref.load %alloca[] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]] + // CHECK: memref.store %[[ADD]], %alloca[] + // CHECK: scf.yield + // CHECK: } + // CHECK: case 3 { + // CHECK: %[[I:.+]] = memref.load %alloca[] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]] + // CHECK: memref.store %[[ADD]], %alloca[] + // CHECK: scf.yield + // CHECK: } + // CHECK: case 8 { + // CHECK: %[[I:.+]] = memref.load %alloca[] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]] + // CHECK: memref.store %[[ADD]], %alloca[] + // CHECK: scf.yield + // CHECK: } + // CHECK: default { + // CHECK: } + // CHECK: } +}