Skip to content

Commit

Permalink
[𝘀𝗽𝗿] initial version
Browse files Browse the repository at this point in the history
Created using spr 1.3.6-beta.1
  • Loading branch information
mgehre-amd committed Feb 28, 2024
2 parents d3c90a4 + fcc2ba7 commit 91441e5
Show file tree
Hide file tree
Showing 11 changed files with 234 additions and 15 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,17 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOptions {
/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;

// Filter function; returns true if the function should be converted.
// Defaults to true, i.e. all functions are converted.
llvm::function_ref<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
return true;
};

std::optional<MemCpyFn> memCpyFn;
};

/// Creates a pass that converts memref function results to out-params.
Expand Down
32 changes: 32 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -565,4 +565,36 @@ def EmitC_IfOp : EmitC_Op<"if",
let hasCustomAssemblyFormat = 1;
}

def EmitC_SubscriptOp : EmitC_Op<"subscript",
[TypesMatchWith<"result type matches element type of 'array'",
"array", "result",
"::llvm::cast<ArrayType>($_self).getElementType()">]> {
let summary = "Array subscript operation";
let description = [{
With the `subscript` operation the subscript operator `[]` can be applied
to variables or arguments of array type.

Example:

```mlir
%i = index.constant 1
%j = index.constant 7
%0 = emitc.subscript %arg0[%i][%j] : (!emitc.array<4x8xf32>) -> f32
```
}];
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
Variadic<Index>:$indices);
let results = (outs AnyType:$result);

let builders = [
OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
}]>
];

let hasVerifier = 1;
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array)";
}


#endif // MLIR_DIALECT_EMITC_IR_EMITC
57 changes: 57 additions & 0 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,60 @@ bool isSignatureLegal(FunctionType ty) {
return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
}

struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

rewriter.replaceOpWithNewOp<emitc::SubscriptOp>(op, operands.getMemref(),
operands.getIndices());
return success();
}
};

struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

auto subscript = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), operands.getMemref(), operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
return success();
}
};

struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with dynamic shape");
}

if (op.getAlignment().value_or(1) > 1) {
// TODO: Allow alignment if it is not more than the natural alignment
// of the C array.
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with alignment requirement");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
return success();
}
};

struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
void runOnOperation() override {
Expand Down Expand Up @@ -91,6 +145,7 @@ struct ConvertMemRefToEmitCPass
target.addDynamicallyLegalDialect<func::FuncDialect>(
[](Operation *op) { return isLegal(op); });
target.addIllegalDialect<memref::MemRefDialect>();
target.addLegalDialect<emitc::EmitCDialect>();

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
Expand All @@ -106,6 +161,8 @@ void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateReturnOpTypeConversionPattern(patterns, converter);
patterns.add<ConvertLoad, ConvertStore, ConvertAlloca>(converter,
patterns.getContext());
}

std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace bufferization {
} // namespace mlir

using namespace mlir;
using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;

