diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index 725a1bcb4e6cb1..fd9525dd1d2858 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -30,6 +30,8 @@ namespace mlir { namespace emitc { void buildTerminatedBody(OpBuilder &builder, Location loc); +/// Determines whether \p type is valid in EmitC. +bool isSupportedEmitCType(mlir::Type type); /// Determines whether \p type is a valid integer type in EmitC. bool isSupportedIntegerType(mlir::Type type); /// Determines whether \p type is a valid floating-point type in EmitC. diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index ee5fc0b09a1611..e5fe70af95c33f 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -34,16 +34,16 @@ class EmitC_Op traits = []> // Base class for unary operations. class EmitC_UnaryOp traits = []> : EmitC_Op { - let arguments = (ins AnyType); - let results = (outs AnyType); + let arguments = (ins EmitCType); + let results = (outs EmitCType); let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } // Base class for binary operations. class EmitC_BinaryOp traits = []> : EmitC_Op { - let arguments = (ins AnyType:$lhs, AnyType:$rhs); - let results = (outs AnyType); + let arguments = (ins EmitCType:$lhs, EmitCType:$rhs); + let results = (outs EmitCType); let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } @@ -97,9 +97,9 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> { }]; let arguments = (ins Arg:$applicableOperator, - AnyType:$operand + EmitCType:$operand ); - let results = (outs AnyType:$result); + let results = (outs EmitCType:$result); let assemblyFormat = [{ $applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results) }]; @@ -240,9 +240,9 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { Arg:$callee, Arg, "the order of operands and further attributes">:$args, Arg, "template arguments">:$template_args, - Variadic:$operands + Variadic:$operands ); - let results = (outs Variadic); + let results = (outs Variadic); let builders = [ OpBuilder<(ins "::mlir::TypeRange":$resultTypes, @@ -284,8 +284,8 @@ def EmitC_CastOp : EmitC_Op<"cast", ``` }]; - let arguments = (ins AnyType:$source); - let results = (outs AnyType:$dest); + let arguments = (ins EmitCType:$source); + let results = (outs EmitCType:$dest); let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; } @@ -323,9 +323,9 @@ def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> { }]; let arguments = (ins EmitC_CmpPredicateAttr:$predicate, - AnyType:$lhs, - AnyType:$rhs); - let results = (outs AnyType); + EmitCType:$lhs, + EmitCType:$rhs); + let results = (outs EmitCType); let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)"; } @@ -354,7 +354,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { }]; let arguments = (ins EmitC_OpaqueOrTypedAttr:$value); - let results = (outs AnyType); + let results = (outs EmitCType); let hasFolder = 1; let hasVerifier = 1; @@ -424,7 +424,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression", }]; let arguments = (ins UnitAttr:$do_not_inline); - let results = (outs AnyType:$result); + let results = (outs EmitCType:$result); let regions = (region SizedRegion<1>:$region); let hasVerifier = 1; @@ -532,8 +532,8 @@ def EmitC_CallOp : EmitC_Op<"call", %2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32 ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); - let results = (outs Variadic); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); let builders = [ OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ @@ -723,7 +723,7 @@ def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">, } ``` }]; - let arguments = (ins Optional:$operand); + let arguments = (ins Optional:$operand); let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?"; let hasVerifier = 1; @@ -767,7 +767,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> { }]; let arguments = (ins StrAttr:$value); - let results = (outs AnyType:$result); + let results = (outs EmitCType:$result); let hasVerifier = 1; let assemblyFormat = "$value attr-dict `:` type($result)"; @@ -933,8 +933,8 @@ def EmitC_ConditionalOp : EmitC_Op<"conditional", int32_t v6 = v3 ? v4 : v5; ``` }]; - let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value); - let results = (outs AnyType:$result); + let arguments = (ins I1:$condition, EmitCType:$true_value, EmitCType:$false_value); + let results = (outs EmitCType:$result); let assemblyFormat = "operands attr-dict `:` type($result)"; } @@ -1011,7 +1011,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> { }]; let arguments = (ins EmitC_OpaqueOrTypedAttr:$value); - let results = (outs AnyType); + let results = (outs EmitCType); let hasVerifier = 1; } @@ -1081,7 +1081,7 @@ def EmitC_GetGlobalOp : EmitC_Op<"get_global", }]; let arguments = (ins FlatSymbolRefAttr:$name); - let results = (outs AnyType:$result); + let results = (outs EmitCType:$result); let assemblyFormat = "$name `:` type($result) attr-dict"; } @@ -1139,7 +1139,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> { ``` }]; - let arguments = (ins AnyType:$var, AnyType:$value); + let arguments = (ins EmitCType:$var, EmitCType:$value); let results = (outs); let hasVerifier = 1; @@ -1160,7 +1160,7 @@ def EmitC_YieldOp : EmitC_Op<"yield", value is yielded. }]; - let arguments = (ins Optional:$result); + let arguments = (ins Optional:$result); let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; let hasVerifier = 1; @@ -1243,7 +1243,7 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", }]; let arguments = (ins Arg:$array, Variadic:$indices); - let results = (outs AnyType:$result); + let results = (outs EmitCType:$result); let builders = [ OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{ diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td index bce5807230ce49..444395b915e250 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -22,6 +22,9 @@ include "mlir/IR/BuiltinTypeInterfaces.td" // EmitC type definitions //===----------------------------------------------------------------------===// +def EmitCType : Type, + "type supported by EmitC">; + def EmitCIntegerType : Type, "integer type supported by EmitC">; diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 41e290397e3cfc..d07a4e7a245139 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -10,11 +10,15 @@ #include "mlir/Dialect/EmitC/IR/EmitCTraits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Types.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" using namespace mlir; using namespace mlir::emitc; @@ -54,6 +58,40 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) { builder.create(loc); } +bool mlir::emitc::isSupportedEmitCType(Type type) { + if (llvm::isa(type)) + return true; + if (auto ptrType = llvm::dyn_cast(type)) + return isSupportedEmitCType(ptrType.getPointee()); + if (auto arrayType = llvm::dyn_cast(type)) { + auto elemType = arrayType.getElementType(); + return !llvm::isa(elemType) && + isSupportedEmitCType(elemType); + } + if (type.isIndex()) + return true; + if (llvm::isa(type)) + return isSupportedIntegerType(type); + if (llvm::isa(type)) + return isSupportedFloatType(type); + if (auto tensorType = llvm::dyn_cast(type)) { + if (!tensorType.hasStaticShape()) { + return false; + } + auto elemType = tensorType.getElementType(); + if (llvm::isa(elemType)) { + return false; + } + return isSupportedEmitCType(elemType); + } + if (auto tupleType = llvm::dyn_cast(type)) { + return llvm::all_of(tupleType.getTypes(), [](Type type) { + return !llvm::isa(type) && isSupportedEmitCType(type); + }); + } + return false; +} + bool mlir::emitc::isSupportedIntegerType(Type type) { if (auto intType = llvm::dyn_cast(type)) { switch (intType.getWidth()) { diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir index f9d517bf689b95..0ad8d4eabe6b8b 100644 --- a/mlir/test/Dialect/EmitC/invalid_types.mlir +++ b/mlir/test/Dialect/EmitC/invalid_types.mlir @@ -97,3 +97,51 @@ func.func @illegal_float_type(%arg0: f80, %arg1: f80) { %mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80 return } + +// ----- + +func.func @illegal_pointee_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got '!emitc.ptr'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr + return +} + +// ----- + +func.func @illegal_non_static_tensor_shape_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor + return +} + +// ----- + +func.func @illegal_tensor_array_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor> + return +} + +// ----- + +func.func @illegal_tensor_integer_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<9xi11>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<9xi11> + return +} + +// ----- + +func.func @illegal_tuple_array_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple, f32>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple, f32> + return +} + +// ----- + +func.func @illegal_tuple_float_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple + return +}