-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[MLIR][SCF] Actually use conversion interface in scf-to-cf conversion #154075
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir Author: William Moses (wsmoses) ChangesFull diff: https://github.com/llvm/llvm-project/pull/154075.diff 1 Files Affected:
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 37cfc9f2c23e6..d9ec932244770 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -100,11 +100,13 @@ struct SCFToControlFlowPass
// | <%init visible by dominance> |
// +--------------------------------+
//
-struct ForLowering : public OpRewritePattern<ForOp> {
- using OpRewritePattern<ForOp>::OpRewritePattern;
+struct ForLowering : public OpConversionPattern<ForOp> {
+ using OpConversionPattern<ForOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(ForOp forOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult
+ matchAndRewrite(ForOp forOp,
+ typename OpConversionPattern<ForOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
};
// Create a CFG subgraph for the scf.if operation (including its "then" and
@@ -193,25 +195,31 @@ struct ForLowering : public OpRewritePattern<ForOp> {
// | <code after the IfOp> |
// +--------------------------------+
//
-struct IfLowering : public OpRewritePattern<IfOp> {
- using OpRewritePattern<IfOp>::OpRewritePattern;
+struct IfLowering : public OpConversionPattern<IfOp> {
+ using OpConversionPattern<IfOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(IfOp ifOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult
+ matchAndRewrite(IfOp ifOp,
+ typename OpConversionPattern<IfOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
};
-struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
- using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+struct ExecuteRegionLowering : public OpConversionPattern<ExecuteRegionOp> {
+ using OpConversionPattern<ExecuteRegionOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(ExecuteRegionOp op,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(
+ ExecuteRegionOp op,
+ typename OpConversionPattern<ExecuteRegionOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
};
-struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
- using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
+struct ParallelLowering : public OpConversionPattern<mlir::scf::ParallelOp> {
+ using OpConversionPattern<mlir::scf::ParallelOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(
+ mlir::scf::ParallelOp parallelOp,
+ typename OpConversionPattern<mlir::scf::ParallelOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
};
/// Create a CFG subgraph for this loop construct. The regions of the loop need
@@ -273,41 +281,49 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
/// the results of the WhileOp are defined in the 'before' region, which is
/// required to have a single existing block, and are therefore accessible in
/// the continuation block due to dominance.
-struct WhileLowering : public OpRewritePattern<WhileOp> {
- using OpRewritePattern<WhileOp>::OpRewritePattern;
+struct WhileLowering : public OpConversionPattern<WhileOp> {
+ using OpConversionPattern<WhileOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(WhileOp whileOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult
+ matchAndRewrite(WhileOp whileOp,
+ typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
};
/// Optimized version of the above for the case of the "after" region merely
/// forwarding its arguments back to the "before" region (i.e., a "do-while"
/// loop). This avoid inlining the "after" region completely and branches back
/// to the "before" entry instead.
-struct DoWhileLowering : public OpRewritePattern<WhileOp> {
- using OpRewritePattern<WhileOp>::OpRewritePattern;
+struct DoWhileLowering : public OpConversionPattern<WhileOp> {
+ using OpConversionPattern<WhileOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(WhileOp whileOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult
+ matchAndRewrite(WhileOp whileOp,
+ typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
};
/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
-struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
- using OpRewritePattern::OpRewritePattern;
+struct IndexSwitchLowering : public OpConversionPattern<IndexSwitchOp> {
+ using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(IndexSwitchOp op,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(
+ IndexSwitchOp op,
+ typename OpConversionPattern<IndexSwitchOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
};
/// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
/// has no shared outputs. Ops with shared outputs should be bufferized first.
/// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
/// dialects/passes.
-struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
- using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
+struct ForallLowering : public OpConversionPattern<mlir::scf::ForallOp> {
+ using OpConversionPattern<mlir::scf::ForallOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
- PatternRewriter &rewriter) const override;
+ LogicalResult matchAndRewrite(
+ mlir::scf::ForallOp forallOp,
+ typename OpConversionPattern<mlir::scf::ForallOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
};
} // namespace
@@ -325,8 +341,9 @@ static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
brOp->setDiscardableAttrs(llvmAttrs);
}
-LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
- PatternRewriter &rewriter) const {
+LogicalResult ForLowering::matchAndRewrite(
+ ForOp forOp, typename OpConversionPattern<ForOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
// Start by splitting the block containing the 'scf.for' into two parts.
@@ -397,8 +414,9 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
return success();
}
-LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
- PatternRewriter &rewriter) const {
+LogicalResult IfLowering::matchAndRewrite(
+ IfOp ifOp, typename OpConversionPattern<IfOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
auto loc = ifOp.getLoc();
// Start by splitting the block containing the 'scf.if' into two parts.
@@ -453,9 +471,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
return success();
}
-LogicalResult
-ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
- PatternRewriter &rewriter) const {
+LogicalResult ExecuteRegionLowering::matchAndRewrite(
+ ExecuteRegionOp op,
+ typename OpConversionPattern<ExecuteRegionOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto *condBlock = rewriter.getInsertionBlock();
@@ -487,9 +506,10 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
return success();
}
-LogicalResult
-ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
- PatternRewriter &rewriter) const {
+LogicalResult ParallelLowering::matchAndRewrite(
+ ParallelOp parallelOp,
+ typename OpConversionPattern<ParallelOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
Location loc = parallelOp.getLoc();
auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
if (!reductionOp) {
@@ -563,8 +583,9 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
return success();
}
-LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
- PatternRewriter &rewriter) const {
+LogicalResult WhileLowering::matchAndRewrite(
+ WhileOp whileOp, typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
OpBuilder::InsertionGuard guard(rewriter);
Location loc = whileOp.getLoc();
@@ -606,9 +627,9 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
return success();
}
-LogicalResult
-DoWhileLowering::matchAndRewrite(WhileOp whileOp,
- PatternRewriter &rewriter) const {
+LogicalResult DoWhileLowering::matchAndRewrite(
+ WhileOp whileOp, typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
Block &afterBlock = *whileOp.getAfterBody();
if (!llvm::hasSingleElement(afterBlock))
return rewriter.notifyMatchFailure(whileOp,
@@ -652,9 +673,10 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
return success();
}
-LogicalResult
-IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
- PatternRewriter &rewriter) const {
+LogicalResult IndexSwitchLowering::matchAndRewrite(
+ IndexSwitchOp op,
+ typename OpConversionPattern<IndexSwitchOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
// Split the block at the op.
Block *condBlock = rewriter.getInsertionBlock();
Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
@@ -714,8 +736,10 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
return success();
}
-LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
- PatternRewriter &rewriter) const {
+LogicalResult ForallLowering::matchAndRewrite(
+ ForallOp forallOp,
+ typename OpConversionPattern<ForallOp>::OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
return scf::forallToParallelLoop(rewriter, forallOp);
}
|
Please add some context in the description, explaining why this is useful/desirable. |
done! @joker-eph, relatedly mind giving it a review? |
Thanks for elaborating, but I still don't understand: is this something more specific we could describe here? The debug trace shows a crash on null values in the folder, and then shows that this was created by a rollback in the conversion dialect infra? Ping @matthias-springer who may help narrowing down the criteria that makes this needed to be an OpConversion. (also I don't think the long debug trace belongs to the commit message) |
yeah I guess at a high-level it's illegal to combine rewrite patterns and conversion patterns. Nearly every other pass inside of Conversion/ (the other exception MathToLibm.cpp I'm making a PR for concurrently) use conversion patterns. That alone imo merits the change, but re the segfault above -- this means that one cannot combine several conversion patterns in a single pass legally, at least without rewriting. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't fully understand what causes the crash here, but this PR makes sense to me. Conversion patterns and rewrite patterns do not compose well when running with "rollback enabled". The name of the populateSCFToControlFlowConversionPatterns
function suggests that these should be conversion patterns.
However, I doubt that whatever crash you're seeing would be fixed by this change. Turning the rewrite patterns into conversion patterns is adding an indicator to the user what kind of things are allowed in the matchAndRewrite
implementation. (E.g., replaceAllUsesWith
is not allowed in a dialect conversion.) But that implementation did not actually change, so I believe this PR is NFC for existing users that populate those patterns in a dialect conversion.
Alternatively, the name of the function could be renamed to populateSCFToControlFlowPatterns
, and the dialect conversion be replaced with the walkPatterns
driver.
Looking at your stack trace:
One possible scenario where the rollback of an operation created could introduce a NULL value:
The problem with this example is that |
@matthias-springer : how is updating the patterns from OpRewritePattern to OpConversionPattern changing anything to what you're describing? Edit: I missed the first message, you acknowledged it:
So we don't have an explanation yet.
That seems like something to consider: keeping compatibility with the walk-and-apply driver means the fastest path is always available. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(marking as blocking pending resolution of the current discussion)
Given that the walk-patterns driver has much less overhead, it would make sense to switch to that one and leave the patterns here as rewrite patterns. |
To be clear I didn't confirm that these patches resolve the segfault I am seeing, but considering there is already illegal (and misnamed) use of rewriters in a conversion pattern, I figure it would be a good first place to start. Patching in LLVM to do an integration test unfortunately requires me to rebase xla and a bunch of other big deps to a specific llvm which might be hard -- but I'll try to work it in and test shortly. |
It's a good start. But I think using the walk-pattern driver instead of a dialect conversion for this pass would be even better here. |
assuming I applied your patch correctly, it didn't hit anything that changes the error. also for ease, here's the full debug log from a mwe example we have: EnzymeAD/Enzyme-JAX#1317 (comment) |
Another thing you could try to help narrowing down the issue: Set this variable to |
okay figured out and fixedthe source of the segfault from our end which is indeed unrelated to this PR (essentially blockargument being added during conversion, outside of pattern infra). This still might be useful to do (will leave that discussion to you guys being more familiar with this part of the code) |
This leads to catastrophic failures, especially when using conversion patterns inside the greedy driver that doesn't provide sufficient rewriter infrastructure.
As this is in conversion (and exported in populate functions), this should be a conversion rather than rewriter. And indeed I do see a segfault due to a mismatch of rewrite/conversion in a downstream project.
x/ref #154083 EnzymeAD/Enzyme-JAX#1278