Skip to content

[mlir] Generalize OneShotModuleBufferize to operate on any Operation #148327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct LogicalResult;
} // namespace llvm

namespace mlir {
class ModuleOp;
class Operation;

namespace bufferization {
struct BufferizationStatistics;
Expand All @@ -23,12 +23,13 @@ struct OneShotBufferizationOptions;
class BufferizationState;

/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
/// `state`.
/// `state`. This operates on any `SymbolTable` op.
llvm::LogicalResult
analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state,
BufferizationStatistics *statistics = nullptr);

/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
/// Bufferize an `op`s nested ops that implement `BufferizableOpInterface`.
/// This operates on any `SymbolTable` op.
///
/// Note: This function does not run One-Shot Analysis. No buffer copies are
/// inserted except two cases:
Expand All @@ -37,20 +38,20 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// - `options.copyBeforeWrite` is not set and `options.noAnalysisFuncFilter`
/// is not empty. The FuncOps it contains were not analyzed. Buffer copies
/// will be inserted only to these FuncOps.
llvm::LogicalResult
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
llvm::LogicalResult bufferizeModuleOp(
Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics = nullptr);

/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
void removeBufferizationAttributesInModule(ModuleOp moduleOp);
/// Remove bufferization attributes on every FuncOp arguments in the SymbolTable
/// op.
void removeBufferizationAttributesInModule(Operation *moduleOp);

/// Run One-Shot Module Bufferization on the given module. Performs a simple
/// function call analysis to determine which function arguments are
/// Run One-Shot Module Bufferization on the given SymbolTable. Performs a
/// simple function call analysis to determine which function arguments are
/// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot
/// Bufferize.
llvm::LogicalResult runOneShotModuleBufferize(
ModuleOp moduleOp,
Operation *moduleOp,
const bufferization::OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics = nullptr);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries
//----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -8,12 +9,13 @@
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
// implementations for FuncOp, CallOp and ReturnOp. Although it is named
// Module Bufferization, it may operate on any SymbolTable.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
// ...)`. This function analyzes the given op and determines the order of
// analysis and bufferization: Functions that are called are processed before
// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
Expand Down Expand Up @@ -309,34 +311,37 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
SymbolTableCollection &symbolTables) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;

for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
if (!hasTensorSignature(calledFunction))
return WalkResult::skip();

callerMap[calledFunction].insert(callOp);
if (calledBy[calledFunction].insert(funcOp).second) {
numberCallOpsContainedInFuncOp[funcOp]++;
for (mlir::Region &region : moduleOp->getRegions()) {
for (mlir::Block &block : region.getBlocks()) {
for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature,
// then it is not necessary to bufferize the callee before the caller.
if (!hasTensorSignature(calledFunction))
return WalkResult::skip();

callerMap[calledFunction].insert(callOp);
if (calledBy[calledFunction].insert(funcOp).second) {
numberCallOpsContainedInFuncOp[funcOp]++;
}
return WalkResult::advance();
});
if (res.wasInterrupted())
return failure();
}
return WalkResult::advance();
});
if (res.wasInterrupted())
return failure();
}
}

// Iteratively remove function operations that do not call any of the
Expand Down Expand Up @@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}

LogicalResult
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
mlir::bufferization::analyzeModuleOp(Operation *moduleOp,
OneShotAnalysisState &state,
BufferizationStatistics *statistics) {
assert(state.getOptions().bufferizeFunctionBoundaries &&
Expand Down Expand Up @@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
}

void mlir::bufferization::removeBufferizationAttributesInModule(
ModuleOp moduleOp) {
for (auto op : moduleOp.getOps<func::FuncOp>()) {
for (BlockArgument bbArg : op.getArguments())
removeBufferizationAttributes(bbArg);
Operation *moduleOp) {
for (mlir::Region &region : moduleOp->getRegions()) {
for (mlir::Block &block : region.getBlocks()) {
for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
for (BlockArgument bbArg : funcOp.getArguments())
removeBufferizationAttributes(bbArg);
}
}
}
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
IRRewriter rewriter(moduleOp.getContext());
IRRewriter rewriter(moduleOp->getContext());

// A list of non-circular functions in the order in which they are analyzed
// and bufferized.
Expand Down Expand Up @@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}

// Bufferize all other ops.
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
// Functions were already bufferized.
if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
continue;
if (failed(bufferizeOp(&op, options, state, statistics)))
return failure();
for (mlir::Region &region : moduleOp->getRegions()) {
for (mlir::Block &block : region.getBlocks()) {
for (mlir::Operation &op :
llvm::make_early_inc_range(block.getOperations())) {
// Functions were already bufferized.
if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
continue;
if (failed(bufferizeOp(&op, options, state, statistics)))
return failure();
}
}
}

// Post-pass cleanup of function argument attributes.
Expand All @@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
// analysis depending on whether function boundary bufferization is enabled or
// not.
if (options.bufferizeFunctionBoundaries) {
if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
if (failed(analyzeModuleOp(op, analysisState, statistics)))
return failure();
} else {
if (failed(analyzeOp(op, analysisState, statistics)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ class SparsificationAndBufferizationPass

bufferization::BufferizationState bufferizationState;

if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
updatedOptions,
if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions,
bufferizationState)))
return failure();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-module-bufferize))' -split-input-file | FileCheck %s

