diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index a6f668b26aa10e..5b38e83536633a 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -148,11 +148,17 @@ std::unique_ptr createBufferLoopHoistingPass(); // Options struct for BufferResultsToOutParams pass. // Note: defined only here, not in tablegen. struct BufferResultsToOutParamsOptions { + /// Memcpy function: Generate a memcpy between two memrefs. + using MemCpyFn = + std::function; + // Filter function; returns true if the function should be converted. // Defaults to true, i.e. all functions are converted. llvm::function_ref filterFn = [](func::FuncOp *func) { return true; }; + + std::optional memCpyFn; }; /// Creates a pass that converts memref function results to out-params. diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index e09c63295515cb..644a6ed2566e5c 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -565,4 +565,36 @@ def EmitC_IfOp : EmitC_Op<"if", let hasCustomAssemblyFormat = 1; } +def EmitC_SubscriptOp : EmitC_Op<"subscript", + [TypesMatchWith<"result type matches element type of 'array'", + "array", "result", + "::llvm::cast($_self).getElementType()">]> { + let summary = "Array subscript operation"; + let description = [{ + With the `subscript` operation the subscript operator `[]` can be applied + to variables or arguments of array type. + + Example: + + ```mlir + %i = index.constant 1 + %j = index.constant 7 + %0 = emitc.subscript %arg0[%i][%j] : (!emitc.array<4x8xf32>) -> f32 + ``` + }]; + let arguments = (ins Arg:$array, + Variadic:$indices); + let results = (outs AnyType:$result); + + let builders = [ + OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{ + build($_builder, $_state, cast(array.getType()).getElementType(), array, indices); + }]> + ]; + + let hasVerifier = 1; + let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array)"; +} + + #endif // MLIR_DIALECT_EMITC_IR_EMITC diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 9d7cfa7f840b29..7eedd3a8769ff0 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -54,6 +54,60 @@ bool isSignatureLegal(FunctionType ty) { return isLegal(llvm::concat(ty.getInputs(), ty.getResults())); } +struct ConvertLoad final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp(op, operands.getMemref(), + operands.getIndices()); + return success(); + } +}; + +struct ConvertStore final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::StoreOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + + auto subscript = rewriter.create( + op.getLoc(), operands.getMemref(), operands.getIndices()); + rewriter.replaceOpWithNewOp(op, subscript, + operands.getValue()); + return success(); + } +}; + +struct ConvertAlloca final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocaOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getType().hasStaticShape()) { + return rewriter.notifyMatchFailure( + op.getLoc(), "cannot transform alloca with dynamic shape"); + } + + if (op.getAlignment().value_or(1) > 1) { + // TODO: Allow alignment if it is not more than the natural alignment + // of the C array. + return rewriter.notifyMatchFailure( + op.getLoc(), "cannot transform alloca with alignment requirement"); + } + + auto resultTy = getTypeConverter()->convertType(op.getType()); + auto noInit = emitc::OpaqueAttr::get(getContext(), ""); + rewriter.replaceOpWithNewOp(op, resultTy, noInit); + return success(); + } +}; + struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase { void runOnOperation() override { @@ -91,6 +145,7 @@ struct ConvertMemRefToEmitCPass target.addDynamicallyLegalDialect( [](Operation *op) { return isLegal(op); }); target.addIllegalDialect(); + target.addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -106,6 +161,8 @@ void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, converter); populateCallOpTypeConversionPattern(patterns, converter); populateReturnOpTypeConversionPattern(patterns, converter); + patterns.add(converter, + patterns.getContext()); } std::unique_ptr> mlir::createConvertMemRefToEmitCPass() { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index dd359c2dcca5dd..930f035339c1d3 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -21,6 +21,7 @@ namespace bufferization { } // namespace mlir using namespace mlir; +using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn; /// Return `true` if the given MemRef type has a fully dynamic layout. static bool hasFullyDynamicLayoutMap(MemRefType type) { @@ -97,9 +98,10 @@ updateFuncOp(func::FuncOp func, // Updates all ReturnOps in the scope of the given func::FuncOp by either // keeping them as return values or copying the associated buffer contents into // the given out-params. -static void updateReturnOps(func::FuncOp func, - ArrayRef appendedEntryArgs) { - func.walk([&](func::ReturnOp op) { +static LogicalResult updateReturnOps(func::FuncOp func, + ArrayRef appendedEntryArgs, + MemCpyFn memCpyFn) { + auto res = func.walk([&](func::ReturnOp op) { SmallVector copyIntoOutParams; SmallVector keepAsReturnOperands; for (Value operand : op.getOperands()) { @@ -109,12 +111,16 @@ static void updateReturnOps(func::FuncOp func, keepAsReturnOperands.push_back(operand); } OpBuilder builder(op); - for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) - builder.create(op.getLoc(), std::get<0>(t), - std::get<1>(t)); + for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) { + if (failed( + memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t)))) + return WalkResult::interrupt(); + } builder.create(op.getLoc(), keepAsReturnOperands); op.erase(); + return WalkResult::advance(); }); + return failure(res.wasInterrupted()); } // Updates all CallOps in the scope of the given ModuleOp by allocating @@ -192,7 +198,15 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( return failure(); if (func.isExternal()) continue; - updateReturnOps(func, appendedEntryArgs); + auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from, + Value to) { + builder.create(loc, from, to); + return success(); + }; + if (failed(updateReturnOps(func, appendedEntryArgs, + options.memCpyFn.value_or(defaultMemCpyFn)))) { + return failure(); + } } if (failed(updateCalls(module, options))) return failure(); diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 776285d842db97..2d578d47aa4a88 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -102,9 +102,10 @@ LogicalResult ApplyOp::verify() { LogicalResult emitc::AssignOp::verify() { Value variable = getVar(); Operation *variableDef = variable.getDefiningOp(); - if (!variableDef || !llvm::isa(variableDef)) + if (!variableDef || + !llvm::isa(variableDef)) return emitOpError() << "requires first operand (" << variable - << ") to be a Variable"; + << ") to be a Variable or subscript"; Value value = getValue(); if (variable.getType() != value.getType()) @@ -530,6 +531,20 @@ LogicalResult emitc::VariableOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// SubscriptOp +//===----------------------------------------------------------------------===// + +LogicalResult emitc::SubscriptOp::verify() { + if (getIndices().size() != (size_t)getArray().getType().getRank()) { + return emitOpError() << "requires number of indices (" + << getIndices().size() + << ") to match the rank of the array type (" + << getArray().getType().getRank() << ")"; + } + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index d4dadc12d41de9..1edf679390d7d4 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -122,6 +122,9 @@ struct CppEmitter { /// Return the existing or a new name for a Value. StringRef getOrCreateName(Value val); + // Returns the textual representation of a subscript operation. + std::string getSubscriptName(emitc::SubscriptOp op); + /// Return the existing or a new label of a Block. StringRef getOrCreateName(Block &block); @@ -251,8 +254,7 @@ static LogicalResult printOperation(CppEmitter &emitter, static LogicalResult printOperation(CppEmitter &emitter, emitc::AssignOp assignOp) { - auto variableOp = cast(assignOp.getVar().getDefiningOp()); - OpResult result = variableOp->getResult(0); + OpResult result = assignOp.getVar().getDefiningOp()->getResult(0); if (failed(emitter.emitVariableAssignment(result))) return failure(); @@ -262,6 +264,13 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::SubscriptOp subscriptOp) { + // Add name to cache so that `hasValueInScope` works. + emitter.getOrCreateName(subscriptOp.getResult()); + return success(); +} + static LogicalResult printBinaryOperation(CppEmitter &emitter, Operation *operation, StringRef binaryOperator) { @@ -706,12 +715,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) labelInScopeCount.push(0); } +std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << getOrCreateName(op.getArray()); + for (auto index : op.getIndices()) { + ss << "[" << getOrCreateName(index) << "]"; + } + return out; +} + /// Return the existing or a new name for a Value. StringRef CppEmitter::getOrCreateName(Value val) { if (auto literal = dyn_cast_if_present(val.getDefiningOp())) return literal.getValue(); - if (!valueMapper.count(val)) - valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); + if (!valueMapper.count(val)) { + if (auto subscript = + dyn_cast_if_present(val.getDefiningOp())) { + valueMapper.insert(val, getSubscriptName(subscript)); + } else { + valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); + } + } return *valueMapper.begin(val); } @@ -891,6 +916,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, bool trailingSemicolon) { + if (isa(result.getDefiningOp())) + return success(); if (hasValueInScope(result)) { return result.getDefiningOp()->emitError( "result variable for the operation already declared"); @@ -957,7 +984,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp, emitc::ConstantOp, emitc::DivOp, emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::SubOp, - emitc::VariableOp>( + emitc::SubscriptOp, emitc::VariableOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( @@ -973,7 +1000,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (failed(status)) return failure(); - if (isa(op)) + if (isa(op)) return success(); os << (trailingSemicolon ? ";\n" : "\n"); diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir index 56d48bef3e1b34..2903fbffccb02a 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir @@ -22,3 +22,20 @@ func.func @memref_op(%arg0 : memref<2x4xf32>) { memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32> return } + +// ----- + +func.func @alloca_with_dynamic_shape() { + %0 = index.constant 1 + // expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}} + %1 = memref.alloca(%0) : memref<4x?xf32> + return +} + +// ----- + +func.func @alloca_with_alignment() { + // expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}} + %1 = memref.alloca() {alignment = 64 : i64}: memref<4xf32> + return +} diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index 5114d6d91e5046..f3bc5a5124bf02 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -22,3 +22,26 @@ func.func @memref_call(%arg0 : memref<32xf32>) { func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32> func.return } + +// ----- + +// CHECK-LABEL: memref_load_store +// CHECK-SAME: %[[arg0:.*]]: !emitc.array<4x8xf32>, %[[arg1:.*]]: !emitc.array<3x5xf32> +// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index +func.func @memref_load_store(%in: memref<4x8xf32>, %out: memref<3x5xf32>, %i: index, %j: index) { + // CHECK: %[[load:.*]] = emitc.subscript %[[arg0]][%[[i]], %[[j]]] : <4x8xf32> + %0 = memref.load %in[%i, %j] : memref<4x8xf32> + // CHECK: %[[store_loc:.*]] = emitc.subscript %[[arg1]][%[[i]], %[[j]]] : <3x5xf32> + // CHECK: emitc.assign %[[load]] : f32 to %[[store_loc:.*]] : f32 + memref.store %0, %out[%i, %j] : memref<3x5xf32> + return +} + +// ----- + +// CHECK-LABEL: alloca +func.func @alloca() { + // CHECK "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32> + %0 = memref.alloca() : memref<4x8xf32> + return +} diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 49efb962dfa257..fd79bbd8a1d308 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -224,3 +224,11 @@ func.func @test_assign_type_mismatch(%arg1: f32) { emitc.assign %arg1 : f32 to %v : i32 return } + +// ----- + +func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) { + // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}} + %0 = emitc.subscript %arg0[%arg2] : <4x8xf32> + return +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index b3a24c26b96cab..d280f12b78516a 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -149,3 +149,11 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) { } return } + +func.func @test_subscript(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, + %arg2: index, %arg3: index) { + %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32> + %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32> + emitc.assign %0 : f32 to %1 : f32 + return +} diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir new file mode 100644 index 00000000000000..2c771a30bd5c17 --- /dev/null +++ b/mlir/test/Target/Cpp/subscript.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s +// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s + +func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) { + %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32> + %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32> + emitc.assign %0 : f32 to %1 : f32 + return +} +// CHECK: void load_store(float [[V1:[^ ]*]][4][8], float [[V2:[^ ]*]][3][5], +// CHECK-SAME: size_t [[V3:[^ ]*]], size_t [[V4:[^ ]*]]) +// CHECK-NEXT: [[V2]][[[V3]]][[[V4]]] = [[V1]][[[V3]]][[[V4]]];