Skip to content

Commit

Permalink
[MLIR] MemRefToEmitC: convert memref.load/store to emitc.subscript
Browse files Browse the repository at this point in the history
Pull Request: #119
  • Loading branch information
mgehre-amd committed Feb 28, 2024
1 parent c938fb8 commit bb4596e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
31 changes: 31 additions & 0 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,34 @@ bool isSignatureLegal(FunctionType ty) {
return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
}

struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

rewriter.replaceOpWithNewOp<emitc::SubscriptOp>(op, operands.getMemref(),
operands.getIndices());
return success();
}
};

struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

auto subscript = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), operands.getMemref(), operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
return success();
}
};

struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
void runOnOperation() override {
Expand Down Expand Up @@ -91,6 +119,7 @@ struct ConvertMemRefToEmitCPass
target.addDynamicallyLegalDialect<func::FuncDialect>(
[](Operation *op) { return isLegal(op); });
target.addIllegalDialect<memref::MemRefDialect>();
target.addLegalDialect<emitc::EmitCDialect>();

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
Expand All @@ -106,6 +135,8 @@ void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateReturnOpTypeConversionPattern(patterns, converter);
patterns.add<ConvertLoad>(converter, patterns.getContext());
patterns.add<ConvertStore>(converter, patterns.getContext());
}

std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,17 @@ func.func @memref_call(%arg0 : memref<32xf32>) {
func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32>
func.return
}

// -----

// CHECK-LABEL: memref_load_store
// CHECK-SAME: %[[arg0:.*]]: !emitc.array<4x8xf32>, %[[arg1:.*]]: !emitc.array<3x5xf32>
// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
func.func @memref_load_store(%in: memref<4x8xf32>, %out: memref<3x5xf32>, %i: index, %j: index) {
// CHECK: %[[load:.*]] = emitc.subscript %[[arg0]][%[[i]], %[[j]]] : <4x8xf32>
%0 = memref.load %in[%i, %j] : memref<4x8xf32>
// CHECK: %[[store_loc:.*]] = emitc.subscript %[[arg1]][%[[i]], %[[j]]] : <3x5xf32>
// CHECK: emitc.assign %[[load]] : f32 to %[[store_loc:.*]] : f32
memref.store %0, %out[%i, %j] : memref<3x5xf32>
return
}

0 comments on commit bb4596e

Please sign in to comment.