From c3d04b5257adb122684237dab95f252c56d388d0 Mon Sep 17 00:00:00 2001 From: Robert Esclapez-Garcia Date: Tue, 9 Apr 2024 10:29:21 +0100 Subject: [PATCH 1/4] [mlir][PDLInterp]: Allow constraints with no inputs --- .../mlir/Dialect/PDLInterp/IR/PDLInterpOps.td | 2 +- .../PDLToPDLInterp/PredicateTree.cpp | 22 ++++++++++++++----- .../pdl-to-pdl-interp-matcher.mlir | 14 ++++++++++++ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index 27bc493a558a72..fc1048909fa7e0 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -106,7 +106,7 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { DefaultValuedAttr:$isNegated); let results = (outs Variadic:$results); let assemblyFormat = [{ - $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict `->` successors + $name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict `->` successors }]; } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 07bec9efd9104d..031464abead482 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -265,14 +265,24 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, DenseMap &inputs) { OperandRange arguments = op.getArgs(); + Position *pos = nullptr; std::vector allPositions; - allPositions.reserve(arguments.size()); - for (Value arg : arguments) - allPositions.push_back(inputs.lookup(arg)); - // Push the constraint to the furthest position. - Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), - comparePosDepth); + // If this constraint has no arguments, this means it has no dependencies, and + // the same applies to all results + if (arguments.empty()) { + pos = builder.getRoot(); + } else { + allPositions.reserve(arguments.size()); + for (Value arg : arguments) + allPositions.push_back(inputs.lookup(arg)); + + // Push the constraint to the furthest position. + pos = *std::max_element(allPositions.begin(), allPositions.end(), + comparePosDepth); + } + assert(pos && "Must have a non-null value"); + ResultRange results = op.getResults(); PredicateBuilder::Predicate pred = builder.getConstraint( op.getName(), allPositions, SmallVector(results.getTypes()), diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index 92afb765b5ab4e..bdcc6ad492b8ed 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -130,6 +130,20 @@ module @constraint_with_result_multiple { // ----- +// CHECK-LABEL: module @constraint_with_no_inputs +module @constraint_with_no_inputs { + pdl.pattern : benefit(0) { + %attr = apply_native_constraint "constraint_no_inputs" : !pdl.attribute + %root = operation + rewrite %root with "rewriter" + } +} + +// CHECK: func @matcher(%arg0: !pdl.operation) { +// CHECK: pdl_interp.apply_constraint "constraint_no_inputs" : !pdl.attribute -> ^{{.*}}, ^{{.*}} + +// ----- + // CHECK-LABEL: module @negated_constraint module @negated_constraint { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) From 4a543bdb91742472beaf2af637c9db2f29482e7d Mon Sep 17 00:00:00 2001 From: Robert Esclapez-Garcia Date: Tue, 9 Apr 2024 10:31:40 +0100 Subject: [PATCH 2/4] [mlir][PDL]: Allow ApplyNativeConstraintOp have optional inputs --- mlir/include/mlir/Dialect/PDL/IR/PDLOps.td | 4 +++- mlir/lib/Dialect/PDL/IR/PDL.cpp | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index 199b284372e26c..2b5d2695475622 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -49,7 +49,9 @@ def PDL_ApplyNativeConstraintOp Variadic:$args, DefaultValuedAttr:$isNegated); let results = (outs Variadic:$results); - let assemblyFormat = "$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict"; + let assemblyFormat = [{ + $name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict + }]; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp index 6615b142c31b7f..109171744afa35 100644 --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -92,8 +92,8 @@ static void visit(Operation *op, DenseSet &visited) { //===----------------------------------------------------------------------===// LogicalResult ApplyNativeConstraintOp::verify() { - if (getNumOperands() == 0) - return emitOpError("expected at least one argument"); + if (getNumOperands() == 0 && getNumResults() == 0) + return emitOpError("expected at least one argument or result"); return success(); } From df673ea7cb8e0440c236a5b7c40676f283de46fb Mon Sep 17 00:00:00 2001 From: Robert Esclapez-Garcia Date: Tue, 9 Apr 2024 10:36:10 +0100 Subject: [PATCH 3/4] [mlir][PDL]: Allow dict builtin function on Constraint and Rewrite --- mlir/include/mlir/Dialect/PDL/IR/Builtins.h | 10 +++-- mlir/lib/Dialect/PDL/IR/Builtins.cpp | 38 ++++++++++++------- mlir/lib/Tools/PDLL/Parser/Parser.cpp | 19 ++++++++-- mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll | 20 +++++----- mlir/test/mlir-pdll/Parser/expr.pdll | 4 +- mlir/unittests/Dialect/PDL/BuiltinTest.cpp | 41 +++++++++++++-------- 6 files changed, 83 insertions(+), 49 deletions(-) diff --git a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h index d77710ae889c05..fa9ed5d15c2b3c 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h +++ b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h @@ -26,10 +26,12 @@ namespace pdl { void registerBuiltins(PDLPatternModule &pdlPattern); namespace builtin { -Attribute createDictionaryAttr(PatternRewriter &rewriter); -Attribute addEntryToDictionaryAttr(PatternRewriter &rewriter, - Attribute dictAttr, Attribute attrName, - Attribute attrEntry); +LogicalResult createDictionaryAttr(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args); +LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args); Attribute createArrayAttr(PatternRewriter &rewriter); Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr, Attribute element); diff --git a/mlir/lib/Dialect/PDL/IR/Builtins.cpp b/mlir/lib/Dialect/PDL/IR/Builtins.cpp index 75f0d8ca2ade31..4d77e67cc436ae 100644 --- a/mlir/lib/Dialect/PDL/IR/Builtins.cpp +++ b/mlir/lib/Dialect/PDL/IR/Builtins.cpp @@ -16,18 +16,22 @@ using namespace mlir; namespace mlir::pdl { namespace builtin { -mlir::Attribute createDictionaryAttr(mlir::PatternRewriter &rewriter) { - return rewriter.getDictionaryAttr({}); + +LogicalResult createDictionaryAttr(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args) { + results.push_back(rewriter.getDictionaryAttr({})); + return success(); } -mlir::Attribute addEntryToDictionaryAttr(mlir::PatternRewriter &rewriter, - mlir::Attribute dictAttr, - mlir::Attribute attrName, - mlir::Attribute attrEntry) { - assert(isa(dictAttr)); - auto attr = dictAttr.cast(); - auto name = attrName.cast(); - std::vector values = attr.getValue().vec(); +LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args) { + auto dictAttr = cast(args[0].cast()); + auto name = cast(args[1].cast()); + auto attrEntry = args[2].cast(); + + std::vector values = dictAttr.getValue().vec(); // Remove entry if it exists in the dictionary. llvm::erase_if(values, [&](NamedAttribute &namedAttr) { @@ -35,7 +39,8 @@ mlir::Attribute addEntryToDictionaryAttr(mlir::PatternRewriter &rewriter, }); values.push_back(rewriter.getNamedAttr(name, attrEntry)); - return rewriter.getDictionaryAttr(values); + results.push_back(rewriter.getDictionaryAttr(values)); + return success(); } mlir::Attribute createArrayAttr(mlir::PatternRewriter &rewriter) { @@ -110,14 +115,19 @@ LogicalResult add(mlir::PatternRewriter &rewriter, mlir::PDLResultList &results, void registerBuiltins(PDLPatternModule &pdlPattern) { using namespace builtin; // See Parser::defineBuiltins() - pdlPattern.registerRewriteFunction("__builtin_createDictionaryAttr", + pdlPattern.registerRewriteFunction("__builtin_createDictionaryAttr_rewrite", createDictionaryAttr); - pdlPattern.registerRewriteFunction("__builtin_addEntryToDictionaryAttr", - addEntryToDictionaryAttr); + pdlPattern.registerRewriteFunction( + "__builtin_addEntryToDictionaryAttr_rewrite", addEntryToDictionaryAttr); pdlPattern.registerRewriteFunction("__builtin_createArrayAttr", createArrayAttr); pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr", addElemToArrayAttr); + pdlPattern.registerConstraintFunctionWithResults( + "__builtin_createDictionaryAttr_constraint", createDictionaryAttr); + pdlPattern.registerConstraintFunctionWithResults( + "__builtin_addEntryToDictionaryAttr_constraint", + addEntryToDictionaryAttr); pdlPattern.registerConstraintFunctionWithResults("__builtin_add", add); } } // namespace mlir::pdl diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index fc1dbcfe3c1edf..43221a703ca906 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -640,10 +640,21 @@ T *Parser::declareBuiltin(StringRef name, ArrayRef argNames, } void Parser::declareBuiltins() { - builtins.createDictionaryAttr = declareBuiltin( - "__builtin_createDictionaryAttr", {}, /*returnsAttr=*/true); - builtins.addEntryToDictionaryAttr = declareBuiltin( - "__builtin_addEntryToDictionaryAttr", {"attr", "attrName", "attrEntry"}, + builtins.createDictionaryAttr_Rewrite = declareBuiltin( + "__builtin_createDictionaryAttr_rewrite", {}, /*returnsAttr=*/true); + builtins.createDictionaryAttr_Constraint = + declareBuiltin( + "__builtin_createDictionaryAttr_constraint", {}, + /*returnsAttr=*/true); + builtins.addEntryToDictionaryAttr_Rewrite = + declareBuiltin( + "__builtin_addEntryToDictionaryAttr_rewrite", + {"attr", "attrName", "attrEntry"}, + /*returnsAttr=*/true); + builtins.addEntryToDictionaryAttr_Constraint = + declareBuiltin( + "__builtin_addEntryToDictionaryAttr_constraint", + {"attr", "attrName", "attrEntry"}, /*returnsAttr=*/true); builtins.createArrayAttr = declareBuiltin( "__builtin_createArrayAttr", {}, /*returnsAttr=*/true); diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index aa2e268c81f368..1d715372009c55 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -167,9 +167,9 @@ Pattern TypeExpr => erase op<> -> (type<"i32">); // CHECK: %[[VAL_1:.*]] = operation "test.op" // CHECK: %[[VAL_2:.*]] = attribute = "test" // CHECK: rewrite %[[VAL_1]] { -// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr" +// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr_rewrite" // CHECK: %[[VAL_4:.*]] = attribute = "firstAttr" -// CHECK: %[[VAL_5:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]] +// CHECK: %[[VAL_5:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]] // CHECK: %[[VAL_6:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_5]]} // CHECK: replace %[[VAL_1]] with %[[VAL_6]] Pattern RewriteOneEntryDictionary { @@ -188,14 +188,14 @@ Pattern RewriteOneEntryDictionary { // CHECK: %[[VAL_2:.*]] = attribute = "test2" // CHECK: %[[VAL_3:.*]] = attribute = "test3" // CHECK: rewrite %[[VAL_1]] { -// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr" +// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr_rewrite" // CHECK: %[[VAL_5:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_6:.*]] = attribute = "test1" -// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] +// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] // CHECK: %[[VAL_8:.*]] = attribute = "secondAttr" -// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_7]], %[[VAL_8]], %[[VAL_2]] +// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_7]], %[[VAL_8]], %[[VAL_2]] // CHECK: %[[VAL_10:.*]] = attribute = "thirdAttr" -// CHECK: %[[VAL_11:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_9]], %[[VAL_10]], %[[VAL_3]] +// CHECK: %[[VAL_11:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_9]], %[[VAL_10]], %[[VAL_3]] // CHECK: %[[VAL_12:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_11]]} // CHECK: replace %[[VAL_1]] with %[[VAL_12]] Pattern RewriteMultipleEntriesDictionary { @@ -214,10 +214,10 @@ Pattern RewriteMultipleEntriesDictionary { // CHECK: %[[VAL_1:.*]] = operation "test.op" // CHECK: rewrite %[[VAL_1]] { // CHECK: %[[VAL_2:.*]] = apply_native_rewrite "__builtin_createArrayAttr" -// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr" +// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr_rewrite" // CHECK: %[[VAL_4:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_5:.*]] = attribute = "test1" -// CHECK: %[[VAL_6:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]] +// CHECK: %[[VAL_6:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]] // CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]] // CHECK: %[[VAL_8:.*]] = operation "test.success" {"some_array" = %[[VAL_7]]} // CHECK: replace %[[VAL_1]] with %[[VAL_8]] @@ -236,10 +236,10 @@ Pattern RewriteOneDictionaryArrayAttr { // CHECK: %[[VAL_2:.*]] = attribute = "test2" // CHECK: rewrite %[[VAL_1]] { // CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createArrayAttr" -// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr" +// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr_rewrite" // CHECK: %[[VAL_5:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_6:.*]] = attribute = "test1" -// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] +// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] // CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]] // CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]] // CHECK: %[[VAL_10:.*]] = operation "test.success" {"some_array" = %[[VAL_9]]} diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index 98181e313b7150..018f2be4beacd7 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -38,10 +38,10 @@ Pattern { // CHECK: CallExpr {{.*}} Type // CHECK: UserRewriteDecl {{.*}} Name<__builtin_createArrayAttr> ResultType // CHECK: CallExpr {{.*}} Type -// CHECK: UserRewriteDecl {{.*}} Name<__builtin_addEntryToDictionaryAttr> ResultType +// CHECK: UserRewriteDecl {{.*}} Name<__builtin_addEntryToDictionaryAttr_rewrite> ResultType // CHECK: `Arguments` // CHECK: CallExpr {{.*}} Type -// CHECK: UserRewriteDecl {{.*}} Name<__builtin_createDictionaryAttr> ResultType +// CHECK: UserRewriteDecl {{.*}} Name<__builtin_createDictionaryAttr_rewrite> ResultType // CHECK: AttributeExpr {{.*}} Value<""firstAttr""> Pattern { diff --git a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp index dca5967c6112b2..c38bed21134a9e 100644 --- a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp +++ b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp @@ -42,26 +42,37 @@ class BuiltinTest : public ::testing::Test { }; TEST_F(BuiltinTest, createDictionaryAttr) { - auto attr = builtin::createDictionaryAttr(rewriter); - auto dict = dyn_cast(attr); - EXPECT_TRUE(dict); + TestPDLResultList results(1); + EXPECT_TRUE(succeeded(builtin::createDictionaryAttr(rewriter, results, {}))); + ASSERT_TRUE(results.getResults().size() == 1); + auto dict = dyn_cast_or_null( + results.getResults().back().cast()); + ASSERT_TRUE(dict); EXPECT_TRUE(dict.empty()); } TEST_F(BuiltinTest, addEntryToDictionaryAttr) { - auto dictAttr = rewriter.getDictionaryAttr({}); - - mlir::Attribute updated = builtin::addEntryToDictionaryAttr( - rewriter, dictAttr, rewriter.getStringAttr("testAttr"), - rewriter.getI16IntegerAttr(0)); + TestPDLResultList results(1); - EXPECT_TRUE(updated.cast().contains("testAttr")); - - auto second = builtin::addEntryToDictionaryAttr( - rewriter, updated, rewriter.getStringAttr("testAttr2"), - rewriter.getI16IntegerAttr(0)); - EXPECT_TRUE(second.cast().contains("testAttr")); - EXPECT_TRUE(second.cast().contains("testAttr2")); + auto dictAttr = rewriter.getDictionaryAttr({}); + EXPECT_TRUE(succeeded(builtin::addEntryToDictionaryAttr( + rewriter, results, + {dictAttr, rewriter.getStringAttr("testAttr"), + rewriter.getI16IntegerAttr(0)}))); + ASSERT_TRUE(results.getResults().size() == 1); + mlir::Attribute updated = results.getResults().front().cast(); + EXPECT_TRUE(cast(updated).contains("testAttr")); + + results = TestPDLResultList(1); + EXPECT_TRUE(succeeded(builtin::addEntryToDictionaryAttr( + rewriter, results, + {updated, rewriter.getStringAttr("testAttr2"), + rewriter.getI16IntegerAttr(0)}))); + ASSERT_TRUE(results.getResults().size() == 1); + mlir::Attribute second = results.getResults().front().cast(); + + EXPECT_TRUE(cast(second).contains("testAttr")); + EXPECT_TRUE(cast(second).contains("testAttr2")); } TEST_F(BuiltinTest, createArrayAttr) { From 3280a80f42bf9d8a7c3417e1f8106156864b320f Mon Sep 17 00:00:00 2001 From: Robert Esclapez-Garcia Date: Tue, 9 Apr 2024 10:37:14 +0100 Subject: [PATCH 4/4] [mlir][mlir-pdll][Parser] Allow parsing dictionaries in Constraint section --- mlir/lib/Tools/PDLL/Parser/Parser.cpp | 37 +++++++++++++----- mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll | 30 +++++++++++++++ mlir/test/mlir-pdll/Parser/expr.pdll | 44 ++++++++++++++++++++++ 3 files changed, 101 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 43221a703ca906..0a330ce96e1246 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -600,8 +600,10 @@ class Parser { CodeCompleteContext *codeCompleteContext; struct { - ast::UserRewriteDecl *createDictionaryAttr; - ast::UserRewriteDecl *addEntryToDictionaryAttr; + ast::UserRewriteDecl *createDictionaryAttr_Rewrite; + ast::UserConstraintDecl *createDictionaryAttr_Constraint; + ast::UserRewriteDecl *addEntryToDictionaryAttr_Rewrite; + ast::UserConstraintDecl *addEntryToDictionaryAttr_Constraint; ast::UserRewriteDecl *createArrayAttr; ast::UserRewriteDecl *addElemToArrayAttr; ast::UserConstraintDecl *add; @@ -655,7 +657,7 @@ void Parser::declareBuiltins() { declareBuiltin( "__builtin_addEntryToDictionaryAttr_constraint", {"attr", "attrName", "attrEntry"}, - /*returnsAttr=*/true); + /*returnsAttr=*/true); builtins.createArrayAttr = declareBuiltin( "__builtin_createArrayAttr", {}, /*returnsAttr=*/true); builtins.addElemToArrayAttr = declareBuiltin( @@ -2121,14 +2123,22 @@ FailureOr Parser::parseDictAttrExpr() { consumeToken(Token::l_brace); SMRange loc = curToken.getLoc(); - if (parserContext != ParserContext::Rewrite) - return emitError( - "Parsing of dictionary attributes as constraint not supported!"); - - auto dictAttrCall = createBuiltinCall(loc, builtins.createDictionaryAttr, {}); + FailureOr dictAttrCall; + if (parserContext == ParserContext::Rewrite) { + dictAttrCall = + createBuiltinCall(loc, builtins.createDictionaryAttr_Rewrite, {}); + } else { + dictAttrCall = + createBuiltinCall(loc, builtins.createDictionaryAttr_Constraint, {}); + } if (failed(dictAttrCall)) return failure(); + // No key-values inside dictionary + if (consumeIf(Token::r_brace)) { + return dictAttrCall; + } + // Add each nested attribute to the dict do { FailureOr decl = @@ -2159,8 +2169,15 @@ FailureOr Parser::parseDictAttrExpr() { // Create addEntryToDictionaryAttr native call. SmallVector arrayAttrArgs{*dictAttrCall, *stringAttrRef, namedDecl->getValue()}; - auto entryToDictionaryCall = createBuiltinCall( - loc, builtins.addEntryToDictionaryAttr, arrayAttrArgs); + + FailureOr entryToDictionaryCall; + if (parserContext == ParserContext::Rewrite) { + entryToDictionaryCall = createBuiltinCall( + loc, builtins.addEntryToDictionaryAttr_Rewrite, arrayAttrArgs); + } else { + entryToDictionaryCall = createBuiltinCall( + loc, builtins.addEntryToDictionaryAttr_Constraint, arrayAttrArgs); + } if (failed(entryToDictionaryCall)) return failure(); diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index 1d715372009c55..49b9add1992539 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -253,6 +253,36 @@ Pattern RewriteMultiplyElementsArrayAttr { }; } +// ----- + +// CHECK-LABEL: pdl.pattern @ConstraintWithEmptyDictionary : benefit(0) { +// CHECK: %[[DICT:.+]] = apply_native_constraint "__builtin_createDictionaryAttr_constraint" : !pdl.attribute +// CHECK: rewrite %{{.*}} { +// CHECK-NEXT: operation "test.success" {"importantAttr" = %[[DICT]]} + +Pattern ConstraintWithEmptyDictionary { + let dict = {}; + replace op with op() { importantAttr = dict }; +} + +// ----- +// CHECK-LABEL: pdl.pattern @ConstraintWithTwoEntriesDictionary : benefit(0) { +// CHECK: %[[DICT:.+]] = apply_native_constraint "__builtin_createDictionaryAttr_constraint" : !pdl.attribute +// CHECK-NEXT: %[[KEY:.+]] = attribute = "hello" +// CHECK-NEXT: %[[VAL:.+]] = attribute = "world" +// CHECK-NEXT: %[[DICT2:.+]] = apply_native_constraint "__builtin_addEntryToDictionaryAttr_constraint"(%[[DICT]], %[[KEY]], %[[VAL]] : !pdl.attribute, !pdl.attribute, !pdl.attribute) : !pdl.attribute +// CHECK-NEXT: %[[KEY2:.+]] = attribute = "bye" +// CHECK-NEXT: %[[VAL2:.+]] = attribute = 100 : ui8 +// CHECK-NEXT: %[[DICT3:.+]] = apply_native_constraint "__builtin_addEntryToDictionaryAttr_constraint"(%[[DICT2]], %[[KEY2]], %[[VAL2]] : !pdl.attribute, !pdl.attribute, !pdl.attribute) : !pdl.attribute +// CHECK: rewrite %{{.*}} { +// CHECK-NEXT: operation "test.success" {"importantAttr" = %[[DICT3]]} + + +Pattern ConstraintWithTwoEntriesDictionary { + let dict = { hello = "world", bye = 100 : ui8 }; + replace op with op() { importantAttr = dict }; +} + // ----- // CHECK-LABEL: pdl.pattern @TestAdd : benefit(0) { // CHECK: %[[VAL_0:.*]] = attribute = 4 : i32 diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index 018f2be4beacd7..5f6f86b9729780 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -55,6 +55,50 @@ Pattern { // ----- +// CHECK:LetStmt {{.*}} +//CHECK-NEXT:`-VariableDecl {{.*}} Name Type +//CHECK-NEXT: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_createDictionaryAttr_constraint> ResultType +//CHECK-NEXT: `Results` +//CHECK-NEXT: `-VariableDecl {{.*}} Name<> Type +//CHECK-NEXT: `Constraints` +//CHECK-NEXT: `-AttrConstraintDecl {{.*}} +//CHECK-NEXT:ReturnStmt {{.*}} + +Constraint getEmptyDict() -> Attr { + let dictionary = {}; + return dictionary; +} + +// ----- + +// CHECK:LetStmt {{.*}} +//CHECK-NEXT:`-VariableDecl {{.*}} Name Type +//CHECK-NEXT: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_addEntryToDictionaryAttr_constraint> ResultType +// CHECK: `Arguments` +//CHECK-NEXT: |-CallExpr {{.*}} Type +//CHECK-NEXT: | `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: | `-UserConstraintDecl {{.*}} Name<__builtin_createDictionaryAttr_constraint> ResultType +//CHECK-NEXT: | `Results` +//CHECK-NEXT: | `-VariableDecl {{.*}} Name<> Type +//CHECK-NEXT: | `Constraints` +//CHECK-NEXT: | `-AttrConstraintDecl {{.*}} +//CHECK-NEXT: |-DeclRefExpr {{.*}} Type +//CHECK-NEXT: | `-VariableDecl {{.*}} Name Type +//CHECK-NEXT: | `-AttributeExpr {{.*}} Value<""test""> +//CHECK-NEXT: `-AttributeExpr {{.*}} Value<""String""> +//CHECK-NEXT:ReturnStmt {{.*}} + +Constraint getPopulatedDict() -> Attr { + let dictionary = { test = "String" }; + return dictionary; +} + +// ----- + //===----------------------------------------------------------------------===// // CallExpr //===----------------------------------------------------------------------===//