@@ -1444,28 +1444,119 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
1444
1444
}
1445
1445
};
1446
1446
1447
+ class CIRSwitchOpLowering : public mlir ::OpConversionPattern<cir::SwitchOp> {
1448
+ public:
1449
+ using OpConversionPattern<cir::SwitchOp>::OpConversionPattern;
1450
+
1451
+ mlir::LogicalResult
1452
+ matchAndRewrite (cir::SwitchOp op, OpAdaptor adaptor,
1453
+ mlir::ConversionPatternRewriter &rewriter) const override {
1454
+ rewriter.setInsertionPointAfter (op);
1455
+ llvm::SmallVector<CaseOp> cases;
1456
+ if (!op.isSimpleForm (cases))
1457
+ llvm_unreachable (" NYI" );
1458
+
1459
+ llvm::SmallVector<int64_t > caseValues;
1460
+ // Maps the index of a CaseOp in `cases`, to the index in `caseValues`.
1461
+ // This is necessary because some CaseOp might carry 0 or multiple values.
1462
+ llvm::DenseMap<size_t , unsigned > indexMap;
1463
+ caseValues.reserve (cases.size ());
1464
+ for (auto [i, caseOp] : llvm::enumerate (cases)) {
1465
+ switch (caseOp.getKind ()) {
1466
+ case CaseOpKind::Equal: {
1467
+ auto valueAttr = caseOp.getValue ()[0 ];
1468
+ auto value = cast<cir::IntAttr>(valueAttr);
1469
+ indexMap[i] = caseValues.size ();
1470
+ caseValues.push_back (value.getUInt ());
1471
+ break ;
1472
+ }
1473
+ case CaseOpKind::Default:
1474
+ break ;
1475
+ case CaseOpKind::Range:
1476
+ case CaseOpKind::Anyof:
1477
+ llvm_unreachable (" NYI" );
1478
+ }
1479
+ }
1480
+
1481
+ auto operand = adaptor.getOperands ()[0 ];
1482
+ // `scf.index_switch` expects an index of type `index`.
1483
+ auto indexType = mlir::IndexType::get (getContext ());
1484
+ auto indexCast = rewriter.create <mlir::arith::IndexCastOp>(
1485
+ op.getLoc (), indexType, operand);
1486
+ auto indexSwitch = rewriter.create <mlir::scf::IndexSwitchOp>(
1487
+ op.getLoc (), mlir::TypeRange{}, indexCast, caseValues, cases.size ());
1488
+
1489
+ bool metDefault = false ;
1490
+ for (auto [i, caseOp] : llvm::enumerate (cases)) {
1491
+ auto ®ion = caseOp.getRegion ();
1492
+ switch (caseOp.getKind ()) {
1493
+ case CaseOpKind::Equal: {
1494
+ auto &caseRegion = indexSwitch.getCaseRegions ()[indexMap[i]];
1495
+ rewriter.inlineRegionBefore (region, caseRegion, caseRegion.end ());
1496
+ break ;
1497
+ }
1498
+ case CaseOpKind::Default: {
1499
+ auto &defaultRegion = indexSwitch.getDefaultRegion ();
1500
+ rewriter.inlineRegionBefore (region, defaultRegion, defaultRegion.end ());
1501
+ metDefault = true ;
1502
+ break ;
1503
+ }
1504
+ case CaseOpKind::Range:
1505
+ case CaseOpKind::Anyof:
1506
+ llvm_unreachable (" NYI" );
1507
+ }
1508
+ }
1509
+
1510
+ // `scf.index_switch` expects its default region to contain exactly one
1511
+ // block. If we don't have a default region in `cir.switch`, we need to
1512
+ // supply it here.
1513
+ if (!metDefault) {
1514
+ auto &defaultRegion = indexSwitch.getDefaultRegion ();
1515
+ mlir::Block *block =
1516
+ rewriter.createBlock (&defaultRegion, defaultRegion.end ());
1517
+ rewriter.setInsertionPointToEnd (block);
1518
+ rewriter.create <mlir::scf::YieldOp>(op.getLoc ());
1519
+ }
1520
+
1521
+ // The final `cir.break` should be replaced to `scf.yield`.
1522
+ // After MLIRLoweringPrepare pass, every case must end with a `cir.break`.
1523
+ for (auto ®ion : indexSwitch.getCaseRegions ()) {
1524
+ auto &lastBlock = region.back ();
1525
+ auto &lastOp = lastBlock.back ();
1526
+ assert (isa<BreakOp>(lastOp));
1527
+ rewriter.setInsertionPointAfter (&lastOp);
1528
+ rewriter.replaceOpWithNewOp <mlir::scf::YieldOp>(&lastOp);
1529
+ }
1530
+
1531
+ rewriter.replaceOp (op, indexSwitch);
1532
+
1533
+ return mlir::success ();
1534
+ }
1535
+ };
1536
+
1447
1537
void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1448
1538
mlir::TypeConverter &converter) {
1449
1539
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
1450
1540
1451
1541
patterns
1452
- .add <CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1453
- CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1454
- CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1455
- CIRFuncOpLowering, CIRScopeOpLowering, CIRBrCondOpLowering,
1456
- CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1457
- CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1458
- CIRPtrStrideOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
1459
- CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1460
- CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1461
- CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
1462
- CIRPtrStrideOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1463
- CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1464
- CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1465
- CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1466
- CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1467
- CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1468
- CIRTrapOpLowering>(converter, patterns.getContext ());
1542
+ .add <CIRSwitchOpLowering, CIRATanOpLowering, CIRCmpOpLowering,
1543
+ CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1544
+ CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1545
+ CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1546
+ CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1547
+ CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1548
+ CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1549
+ CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1550
+ CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1551
+ CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1552
+ CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering,
1553
+ CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1554
+ CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1555
+ CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1556
+ CIRVectorInsertLowering, CIRVectorExtractLowering,
1557
+ CIRVectorCmpOpLowering, CIRACosOpLowering, CIRASinOpLowering,
1558
+ CIRUnreachableOpLowering, CIRTanOpLowering, CIRTrapOpLowering>(
1559
+ converter, patterns.getContext ());
1469
1560
}
1470
1561
1471
1562
static mlir::TypeConverter prepareTypeConverter () {
@@ -1571,6 +1662,7 @@ mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,
1571
1662
1572
1663
mlir::PassManager pm (mlirCtx);
1573
1664
1665
+ pm.addPass (createMLIRCoreDialectsLoweringPreparePass ());
1574
1666
pm.addPass (createConvertCIRToMLIRPass ());
1575
1667
pm.addPass (createConvertMLIRToLLVMPass ());
1576
1668
@@ -1616,6 +1708,7 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
1616
1708
1617
1709
mlir::PassManager pm (mlirCtx);
1618
1710
1711
+ pm.addPass (createMLIRCoreDialectsLoweringPreparePass ());
1619
1712
pm.addPass (createConvertCIRToMLIRPass ());
1620
1713
1621
1714
auto result = !mlir::failed (pm.run (theModule));
0 commit comments