diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 9d7cfa7f840b29..4f9a602233536b 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -54,6 +54,34 @@ bool isSignatureLegal(FunctionType ty) { return isLegal(llvm::concat(ty.getInputs(), ty.getResults())); } +struct ConvertLoad final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp(op, operands.getMemref(), + operands.getIndices()); + return success(); + } +}; + +struct ConvertStore final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::StoreOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + + auto subscript = rewriter.create( + op.getLoc(), operands.getMemref(), operands.getIndices()); + rewriter.replaceOpWithNewOp(op, subscript, + operands.getValue()); + return success(); + } +}; + struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase { void runOnOperation() override { @@ -91,6 +119,7 @@ struct ConvertMemRefToEmitCPass target.addDynamicallyLegalDialect( [](Operation *op) { return isLegal(op); }); target.addIllegalDialect(); + target.addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -106,6 +135,8 @@ void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, converter); populateCallOpTypeConversionPattern(patterns, converter); populateReturnOpTypeConversionPattern(patterns, converter); + patterns.add(converter, patterns.getContext()); + patterns.add(converter, patterns.getContext()); } std::unique_ptr> mlir::createConvertMemRefToEmitCPass() { diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index 5114d6d91e5046..b41e27388f2c8e 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -22,3 +22,17 @@ func.func @memref_call(%arg0 : memref<32xf32>) { func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32> func.return } + +// ----- + +// CHECK-LABEL: memref_load_store +// CHECK-SAME: %[[arg0:.*]]: !emitc.array<4x8xf32>, %[[arg1:.*]]: !emitc.array<3x5xf32> +// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index +func.func @memref_load_store(%in: memref<4x8xf32>, %out: memref<3x5xf32>, %i: index, %j: index) { + // CHECK: %[[load:.*]] = emitc.subscript %[[arg0]][%[[i]], %[[j]]] : <4x8xf32> + %0 = memref.load %in[%i, %j] : memref<4x8xf32> + // CHECK: %[[store_loc:.*]] = emitc.subscript %[[arg1]][%[[i]], %[[j]]] : <3x5xf32> + // CHECK: emitc.assign %[[load]] : f32 to %[[store_loc:.*]] : f32 + memref.store %0, %out[%i, %j] : memref<3x5xf32> + return +}