wsmoses
Member

@wsmoses wsmoses commented Aug 18, 2025

The SCF-to-ControlFlow lowering patterns are currently declared as rewrite patterns. This leads to catastrophic failures, especially when using conversion patterns inside the greedy driver that doesn't provide sufficient rewriter infrastructure.

As these patterns live in Conversion/ (and are exported in populate functions), they should be conversion patterns rather than rewrite patterns. Indeed, I do see a segfault due to a rewrite/conversion mismatch in a downstream project.

x/ref #154083 EnzymeAD/Enzyme-JAX#1278

@wsmoses wsmoses requested review from ftynse and chelini August 18, 2025 08:34
@llvmbot llvmbot added the mlir label Aug 18, 2025
@llvmbot
Member

llvmbot commented Aug 18, 2025

@llvm/pr-subscribers-mlir

Author: William Moses (wsmoses)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/154075.diff

1 Files Affected:

  • (modified) mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp (+76-52)
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 37cfc9f2c23e6..d9ec932244770 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -100,11 +100,13 @@ struct SCFToControlFlowPass
 //      |   <%init visible by dominance> |
 //      +--------------------------------+
 //
-struct ForLowering : public OpRewritePattern<ForOp> {
-  using OpRewritePattern<ForOp>::OpRewritePattern;
+struct ForLowering : public OpConversionPattern<ForOp> {
+  using OpConversionPattern<ForOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(ForOp forOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(ForOp forOp,
+                  typename OpConversionPattern<ForOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 // Create a CFG subgraph for the scf.if operation (including its "then" and
@@ -193,25 +195,31 @@ struct ForLowering : public OpRewritePattern<ForOp> {
 //      | <code after the IfOp>          |
 //      +--------------------------------+
 //
-struct IfLowering : public OpRewritePattern<IfOp> {
-  using OpRewritePattern<IfOp>::OpRewritePattern;
+struct IfLowering : public OpConversionPattern<IfOp> {
+  using OpConversionPattern<IfOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(IfOp ifOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(IfOp ifOp,
+                  typename OpConversionPattern<IfOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
-struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
-  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+struct ExecuteRegionLowering : public OpConversionPattern<ExecuteRegionOp> {
+  using OpConversionPattern<ExecuteRegionOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(ExecuteRegionOp op,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(
+      ExecuteRegionOp op,
+      typename OpConversionPattern<ExecuteRegionOp>::OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override;
 };
 
-struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
-  using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
+struct ParallelLowering : public OpConversionPattern<mlir::scf::ParallelOp> {
+  using OpConversionPattern<mlir::scf::ParallelOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(
+      mlir::scf::ParallelOp parallelOp,
+      typename OpConversionPattern<mlir::scf::ParallelOp>::OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Create a CFG subgraph for this loop construct. The regions of the loop need
@@ -273,41 +281,49 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
 /// the results of the WhileOp are defined in the 'before' region, which is
 /// required to have a single existing block, and are therefore accessible in
 /// the continuation block due to dominance.
-struct WhileLowering : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
+struct WhileLowering : public OpConversionPattern<WhileOp> {
+  using OpConversionPattern<WhileOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(WhileOp whileOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(WhileOp whileOp,
+                  typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Optimized version of the above for the case of the "after" region merely
 /// forwarding its arguments back to the "before" region (i.e., a "do-while"
 /// loop). This avoid inlining the "after" region completely and branches back
 /// to the "before" entry instead.
-struct DoWhileLowering : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
+struct DoWhileLowering : public OpConversionPattern<WhileOp> {
+  using OpConversionPattern<WhileOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(WhileOp whileOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(WhileOp whileOp,
+                  typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Lower an `scf.index_switch` operation to a `cf.switch` operation.
-struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct IndexSwitchLowering : public OpConversionPattern<IndexSwitchOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(IndexSwitchOp op,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(
+      IndexSwitchOp op,
+      typename OpConversionPattern<IndexSwitchOp>::OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
 /// has no shared outputs. Ops with shared outputs should be bufferized first.
 /// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
 /// dialects/passes.
-struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
-  using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
+struct ForallLowering : public OpConversionPattern<mlir::scf::ForallOp> {
+  using OpConversionPattern<mlir::scf::ForallOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(
+      mlir::scf::ForallOp forallOp,
+      typename OpConversionPattern<mlir::scf::ForallOp>::OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override;
 };
 
 } // namespace
@@ -325,8 +341,9 @@ static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
   brOp->setDiscardableAttrs(llvmAttrs);
 }
 
-LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
-                                           PatternRewriter &rewriter) const {
+LogicalResult ForLowering::matchAndRewrite(
+    ForOp forOp, typename OpConversionPattern<ForOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   Location loc = forOp.getLoc();
 
   // Start by splitting the block containing the 'scf.for' into two parts.
@@ -397,8 +414,9 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
   return success();
 }
 
-LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
-                                          PatternRewriter &rewriter) const {
+LogicalResult IfLowering::matchAndRewrite(
+    IfOp ifOp, typename OpConversionPattern<IfOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   auto loc = ifOp.getLoc();
 
   // Start by splitting the block containing the 'scf.if' into two parts.
@@ -453,9 +471,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
   return success();
 }
 
-LogicalResult
-ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
-                                       PatternRewriter &rewriter) const {
+LogicalResult ExecuteRegionLowering::matchAndRewrite(
+    ExecuteRegionOp op,
+    typename OpConversionPattern<ExecuteRegionOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   auto loc = op.getLoc();
 
   auto *condBlock = rewriter.getInsertionBlock();
@@ -487,9 +506,10 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
   return success();
 }
 
-LogicalResult
-ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
-                                  PatternRewriter &rewriter) const {
+LogicalResult ParallelLowering::matchAndRewrite(
+    ParallelOp parallelOp,
+    typename OpConversionPattern<ParallelOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   Location loc = parallelOp.getLoc();
   auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
   if (!reductionOp) {
@@ -563,8 +583,9 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
   return success();
 }
 
-LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
-                                             PatternRewriter &rewriter) const {
+LogicalResult WhileLowering::matchAndRewrite(
+    WhileOp whileOp, typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   OpBuilder::InsertionGuard guard(rewriter);
   Location loc = whileOp.getLoc();
 
@@ -606,9 +627,9 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
   return success();
 }
 
-LogicalResult
-DoWhileLowering::matchAndRewrite(WhileOp whileOp,
-                                 PatternRewriter &rewriter) const {
+LogicalResult DoWhileLowering::matchAndRewrite(
+    WhileOp whileOp, typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   Block &afterBlock = *whileOp.getAfterBody();
   if (!llvm::hasSingleElement(afterBlock))
     return rewriter.notifyMatchFailure(whileOp,
@@ -652,9 +673,10 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
   return success();
 }
 
-LogicalResult
-IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
-                                     PatternRewriter &rewriter) const {
+LogicalResult IndexSwitchLowering::matchAndRewrite(
+    IndexSwitchOp op,
+    typename OpConversionPattern<IndexSwitchOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   // Split the block at the op.
   Block *condBlock = rewriter.getInsertionBlock();
   Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
@@ -714,8 +736,10 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
   return success();
 }
 
-LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ForallLowering::matchAndRewrite(
+    ForallOp forallOp,
+    typename OpConversionPattern<ForallOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   return scf::forallToParallelLoop(rewriter, forallOp);
 }
 

@joker-eph
Collaborator

Please add some context in the description, explaining why this is useful/desirable.

@wsmoses wsmoses requested a review from joker-eph August 18, 2025 09:08
@wsmoses
Member Author

wsmoses commented Aug 18, 2025

Done! @joker-eph, relatedly, would you mind giving it a review?

@joker-eph
Collaborator

Thanks for elaborating, but I still don't understand: is there something more specific we could describe here?

The debug trace shows a crash on null values in the folder, and then shows that this was created by a rollback in the conversion dialect infra?
You wrote "when using conversion patterns inside the greedy driver that doesn't provide sufficient rewriter infrastructure" while the issue would come from the dialect conversion here, so I don't quite follow.

Ping @matthias-springer, who may help narrow down the criteria that make it necessary for this to be an OpConversionPattern.

(Also, I don't think the long debug trace belongs in the commit message.)

@wsmoses
Member Author

wsmoses commented Aug 18, 2025

Yeah, I guess at a high level it's illegal to combine rewrite patterns and conversion patterns. Nearly every other pass inside Conversion/ uses conversion patterns (the other exception is MathToLibm.cpp, for which I'm making a PR concurrently). That alone, IMO, merits the change. Regarding the segfault above: this means that one cannot legally combine several conversion patterns in a single pass, at least without rewriting them.

Member

@matthias-springer matthias-springer left a comment

I don't fully understand what causes the crash here, but this PR makes sense to me. Conversion patterns and rewrite patterns do not compose well when running with "rollback enabled". The name of the populateSCFToControlFlowConversionPatterns function suggests that these should be conversion patterns.

However, I doubt that whatever crash you're seeing would be fixed by this change. Turning the rewrite patterns into conversion patterns signals to the user what kinds of things are allowed in the matchAndRewrite implementation. (E.g., replaceAllUsesWith is not allowed in a dialect conversion.) But that implementation did not actually change, so I believe this PR is NFC for existing users that populate those patterns in a dialect conversion.

Alternatively, the name of the function could be renamed to populateSCFToControlFlowPatterns, and the dialect conversion be replaced with the walkPatterns driver.

@matthias-springer
Member

matthias-springer commented Aug 18, 2025

Looking at your stack trace:

#2  0x000060fc32746bb7 in mlir::Value::dropAllUses (this=0x7fff2968b3b0) at external/llvm-project/mlir/include/mlir/IR/Value.h:144
#3  0x000060fc32746d84 in mlir::Operation::dropAllUses (this=0x60fc473b1c30) at external/llvm-project/mlir/include/mlir/IR/Operation.h:836
#4  0x000060fc35db9083 in (anonymous namespace)::CreateOperationRewrite::rollback (this=0x60fc473bc590)

One possible scenario in which rolling back the creation of an operation could introduce a NULL value:

  1. Pattern A creates a new operation X.
  2. A pattern (possibly another one) calls rewriter.replaceAllUsesWith(..., X.getResult()).
  3. A is rolled back.

The problem with this example is that replaceAllUsesWith is not supported in a dialect conversion. (But the API surface still exposes it.) The conversion driver (allowPatternRollback = false) adds some preliminary support, but more work is still needed...
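
To make the scenario above concrete, here is a minimal, self-contained C++ sketch. It is a toy model and does not use MLIR: `ToyOp`, `ToyRewriter`, and `userOperandIsNullAfterRollback` are hypothetical names invented for illustration. The only assumption carried over from the stack trace is that rolling back an operation creation drops all uses of the created operation before erasing it, while `replaceAllUsesWith` is not recorded for rollback.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for an operation with at most one operand. NOT an MLIR class.
struct ToyOp {
  std::string name;
  ToyOp *operand = nullptr; // the op whose result this op uses, if any
};

// Toy stand-in for a rewriter with a rollback journal (hypothetical).
struct ToyRewriter {
  std::vector<std::unique_ptr<ToyOp>> ops; // owns every op in the toy IR
  std::vector<ToyOp *> createdJournal;     // ops that rollback() must erase

  ToyOp *create(std::string name, ToyOp *operand, bool journaled) {
    auto op = std::make_unique<ToyOp>();
    op->name = std::move(name);
    op->operand = operand;
    ToyOp *raw = op.get();
    ops.push_back(std::move(op));
    if (journaled)
      createdJournal.push_back(raw);
    return raw;
  }

  // Rewires users immediately and records nothing, so it cannot be undone.
  void replaceAllUsesWith(ToyOp *from, ToyOp *to) {
    for (auto &op : ops)
      if (op && op->operand == from)
        op->operand = to;
  }

  // Undo every journaled creation: drop all uses of the op, then erase it.
  void rollback() {
    for (ToyOp *created : createdJournal) {
      for (auto &op : ops)
        if (op && op->operand == created)
          op->operand = nullptr; // users are left with a null operand
      for (auto &op : ops)
        if (op.get() == created) {
          op.reset();
          break;
        }
    }
    createdJournal.clear();
  }
};

// Runs steps 1-3; step 2 is optional so both outcomes can be compared.
bool userOperandIsNullAfterRollback(bool callReplaceAllUsesWith) {
  ToyRewriter rewriter;
  ToyOp *orig = rewriter.create("orig", nullptr, /*journaled=*/false);
  ToyOp *user = rewriter.create("user", orig, /*journaled=*/false);
  ToyOp *x = rewriter.create("X", nullptr, /*journaled=*/true); // step 1
  if (callReplaceAllUsesWith)
    rewriter.replaceAllUsesWith(orig, x); // step 2 (untracked)
  rewriter.rollback();                    // step 3
  return user->operand == nullptr;
}
```

In this model, `userOperandIsNullAfterRollback(true)` leaves the user with a null operand, while `userOperandIsNullAfterRollback(false)` leaves it pointing at the original op, because only the journaled creation is undone and the use replacement is not.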

@joker-eph
Collaborator

joker-eph commented Aug 18, 2025

@matthias-springer: how does updating the patterns from OpRewritePattern to OpConversionPattern change anything about what you're describing?

Edit: I missed the first message, you acknowledged it:

> However, I doubt that whatever crash you're seeing would be fixed by this change. [...] so I believe this PR is NFC for existing users that populate those patterns in a dialect conversion.

So we don't have an explanation yet.

> Alternatively, the name of the function could be renamed to populateSCFToControlFlowPatterns, and the dialect conversion be replaced with the walkPatterns driver.

That seems like something to consider: keeping compatibility with the walk-and-apply driver means the fastest path is always available.

Collaborator

@joker-eph joker-eph left a comment

(marking as blocking pending resolution of the current discussion)

@matthias-springer
Member

> That seems like something to consider: keeping compatibility with the walk-and-apply driver means the fastest path is always available.

Given that the walk-patterns driver has much less overhead, it would make sense to switch to that one and leave the patterns here as rewrite patterns.

@matthias-springer
Member

@wsmoses Can you see what happens if you patch this in?

@wsmoses
Member Author

wsmoses commented Aug 18, 2025

To be clear, I didn't confirm that these patches resolve the segfault I am seeing, but considering there is already illegal (and misnamed) use of rewrite patterns in a conversion, I figure it would be a good first place to start.

Patching in LLVM to do an integration test unfortunately requires me to rebase XLA and a bunch of other big dependencies onto a specific LLVM version, which might be hard, but I'll try to work it in and test shortly.

@matthias-springer
Member

> I figure it would be a good first place to start.

It's a good start. But I think using the walk-pattern driver instead of a dialect conversion for this pass would be even better here.

@wsmoses
Member Author

wsmoses commented Aug 19, 2025

Assuming I applied your patch correctly, it didn't hit anything that changes the error.

Also, for ease, here's the full debug log from an MWE we have: EnzymeAD/Enzyme-JAX#1317 (comment)

@matthias-springer
Member

Another thing you could try to help narrow down the issue: set this variable to Never. Folding is known to be expensive and to cause problems sometimes.

@wsmoses
Member Author

wsmoses commented Aug 20, 2025

Okay, I figured out and fixed the source of the segfault on our end, which is indeed unrelated to this PR (essentially, a block argument was being added during conversion, outside of the pattern infrastructure). This still might be useful to do (I will leave that discussion to you, as you are more familiar with this part of the code).
