Skip to content

Commit

Permalink
[mlir][emitc] Restrict types in EmitC
Browse files Browse the repository at this point in the history
Use what is currently supported by the emitter to restrict the valid types of EmitC operations. Define three utility functions for valid types, such that they can be used to restrict the operations in the table gen as well as being available for reuse in dialect conversions.
  • Loading branch information
TinaAMD committed Apr 10, 2024
1 parent d288ea3 commit d7b8edc
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 28 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);
/// Determines whether \p type is valid in EmitC.
bool isValidEmitCType(mlir::Type type);
/// Determines whether \p type is a valid integer type in EmitC.
bool isSupportedIntegerType(mlir::Type type);
/// Determines whether \p type is a valid floating-point type in EmitC.
Expand Down
56 changes: 28 additions & 28 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ class EmitC_Op<string mnemonic, list<Trait> traits = []>
// Base class for unary operations.
class EmitC_UnaryOp<string mnemonic, list<Trait> traits = []> :
EmitC_Op<mnemonic, traits> {
let arguments = (ins AnyType);
let results = (outs AnyType);
let arguments = (ins Valid_EmitC_Type);
let results = (outs Valid_EmitC_Type);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

// Base class for binary operations.
class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
EmitC_Op<mnemonic, traits> {
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
let results = (outs AnyType);
let arguments = (ins Valid_EmitC_Type:$lhs, Valid_EmitC_Type:$rhs);
let results = (outs Valid_EmitC_Type);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

Expand Down Expand Up @@ -97,9 +97,9 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
}];
let arguments = (ins
Arg<StrAttr, "the operator to apply">:$applicableOperator,
AnyType:$operand
Valid_EmitC_Type:$operand
);
let results = (outs AnyType:$result);
let results = (outs Valid_EmitC_Type:$result);
let assemblyFormat = [{
$applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results)
}];
Expand Down Expand Up @@ -240,9 +240,9 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
Arg<StrAttr, "the C++ function to call">:$callee,
Arg<OptionalAttr<ArrayAttr>, "the order of operands and further attributes">:$args,
Arg<OptionalAttr<ArrayAttr>, "template arguments">:$template_args,
Variadic<AnyType>:$operands
Variadic<Valid_EmitC_Type>:$operands
);
let results = (outs Variadic<AnyType>);
let results = (outs Variadic<Valid_EmitC_Type>);
let builders = [
OpBuilder<(ins
"::mlir::TypeRange":$resultTypes,
Expand Down Expand Up @@ -284,15 +284,15 @@ def EmitC_CastOp : EmitC_Op<"cast",
```
}];

let arguments = (ins AnyType:$source);
let results = (outs AnyType:$dest);
let arguments = (ins Valid_EmitC_Type:$source);
let results = (outs Valid_EmitC_Type:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
}

def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
let summary = "Comparison operation";
let description = [{
With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=>
With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=>
can be applied.

Its first argument is an attribute that defines the comparison operator:
Expand All @@ -309,7 +309,7 @@ def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
```mlir
// Custom form of the cmp operation.
%0 = emitc.cmp eq, %arg0, %arg1 : (i32, i32) -> i1
%1 = emitc.cmp lt, %arg2, %arg3 :
%1 = emitc.cmp lt, %arg2, %arg3 :
(
!emitc.opaque<"std::valarray<float>">,
!emitc.opaque<"std::valarray<float>">
Expand All @@ -323,9 +323,9 @@ def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
}];

let arguments = (ins EmitC_CmpPredicateAttr:$predicate,
AnyType:$lhs,
AnyType:$rhs);
let results = (outs AnyType);
Valid_EmitC_Type:$lhs,
Valid_EmitC_Type:$rhs);
let results = (outs Valid_EmitC_Type);

let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)";
}
Expand Down Expand Up @@ -354,7 +354,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
}];

let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
let results = (outs AnyType);
let results = (outs Valid_EmitC_Type);

let hasFolder = 1;
let hasVerifier = 1;
Expand Down Expand Up @@ -424,7 +424,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
}];

let arguments = (ins UnitAttr:$do_not_inline);
let results = (outs AnyType:$result);
let results = (outs Valid_EmitC_Type:$result);
let regions = (region SizedRegion<1>:$region);

