diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 03180a687523bf..af4f13dc09360d 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -168,24 +168,28 @@ def NoneType : Type($_self)">, "none type", BuildableType<"$_builder.getType<::mlir::NoneType>()">; // Any type from the given list -class AnyTypeOf allowedTypes, string summary = "", +class AnyTypeOf allowedTypeList, string summary = "", string cppClassName = "::mlir::Type"> : Type< // Satisfy any of the allowed types' conditions. - Or, + Or, !if(!eq(summary, ""), - !interleave(!foreach(t, allowedTypes, t.summary), " or "), + !interleave(!foreach(t, allowedTypeList, t.summary), " or "), summary), - cppClassName>; + cppClassName> { + list allowedTypes = allowedTypeList; +} // A type that satisfies the constraints of all given types. -class AllOfType allowedTypes, string summary = "", +class AllOfType allowedTypeList, string summary = "", string cppClassName = "::mlir::Type"> : Type< - // Satisfy all of the allowedf types' conditions. - And, + // Satisfy all of the allowed types' conditions. + And, !if(!eq(summary, ""), - !interleave(!foreach(t, allowedTypes, t.summary), " and "), + !interleave(!foreach(t, allowedTypeList, t.summary), " and "), summary), - cppClassName>; + cppClassName> { + list allowedTypes = allowedTypeList; +} // A type that satisfies additional predicates. class ConfinedType predicates, string summary = "", diff --git a/mlir/test/tblgen-to-irdl/CMathDialect.td b/mlir/test/tblgen-to-irdl/CMathDialect.td index 57ae8afbba5eeb..5b9e756727cb36 100644 --- a/mlir/test/tblgen-to-irdl/CMathDialect.td +++ b/mlir/test/tblgen-to-irdl/CMathDialect.td @@ -24,7 +24,7 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> { } // CHECK: irdl.operation @identity { -// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa($_self))" +// CHECK-NEXT: %0 = irdl.base "!cmath.complex" // CHECK-NEXT: irdl.operands() // CHECK-NEXT: irdl.results(%0) // CHECK-NEXT: } @@ -33,9 +33,9 @@ def CMath_IdentityOp : CMath_Op<"identity"> { } // CHECK: irdl.operation @mul { -// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa($_self))" -// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa($_self))" -// CHECK-NEXT: %2 = irdl.c_pred "(::llvm::isa($_self))" +// CHECK-NEXT: %0 = irdl.base "!cmath.complex" +// CHECK-NEXT: %1 = irdl.base "!cmath.complex" +// CHECK-NEXT: %2 = irdl.base "!cmath.complex" // CHECK-NEXT: irdl.operands(%0, %1) // CHECK-NEXT: irdl.results(%2) // CHECK-NEXT: } @@ -45,8 +45,8 @@ def CMath_MulOp : CMath_Op<"mul"> { } // CHECK: irdl.operation @norm { -// CHECK-NEXT: %0 = irdl.c_pred "(true)" -// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa($_self))" +// CHECK-NEXT: %0 = irdl.any +// CHECK-NEXT: %1 = irdl.base "!cmath.complex" // CHECK-NEXT: irdl.operands(%0) // CHECK-NEXT: irdl.results(%1) // CHECK-NEXT: } diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td new file mode 100644 index 00000000000000..fc40da527db00a --- /dev/null +++ b/mlir/test/tblgen-to-irdl/TestDialect.td @@ -0,0 +1,74 @@ +// RUN: tblgen-to-irdl %s -I=%S/../../include --gen-dialect-irdl-defs --dialect=test | FileCheck %s + +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" + +// CHECK-LABEL: irdl.dialect @test { +def Test_Dialect : Dialect { + let name = "test"; +} + +class Test_Type traits = []> +: TypeDef { + let mnemonic = typeMnemonic; +} + +class Test_Op traits = []> + : Op; + +def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {} +def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {} +def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {} + + +// Check that AllOfType is converted correctly. +def Test_AndOp : Test_Op<"and"> { + let arguments = (ins AllOfType<[Test_SingletonAType, AnyType]>:$in); +} +// CHECK-LABEL: irdl.operation @and { +// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a" +// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any +// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]]) +// CHECK-NEXT: irdl.operands(%[[v2]]) +// CHECK-NEXT: irdl.results() +// CHECK-NEXT: } + + +// Check that AnyType is converted correctly. +def Test_AnyOp : Test_Op<"any"> { + let arguments = (ins AnyType:$in); +} +// CHECK-LABEL: irdl.operation @any { +// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any +// CHECK-NEXT: irdl.operands(%[[v0]]) +// CHECK-NEXT: irdl.results() +// CHECK-NEXT: } + + +// Check that AnyTypeOf is converted correctly. +def Test_OrOp : Test_Op<"or"> { + let arguments = (ins AnyTypeOf<[Test_SingletonAType, Test_SingletonBType, Test_SingletonCType]>:$in); +} +// CHECK-LABEL: irdl.operation @or { +// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a" +// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b" +// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c" +// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]]) +// CHECK-NEXT: irdl.operands(%[[v3]]) +// CHECK-NEXT: irdl.results() +// CHECK-NEXT: } + + +// Check that variadics and optionals are converted correctly. +def Test_VariadicityOp : Test_Op<"variadicity"> { + let arguments = (ins Variadic:$variadic, + Optional:$optional, + Test_SingletonCType:$required); +} +// CHECK-LABEL: irdl.operation @variadicity { +// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a" +// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b" +// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c" +// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]]) +// CHECK-NEXT: irdl.results() +// CHECK-NEXT: } diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp index ba5bf4d9d4abbc..a55f3539f31db0 100644 --- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp +++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp @@ -39,15 +39,49 @@ llvm::cl::opt selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"), llvm::cl::cat(dialectGenCat), llvm::cl::Required); -irdl::CPredOp createConstraint(OpBuilder &builder, - NamedTypeConstraint namedConstraint) { +Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) { MLIRContext *ctx = builder.getContext(); - // Build the constraint as a string. - std::string constraint = - namedConstraint.constraint.getPredicate().getCondition(); + const Record &predRec = constraint.getDef(); + + if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional")) + return createConstraint(builder, predRec.getValueAsDef("baseType")); + + if (predRec.getName() == "AnyType") { + auto op = builder.create(UnknownLoc::get(ctx)); + return op.getOutput(); + } + + if (predRec.isSubClassOf("TypeDef")) { + std::string typeName = ("!" + predRec.getValueAsString("typeName")).str(); + auto op = builder.create(UnknownLoc::get(ctx), + StringAttr::get(ctx, typeName)); + return op.getOutput(); + } + + if (predRec.isSubClassOf("AnyTypeOf")) { + std::vector constraints; + for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) { + constraints.push_back( + createConstraint(builder, tblgen::Constraint(child))); + } + auto op = builder.create(UnknownLoc::get(ctx), constraints); + return op.getOutput(); + } + + if (predRec.isSubClassOf("AllOfType")) { + std::vector constraints; + for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) { + constraints.push_back( + createConstraint(builder, tblgen::Constraint(child))); + } + auto op = builder.create(UnknownLoc::get(ctx), constraints); + return op.getOutput(); + } + + std::string condition = constraint.getPredicate().getCondition(); // Build a CPredOp to match the C constraint built. irdl::CPredOp op = builder.create( - UnknownLoc::get(ctx), StringAttr::get(ctx, constraint)); + UnknownLoc::get(ctx), StringAttr::get(ctx, condition)); return op; } @@ -74,7 +108,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder, SmallVector operands; SmallVector variadicity; for (const NamedTypeConstraint &namedCons : namedCons) { - auto operand = createConstraint(consBuilder, namedCons); + auto operand = createConstraint(consBuilder, namedCons.constraint); operands.push_back(operand); irdl::VariadicityAttr var;