@@ -458,6 +458,22 @@ struct LinalgDetensorize
458
458
}
459
459
};
460
460
461
+ // / A listener that forwards notifyBlockErased and notifyOperationErased to
462
+ // / the given callbacks.
463
+ struct CallbackListener : public RewriterBase ::Listener {
464
+ CallbackListener (std::function<void (Operation *op)> onOperationErased,
465
+ std::function<void (Block *block)> onBlockErased)
466
+ : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {}
467
+
468
+ void notifyBlockErased (Block *block) override { onBlockErased (block); }
469
+ void notifyOperationErased (Operation *op) override {
470
+ onOperationErased (op);
471
+ }
472
+
473
+ std::function<void (Operation *op)> onOperationErased;
474
+ std::function<void (Block *block)> onBlockErased;
475
+ };
476
+
461
477
void runOnOperation () override {
462
478
MLIRContext *context = &getContext ();
463
479
DetensorizeTypeConverter typeConverter;
@@ -551,8 +567,22 @@ struct LinalgDetensorize
551
567
populateBranchOpInterfaceTypeConversionPattern (patterns, typeConverter,
552
568
shouldConvertBranchOperand);
553
569
554
- if (failed (
555
- applyFullConversion (getOperation (), target, std::move (patterns))))
570
+ ConversionConfig config;
571
+ auto onOperationErased = [&](Operation *op) {
572
+ opsToDetensor.erase (op);
573
+ detensorableBranchOps.erase (op);
574
+ };
575
+ auto onBlockErased = [&](Block *block) {
576
+ for (BlockArgument arg : block->getArguments ()) {
577
+ blockArgsToDetensor.erase (arg);
578
+ }
579
+ };
580
+ CallbackListener listener (onOperationErased, onBlockErased);
581
+
582
+ config.listener = &listener;
583
+ config.allowPatternRollback = false ;
584
+ if (failed (applyFullConversion (getOperation (), target, std::move (patterns),
585
+ config)))
556
586
signalPassFailure ();
557
587
558
588
RewritePatternSet canonPatterns (context);
0 commit comments