"test.symbol_scope_isolated"() ({
// CHECK-LABEL: func @inner_func(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
// CHECK-NOT: copy
%f = arith.constant 1.0 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// CHECK: memref.store %{{.*}}, %[[arg0]]
%0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
// CHECK: %[[load:.*]] = memref.load %[[arg0]]
%1 = tensor.extract %0[%c1] : tensor<?xf32>
// CHECK: return %[[arg0]], %[[load]] : memref<?xf32{{.*}}>, f32
return %0, %1 : tensor<?xf32>, f32
}

// CHECK-LABEL: func @call_func_with_non_tensor_return(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
func.func @call_func_with_non_tensor_return(
%t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) {
// CHECK-NOT: alloc
// CHECK-NOT: copy
// CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]])
%0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
// CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}>
return %1, %0 : f32, tensor<?xf32>
}
"test.finish" () : () -> ()
}) : () -> ()


1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRBufferizationTestPasses
TestOneShotModuleBufferize.cpp
TestTensorCopyInsertion.cpp
TestTensorLikeAndBufferLike.cpp

Expand Down
57 changes: 57 additions & 0 deletions mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
//===- TestOneShotModuleBufferzation.cpp - Bufferization Test -----*- c++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
struct TestOneShotModuleBufferizePass
: public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass)

TestOneShotModuleBufferizePass() = default;
TestOneShotModuleBufferizePass(const TestOneShotModuleBufferizePass &pass)
: PassWrapper(pass) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect>();
}
StringRef getArgument() const final {
return "test-one-shot-module-bufferize";
}
StringRef getDescription() const final {
return "Pass to test One Shot Module Bufferization";
}

void runOnOperation() override {

llvm::errs() << "Running TestOneShotModuleBufferize on: "
<< getOperation()->getName() << "\n";
bufferization::OneShotBufferizationOptions opt;

opt.bufferizeFunctionBoundaries = true;
bufferization::BufferizationState bufferizationState;

if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt,
bufferizationState)))
signalPassFailure();
}
};
} // namespace

namespace mlir::test {
void registerTestOneShotModuleBufferizePass() {
PassRegistration<TestOneShotModuleBufferizePass>();
}
} // namespace mlir::test
9 changes: 9 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ def SymbolScopeOp : TEST_Op<"symbol_scope",
let regions = (region SizedRegion<1>:$region);
}

def SymbolScopeIsolatedOp
: TEST_Op<"symbol_scope_isolated", [IsolatedFromAbove, SymbolTable,
SingleBlockImplicitTerminator<
"TerminatorOp">]> {
let summary =
"operation which defines a new symbol table that is IsolatedFromAbove";
let regions = (region SizedRegion<1>:$region);
}

def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> {
let summary = "operation which defines a new symbol table without a "
"restriction on a terminator";
Expand Down
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ void registerTestShardSimplificationsPass();
void registerTestMultiBuffering();
void registerTestNextAccessPass();
void registerTestNVGPULowerings();
void registerTestOneShotModuleBufferizePass();
void registerTestOpaqueLoc();
void registerTestOpLoweringPasses();
void registerTestPadFusion();
Expand Down Expand Up @@ -281,6 +282,7 @@ void registerTestPasses() {
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestNextAccessPass();
mlir::test::registerTestNVGPULowerings();
mlir::test::registerTestOneShotModuleBufferizePass();
mlir::test::registerTestOpaqueLoc();
mlir::test::registerTestOpLoweringPasses();
mlir::test::registerTestPadFusion();
Expand Down