Skip to content

Commit

Permalink
[mlir][irdl] Add support for basic structural constraints in tblgen-t…
Browse files Browse the repository at this point in the history
…o-irdl (llvm#82862)
  • Loading branch information
math-fehr authored Mar 5, 2024
1 parent 1c2b79a commit a64975f
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 22 deletions.
22 changes: 13 additions & 9 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -168,24 +168,28 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
BuildableType<"$_builder.getType<::mlir::NoneType>()">;

// Any type from the given list
class AnyTypeOf<list<Type> allowedTypes, string summary = "",
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
// Satisfy any of the allowed types' conditions.
Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypes, t.summary), " or "),
!interleave(!foreach(t, allowedTypeList, t.summary), " or "),
summary),
cppClassName>;
cppClassName> {
list<Type> allowedTypes = allowedTypeList;
}

// A type that satisfies the constraints of all given types.
class AllOfType<list<Type> allowedTypes, string summary = "",
class AllOfType<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
// Satisfy all of the allowedf types' conditions.
And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
// Satisfy all of the allowed types' conditions.
And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypes, t.summary), " and "),
!interleave(!foreach(t, allowedTypeList, t.summary), " and "),
summary),
cppClassName>;
cppClassName> {
list<Type> allowedTypes = allowedTypeList;
}

// A type that satisfies additional predicates.
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/tblgen-to-irdl/CMathDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
}

// CHECK: irdl.operation @identity {
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands()
// CHECK-NEXT: irdl.results(%0)
// CHECK-NEXT: }
Expand All @@ -33,9 +33,9 @@ def CMath_IdentityOp : CMath_Op<"identity"> {
}

// CHECK: irdl.operation @mul {
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %2 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
// CHECK-NEXT: %2 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands(%0, %1)
// CHECK-NEXT: irdl.results(%2)
// CHECK-NEXT: }
Expand All @@ -45,8 +45,8 @@ def CMath_MulOp : CMath_Op<"mul"> {
}

// CHECK: irdl.operation @norm {
// CHECK-NEXT: %0 = irdl.c_pred "(true)"
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %0 = irdl.any
// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands(%0)
// CHECK-NEXT: irdl.results(%1)
// CHECK-NEXT: }
Expand Down
74 changes: 74 additions & 0 deletions mlir/test/tblgen-to-irdl/TestDialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: tblgen-to-irdl %s -I=%S/../../include --gen-dialect-irdl-defs --dialect=test | FileCheck %s

include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"

// CHECK-LABEL: irdl.dialect @test {
def Test_Dialect : Dialect {
let name = "test";
}

class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<Test_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;

def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}


// Check that AllOfType is converted correctly.
def Test_AndOp : Test_Op<"and"> {
let arguments = (ins AllOfType<[Test_SingletonAType, AnyType]>:$in);
}
// CHECK-LABEL: irdl.operation @and {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
// CHECK-NEXT: irdl.operands(%[[v2]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }


// Check that AnyType is converted correctly.
def Test_AnyOp : Test_Op<"any"> {
let arguments = (ins AnyType:$in);
}
// CHECK-LABEL: irdl.operation @any {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
// CHECK-NEXT: irdl.operands(%[[v0]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }


// Check that AnyTypeOf is converted correctly.
def Test_OrOp : Test_Op<"or"> {
let arguments = (ins AnyTypeOf<[Test_SingletonAType, Test_SingletonBType, Test_SingletonCType]>:$in);
}
// CHECK-LABEL: irdl.operation @or {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
// CHECK-NEXT: irdl.operands(%[[v3]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }


// Check that variadics and optionals are converted correctly.
def Test_VariadicityOp : Test_Op<"variadicity"> {
let arguments = (ins Variadic<Test_SingletonAType>:$variadic,
Optional<Test_SingletonBType>:$optional,
Test_SingletonCType:$required);
}
// CHECK-LABEL: irdl.operation @variadicity {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
48 changes: 41 additions & 7 deletions mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,49 @@ llvm::cl::opt<std::string>
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::Required);

irdl::CPredOp createConstraint(OpBuilder &builder,
NamedTypeConstraint namedConstraint) {
Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
// Build the constraint as a string.
std::string constraint =
namedConstraint.constraint.getPredicate().getCondition();
const Record &predRec = constraint.getDef();

if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
return createConstraint(builder, predRec.getValueAsDef("baseType"));

if (predRec.getName() == "AnyType") {
auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
return op.getOutput();
}

if (predRec.isSubClassOf("TypeDef")) {
std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
StringAttr::get(ctx, typeName));
return op.getOutput();
}

if (predRec.isSubClassOf("AnyTypeOf")) {
std::vector<Value> constraints;
for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
createConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
}

if (predRec.isSubClassOf("AllOfType")) {
std::vector<Value> constraints;
for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
createConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
}

std::string condition = constraint.getPredicate().getCondition();
// Build a CPredOp to match the C constraint built.
irdl::CPredOp op = builder.create<irdl::CPredOp>(
UnknownLoc::get(ctx), StringAttr::get(ctx, constraint));
UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
return op;
}

Expand All @@ -74,7 +108,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
SmallVector<Value> operands;
SmallVector<irdl::VariadicityAttr> variadicity;
for (const NamedTypeConstraint &namedCons : namedCons) {
auto operand = createConstraint(consBuilder, namedCons);
auto operand = createConstraint(consBuilder, namedCons.constraint);
operands.push_back(operand);

irdl::VariadicityAttr var;
Expand Down

0 comments on commit a64975f

Please sign in to comment.