diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 9cd2ef34e15ea..4749a45e51c1f 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -26,6 +26,7 @@ namespace mlir { class GreedyRewriteConfig; +class RuntimeVerifiableOpInterface; //===----------------------------------------------------------------------===// // Passes @@ -77,6 +78,13 @@ std::unique_ptr createPrintIRPass(const PrintIRPassOptions & = {}); /// Creates a pass that generates IR to verify ops at runtime. std::unique_ptr createGenerateRuntimeVerificationPass(); +/// Create an instance of the generate runtime verification pass, and +/// use the provided filter function to skip certain verifiable ops. +/// The default implementation does not filter any ops. +std::unique_ptr createGenerateRuntimeVerificationPass( + std::function + shouldHandleVerifiableOpFn); + /// Creates a loop invariant code motion pass that hoists loop invariant /// instructions out of the loop. std::unique_ptr createLoopInvariantCodeMotionPass(); diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp index a40bc2b3272fc..214510ca8ccd4 100644 --- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp +++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp @@ -17,16 +17,46 @@ namespace mlir { #include "mlir/Transforms/Passes.h.inc" } // namespace mlir +#define DEBUG_TYPE "generate-runtime-verification" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + using namespace mlir; +static bool defaultShouldHandleVerifiableOpFn(RuntimeVerifiableOpInterface op) { + // By default, all verifiable ops are considered + return true; +} + namespace { struct GenerateRuntimeVerificationPass : public impl::GenerateRuntimeVerificationBase< GenerateRuntimeVerificationPass> { + + GenerateRuntimeVerificationPass(); + GenerateRuntimeVerificationPass(const GenerateRuntimeVerificationPass &) = + default; + GenerateRuntimeVerificationPass( + std::function + shouldHandleVerifiableOpFn); + void runOnOperation() override; + +private: + // A filter function to select verifiable ops to generate verification for. + // If empty, all verifiable ops are considered. + std::function shouldHandleVerifiableOpFn; }; } // namespace +GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass() + : shouldHandleVerifiableOpFn(defaultShouldHandleVerifiableOpFn) {} + +GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass( + std::function + shouldHandleVerifiableOpFn) + : shouldHandleVerifiableOpFn(std::move(shouldHandleVerifiableOpFn)) {} + void GenerateRuntimeVerificationPass::runOnOperation() { // The implementation of the RuntimeVerifiableOpInterface may create ops that // can be verified. We don't want to generate verification for IR that @@ -38,11 +68,22 @@ void GenerateRuntimeVerificationPass::runOnOperation() { OpBuilder builder(getOperation()->getContext()); for (RuntimeVerifiableOpInterface verifiableOp : ops) { - builder.setInsertionPoint(verifiableOp); - verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc()); - }; + if (shouldHandleVerifiableOpFn(verifiableOp)) { + builder.setInsertionPoint(verifiableOp); + verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc()); + } else { + LDBG("Skipping operation: " << verifiableOp.getOperation()); + } + } } std::unique_ptr mlir::createGenerateRuntimeVerificationPass() { return std::make_unique(); } + +std::unique_ptr mlir::createGenerateRuntimeVerificationPass( + std::function + shouldHandleVerifiableOpFn) { + return std::make_unique( + std::move(shouldHandleVerifiableOpFn)); +}