Skip to content

Commit fdd7e25

Browse files
[mlir][linalg] Migrate Detensorize pass to new dialect conversion driver
1 parent 0d8aa9d commit fdd7e25

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,22 @@ struct LinalgDetensorize
458458
}
459459
};
460460

461+
/// A listener that forwards notifyBlockErased and notifyOperationErased to
462+
/// the given callbacks.
463+
struct CallbackListener : public RewriterBase::Listener {
464+
CallbackListener(std::function<void(Operation *op)> onOperationErased,
465+
std::function<void(Block *block)> onBlockErased)
466+
: onOperationErased(onOperationErased), onBlockErased(onBlockErased) {}
467+
468+
void notifyBlockErased(Block *block) override { onBlockErased(block); }
469+
void notifyOperationErased(Operation *op) override {
470+
onOperationErased(op);
471+
}
472+
473+
std::function<void(Operation *op)> onOperationErased;
474+
std::function<void(Block *block)> onBlockErased;
475+
};
476+
461477
void runOnOperation() override {
462478
MLIRContext *context = &getContext();
463479
DetensorizeTypeConverter typeConverter;
@@ -551,8 +567,22 @@ struct LinalgDetensorize
551567
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
552568
shouldConvertBranchOperand);
553569

554-
if (failed(
555-
applyFullConversion(getOperation(), target, std::move(patterns))))
570+
ConversionConfig config;
571+
auto onOperationErased = [&](Operation *op) {
572+
opsToDetensor.erase(op);
573+
detensorableBranchOps.erase(op);
574+
};
575+
auto onBlockErased = [&](Block *block) {
576+
for (BlockArgument arg : block->getArguments()) {
577+
blockArgsToDetensor.erase(arg);
578+
}
579+
};
580+
CallbackListener listener(onOperationErased, onBlockErased);
581+
582+
config.listener = &listener;
583+
config.allowPatternRollback = false;
584+
if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
585+
config)))
556586
signalPassFailure();
557587

558588
RewritePatternSet canonPatterns(context);

mlir/test/Dialect/Linalg/detensorize_0d.mlir

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,11 @@ func.func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tenso
5353
}
5454
// CHECK-LABEL: func @detensor_op_sequence
5555
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
56-
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
57-
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
58-
// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
59-
// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]]
56+
// CHECK: %[[arg1_val_1:.*]] = tensor.extract %[[arg1]]
57+
// CHECK: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
58+
// CHECK: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]]
59+
// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]]
60+
// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val_1]], %[[detensored_res]]
6061
// CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]]
6162
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
6263
// CHECK: return %[[new_tensor_res]]

0 commit comments

Comments
 (0)