Skip to content

Commit

Permalink
[FXML-4791] Lower memref expand/collapse to EmitC (#313)
Browse files Browse the repository at this point in the history
  • Loading branch information
cferry-AMD authored Sep 9, 2024
1 parent dec1017 commit dfb5921
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 5 deletions.
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {

def EmitC_CastOp : EmitC_Op<"cast",
[CExpression,
DeclareOpInterfaceMethods<CastOpInterface>,
SameOperandsAndResultShape]> {
DeclareOpInterfaceMethods<CastOpInterface>]> {
let summary = "Cast operation";
let description = [{
The `cast` operation performs an explicit type conversion and is emitted
Expand Down
67 changes: 66 additions & 1 deletion mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

Expand Down Expand Up @@ -166,6 +168,68 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
return success();
}
};

struct ConvertCollapseShape final
: public OpConversionPattern<memref::CollapseShapeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::CollapseShapeOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSrc());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}

// Do not generate casts between arrays with dynamic shapes
if (!arrayValue.getType().hasStaticShape())
return rewriter.notifyMatchFailure(op.getLoc(),
"dynamic shapes not supported");
auto newCastOp = rewriter.create<emitc::CastOp>(op->getLoc(), resultTy,
operands.getSrc());
newCastOp.setReference(true);
rewriter.replaceOp(op, newCastOp);
return success();
}
};

struct ConvertExpandShape final
: public OpConversionPattern<memref::ExpandShapeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::ExpandShapeOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSrc());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}

// Do not generate casts between arrays with dynamic shapes
if (!arrayValue.getType().hasStaticShape())
return rewriter.notifyMatchFailure(op.getLoc(),
"dynamic shapes not supported");

auto newCastOp = rewriter.create<emitc::CastOp>(op->getLoc(), resultTy,
operands.getSrc());
newCastOp.setReference(true);
rewriter.replaceOp(op, newCastOp);
return success();
}
};

} // namespace

void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
Expand All @@ -187,5 +251,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
ConvertStore>(converter, patterns.getContext());
ConvertStore, ConvertCollapseShape, ConvertExpandShape>(
converter, patterns.getContext());
}
21 changes: 20 additions & 1 deletion mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,15 @@ LogicalResult emitc::AssignOp::verify() {
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
Type input = inputs.front(), output = outputs.front();

// Cast to array is only possible from an array
if (isa<emitc::ArrayType>(input) != isa<emitc::ArrayType>(output))
return false;

// Arrays can be casted to arrays by reference.
if (isa<emitc::ArrayType>(input) && isa<emitc::ArrayType>(output))
return true;

// Scalars
return (
(emitc::isIntegerIndexOrOpaqueType(input) ||
emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
Expand All @@ -236,7 +245,15 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}

LogicalResult CastOp::verify() {
if (getReference())
bool isReference = getReference();

if (isa<emitc::ArrayType>(getDest().getType())) {
if (!isReference)
return emitOpError("cast of array must bear a reference");
return success();
}

if (isReference)
return emitOpError("cast of value type must not bear a reference");

return success();
Expand Down Expand Up @@ -954,6 +971,8 @@ LogicalResult emitc::ArrayType::verify(
for (int64_t dim : shape) {
if (dim < 0)
return emitError() << "dimensions must have non-negative size";
if (dim == ShapedType::kDynamic)
return emitError() << "dimensions must have static size";
}

if (!elementType)
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,21 @@ func.func @zero_rank() {

// expected-error@+1 {{failed to legalize operation 'memref.global'}}
memref.global "nested" constant @nested_global : memref<3x7xf32>

// -----

// CHECK-LABEL: memref_expand_dyn_shape
func.func @memref_expand_dyn_shape(%arg: memref<?xi32>, %size: index) -> memref<?x5xi32> {
// expected-error@+1 {{failed to legalize operation 'memref.expand_shape'}}
%0 = memref.expand_shape %arg [[0, 1]] output_shape [%size, 5] : memref<?xi32> into memref<?x5xi32>
return %0 : memref<?x5xi32>
}

// -----

// CHECK-LABEL: memref_collapse_dyn_shape
func.func @memref_collapse_dyn_shape(%arg: memref<?x5xi32>) -> memref<?xi32> {
// expected-error@+1 {{failed to legalize operation 'memref.collapse_shape'}}
%0 = memref.collapse_shape %arg [[0, 1]] : memref<?x5xi32> into memref<?xi32>
return %0 : memref<?xi32>
}
19 changes: 19 additions & 0 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,22 @@ func.func @memref_index_values(%i: index, %j: index) -> index {
// CHECK: return %[[CAST_RET]] : index
return %1 : index
}

// -----

// CHECK-LABEL: memref_expand_shape
func.func @memref_expand_shape(%arg: memref<10xi32>) -> memref<2x5xi32> {
// CHECK: emitc.cast %{{[^ ]*}} : !emitc.array<10xi32> to !emitc.array<2x5xi32> ref
%0 = memref.expand_shape %arg [[0, 1]] output_shape [2, 5] : memref<10xi32> into memref<2x5xi32>
return %0 : memref<2x5xi32>
}


// -----

// CHECK-LABEL: memref_collapse_shape
func.func @memref_collapse_shape(%arg: memref<2x5xi32>) -> memref<10xi32> {
// CHECK: emitc.cast %{{[^ ]*}} : !emitc.array<2x5xi32> to !emitc.array<10xi32> ref
%0 = memref.collapse_shape %arg [[0, 1]] : memref<2x5xi32> into memref<10xi32>
return %0 : memref<10xi32>
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func.func @cast_tensor(%arg : tensor<f32>) {
// -----

func.func @cast_array(%arg : !emitc.array<4xf32>) {
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type '!emitc.array<4xf32>' are cast incompatible}}
// expected-error @+1 {{'emitc.cast' op cast of array must bear a reference}}
%1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32>
return
}
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ func.func @cast(%arg0: i32) {
return
}

func.func @cast_array(%arg : !emitc.array<4xf32>) {
%1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> ref
return
}

func.func @c() {
%1 = "emitc.constant"(){value = 42 : i32} : () -> i32
%2 = "emitc.constant"(){value = 42 : index} : () -> !emitc.size_t
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Target/Cpp/cast.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,12 @@ func.func @cast_ptr(%arg0 : !emitc.ptr<!emitc.opaque<"void">>) {
%1 = emitc.cast %arg0 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
return
}

// CHECK-LABEL: void cast_array
func.func @cast_array(%arg0: !emitc.array<10xi32>) {
// CHECK-NEXT: int32_t (&[[V1:[^ ]*]])[2][5] = (int32_t (&)[2][5]) [[V0:[^ ]*]]
%1 = emitc.cast %arg0 : !emitc.array<10xi32> to !emitc.array<2x5xi32> ref
// CHECK-NEXT: int32_t (&[[V2:[^ ]*]])[10] = (int32_t (&)[10]) [[V1]]
%2 = emitc.cast %1 : !emitc.array<2x5xi32> to !emitc.array<10xi32> ref
return
}

0 comments on commit dfb5921

Please sign in to comment.