let hasVerifier = 1;
Expand Down Expand Up @@ -532,8 +532,8 @@ def EmitC_CallOp : EmitC_Op<"call",
%2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32
```
}];
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
let results = (outs Variadic<AnyType>);
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<Valid_EmitC_Type>:$operands);
let results = (outs Variadic<Valid_EmitC_Type>);

let builders = [
OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
Expand Down Expand Up @@ -723,7 +723,7 @@ def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">,
}
```
}];
let arguments = (ins Optional<AnyType>:$operand);
let arguments = (ins Optional<Valid_EmitC_Type>:$operand);

let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?";
let hasVerifier = 1;
Expand Down Expand Up @@ -767,7 +767,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
}];

let arguments = (ins StrAttr:$value);
let results = (outs AnyType:$result);
let results = (outs Valid_EmitC_Type:$result);

let hasVerifier = 1;
let assemblyFormat = "$value attr-dict `:` type($result)";
Expand Down Expand Up @@ -933,8 +933,8 @@ def EmitC_ConditionalOp : EmitC_Op<"conditional",
int32_t v6 = v3 ? v4 : v5;
```
}];
let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
let results = (outs AnyType:$result);
let arguments = (ins I1:$condition, Valid_EmitC_Type:$true_value, Valid_EmitC_Type:$false_value);
let results = (outs Valid_EmitC_Type:$result);
let assemblyFormat = "operands attr-dict `:` type($result)";
}

Expand Down Expand Up @@ -1011,7 +1011,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
}];

let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
let results = (outs AnyType);
let results = (outs Valid_EmitC_Type);

let hasVerifier = 1;
}
Expand Down Expand Up @@ -1108,7 +1108,7 @@ def EmitC_VerbatimOp : EmitC_Op<"verbatim"> {
#endif

...

#ifdef __cplusplus
}
#endif
Expand Down Expand Up @@ -1139,7 +1139,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
```
}];

let arguments = (ins AnyType:$var, AnyType:$value);
let arguments = (ins Valid_EmitC_Type:$var, Valid_EmitC_Type:$value);
let results = (outs);

let hasVerifier = 1;
Expand All @@ -1160,7 +1160,7 @@ def EmitC_YieldOp : EmitC_Op<"yield",
value is yielded.
}];

let arguments = (ins Optional<AnyType>:$result);
let arguments = (ins Optional<Valid_EmitC_Type>:$result);
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];

let hasVerifier = 1;
Expand Down Expand Up @@ -1243,7 +1243,7 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript",
}];
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
Variadic<IntegerIndexOrOpaqueType>:$indices);
let results = (outs AnyType:$result);
let results = (outs Valid_EmitC_Type:$result);

let builders = [
OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
// EmitC type definitions
//===----------------------------------------------------------------------===//

def Valid_EmitC_Type : Type<CPred<"emitc::isValidEmitCType($_self)">,
"EmitC dialect type">;

def EmitCIntegerType : Type<CPred<"emitc::isSupportedIntegerType($_self)">,
"integer type supported by EmitC">;

Expand Down
42 changes: 42 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@
#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

using namespace mlir;
using namespace mlir::emitc;
Expand Down Expand Up @@ -54,6 +58,44 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
builder.create<emitc::YieldOp>(loc);
}

bool mlir::emitc::isValidEmitCType(Type type) {
if (isa<emitc::OpaqueType>(type)) {
return true;
}
if (auto ptrType = dyn_cast<emitc::PointerType>(type)) {
return isValidEmitCType(ptrType.getPointee());
}
if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
auto elemType = arrayType.getElementType();
return !isa<emitc::ArrayType>(elemType) && isValidEmitCType(elemType);
}
if (type.isIndex()) {
return true;
}
if (isa<IntegerType>(type)) {
return isSupportedIntegerType(type);
}
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
return isSupportedFloatType(type);
}
if (auto tensorType = llvm::dyn_cast<TensorType>(type)) {
if (!tensorType.hasStaticShape()) {
return false;
}
auto elemType = tensorType.getElementType();
if (isa<emitc::ArrayType>(elemType)) {
return false;
}
return isValidEmitCType(elemType);
}
if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
return llvm::all_of(tupleType.getTypes(), [](Type type) {
return !isa<emitc::ArrayType>(type) && isValidEmitCType(type);
});
}
return false;
}

bool mlir::emitc::isSupportedIntegerType(Type type) {
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
switch (intType.getWidth()) {
Expand Down

0 comments on commit d7b8edc

Please sign in to comment.