/// Return `true` if the given MemRef type has a fully dynamic layout.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
Expand Down Expand Up @@ -97,9 +98,10 @@ updateFuncOp(func::FuncOp func,
// Updates all ReturnOps in the scope of the given func::FuncOp by either
// keeping them as return values or copying the associated buffer contents into
// the given out-params.
static void updateReturnOps(func::FuncOp func,
ArrayRef<BlockArgument> appendedEntryArgs) {
func.walk([&](func::ReturnOp op) {
static LogicalResult updateReturnOps(func::FuncOp func,
ArrayRef<BlockArgument> appendedEntryArgs,
MemCpyFn memCpyFn) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
for (Value operand : op.getOperands()) {
Expand All @@ -109,12 +111,16 @@ static void updateReturnOps(func::FuncOp func,
keepAsReturnOperands.push_back(operand);
}
OpBuilder builder(op);
for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
std::get<1>(t));
for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (failed(
memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
return WalkResult::interrupt();
}
builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
op.erase();
return WalkResult::advance();
});
return failure(res.wasInterrupted());
}

// Updates all CallOps in the scope of the given ModuleOp by allocating
Expand Down Expand Up @@ -192,7 +198,15 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
if (func.isExternal())
continue;
updateReturnOps(func, appendedEntryArgs);
auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
Value to) {
builder.create<memref::CopyOp>(loc, from, to);
return success();
};
if (failed(updateReturnOps(func, appendedEntryArgs,
options.memCpyFn.value_or(defaultMemCpyFn)))) {
return failure();
}
}
if (failed(updateCalls(module, options)))
return failure();
Expand Down
19 changes: 17 additions & 2 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,10 @@ LogicalResult ApplyOp::verify() {
LogicalResult emitc::AssignOp::verify() {
Value variable = getVar();
Operation *variableDef = variable.getDefiningOp();
if (!variableDef || !llvm::isa<emitc::VariableOp>(variableDef))
if (!variableDef ||
!llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
return emitOpError() << "requires first operand (" << variable
<< ") to be a Variable";
<< ") to be a Variable or subscript";

Value value = getValue();
if (variable.getType() != value.getType())
Expand Down Expand Up @@ -530,6 +531,20 @@ LogicalResult emitc::VariableOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// SubscriptOp
//===----------------------------------------------------------------------===//

LogicalResult emitc::SubscriptOp::verify() {
if (getIndices().size() != (size_t)getArray().getType().getRank()) {
return emitOpError() << "requires number of indices ("
<< getIndices().size()
<< ") to match the rank of the array type ("
<< getArray().getType().getRank() << ")";
}
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
39 changes: 33 additions & 6 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ struct CppEmitter {
/// Return the existing or a new name for a Value.
StringRef getOrCreateName(Value val);

// Returns the textual representation of a subscript operation.
std::string getSubscriptName(emitc::SubscriptOp op);

/// Return the existing or a new label of a Block.
StringRef getOrCreateName(Block &block);

Expand Down Expand Up @@ -251,8 +254,7 @@ static LogicalResult printOperation(CppEmitter &emitter,

static LogicalResult printOperation(CppEmitter &emitter,
emitc::AssignOp assignOp) {
auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
OpResult result = variableOp->getResult(0);
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);

if (failed(emitter.emitVariableAssignment(result)))
return failure();
Expand All @@ -262,6 +264,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::SubscriptOp subscriptOp) {
// Add name to cache so that `hasValueInScope` works.
emitter.getOrCreateName(subscriptOp.getResult());
return success();
}

static LogicalResult printBinaryOperation(CppEmitter &emitter,
Operation *operation,
StringRef binaryOperator) {
Expand Down Expand Up @@ -706,12 +715,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
labelInScopeCount.push(0);
}

std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
std::string out;
llvm::raw_string_ostream ss(out);
ss << getOrCreateName(op.getArray());
for (auto index : op.getIndices()) {
ss << "[" << getOrCreateName(index) << "]";
}
return out;
}

/// Return the existing or a new name for a Value.
StringRef CppEmitter::getOrCreateName(Value val) {
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
return literal.getValue();
if (!valueMapper.count(val))
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
if (!valueMapper.count(val)) {
if (auto subscript =
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
valueMapper.insert(val, getSubscriptName(subscript));
} else {
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
}
}
return *valueMapper.begin(val);
}

Expand Down Expand Up @@ -891,6 +916,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {

LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
bool trailingSemicolon) {
if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
return success();
if (hasValueInScope(result)) {
return result.getDefiningOp()->emitError(
"result variable for the operation already declared");
Expand Down Expand Up @@ -957,7 +984,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
emitc::ConstantOp, emitc::DivOp, emitc::ForOp, emitc::IfOp,
emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::SubOp,
emitc::VariableOp>(
emitc::SubscriptOp, emitc::VariableOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
Expand All @@ -973,7 +1000,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (failed(status))
return failure();

if (isa<emitc::LiteralOp>(op))
if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
return success();

os << (trailingSemicolon ? ";\n" : "\n");
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,20 @@ func.func @memref_op(%arg0 : memref<2x4xf32>) {
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
return
}

// -----

func.func @alloca_with_dynamic_shape() {
%0 = index.constant 1
// expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
%1 = memref.alloca(%0) : memref<4x?xf32>
return
}

// -----

func.func @alloca_with_alignment() {
// expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
%1 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
return
}
23 changes: 23 additions & 0 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,26 @@ func.func @memref_call(%arg0 : memref<32xf32>) {
func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32>
func.return
}

// -----

// CHECK-LABEL: memref_load_store
// CHECK-SAME: %[[arg0:.*]]: !emitc.array<4x8xf32>, %[[arg1:.*]]: !emitc.array<3x5xf32>
// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
func.func @memref_load_store(%in: memref<4x8xf32>, %out: memref<3x5xf32>, %i: index, %j: index) {
// CHECK: %[[load:.*]] = emitc.subscript %[[arg0]][%[[i]], %[[j]]] : <4x8xf32>
%0 = memref.load %in[%i, %j] : memref<4x8xf32>
// CHECK: %[[store_loc:.*]] = emitc.subscript %[[arg1]][%[[i]], %[[j]]] : <3x5xf32>
// CHECK: emitc.assign %[[load]] : f32 to %[[store_loc:.*]] : f32
memref.store %0, %out[%i, %j] : memref<3x5xf32>
return
}

// -----

// CHECK-LABEL: alloca
func.func @alloca() {
// CHECK "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>
return
}
8 changes: 8 additions & 0 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,11 @@ func.func @test_assign_type_mismatch(%arg1: f32) {
emitc.assign %arg1 : f32 to %v : i32
return
}

// -----

func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
// expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
%0 = emitc.subscript %arg0[%arg2] : <4x8xf32>
return
}
8 changes: 8 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,11 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
}
return
}

func.func @test_subscript(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>,
%arg2: index, %arg3: index) {
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
emitc.assign %0 : f32 to %1 : f32
return
}
Loading

0 comments on commit 91441e5

Please sign in to comment.