diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index f6ce553dd899a0..da896d03cd961d 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -230,6 +230,33 @@ struct ConvertExpandShape final } }; +struct ConvertReinterpretCast final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + auto arrayValue = + dyn_cast>(operands.getSource()); + if (!arrayValue) { + return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); + } + + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), + "cannot convert result type"); + } + + auto newCastOp = rewriter.create(op->getLoc(), resultTy, + operands.getSource()); + newCastOp.setReference(true); + rewriter.replaceOp(op, newCastOp); + return success(); + } +}; + } // namespace void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { @@ -251,6 +278,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter) { patterns.add( - converter, patterns.getContext()); + ConvertStore, ConvertCollapseShape, ConvertExpandShape, + ConvertReinterpretCast>(converter, patterns.getContext()); } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir index 4df7bac0b55806..87ed7a63b9b1c5 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir @@ -61,3 +61,48 @@ func.func @memref_collapse_dyn_shape(%arg: memref) -> memref { %0 = memref.collapse_shape %arg [[0, 1]] : memref into memref return %0 : memref } + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast_dyn_shape +func.func @memref_reinterpret_cast_dyn_shape(%arg: memref<2x5xi32>, %size: index) -> memref { + // expected-error@+1 {{failed to legalize operation 'memref.reinterpret_cast'}} + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<2x5xi32> to memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast_dyn_offset +func.func @memref_reinterpret_cast_dyn_offset(%arg: memref<2x5xi32>, %offset: index) -> memref<10xi32, strided<[1], offset: ?>> { + // expected-error@+1 {{failed to legalize operation 'memref.reinterpret_cast'}} + %0 = memref.reinterpret_cast %arg to offset: [%offset], sizes: [10], strides: [1] : memref<2x5xi32> to memref<10xi32, strided<[1], offset: ?>> + return %0 : memref<10xi32, strided<[1], offset:? >> +} + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast_static_offset +func.func @memref_reinterpret_cast_static_offset(%arg: memref<2x5xi32>) -> memref<10xi32, strided<[1], offset: 10>> { + // expected-error@+1 {{failed to legalize operation 'memref.reinterpret_cast'}} + %0 = memref.reinterpret_cast %arg to offset: [10], sizes: [10], strides: [1] : memref<2x5xi32> to memref<10xi32, strided<[1], offset: 10>> + return %0 : memref<10xi32, strided<[1], offset: 10>> +} + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast_static_strides +func.func @memref_reinterpret_cast_offset(%arg: memref<2x5xi32>) -> memref<10xi32, strided<[2], offset: 0>> { + // expected-error@+1 {{failed to legalize operation 'memref.reinterpret_cast'}} + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [2] : memref<2x5xi32> to memref<10xi32, strided<[2], offset: 0>> + return %0 : memref<10xi32, strided<[2], offset: 0>> +} + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast_dyn_strides +func.func @memref_reinterpret_cast_offset(%arg: memref<2x5xi32>, %stride: index) -> memref<10xi32, strided<[?], offset: 0>> { + // expected-error@+1 {{failed to legalize operation 'memref.reinterpret_cast'}} + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [%stride] : memref<2x5xi32> to memref<10xi32, strided<[?], offset: 0>> + return %0 : memref<10xi32, strided<[?], offset: 0>> +} diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index 1effcb66cd62b3..c6206165266590 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -94,3 +94,12 @@ func.func @memref_collapse_shape(%arg: memref<2x5xi32>) -> memref<10xi32> { %0 = memref.collapse_shape %arg [[0, 1]] : memref<2x5xi32> into memref<10xi32> return %0 : memref<10xi32> } + +// ----- + +// CHECK-LABEL: memref_reinterpret_cast +func.func @memref_reinterpret_cast(%arg: memref<2x5xi32>) -> memref<10xi32> { + // CHECK: emitc.cast %{{[^ ]*}} : !emitc.array<2x5xi32> to !emitc.array<10xi32> ref + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref<2x5xi32> to memref<10xi32> + return %0 : memref<10xi32> +}