From 9d168220ea912bef7718a62c2754536b82b286cf Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Sat, 27 Jan 2024 02:48:00 +0100 Subject: [PATCH] PDLL: Add equals operator --- mlir/include/mlir/Dialect/PDL/IR/Builtins.h | 3 +++ mlir/lib/Dialect/PDL/IR/Builtins.cpp | 25 +++++++++++++++++++ mlir/lib/Tools/PDLL/Parser/Lexer.cpp | 4 +++ mlir/lib/Tools/PDLL/Parser/Lexer.h | 1 + mlir/lib/Tools/PDLL/Parser/Parser.cpp | 24 ++++++++++++++++-- mlir/test/lib/Tools/PDLL/TestPDLL.pdll | 7 ++++++ mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll | 14 +++++++++++ .../test/mlir-pdll/Integration/test-pdll.mlir | 9 +++++++ mlir/test/mlir-pdll/Parser/expr-failure.pdll | 20 +++++++++++++++ mlir/test/mlir-pdll/Parser/expr.pdll | 16 ++++++++++++ mlir/unittests/Dialect/PDL/BuiltinTest.cpp | 20 +++++++++++++++ 11 files changed, 141 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h index 72603b7ec100c1..8c56fa1446ea83 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h +++ b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h @@ -13,6 +13,8 @@ #ifndef MLIR_DIALECT_PDL_IR_BUILTINS_H_ #define MLIR_DIALECT_PDL_IR_BUILTINS_H_ +#include "mlir/Support/LogicalResult.h" + namespace mlir { class PDLPatternModule; class Attribute; @@ -29,6 +31,7 @@ Attribute addEntryToDictionaryAttr(PatternRewriter &rewriter, Attribute createArrayAttr(PatternRewriter &rewriter); Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr, Attribute element); +LogicalResult equals(PatternRewriter &rewriter, Attribute lhs, Attribute rhs); } // namespace builtin } // namespace pdl } // namespace mlir diff --git a/mlir/lib/Dialect/PDL/IR/Builtins.cpp b/mlir/lib/Dialect/PDL/IR/Builtins.cpp index dfe635e8ab3049..fbe25774bb2329 100644 --- a/mlir/lib/Dialect/PDL/IR/Builtins.cpp +++ b/mlir/lib/Dialect/PDL/IR/Builtins.cpp @@ -39,6 +39,30 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter, values.push_back(element); return rewriter.getArrayAttr(values); } + +LogicalResult equals(mlir::PatternRewriter &, mlir::Attribute lhs, + mlir::Attribute rhs) { + if (auto lhsAttr = dyn_cast_or_null(lhs)) { + auto rhsAttr = dyn_cast_or_null(rhs); + if (!rhsAttr || lhsAttr.getType() != rhsAttr.getType()) + return failure(); + + APInt lhsVal = lhsAttr.getValue(); + APInt rhsVal = rhsAttr.getValue(); + return success(lhsVal.eq(rhsVal)); + } + + if (auto lhsAttr = dyn_cast_or_null(lhs)) { + auto rhsAttr = dyn_cast_or_null(rhs); + if (!rhsAttr || lhsAttr.getType() != rhsAttr.getType()) + return failure(); + + APFloat lhsVal = lhsAttr.getValue(); + APFloat rhsVal = rhsAttr.getValue(); + return success(lhsVal.compare(rhsVal) == llvm::APFloatBase::cmpEqual); + } + return failure(); +} } // namespace builtin void registerBuiltins(PDLPatternModule &pdlPattern) { @@ -52,5 +76,6 @@ void registerBuiltins(PDLPatternModule &pdlPattern) { createArrayAttr); pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr", addElemToArrayAttr); + pdlPattern.registerConstraintFunction("__builtin_equals", equals); } } // namespace mlir::pdl diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp index eff82dc0211657..f47a5f3115fb63 100644 --- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp @@ -209,6 +209,10 @@ Token Lexer::lexToken() { ++curPtr; return formToken(Token::equal_arrow, tokStart); } + if (*curPtr == '=') { + ++curPtr; + return formToken(Token::equal_equal, tokStart); + } return formToken(Token::equal, tokStart); case ';': return formToken(Token::semicolon, tokStart); diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h index 509bfd6678289d..6e823e1e429ea2 100644 --- a/mlir/lib/Tools/PDLL/Parser/Lexer.h +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h @@ -78,6 +78,7 @@ class Token { dot, equal, equal_arrow, + equal_equal, semicolon, exclam, /// Paired punctuation. diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index ef471e5a2edc10..f661b84f11e1ed 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -601,6 +601,7 @@ class Parser { ast::UserRewriteDecl *addEntryToDictionaryAttr; ast::UserRewriteDecl *createArrayAttr; ast::UserRewriteDecl *addElemToArrayAttr; + ast::UserConstraintDecl *equals; } builtins{}; }; } // namespace @@ -629,7 +630,7 @@ T *Parser::declareBuiltin(StringRef name, ArrayRef argNames, popDeclScope(); auto *constraintDecl = T::createNative(ctx, ast::Name::create(ctx, name, loc), - args, results, {}, attrTy); + args, results, {}, createUserConstraintRewriteResultType(results)); curDeclScope->add(constraintDecl); return constraintDecl; } @@ -645,6 +646,10 @@ void Parser::declareBuiltins() { builtins.addElemToArrayAttr = declareBuiltin( "__builtin_addElemToArrayAttr", {"attr", "element"}, /*returnsAttr=*/true); + + builtins.equals = declareBuiltin( + "__builtin_equals", {"lhs", "rhs"}, + /*returnsAttr=*/false); } FailureOr Parser::parseModule() { @@ -1892,7 +1897,22 @@ FailureOr Parser::parseLogicalAndExpr() { } FailureOr Parser::parseEqualityExpr() { - return parseRelationExpr(); + auto lhs = parseRelationExpr(); + if (failed(lhs)) + return failure(); + + switch (curToken.getKind()) { + case Token::equal_equal: { + consumeToken(); + auto rhs = parseRelationExpr(); + if (failed(rhs)) + return failure(); + SmallVector args{*lhs, *rhs}; + return createBuiltinCall(curToken.getLoc(), builtins.equals, args); + } + default: + return lhs; + } } FailureOr Parser::parseRelationExpr() { return parseAddSubExpr(); } diff --git a/mlir/test/lib/Tools/PDLL/TestPDLL.pdll b/mlir/test/lib/Tools/PDLL/TestPDLL.pdll index 9715b556bbe214..3b1359efaa03ee 100644 --- a/mlir/test/lib/Tools/PDLL/TestPDLL.pdll +++ b/mlir/test/lib/Tools/PDLL/TestPDLL.pdll @@ -14,3 +14,10 @@ Pattern TestSimplePattern => replace op with op; // Test the import of interfaces. Pattern TestInterface => replace _: CastOpInterface with op; + +// Test equals builtin +Pattern TestEquals { + let op = op {val = val : Attr}; + val == attr<"4 : i32">; + replace op with op; +} diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index 752e3c8268ede1..da1c2ecb7f873c 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -252,3 +252,17 @@ Pattern RewriteMultiplyElementsArrayAttr { replace root with newRoot; }; } + +// ----- +// CHECK-LABEL: pdl.pattern @TestEquals : benefit(0) { +// CHECK: %[[VAL_0:.*]] = operands +// CHECK: %[[VAL_1:.*]] = attribute +// CHECK: %[[VAL_2:.*]] = types +// CHECK: %[[VAL_3:.*]] = operation "test.op"(%[[VAL_0]] : !pdl.range) {"val" = %[[VAL_1]]} -> (%[[VAL_2]] : !pdl.range) +// CHECK: %[[VAL_4:.*]] = attribute = 4 : i32 +// CHECK: apply_native_constraint "__builtin_equals"(%[[VAL_1]], %[[VAL_4]] : !pdl.attribute, !pdl.attribute) +Pattern TestEquals { + let op = op {val = val : Attr}; + val == attr<"4 : i32">; + replace op with op; +} diff --git a/mlir/test/mlir-pdll/Integration/test-pdll.mlir b/mlir/test/mlir-pdll/Integration/test-pdll.mlir index baaffc74bf1f1b..88a26d8cf9767e 100644 --- a/mlir/test/mlir-pdll/Integration/test-pdll.mlir +++ b/mlir/test/mlir-pdll/Integration/test-pdll.mlir @@ -15,3 +15,12 @@ func.func @testImportedInterface() -> i1 { %value = "builtin.unrealized_conversion_cast"() : () -> (i1) return %value : i1 } + +// CHECK-LABEL: func @test_builtin +func.func @test_builtin() { + // CHECK: test.success + // CHECK: test.equals_neg + "test.equals"() { val = 4 : i32 }: () -> () + "test.equals_neg"() { val = 4 : i32 }: () -> () + return +} diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll index 9c1982a85cd876..9af5fc1ceba6a8 100644 --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -412,3 +412,23 @@ Pattern { // CHECK: expected `>` after type literal let foo = type<""; } + +// ----- + +//===----------------------------------------------------------------------===// +// Builtins +//===----------------------------------------------------------------------===// + +Pattern { + // CHECK: expected expression + == + erase _: Op; +} + +// ----- + +Pattern { + // CHECK: expected expression + attr<"4 : i32"> == + erase _: Op; +} diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index 2cb3242f39c697..4893cdccbc0a31 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -359,3 +359,19 @@ Pattern { erase _: Op; } + +// ----- + +//===----------------------------------------------------------------------===// +// Builtins +//===----------------------------------------------------------------------===// + +// CHECK: Module {{.*}} +// CHECK: UserConstraintDecl {{.*}} Name<__builtin_equals> ResultType> +// CHECK: UserConstraintDecl {{.*}} Name<__builtin_equals> ResultType> +Pattern { + attr<"4 : i32"> == attr<"5 : i32">; + let a: Attr; + a == a; + erase _: Op; +} diff --git a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp index f653477a5ac87b..8118e0e8c4dfdc 100644 --- a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp +++ b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp @@ -69,4 +69,24 @@ TEST_F(BuiltinTest, addElemToArrayAttr) { cast(*cast(updatedArrAttr).begin()); EXPECT_EQ(dictInsideArrAttr, dict); } + +TEST_F(BuiltinTest, equals) { + auto onei16 = rewriter.getI16IntegerAttr(1); + auto onei32 = rewriter.getI32IntegerAttr(1); + auto zeroi32 = rewriter.getI32IntegerAttr(0); + + EXPECT_TRUE(builtin::equals(rewriter, onei16, onei16).succeeded()); + EXPECT_TRUE(builtin::equals(rewriter, onei16, onei32).failed()); + EXPECT_TRUE(builtin::equals(rewriter, zeroi32, onei32).failed()); + + auto onef32 = rewriter.getF32FloatAttr(1.0); + auto zerof32 = rewriter.getF32FloatAttr(0.0); + auto negzerof32 = rewriter.getF32FloatAttr(-0.0); + auto zerof64 = rewriter.getF64FloatAttr(0.0); + + EXPECT_TRUE(builtin::equals(rewriter, onef32, onef32).succeeded()); + EXPECT_TRUE(builtin::equals(rewriter, onef32, zerof32).failed()); + EXPECT_TRUE(builtin::equals(rewriter, negzerof32, zerof32).succeeded()); + EXPECT_TRUE(builtin::equals(rewriter, zerof32, zerof64).failed()); +} } // namespace