Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 76 additions & 52 deletions mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,13 @@ struct SCFToControlFlowPass
// | <%init visible by dominance> |
// +--------------------------------+
//
struct ForLowering : public OpRewritePattern<ForOp> {
using OpRewritePattern<ForOp>::OpRewritePattern;
struct ForLowering : public OpConversionPattern<ForOp> {
using OpConversionPattern<ForOp>::OpConversionPattern;

LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override;
LogicalResult
matchAndRewrite(ForOp forOp,
typename OpConversionPattern<ForOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

// Create a CFG subgraph for the scf.if operation (including its "then" and
Expand Down Expand Up @@ -193,25 +195,31 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// | <code after the IfOp> |
// +--------------------------------+
//
struct IfLowering : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
struct IfLowering : public OpConversionPattern<IfOp> {
using OpConversionPattern<IfOp>::OpConversionPattern;

LogicalResult matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const override;
LogicalResult
matchAndRewrite(IfOp ifOp,
typename OpConversionPattern<IfOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
struct ExecuteRegionLowering : public OpConversionPattern<ExecuteRegionOp> {
using OpConversionPattern<ExecuteRegionOp>::OpConversionPattern;

LogicalResult matchAndRewrite(ExecuteRegionOp op,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(
ExecuteRegionOp op,
typename OpConversionPattern<ExecuteRegionOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
struct ParallelLowering : public OpConversionPattern<mlir::scf::ParallelOp> {
using OpConversionPattern<mlir::scf::ParallelOp>::OpConversionPattern;

LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(
mlir::scf::ParallelOp parallelOp,
typename OpConversionPattern<mlir::scf::ParallelOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

/// Create a CFG subgraph for this loop construct. The regions of the loop need
Expand Down Expand Up @@ -273,41 +281,49 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
/// the results of the WhileOp are defined in the 'before' region, which is
/// required to have a single existing block, and are therefore accessible in
/// the continuation block due to dominance.
struct WhileLowering : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
struct WhileLowering : public OpConversionPattern<WhileOp> {
using OpConversionPattern<WhileOp>::OpConversionPattern;

LogicalResult matchAndRewrite(WhileOp whileOp,
PatternRewriter &rewriter) const override;
LogicalResult
matchAndRewrite(WhileOp whileOp,
typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

/// Optimized version of the above for the case of the "after" region merely
/// forwarding its arguments back to the "before" region (i.e., a "do-while"
/// loop). This avoid inlining the "after" region completely and branches back
/// to the "before" entry instead.
struct DoWhileLowering : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
struct DoWhileLowering : public OpConversionPattern<WhileOp> {
using OpConversionPattern<WhileOp>::OpConversionPattern;

LogicalResult matchAndRewrite(WhileOp whileOp,
PatternRewriter &rewriter) const override;
LogicalResult
matchAndRewrite(WhileOp whileOp,
typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
using OpRewritePattern::OpRewritePattern;
struct IndexSwitchLowering : public OpConversionPattern<IndexSwitchOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(IndexSwitchOp op,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(
IndexSwitchOp op,
typename OpConversionPattern<IndexSwitchOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

/// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
/// has no shared outputs. Ops with shared outputs should be bufferized first.
/// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
/// dialects/passes.
struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
struct ForallLowering : public OpConversionPattern<mlir::scf::ForallOp> {
using OpConversionPattern<mlir::scf::ForallOp>::OpConversionPattern;

LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(
mlir::scf::ForallOp forallOp,
typename OpConversionPattern<mlir::scf::ForallOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

} // namespace
Expand All @@ -325,8 +341,9 @@ static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
brOp->setDiscardableAttrs(llvmAttrs);
}

LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
LogicalResult ForLowering::matchAndRewrite(
ForOp forOp, typename OpConversionPattern<ForOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = forOp.getLoc();

// Start by splitting the block containing the 'scf.for' into two parts.
Expand Down Expand Up @@ -397,8 +414,9 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
return success();
}

LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const {
LogicalResult IfLowering::matchAndRewrite(
IfOp ifOp, typename OpConversionPattern<IfOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = ifOp.getLoc();

// Start by splitting the block containing the 'scf.if' into two parts.
Expand Down Expand Up @@ -453,9 +471,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
return success();
}

LogicalResult
ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
PatternRewriter &rewriter) const {
LogicalResult ExecuteRegionLowering::matchAndRewrite(
ExecuteRegionOp op,
typename OpConversionPattern<ExecuteRegionOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();

auto *condBlock = rewriter.getInsertionBlock();
Expand Down Expand Up @@ -487,9 +506,10 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
return success();
}

LogicalResult
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
PatternRewriter &rewriter) const {
LogicalResult ParallelLowering::matchAndRewrite(
ParallelOp parallelOp,
typename OpConversionPattern<ParallelOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = parallelOp.getLoc();
auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
if (!reductionOp) {
Expand Down Expand Up @@ -563,8 +583,9 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
return success();
}

LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
PatternRewriter &rewriter) const {
LogicalResult WhileLowering::matchAndRewrite(
WhileOp whileOp, typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
OpBuilder::InsertionGuard guard(rewriter);
Location loc = whileOp.getLoc();

Expand Down Expand Up @@ -606,9 +627,9 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
return success();
}

LogicalResult
DoWhileLowering::matchAndRewrite(WhileOp whileOp,
PatternRewriter &rewriter) const {
LogicalResult DoWhileLowering::matchAndRewrite(
WhileOp whileOp, typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Block &afterBlock = *whileOp.getAfterBody();
if (!llvm::hasSingleElement(afterBlock))
return rewriter.notifyMatchFailure(whileOp,
Expand Down Expand Up @@ -652,9 +673,10 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
return success();
}

LogicalResult
IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
PatternRewriter &rewriter) const {
LogicalResult IndexSwitchLowering::matchAndRewrite(
IndexSwitchOp op,
typename OpConversionPattern<IndexSwitchOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Split the block at the op.
Block *condBlock = rewriter.getInsertionBlock();
Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
Expand Down Expand Up @@ -714,8 +736,10 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
return success();
}

LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
PatternRewriter &rewriter) const {
LogicalResult ForallLowering::matchAndRewrite(
ForallOp forallOp,
typename OpConversionPattern<ForallOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
return scf::forallToParallelLoop(rewriter, forallOp);
}

Expand Down