From fdd7e25a1840eff4f33bb9a0369cd4bc36c95dde Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 10 Aug 2025 11:41:51 +0000 Subject: [PATCH] [mlir][linalg] Migrate Detensorize pass to new dialect conversion driver --- .../Dialect/Linalg/Transforms/Detensorize.cpp | 34 +++++++++++++++++-- mlir/test/Dialect/Linalg/detensorize_0d.mlir | 9 ++--- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 830905495e759..221f95a8d8f33 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -458,6 +458,22 @@ struct LinalgDetensorize } }; + /// A listener that forwards notifyBlockErased and notifyOperationErased to + /// the given callbacks. + struct CallbackListener : public RewriterBase::Listener { + CallbackListener(std::function onOperationErased, + std::function onBlockErased) + : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {} + + void notifyBlockErased(Block *block) override { onBlockErased(block); } + void notifyOperationErased(Operation *op) override { + onOperationErased(op); + } + + std::function onOperationErased; + std::function onBlockErased; + }; + void runOnOperation() override { MLIRContext *context = &getContext(); DetensorizeTypeConverter typeConverter; @@ -551,8 +567,22 @@ struct LinalgDetensorize populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, shouldConvertBranchOperand); - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) + ConversionConfig config; + auto onOperationErased = [&](Operation *op) { + opsToDetensor.erase(op); + detensorableBranchOps.erase(op); + }; + auto onBlockErased = [&](Block *block) { + for (BlockArgument arg : block->getArguments()) { + blockArgsToDetensor.erase(arg); + } + }; + CallbackListener listener(onOperationErased, onBlockErased); + + config.listener = &listener; + config.allowPatternRollback = false; + if (failed(applyFullConversion(getOperation(), target, std::move(patterns), + config))) signalPassFailure(); RewritePatternSet canonPatterns(context); diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir index 74931cb0830bc..76e8c7e8daba9 100644 --- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir @@ -53,10 +53,11 @@ func.func @detensor_op_sequence(%arg1: tensor, %arg2: tensor) -> tenso } // CHECK-LABEL: func @detensor_op_sequence // CHECK-SAME: (%[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor) -// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]] -// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] -// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]] -// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]] +// CHECK: %[[arg1_val_1:.*]] = tensor.extract %[[arg1]] +// CHECK: %[[arg2_val:.*]] = tensor.extract %[[arg2]] +// CHECK: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]] +// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]] +// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val_1]], %[[detensored_res]] // CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]] // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]] // CHECK: return %[[new_tensor_res]]