Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] BufferResultsToOutParams: Allow to configure memCpyFn #121

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,17 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOptions {
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;

// Filter function; returns true if the function should be converted.
// Defaults to true, i.e. all functions are converted.
llvm::function_ref<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
return true;
};

std::optional<MemCpyFn> memCpyFn;
};

/// Creates a pass that converts memref function results to out-params.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace bufferization {
} // namespace mlir

using namespace mlir;
using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;

/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
Expand Down Expand Up @@ -97,9 +98,10 @@ updateFuncOp(func::FuncOp func,
// Updates all ReturnOps in the scope of the given func::FuncOp by either
// keeping them as return values or copying the associated buffer contents into
// the given out-params.
static void updateReturnOps(func::FuncOp func,
ArrayRef<BlockArgument> appendedEntryArgs) {
func.walk([&](func::ReturnOp op) {
static LogicalResult updateReturnOps(func::FuncOp func,
ArrayRef<BlockArgument> appendedEntryArgs,
MemCpyFn memCpyFn) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
for (Value operand : op.getOperands()) {
Expand All @@ -109,12 +111,16 @@ static void updateReturnOps(func::FuncOp func,
keepAsReturnOperands.push_back(operand);
}
OpBuilder builder(op);
for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
std::get<1>(t));
for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (failed(
memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
return WalkResult::interrupt();
}
builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
op.erase();
return WalkResult::advance();
});
return failure(res.wasInterrupted());
}

// Updates all CallOps in the scope of the given ModuleOp by allocating
Expand Down Expand Up @@ -192,7 +198,15 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
if (func.isExternal())
continue;
updateReturnOps(func, appendedEntryArgs);
auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
Value to) {
builder.create<memref::CopyOp>(loc, from, to);
return success();
};
if (failed(updateReturnOps(func, appendedEntryArgs,
options.memCpyFn.value_or(defaultMemCpyFn)))) {
return failure();
}
}
if (failed(updateCalls(module, options)))
return failure();
Expand Down
Loading